torchvisionで学習済みモデルをweightsで指定する
PyTorchのtorchvisionで学習済みモデルをダウンロードする際に、pretrainedを指定すると警告が出るようになりました。
torchvision 0.13からは非推奨、0.15で削除予定なので、代わりにweightsを使ってとのこと(torchvisionのバージョンはこちらをご参照ください)。
非推奨になってしまったことですし、次の公式サイトを参考にしてweightsを使ってみます。なお、公式以上の説明はありません。悪しからず。
まずは従来のコード。モデルは何でもよいですが、ここではEfficientNet B1を指定しています。
from torchvision import models
model = models.efficientnet_b1(pretrained=True)
では、weightsを使います。指定できるweightsは、先のページ内のTable of all available classification weightsにあります。
weights = models.EfficientNet_B1_Weights.IMAGENET1K_V1
model = models.efficientnet_b1(weights=weights)
weightsをわざわざ変数にしているのは、その後の前処理やクラスラベルの読み込みのときにも利用するためです。
# 前処理の内容
print(weights.transforms())
# クラスラベルの内容
print(weights.meta["categories"])
以前は学習時の前処理の内容を自分で定義したり、クラスラベル用のファイルを読み込んでいたりしたので、この辺はとてもありがたいです。
以上、weightsの使い方でした。どなたかお困りの方のご参考になれば幸いです。