13章 PyTorch LightningでMNIST手書き数字の分類タスクを訓練できない!!
はじめに
シリーズ「Python機械学習プログラミング」の紹介
本シリーズは書籍「Python機械学習プログラミング PyTorch & scikit-learn編」(初版第1刷)に関する記事を取り扱います。
この書籍のよいところは、Pythonのコードを動かしたり、アルゴリズムの説明を読み、ときに数式を確認して、包括的に機械学習を学ぶことができることです。
Pythonで機械学習を学びたい方におすすめです!
この記事では、この書籍のことを「テキスト」と呼びます。
記事の内容
この記事は「第13章 PyTorchのメカニズム」の「13.8.3 PyTorch LightningのTrainerクラスを使ってモデルを訓練する」のコード実行時に発生するエラーの対処方法を取り上げています。
13章のダイジェスト
13章はPyTorchメニューがてんこ盛りの贅沢な章です。
まず、PyTorchの計算グラフと自動微分を学び、torch.nnモジュールを使ってXOR分類問題に挑みます。
続いて、実用的な応用例として、Auto MPGデータセットを用いた車の燃費の予測(回帰タスク)、MNIST手書き数字データセットを用いた数字の分類(分類タスク)を実施します。
最後に、PyTorch LightningとTensorBoardの活用に進みます。
件の問題は、最後のPyTorch LigntningでMNIST手書き数字データセットを用いたモデルの訓練時に起きました。
PyTorch Lightningで手書き文字分類の訓練を実行できるようにする
エラー発生の概要
MNIST手書き数字データセットの分類タスクにおいて、テキストに記載された訓練コードの実行時にエラーが発生しました。
直接のエラー発生箇所は、MultiLayerPerceptronクラスのインスタンス化の部分です。
# エラー発生箇所
mnistclassifier = MultiLayerPerceptron()
エラー情報を読み解くと、MultiLayerPerceptronクラスの__init__コンストラクタの部分に問題があるような感じでした。
# エラー発生箇所
class MultiLayerPerceptron(pl.LightningModule):
def __init__(self, image_shape=(1, 28,28), hidden_units=(32, 16)):
super().__init__()
# Lightningの新しい属性
self.train_acc = Accuracy() # <--- この部分でエラーが発生
Accuracyクラスのインスタンス化の部分にエラーの原因がありそうです。
このAccuracyはtorchmetricsモジュールのクラスです。
ちなみに、Accuracyとは正解率のことです。
エラーメッセージは次の内容でした。
TypeError: __new__() missing 1 required positional argument: 'task'
型エラー:__new__() に 1 つの必須の位置引数'task'がありません
エラー発生原因を探る
エラーメッセージの意味がよくわからなかったのでネットで調べてみましたが、直球の回答は得られませんでした。
どうやら task という引数が足りない、というところまではたどり着けました。
以下のサイトでヒントをいただきました。ありがとうございます!
続いて、Accuracyクラスの引数を調べるために、PyTorch Metricsのサイトを訪れて、パラメータの内容を確認することにしました。
Accuracyの引数には、たしかにtask引数がありました。
エラー箇所の変更を実施
手書き数字の分類タスクは 'multiclass' に該当するので、task='multiclass'を指定してみます。
# コードの変更
class MultiLayerPerceptron(pl.LightningModule):
def __init__(self, image_shape=(1, 28,28), hidden_units=(32, 16)):
super().__init__()
# Lightningの新しい属性
self.train_acc = Accuracy(task='multiclass') # <--- この部分を変更
self.valid_acc = Accuracy(task= 'multiclass') # <--- この部分を変更
self.test_acc = Accuracy(task= 'multiclass') # <--- この部分を変更
またしてもエラーが、、、
コードの変更後に、訓練を実施したところ、またエラー発生です。。。
AssertionError:
アサーション例外エラー
そっけないエラーメッセージです。。。
エラーメッセージの上部をみると、359行、360行にヒントが記載されていました。
task引数がmulticlassの場合に、アサーションチェックを実施していて、360行目の部分で num_classes 引数 の存在チェックをしているようです。
358 if task == "multiclass":
359 assert isinstance(num_classes, int)
--> 360 assert isinstance(top_k, int)
361 return MulticlassAccuracy(num_classes, top_k, average, **kwargs)
もう一度、Accuracyクラスの引数を確認してみます。
引数に num_classes がありました。
おそらくこの引数は「クラスの数」でしょう、という見立てで進むことにします。
エラー箇所を変更して再実行!
MNIST手書き数字は、0~9までの「10個」の数字に分類するので、num_class=10 と設定してみましょう。
# コードの変更(2回目)
class MultiLayerPerceptron(pl.LightningModule):
def __init__(self, image_shape=(1, 28,28), hidden_units=(32, 16)):
super().__init__()
# Lightningの新しい属性
self.train_acc = Accuracy(task='multiclass', num_classes=10) # <--- この部分を変更
self.valid_acc = Accuracy(task='multiclass', num_classes=10) # <--- この部分を変更
self.test_acc = Accuracy(task='multiclass', num_classes=10) # <--- この部分を変更
この変更により、13.8.3 のコードを動かせるようになりました!
ひとまず動くようになったのですが、もっとスマートなコードの書き方があるかもしません。これは宿題にしておきます。
テキストのコードでモデルを訓練した結果が出ました!
性能指標の値がテキストと異なっていますが、ひとまず良しとします。
しかし、なぜテキストは引数を指定しないコードを掲載しているのでしょう?
ちなみに、Accuracyを含む torchmetrics モジュールのバージョンは以下のとおりです。
- 自環境:0.11.0
- テキストの環境:0.6.2
まとめ
今回は、PyTorch MetricsのAccuracyクラスの引数に関して、エラーの解消に取り組みました。
世の中にはPyTorchをベースにした多種多様な周辺ライブラリが多くリリースされているそうです。
これらの周辺ライブラリに慣れるには、まだまだ時間がかかりそうです。日々進化するPyTorchに追随することの難しさを実感しました。
# 今日の一句
print('原書にとって正解率Accuracyの確からしさとは?')
楽しくPython機械学習プログラミングを学びましょう!
おまけ数式
noteでは数式記法を利用できます。
今回はディープラーニングの重みの初期化に用いられるXavier/Glorot初期化の一様分布の区間を紹介します。
$$
W \sim Uniform \left( - \cfrac{\sqrt6}{\sqrt{n_{in}+n_{out}}} \ ,\ \cfrac{\sqrt6}{\sqrt{n_{in}+n_{out}}} \right)
$$
$${n_{in}}$$は重みと掛け合わせる入力ニューロンの個数、$${n_{out}}$$は次の層に与える出力ニューロンの個数です。
おわりに
AI・機械学習の学習でおすすめの書籍を紹介いたします。
「最短コースでわかる ディープラーニングの数学」
機械学習やディープラーニングなどの手法を理解する際に、数学的な知識があると、いっそう深い理解につながると思います。
でも、難しい数式がびっしりと並んでいる書面を想像すると、なんだかゾッとします。
そんな数式にアレルギーのある方にとって、この「ディープラーニングの数学」は優しく寄り添ってくれて、「数学的」な見方を広げてくれるのではないでしょうか。
この書籍は、機械学習/深層学習の基礎的なテーマを、Pythonのコードを動かしながら、そして数学的な見解も実感しながら、楽しく学ぶことができると思います。
機械学習/深層学習の入門者にとって、次のようなトピックの理解を深くするチャンスとなるでしょう。
損失関数とその微分
活性化関数
交差エントロピー関数
誤差逆伝播
勾配降下法(最急降下法)
ちなみに、私が初めて手にした機械学習/深層学習の書籍が、このディープラーニングと数学でした。思い出深い一冊です。
今もときどき、自分の理解を整理する際に、ページを捲ります。
最後まで読んでくださり、ありがとうございました。
この記事が参加している募集
この記事が気に入ったらサポートをしてみませんか?