機械学習用Template準備-3_1(補足)
はじめに
以前の投稿でDatamoduleの実装を紹介しました。CIFAR10をDatasetとして使用する実装です。ですが実装に関して汎用性がなく、CIFAR100やMNISTの時に個別にModuleを実装する必要がありました。良い実装方法が思いついたのでその部分のみUpdateします。
何が変わったか
TorchvisionでSupportしているDatasetsは呼び出すClass名を変えるだけで、Interface等は大きな差異がないということを利用し、Datasetsの呼び出し部分をgetattrで書けばいいじゃないかと気げ尽きました。またtransformもload時に適応すれば処理時間が短くなることがわかり、その部分も実装しなおしました。
import multiprocessing as mp
from typing import List, Union, Dict
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import torchvision
from pytorch_lightning import LightningDataModule as LitDataModule
class DataModule(LitDataModule):
# default transform
transform = transforms.Compose([
transforms.ToTensor(),
])
# sampleing for validataion
val_data_num: int = 5000
def __init__(self, data_dir: str, batch_size: int, dataset: str, download: bool=True):
super().__init__()
self.data_dir: str = data_dir
self.batch_size: int = batch_size
self.dataset: str = dataset
self.dowload: bool = download
# count of # of CPUs
self._num_workers = mp.cpu_count()
@classmethod
def update_transform(cls, transform):
cls.transform = transform
@classmethod
def change_num_of_val_data(cls, number: int):
cls.val_data_num = number
def dataset_setup(self, dataset_name: str):
ds = getattr(torchvision.datasets, dataset_name)(
self.data_dir,
download=self.dowload,
transform=self.transform,
train=True
)
test_ds = getattr(torchvision.datasets, dataset_name)(
self.data_dir,
download=self.dowload,
transform=self.transform,
train=False
)
num_items = ds.__len__()
return ds, test_ds, num_items
def prepare_data(self) -> None:
if self.data_dir[-1] != '/':
self.data_dir = self.data_dir + '/'
if self.dataset != 'MNIST':
self.data_dir = self.data_dir + self.dataset
self.ds, self.test_ds, self._num_of_train_items = self.dataset_setup(self.dataset)
def setup(self, stage: str='fit') -> None:
if stage == 'fit' or stage is None:
train_data_num = self._num_of_train_items - self.val_data_num
self.train_ds, self.val_ds = random_split(self.ds, [train_data_num, self.val_data_num])
if stage == 'test' or stage is None:
pass
他の部分は依然と変わりません。これでスッキリです。
アメリカSilicon Valley在住のエンジニアです。日本企業から突然アメリカ企業に転職して気が付いた事や知って役に立った事を書いています。