from typing import Any, Dict, Optional, Tuple

import torch
from lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms


class MNISTDataModule(LightningDataModule):
    """Example of LightningDataModule for the MNIST dataset.

    A DataModule implements 6 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc...
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc...
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
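
    Example (a minimal sketch; in a typical project the Lightning `Trainer`
    calls these hooks for you):

        dm = MNISTDataModule(batch_size=128)
        dm.prepare_data()   # downloads MNIST to `data_dir`
        dm.setup()          # creates the train/val/test splits
        batch = next(iter(dm.train_dataloader()))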
""" | |

    def __init__(
        self,
        data_dir: str = "data/",
        train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()

        # this line allows accessing init params with the `self.hparams` attribute
        # and also ensures init params will be stored in the ckpt
        self.save_hyperparameters(logger=False)
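        # e.g. `self.hparams.batch_size` now returns the value passed above (64 by default)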

        # data transformations
        self.transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    @property
    def num_classes(self) -> int:
        """Number of MNIST classes (digits 0-9)."""
        return 10

    def prepare_data(self):
        """Download data if needed.

        Do not use it to assign state (self.x = y).
        """
        MNIST(self.hparams.data_dir, train=True, download=True)
        MNIST(self.hparams.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute things like random split twice!
        """
        # load and split datasets only if not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
            testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)
            dataset = ConcatDataset(datasets=[trainset, testset])
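            # the default 55_000/5_000/10_000 split covers all 70_000 MNIST images
            # (60k train + 10k test), re-shuffled deterministically via the fixed seed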
            self.data_train, self.data_val, self.data_test = random_split(
                dataset=dataset,
                lengths=self.hparams.train_val_test_split,
                generator=torch.Generator().manual_seed(42),
            )

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass


if __name__ == "__main__":
    _ = MNISTDataModule()
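
    # Optional standalone sanity check (a sketch; uncommenting downloads MNIST):
    # dm = MNISTDataModule(batch_size=32)
    # dm.prepare_data()
    # dm.setup()
    # images, labels = next(iter(dm.train_dataloader()))
    # print(images.shape, labels.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])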