# Spaces: Runtime error  (web-page status text captured together with this file; not Python code)
from pathlib import Path | |
from PIL import Image | |
import torchvision | |
import random | |
from torch.utils.data import Dataset, DataLoader | |
from functools import partial | |
from multiprocessing import cpu_count | |
from datasets import load_dataset | |
import cv2 | |
import numpy as np | |
import torch | |
class PNGDataset(Dataset):
    """Image dataset whose text prompt is read from each image's ``info`` metadata.

    Images come either from a local directory (globbed by file extension) or
    from the ``"train"`` split returned by ``load_dataset(data_dir)``.  Every
    item is a dict with the resized, normalized image tensor under
    ``target_key``, the prompt string under ``cond_key`` and, when
    ``controlnet_hint_key == "canny"``, a Canny edge-map tensor under that key.
    """

    def __init__(
        self,
        data_dir,
        tokenizer,
        from_hf_hub=False,
        ucg=0.10,
        resolution=(512, 512),
        prompt_key="tags",
        cond_key="cond",
        target_key="image",
        controlnet_hint_key=None,
        file_extension="png",
    ):
        """Index the images and build the transforms.

        Args:
            data_dir: Directory to glob for images, or the name passed to
                ``load_dataset`` when ``from_hf_hub`` is true.
            tokenizer: Callable used by ``collate_fn``; it is invoked as
                ``tokenizer(text, padding="max_length", truncation=True)`` and
                the result must expose ``.input_ids``.
            from_hf_hub: Load via ``load_dataset(data_dir)["train"]`` instead
                of globbing a local directory.
            ucg: Probability with which an item's prompt is replaced by the
                empty string (presumably "unconditional guidance" dropout --
                confirm with the training code).
            resolution: Size handed to ``torchvision.transforms.Resize``.
            prompt_key: Key looked up in ``image.info`` to obtain the prompt.
            cond_key: Output-dict key for the prompt / token ids.
            target_key: Output-dict key for the image tensor.
            controlnet_hint_key: ``"canny"`` to add a Canny edge hint under
                that key; ``None`` for no hint.
            file_extension: Extension (without dot) globbed in ``data_dir``.
        """
        super().__init__()
        # Explicit assignments instead of ``vars(self).update(locals())``,
        # which also stored ``self`` (a reference cycle) and ``__class__``.
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.from_hf_hub = from_hf_hub
        self.ucg = ucg
        self.resolution = resolution
        self.prompt_key = prompt_key
        self.cond_key = cond_key
        self.target_key = target_key
        self.controlnet_hint_key = controlnet_hint_key
        self.file_extension = file_extension

        if from_hf_hub:
            self.img_paths = load_dataset(data_dir)["train"]
        else:
            self.img_paths = list(Path(data_dir).glob(f"*.{file_extension}"))

        self.flip_transform = torchvision.transforms.RandomHorizontalFlip(p=0.5)
        self.transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(resolution),
                torchvision.transforms.ToTensor(),
            ]
        )
        # Applied to the target image only, never to the canny hint.
        self.normalize = torchvision.transforms.Normalize([0.5], [0.5])

    def process_canny(self, image, low_threshold=100, high_threshold=200):
        """Return a 3-channel PIL image holding the Canny edges of ``image``.

        Args:
            image: Input image; converted with ``np.array`` before edge
                detection.
            low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
            high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.
        """
        # code from https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/controlnet
        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        edges = edges[:, :, None]
        # Replicate the single edge channel three times.
        stacked = np.concatenate([edges, edges, edges], axis=2)
        return Image.fromarray(stacked)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.img_paths)

    def _load_image(self, idx):
        """Fetch the image at ``idx`` from the hub dataset or from disk."""
        if self.from_hf_hub:
            return self.img_paths[idx]["image"]
        return Image.open(self.img_paths[idx])

    def __getitem__(self, idx):
        """Return the training sample at ``idx``.

        Images without ``prompt_key`` in their metadata are skipped: the next
        index (wrapping around at the end of the dataset) is tried instead.

        Raises:
            IndexError: If ``idx`` itself is out of range.
            RuntimeError: If no image in the dataset carries ``prompt_key``.
        """
        total = len(self)
        position = idx
        # Bounded scan.  The original retried with ``idx + 1 % len(self)``,
        # which parses as ``idx + (1 % len(self))``: it ran past the end of
        # the dataset and recursed forever on a one-item dataset.
        # The first lookup uses ``idx`` unchanged so that an out-of-range
        # index still raises IndexError.
        for _ in range(max(total, 1)):
            image = self._load_image(position)
            if self.prompt_key in image.info:
                break
            print(f"Image {position} lacks {self.prompt_key}, skipping to next image")
            position = (position + 1) % total
        else:
            raise RuntimeError(
                f"No image in the dataset has '{self.prompt_key}' in its metadata"
            )

        # Drop the prompt with probability ``ucg``.
        if random.random() < self.ucg:
            tags = ""
        else:
            tags = image.info[self.prompt_key]

        # randomly flip image here so input image to canny has matching flip
        image = self.flip_transform(image)
        target = self.normalize(self.transforms(image))
        output_dict = {self.target_key: target, self.cond_key: tags}

        if self.controlnet_hint_key == "canny":
            canny_image = self.transforms(self.process_canny(image))
            output_dict[self.controlnet_hint_key] = canny_image
        return output_dict

    def collate_fn(self, samples):
        """Batch ``__getitem__`` outputs: tokenize the prompts, stack the tensors.

        Returns a dict with token ids under ``cond_key``, stacked images under
        ``target_key`` and, if ``controlnet_hint_key`` is set, stacked hints
        under that key.
        """
        prompts = torch.tensor(
            [
                self.tokenizer(
                    sample[self.cond_key],
                    padding="max_length",
                    truncation=True,
                ).input_ids
                for sample in samples
            ]
        )
        images = torch.stack(
            [sample[self.target_key] for sample in samples]
        ).contiguous()
        batch = {
            self.cond_key: prompts,
            self.target_key: images,
        }
        # NOTE(review): __getitem__ only emits a hint for "canny"; any other
        # non-None controlnet_hint_key raises KeyError here.
        if self.controlnet_hint_key is not None:
            hint = torch.stack(
                [sample[self.controlnet_hint_key] for sample in samples]
            ).contiguous()
            batch[self.controlnet_hint_key] = hint
        return batch
class PNGDataModule:
    """Factory that builds ``DataLoader`` objects over ``PNGDataset`` instances.

    Dataset options and loader options are fixed at construction time;
    ``get_dataloader`` then only needs the data directory.
    """

    def __init__(
        self,
        batch_size=1,
        num_workers=None,
        persistent_workers=True,
        **kwargs,  # passed to dataset class
    ):
        """Store loader settings and pre-bind the dataset / loader constructors.

        Args:
            batch_size: Batch size forwarded to ``DataLoader``.
            num_workers: Worker process count; ``None`` selects
                ``cpu_count() // 2``.
            persistent_workers: Forwarded to ``DataLoader``; forced to
                ``False`` when the resolved worker count is 0.
            **kwargs: Forwarded unchanged to ``PNGDataset``.
        """
        super().__init__()
        num_workers, persistent_workers = self._resolve_worker_settings(
            num_workers, persistent_workers
        )
        # Explicit attributes instead of ``vars(self).update(locals())``,
        # which stored ``self`` on itself and kept ``num_workers`` as None.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers
        self.kwargs = kwargs

        self.ds_wrapper = partial(PNGDataset, **kwargs)
        self.dl_wrapper = partial(
            DataLoader,
            batch_size=batch_size,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
        )

    @staticmethod
    def _resolve_worker_settings(num_workers, persistent_workers):
        """Return ``(num_workers, persistent_workers)`` safe to pass to DataLoader."""
        if num_workers is None:
            num_workers = cpu_count() // 2
        # DataLoader raises ValueError for persistent_workers=True with zero
        # workers, which the default hits on a single-CPU machine.
        persistent_workers = bool(persistent_workers) and num_workers > 0
        return num_workers, persistent_workers

    def get_dataloader(self, data_dir, shuffle=False):
        """Build a dataset for ``data_dir`` and wrap it in a ``DataLoader``.

        Args:
            data_dir: Passed to ``PNGDataset`` as ``data_dir``.
            shuffle: Forwarded to ``DataLoader``.

        Returns:
            A ``DataLoader`` that batches with the dataset's own ``collate_fn``.
        """
        dataset = self.ds_wrapper(data_dir=data_dir)
        dataloader = self.dl_wrapper(
            dataset, shuffle=shuffle, collate_fn=dataset.collate_fn
        )
        return dataloader