
Photo by
ia19200102
Pytorch Lightning トライアル
以下のページを参考にトライアル。
インストールコマンドは公式サイトと違っていたので、公式サイトのコマンドを実行。
PyTorch Lightning へようこそ ⚡ — PyTorch Lightning 2.4.0 ドキュメント
pip install lightning
トライアルコードとして、サイトのサンプルコードを微調整して使用。(validation_stepと、推論結果を確認するtest_stepを追加。)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning.loggers import TensorBoardLogger
# モデルの定義
class SimpleCNN(pl.LightningModule):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool1(torch.relu(self.conv1(x)))
x = self.pool2(torch.relu(self.conv2(x)))
x = x.view(x.size(0), -1)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.cross_entropy(y_hat, y)
self.log('train_loss', loss)
return loss
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=0.001)
# validationステップを追加
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.cross_entropy(y_hat, y)
self.log('val_loss', loss)
# 正解率を計算
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log('val_acc', acc)
# テストステップを追加
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.cross_entropy(y_hat, y)
self.log('test_loss', loss)
# 正解率を計算
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log('test_acc', acc)
# 予測と正解ラベルを保存
self.test_predictions = torch.cat((getattr(self, 'test_predictions', torch.tensor([])), y_hat.argmax(dim=1)), dim=0)
self.test_true_labels = torch.cat((getattr(self, 'test_true_labels', torch.tensor([])), y), dim=0)
# テストエポック終了時に画像を表示
def on_test_epoch_end(self):
# ランダムに10個の画像を選択
num_images = 10
indices = torch.randint(0, len(self.test_predictions), (num_images,))
fig, axes = plt.subplots(1, num_images, figsize=(20, 4))
for i, idx in enumerate(indices):
ax = axes[i]
# テストデータセットから画像を取得
image, _ = self.trainer.datamodule.mnist_val[idx]
# 画像を表示
ax.imshow(image.squeeze().numpy(), cmap='gray')
ax.set_title(f"Pred: {self.test_predictions[idx]}, True: {self.test_true_labels[idx]}")
ax.axis('off')
plt.show()
# データローダーの設定
class DataModule(pl.LightningDataModule):
def __init__(self, batch_size=64):
super().__init__()
self.batch_size = batch_size
def prepare_data(self):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
self.mnist_train = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
self.mnist_val = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
def train_dataloader(self):
return torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=7)
def val_dataloader(self):
return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=7)
# テストデータローダーを追加
def test_dataloader(self):
return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=7)
logger = TensorBoardLogger("/workspaces/output/tensorboard/src/sandbox", name="sandbox")
data_module = DataModule(batch_size=64)
model = SimpleCNN()
# CUDAが利用可能なら使う
if torch.cuda.is_available():
model = model.cuda()
# モデルの訓練とテスト
trainer = pl.Trainer(max_epochs=1, logger=logger)
trainer.fit(model, datamodule=data_module)
trainer.test(model, datamodule=data_module)
サイトの解説にもあるように、訓練ループなどが省略されて、すっきり。
PyTorchとPyTorch Lightningのコード比較イメージ。

(古いバージョンのものなので、若干変わっているところもあるかもしれません。)
この比較イメージでも、訓練ループがごっそり省略されているのが、見て取れます。
逆に省略されているせいで、何をやっているかわかりにくくなっているので、とりあえず、素のPyTorchを試したあと、PyTorch Lightningを試すのがいいかもしれません。
素のPyTorchが必要になった際も、pl.LightningModuleやpl.LightningDataModuleのインターフェースが参考になりそうな気がします。
ログ出力の処理もすっきりして、見やすく。
出力はデフォルトだとtensorboard形式のようで、tensorboard環境をつくれば、いい感じに結果が表示されます。