https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#lightningdatamodule
1. What is a DataModule?
A DataModule is a collection of train_dataloader, val_dataloader, test_dataloader, and predict_dataloader.
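# PyTorch Example
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

# my_path is the directory where the MNIST files are stored.
# ToTensor() is needed so that DataLoader can collate the images into batches.
transform = transforms.ToTensor()
test_data = MNIST(my_path, train=False, download=True, transform=transform)
predict_data = MNIST(my_path, train=False, download=True, transform=transform)
train_data = MNIST(my_path, train=True, download=True, transform=transform)
train_data, val_data = random_split(train_data, [55000, 5000])
train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)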
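The same split-and-load pattern can be tried without downloading anything. The following is a minimal sketch, assuming a synthetic TensorDataset stands in for MNIST. The dataset size, the seed value 42, and shuffle=True on the training loader are illustrative choices that do not appear in the example above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for MNIST (hypothetical data): 100 images of shape 1x28x28.
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
full_train = TensorDataset(images, labels)

# Seed the split so that train/val membership is the same on every run.
generator = torch.Generator().manual_seed(42)
train_data, val_data = random_split(full_train, [90, 10], generator=generator)

# Shuffle only the training loader; the validation order can stay fixed.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32)

# Pull one batch to confirm that the loader yields (images, labels) pairs.
x, y = next(iter(train_loader))
print(x.shape, y.shape)
print(len(train_loader), len(val_loader))
```

Passing an explicit generator to random_split is optional, but without it the validation set changes between runs, which makes results harder to compare.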
2. LightningDataModule API
1) prepare_data (how to download, tokenize, etc…)
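Downloads and prepares the data. Lightning calls prepare_data from a single process, because downloading and saving data from multiple processes (distributed settings) can corrupt it. Do not assign state here (e.g. self.x = y), since the other processes will not see it.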
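import os
import pytorch_lightning as pl
from torchvision import transforms
from torchvision.datasets import MNIST

class MNISTDataModule(pl.LightningDataModule):
    def prepare_data(self):
        # download only; the returned datasets are intentionally discarded
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())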
2) setup (stage = None)
What setup does depends on the stage (fit (train + validate) / validate / test / predict). When stage is None, the work for every stage is performed.
- train/val/test split
- dataset creation
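from typing import Optional
from torch.utils.data import random_split

class MNISTDataModule(pl.LightningDataModule):
    def setup(self, stage: Optional[str] = None):
        # self.data_dir and self.transform are assumed to be set in __init__
        # Assign train/val split(s) for use in DataLoaders
        # ("validate" is included so that val_dataloader also works on its own)
        if stage in (None, "fit", "validate"):
            mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # Assign test split(s) for use in DataLoaders
        if stage in (None, "test"):
            self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=self.transform)
        # Assign predict split for use in DataLoaders (read by predict_dataloader)
        if stage in (None, "predict"):
            self.mnist_predict = MNIST(self.data_dir, train=False, download=True, transform=self.transform)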
3) train_dataloader
- Implements one or more training dataloaders
- Returns a torch.utils.data.DataLoader or a collection of them
from torch.utils.data import DataLoader

class MNISTDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)
4) val_dataloader
class MNISTDataModule(pl.LightningDataModule):
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)
5) test_dataloader
class MNISTDataModule(pl.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)
6) predict_dataloader
class MNISTDataModule(pl.LightningDataModule):
    def predict_dataloader(self):
        # self.mnist_predict must be assigned in setup() before this is called
        return DataLoader(self.mnist_predict, batch_size=64)
<Reference links>