from functools import partial
from multiprocessing import cpu_count
from pathlib import Path
import random

import cv2
import numpy as np
import torch
import torchvision
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class PNGDataset(Dataset):
    """Dataset of captioned PNG images (local directory or HF hub).

    Captions are read from the PNG metadata (``image.info``) under
    ``prompt_key``. Optionally produces a Canny-edge control hint for
    ControlNet training when ``controlnet_hint_key == "canny"``.

    Args:
        data_dir: Local directory of images, or an HF hub dataset name
            when ``from_hf_hub`` is True.
        tokenizer: Tokenizer used by ``collate_fn`` to encode captions.
        from_hf_hub: Load via ``datasets.load_dataset`` instead of globbing.
        ucg: Unconditional-guidance probability — chance of replacing the
            caption with an empty string (classifier-free guidance dropout).
        resolution: (height, width) images are resized to.
        prompt_key: Key in ``image.info`` holding the caption text.
        cond_key: Output-dict key for the caption.
        target_key: Output-dict key for the image tensor.
        controlnet_hint_key: If "canny", adds a Canny edge map to each item.
        file_extension: Glob extension for local loading.
    """

    def __init__(
        self,
        data_dir,
        tokenizer,
        from_hf_hub=False,
        ucg=0.10,
        resolution=(512, 512),
        prompt_key="tags",
        cond_key="cond",
        target_key="image",
        controlnet_hint_key=None,
        file_extension="png",
    ):
        super().__init__()
        # Explicit assignment instead of vars(self).update(locals()),
        # which also injected a bogus `self.self` attribute.
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.from_hf_hub = from_hf_hub
        self.ucg = ucg
        self.resolution = resolution
        self.prompt_key = prompt_key
        self.cond_key = cond_key
        self.target_key = target_key
        self.controlnet_hint_key = controlnet_hint_key
        self.file_extension = file_extension

        if from_hf_hub:
            self.img_paths = load_dataset(data_dir)["train"]
        else:
            self.img_paths = list(Path(data_dir).glob(f"*.{file_extension}"))

        # Flip is applied *before* the canny hint is computed so that the
        # hint matches the (possibly flipped) target image.
        self.flip_transform = torchvision.transforms.RandomHorizontalFlip(p=0.5)
        self.transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(resolution),
                torchvision.transforms.ToTensor(),
            ]
        )
        # Maps [0, 1] -> [-1, 1]; applied to the target only, not the hint.
        self.normalize = torchvision.transforms.Normalize([0.5], [0.5])

    def process_canny(self, image):
        """Return a 3-channel PIL image of Canny edges for *image*.

        Code from
        https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/controlnet
        """
        image = np.array(image)
        low_threshold, high_threshold = (100, 200)
        image = cv2.Canny(image, low_threshold, high_threshold)
        image = image[:, :, None]
        # Replicate the single edge channel to RGB.
        image = np.concatenate([image, image, image], axis=2)
        canny_image = Image.fromarray(image)
        return canny_image

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return a dict with the image tensor, caption, and optional hint.

        Images missing the caption key are skipped by falling through to
        the next index (wrapping at the end of the dataset).
        """
        if self.from_hf_hub:
            image = self.img_paths[idx]["image"]
        else:
            image = Image.open(self.img_paths[idx])

        if self.prompt_key not in image.info:
            print(f"Image {idx} lacks {self.prompt_key}, skipping to next image")
            # BUGFIX: was `idx + 1 % len(self)`, which `%` binds before `+`,
            # so it never wrapped and overran at the last index.
            return self.__getitem__((idx + 1) % len(self))

        # Caption dropout for classifier-free guidance.
        if random.random() < self.ucg:
            tags = ""
        else:
            tags = image.info[self.prompt_key]

        # Randomly flip image here so input image to canny has matching flip.
        image = self.flip_transform(image)

        target = self.normalize(self.transforms(image))
        output_dict = {self.target_key: target, self.cond_key: tags}

        if self.controlnet_hint_key == "canny":
            # Hint stays in [0, 1]; no normalization.
            canny_image = self.transforms(self.process_canny(image))
            output_dict[self.controlnet_hint_key] = canny_image

        return output_dict

    def collate_fn(self, samples):
        """Collate dataset items into a batch of tokenized prompts and images."""
        prompts = torch.tensor(
            [
                self.tokenizer(
                    sample[self.cond_key],
                    padding="max_length",
                    truncation=True,
                ).input_ids
                for sample in samples
            ]
        )
        images = torch.stack(
            [sample[self.target_key] for sample in samples]
        ).contiguous()

        batch = {
            self.cond_key: prompts,
            self.target_key: images,
        }

        if self.controlnet_hint_key is not None:
            hint = torch.stack(
                [sample[self.controlnet_hint_key] for sample in samples]
            ).contiguous()
            batch[self.controlnet_hint_key] = hint

        return batch


class PNGDataModule:
    """Factory for DataLoaders over :class:`PNGDataset`.

    Args:
        batch_size: Batch size passed to every DataLoader.
        num_workers: Worker process count; defaults to half the CPU count.
        persistent_workers: Keep workers alive between epochs.
        **kwargs: Forwarded to the ``PNGDataset`` constructor.
    """

    def __init__(
        self,
        batch_size=1,
        num_workers=None,
        persistent_workers=True,
        **kwargs,  # passed to dataset class
    ):
        super().__init__()
        if num_workers is None:
            num_workers = cpu_count() // 2

        # Explicit assignment instead of vars(self).update(locals());
        # also ensures self.num_workers reflects the resolved default
        # rather than staying None.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers
        self.kwargs = kwargs

        self.ds_wrapper = partial(PNGDataset, **kwargs)
        self.dl_wrapper = partial(
            DataLoader,
            batch_size=batch_size,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
        )

    def get_dataloader(self, data_dir, shuffle=False):
        """Build a DataLoader for the dataset rooted at *data_dir*."""
        dataset = self.ds_wrapper(data_dir=data_dir)
        dataloader = self.dl_wrapper(
            dataset, shuffle=shuffle, collate_fn=dataset.collate_fn
        )
        return dataloader