Spaces:
Runtime error
Runtime error
import os | |
from torch.utils.data import Dataset | |
from PIL import Image | |
from PTI.utils.data_utils import make_dataset | |
from torchvision import transforms | |
class Image2Dataset(Dataset):
    """Single-item dataset wrapping one in-memory image.

    The image goes through ``ToTensor`` and then ``Normalize`` with mean 0.5
    and std 0.5 per channel, i.e. ``(x - 0.5) / 0.5``. For 8-bit images this
    maps ToTensor's [0, 1] range to [-1, 1]. The item is always returned
    under the fixed name ``"customIMG"``.
    """

    def __init__(self, image) -> None:
        """Store *image* and build its preprocessing pipeline.

        Args:
            image: Anything ``transforms.ToTensor`` accepts (e.g. a PIL image
                or an HxWxC numpy array). The 3-element mean/std below assume
                a 3-channel (RGB) image -- TODO confirm callers convert first.
        """
        super().__init__()
        self.image = image
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

    def __len__(self):
        """Return 1: the dataset holds exactly one image."""
        return 1

    def __getitem__(self, index):
        """Return ``("customIMG", transformed_image)`` for index 0 (or -1).

        Raises:
            IndexError: For any other index. Plain iteration
                (``for item in dataset``) falls back to the ``__getitem__``
                sequence protocol and only stops on IndexError, so without
                this check it would never terminate.
        """
        if index not in (0, -1):
            raise IndexError(f"Image2Dataset index out of range: {index}")
        return "customIMG", self.transform(self.image)
class ImagesDataset(Dataset):
    """Dataset over the image files listed by ``make_dataset(source_root)``.

    Each item is ``(fname, image)``. The image is loaded, converted to RGB,
    resized to ``image_size`` x ``image_size`` and, if given, passed through
    ``source_transform``.
    """

    def __init__(self, source_root, source_transform=None, image_size=1024):
        """Collect and sort the entries found under *source_root*.

        Args:
            source_root: Root handed to ``make_dataset``. Its entries are
                unpacked as ``(fname, path)`` pairs in ``__getitem__``
                (presumably file name and full path -- confirm against
                ``make_dataset``).
            source_transform: Optional callable applied to the resized PIL
                image. Falsy values skip it.
            image_size: Side length of the square resize. Defaults to 1024,
                the value previously hard-coded here.
        """
        self.source_paths = sorted(make_dataset(source_root))
        self.source_transform = source_transform
        self.image_size = image_size

    def __len__(self):
        """Return the number of entries found under ``source_root``."""
        return len(self.source_paths)

    def __getitem__(self, index):
        """Load, RGB-convert, resize and optionally transform entry *index*.

        Returns:
            ``(fname, image)``, where *image* is the resized PIL image or
            whatever ``source_transform`` returns for it.
        """
        fname, from_path = self.source_paths[index]
        # Open via a context manager so the file handle is always released.
        # convert()/resize() return new images, so the result outlives the
        # closed file.
        with Image.open(from_path) as img:
            from_im = img.convert("RGB").resize((self.image_size, self.image_size))
        if self.source_transform:
            from_im = self.source_transform(from_im)
        return fname, from_im