Spaces:
Runtime error
Runtime error
import torchvision | |
import lightning as L | |
from torch.utils.data import DataLoader | |
from utils.transforms import train_transform, test_transform | |
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset that invokes its transform with keyword arguments.

    ``__getitem__`` is overridden so that the stored ``self.data`` entry is
    passed to ``transform(image=...)`` and the ``"image"`` entry of the
    returned mapping is used as the sample. This is the Albumentations-style
    calling convention (presumably what ``utils.transforms`` provides —
    confirm). ``target_transform`` is not applied.
    """

    def __init__(self, root="~/data", train=True, download=True, transform=None):
        # Forward everything unchanged to torchvision's CIFAR10 constructor.
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair for sample ``index``.

        With no transform configured, the image is returned exactly as it is
        stored in ``self.data``.
        """
        image = self.data[index]
        label = self.targets[index]
        if self.transform is None:
            return image, label
        return self.transform(image=image)["image"], label
class CIFARDataModule(L.LightningDataModule):
    """Lightning data module serving CIFAR-10 train / validation / test loaders.

    The training split is wrapped with ``train_transform``. Validation and
    test both read the CIFAR-10 test split (``train=False``) with
    ``test_transform``.

    Args:
        data_dir: Root directory the CIFAR-10 files are downloaded to and
            read from.
        batch_size: Batch size used by every dataloader.
        shuffle: Whether the *training* dataloader shuffles its samples.
            Validation and test loaders are never shuffled.
        num_workers: Number of DataLoader worker processes per loader.
    """

    def __init__(
        self, data_dir="data", batch_size=512, shuffle=True, num_workers=4
    ) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def prepare_data(self) -> None:
        """Download both CIFAR-10 splits into ``data_dir``.

        Lightning runs this hook once (not on every process) before
        ``setup``. Downloading here rather than in ``setup`` keeps multiple
        ranks from downloading into the same directory concurrently.
        """
        for train in (True, False):
            torchvision.datasets.CIFAR10(root=self.data_dir, train=train, download=True)

    def setup(self, stage=None):
        """Build the train, validation and test datasets.

        ``stage`` is accepted to match the Lightning hook signature. All three
        datasets are built regardless of its value.
        """
        self.train_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=True, transform=train_transform
        )
        self.val_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=False, transform=test_transform
        )
        self.test_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=False, transform=test_transform
        )

    def _make_dataloader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader with this module's batch size and workers."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the training loader; shuffling follows ``self.shuffle``."""
        return self._make_dataloader(self.train_dataset, shuffle=self.shuffle)

    def val_dataloader(self):
        """Return the validation loader in a fixed order.

        Sample order does not change aggregate metrics, so shuffling the
        evaluation set only makes runs harder to compare. Lightning also
        warns about it.
        """
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test loader in a fixed order (never shuffled)."""
        return self._make_dataloader(self.test_dataset, shuffle=False)