https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#lightningdatamodule

 

LightningDataModule — PyTorch Lightning 1.6.0dev documentation


 

1. What is a DataModule?

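A class that bundles the data pipeline (download, processing, train/val/test split) together with train_dataloader, val_dataloader, test_dataloader, and predict_dataloader, so the whole data setup can be shared and reused as one object. For comparison, here are the same loaders built by hand in plain PyTorch: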

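# Plain PyTorch example
import os
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

my_path = os.getcwd()
transform = transforms.ToTensor()  # PIL image -> tensor, so DataLoader can batch the samples

test_data = MNIST(my_path, train=False, download=True, transform=transform)
predict_data = MNIST(my_path, train=False, download=True, transform=transform)
train_data = MNIST(my_path, train=True, download=True, transform=transform)
# 60,000 training images -> 55,000 train / 5,000 validation
train_data, val_data = random_split(train_data, [55000, 5000])

train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)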

 

2. LightningDataModule API

 

1) prepare_data (how to download, tokenize, etc…)

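Downloads, tokenizes, or otherwise prepares data on disk. Lightning calls it from a single process (once per node by default), not from every process, so do not assign state here (self.x = ...): it would not exist in the other processes.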

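import os
import pytorch_lightning as pl
from torchvision import transforms
from torchvision.datasets import MNIST

class MNISTDataModule(pl.LightningDataModule):
    def prepare_data(self):
        # download only; do not assign self.* state here
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())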

 

2) setup (stage = None)

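What it does depends on the stage: fit (train + validate), validate, test, or predict. With stage=None, every stage is set up. Unlike prepare_data, setup runs in every process, so assigning state (self.x = ...) here is safe.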

- train/val/test split

- dataset creation

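from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import random_split
from torchvision.datasets import MNIST

class MNISTDataModule(pl.LightningDataModule):
    # self.data_dir and self.transform are assumed to be set in __init__
    def setup(self, stage: Optional[str] = None):

        # Assign Train/val split(s) for use in Dataloaders ("validate" also needs the val split)
        if stage in (None, "fit", "validate"):
            mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign Test split(s) for use in Dataloaders
        if stage in (None, "test"):
            self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=self.transform)

        # Assign the split used by predict_dataloader
        if stage in (None, "predict"):
            self.mnist_predict = MNIST(self.data_dir, train=False, download=True, transform=self.transform)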

 

3) train_dataloader

- Implement one or more training DataLoaders

Returns a single torch.utils.data.DataLoader or a collection of them (e.g. a list or dict of DataLoaders).

class MNISTDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        # shuffle=True reshuffles the training samples every epoch
        return DataLoader(self.mnist_train, batch_size=64, shuffle=True)

 

4) val_dataloader

class MNISTDataModule(pl.LightningDataModule):
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)

 

5) test_dataloader

class MNISTDataModule(pl.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)

 

6) predict_dataloader

class MNISTDataModule(pl.LightningDataModule):
    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=64)

 

 

<Reference links>
