PyTorchを使ってみよ! - MNISTで数字分類 -準備編
PyTorchで数字を分類、手書きの数字が何であるかを推論するモデルを一から作っていきます。MNISTの画像を使って学習、テストします。詳しく紹介されている動画をみながら大事なところを記録しておきたいと思います。
まずコードを書く環境はGoogleColaboratory(Colab)を使います。最初からライブラリが入っていますので手間いらずです。
GPUをみる方法
!nvidia-smi
torchvisionを使った画像データベースMNIST、画像を変換するToTensorのインポート。
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
train_dataとtest_dataという変数を作りそこにMNISTのデータを入れて学習、テストをしてきます。
train_data = MNIST(root='data',train=True,download=True,transform=ToTensor())
test_data = MNIST(root='data',train=False,download=True,transform=ToTensor())
Colabでは変数名を入れて実行するとその中身を見ることができます。例えば
train_data
として実行してやると
Dataset MNIST
Number of datapoints: 60000
Root location: data
Split: Train
StandardTransform
Transform: ToTensor()
と何が入ってるかを表示してくれます。
len(train_data),len(test_data)
としてやることでどれぐらいのデータ量かも把握ができます。この場合
(60000, 10000)
と出てくるので、画像がそれぞれ60000,枚 10000枚入ってることが確認できます。
train_data[0][0].shape
と.shapeを使い配列の次元数や大きさを調べることができます。以下出力結果です。
torch.Size([1, 28, 28])
像一枚あたり、1チャネル(白黒)、28 x 28(784画素)
画像描画をしてみます。準備としてデータの整形をします
train_data[0][0].squeeze().shape
してやると最初の1が取り除かれる。
torch.Size([28, 28])
となります。これで画像を読み込む準備ができました。
import matplotlib.pyplot as plt
plt.imshow(train_data[0][0].squeeze(),cmap='gray')
plt.colorbar()
plt.show
と画像が表示されます。
続いてデータを訓練する方法ですが、ミニバッチと呼ばれる操作により行います。その準備をやっていきます。実際にはDataLoaderと呼ばれるものを作っていきます。
from torch.utils.data import DataLoader
train_loader = DataLoader(train_data,batch_size=64,shuffle=True)
test_loader = DataLoader(test_data,batch_size=64,shuffle=False)
実際に実行して長さをみてみます。
len(train_loader),len(test_loader)
(938, 157)
と出てきます。64ひとまとめで処理していますので、train_loaderは934*64枚、1test_loaderは57*64枚となります。元の枚数が
(60000, 10000)
ということだったにで、大体数字は同じになっています。
このDataLoaderですがデータを取り出すのに一手間必要です。
batch = iter(train_loader).next()
として、それぞれを変数に入れてやります。
image,label = batch
ラベル。
label
tensor([7, 2, 2, 1, 4, 7, 1, 1, 5, 8, 0, 7, 9, 6, 7, 1, 4, 8, 0, 1, 9, 3, 4, 4,
4, 2, 4, 4, 4, 1, 1, 1, 3, 3, 2, 9, 1, 7, 8, 5, 2, 1, 1, 9, 5, 6, 1, 9,
2, 5, 2, 0, 9, 0, 1, 6, 9, 8, 7, 6, 8, 3, 6, 0])
image。
plt.imshow(image[0][0],cmap='gray')
plt.colorbar
plt.show
image[0][0]はlabelの表示から"7"が読み取れます。そしてimageの表示でも"7"が読み取れます。