見出し画像

機械学習用Template準備-3_1(補足)

はじめに

 以前の投稿でDatamoduleの実装を紹介しました。CIFAR10をDatasetとして使用する実装です。ですが実装に関して汎用性がなく、CIFAR100やMNISTの時に個別にModuleを実装する必要がありました。良い実装方法が思いついたのでその部分のみUpdateします。

何が変わったか

 TorchvisionでSupportしているDatasetsは呼び出すClass名を変えるだけで、Interface等は大きな差異がないということを利用し、Datasetsの呼び出し部分をgetattrで書けばいいじゃないかと気げ尽きました。またtransformもload時に適応すれば処理時間が短くなることがわかり、その部分も実装しなおしました。

import multiprocessing as mp
from typing import List, Union, Dict
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import torchvision
from pytorch_lightning import LightningDataModule as LitDataModule
class DataModule(LitDataModule):
   # default transform
   transform = transforms.Compose([
       transforms.ToTensor(),
   ])
   # sampleing for validataion
   val_data_num: int = 5000

   def __init__(self, data_dir: str, batch_size: int, dataset: str, download: bool=True):
       super().__init__()
       self.data_dir: str = data_dir
       self.batch_size: int = batch_size
       self.dataset: str = dataset
       self.dowload: bool = download
       # count of # of CPUs
       self._num_workers = mp.cpu_count()
   @classmethod
   def update_transform(cls, transform):
       cls.transform = transform
   @classmethod
   def change_num_of_val_data(cls, number: int):
       cls.val_data_num = number
   def dataset_setup(self, dataset_name: str): 
       ds = getattr(torchvision.datasets, dataset_name)(
           self.data_dir, 
           download=self.dowload, 
           transform=self.transform, 
           train=True
       )
       test_ds = getattr(torchvision.datasets, dataset_name)(
           self.data_dir, 
           download=self.dowload, 
           transform=self.transform, 
           train=False
       )
       num_items = ds.__len__()
       return ds, test_ds, num_items

   def prepare_data(self) -> None:
       if self.data_dir[-1] != '/':
           self.data_dir = self.data_dir + '/'
       if self.dataset != 'MNIST':
           self.data_dir = self.data_dir + self.dataset
       self.ds, self.test_ds, self._num_of_train_items = self.dataset_setup(self.dataset)
   def setup(self, stage: str='fit') -> None:
       if stage == 'fit' or stage is None:
           train_data_num = self._num_of_train_items - self.val_data_num
           self.train_ds, self.val_ds = random_split(self.ds, [train_data_num, self.val_data_num])
       if stage == 'test' or stage is None:
           pass

他の部分は依然と変わりません。これでスッキリです。



アメリカSilicon Valley在住のエンジニアです。日本企業から突然アメリカ企業に転職して気が付いた事や知って役に立った事を書いています。