見出し画像

Pytorch Lightning トライアル

以下のページを参考にトライアル。

インストールコマンドは公式サイトと違っていたので、公式サイトのコマンドを実行。
PyTorch Lightning へようこそ ⚡ — PyTorch Lightning 2.4.0 ドキュメント

pip install lightning

トライアルコードとして、サイトのサンプルコードを微調整して使用。(validation_stepと、推論結果を確認するtest_stepを追加。)

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning.loggers import TensorBoardLogger

# モデルの定義
class SimpleCNN(pl.LightningModule):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

    # validationステップを追加
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log('val_loss', loss)
        # 正解率を計算
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_acc', acc)

    # テストステップを追加
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log('test_loss', loss)
        # 正解率を計算
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('test_acc', acc)

        # 予測と正解ラベルを保存
        self.test_predictions = torch.cat((getattr(self, 'test_predictions', torch.tensor([])), y_hat.argmax(dim=1)), dim=0)
        self.test_true_labels = torch.cat((getattr(self, 'test_true_labels', torch.tensor([])), y), dim=0)

    # テストエポック終了時に画像を表示
    def on_test_epoch_end(self):
        # ランダムに10個の画像を選択
        num_images = 10
        indices = torch.randint(0, len(self.test_predictions), (num_images,))

        fig, axes = plt.subplots(1, num_images, figsize=(20, 4))
        for i, idx in enumerate(indices):
            ax = axes[i]
            # テストデータセットから画像を取得
            image, _ = self.trainer.datamodule.mnist_val[idx]
            # 画像を表示
            ax.imshow(image.squeeze().numpy(), cmap='gray')
            ax.set_title(f"Pred: {self.test_predictions[idx]}, True: {self.test_true_labels[idx]}")
            ax.axis('off')
        plt.show()

# データローダーの設定
class DataModule(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        self.mnist_train = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
        self.mnist_val = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=7)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=7)

    # テストデータローダーを追加
    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=7)

logger = TensorBoardLogger("/workspaces/output/tensorboard/src/sandbox", name="sandbox")

data_module = DataModule(batch_size=64)
model = SimpleCNN()

# CUDAが利用可能なら使う
if torch.cuda.is_available():
    model = model.cuda()

# モデルの訓練とテスト
trainer = pl.Trainer(max_epochs=1, logger=logger)
trainer.fit(model, datamodule=data_module)
trainer.test(model, datamodule=data_module)

サイトの解説にもあるように、訓練ループなどが省略されて、すっきり。

PyTorchとPyTorch Lightningのコード比較イメージ。

from https://pytorch-lightning.readthedocs.io/en/0.7.1/introduction_guide.html
(古いバージョンのものなので、若干変わっているところもあるかもしれません。)

この比較イメージでも、訓練ループがごっそり省略されているのが、見て取れます。

逆に省略されているせいで、何をやっているかわかりにくくなっているので、とりあえず、素のPyTorchを試したあと、PyTorch Lightningを試すのがいいかもしれません。

素のPyTorchが必要になった際も、pl.LightningModuleやpl.LightningDataModuleのインターフェースが参考になりそうな気がします。

ログ出力の処理もすっきりして、見やすく。
出力はデフォルトだとtensorboard形式のようで、tensorboard環境をつくれば、いい感じに結果が表示されます。

参照サイト

いいなと思ったら応援しよう!