# init commit 6230dda (1lint)
from pathlib import Path
from PIL import Image
import torchvision
import random
from torch.utils.data import Dataset, DataLoader
from functools import partial
from multiprocessing import cpu_count
from datasets import load_dataset
import cv2
import numpy as np
import torch
class PNGDataset(Dataset):
    """Dataset of PNG images whose caption/tags live in the PNG metadata.

    Images come either from a local directory (``*.{file_extension}`` files)
    or from a Hugging Face Hub dataset (``from_hf_hub=True``).  The caption is
    read from the image's ``info`` dict under ``prompt_key`` and is randomly
    replaced by the empty string with probability ``ucg`` (unconditional
    guidance dropout for classifier-free-guidance training).  Optionally a
    Canny-edge hint image is produced for ControlNet training.
    """

    def __init__(
        self,
        data_dir,
        tokenizer,
        from_hf_hub=False,
        ucg=0.10,
        resolution=(512, 512),
        prompt_key="tags",
        cond_key="cond",
        target_key="image",
        controlnet_hint_key=None,
        file_extension="png",
    ):
        super().__init__()
        # Explicit attribute assignment instead of the previous
        # ``vars(self).update(locals())`` hack, which also stored a circular
        # ``self.self`` reference and duplicated the ``ucg`` assignment.
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.from_hf_hub = from_hf_hub
        self.ucg = ucg
        self.resolution = resolution
        self.prompt_key = prompt_key
        self.cond_key = cond_key
        self.target_key = target_key
        self.controlnet_hint_key = controlnet_hint_key
        self.file_extension = file_extension
        if from_hf_hub:
            self.img_paths = load_dataset(data_dir)["train"]
        else:
            self.img_paths = list(Path(data_dir).glob(f"*.{file_extension}"))
        self.flip_transform = torchvision.transforms.RandomHorizontalFlip(p=0.5)
        self.transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(resolution),
                torchvision.transforms.ToTensor(),
            ]
        )
        # Maps [0, 1] tensors to [-1, 1] (single stats broadcast per channel).
        self.normalize = torchvision.transforms.Normalize([0.5], [0.5])

    def process_canny(self, image):
        """Return a 3-channel PIL image of Canny edges computed from *image*."""
        # code from https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/controlnet
        image = np.array(image)
        low_threshold, high_threshold = (100, 200)
        image = cv2.Canny(image, low_threshold, high_threshold)
        # Replicate the single edge channel to RGB for the hint encoder.
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        canny_image = Image.fromarray(image)
        return canny_image

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return a dict with the normalized image tensor, its tags, and
        (optionally) a Canny hint tensor."""
        if self.from_hf_hub:
            image = self.img_paths[idx]["image"]
        else:
            image = Image.open(self.img_paths[idx])
        if self.prompt_key not in image.info:
            print(f"Image {idx} lacks {self.prompt_key}, skipping to next image")
            # BUG FIX: was ``idx + 1 % len(self)``, which by operator
            # precedence equals ``idx + 1`` — the index never wrapped and the
            # last image raised IndexError.  Parenthesize the sum.
            # NOTE(review): still recurses forever if *no* image has the key.
            return self.__getitem__((idx + 1) % len(self))
        if random.random() < self.ucg:
            tags = ""
        else:
            tags = image.info[self.prompt_key]
        # randomly flip image here so input image to canny has matching flip
        image = self.flip_transform(image)
        target = self.normalize(self.transforms(image))
        output_dict = {self.target_key: target, self.cond_key: tags}
        if self.controlnet_hint_key == "canny":
            # Hint is only resized/tensorized, not normalized to [-1, 1].
            canny_image = self.transforms(self.process_canny(image))
            output_dict[self.controlnet_hint_key] = canny_image
        return output_dict

    def collate_fn(self, samples):
        """Collate dataset items into a batch dict of stacked tensors.

        Tokenizes each prompt (``padding="max_length"`` — assumes the
        tokenizer has a ``model_max_length`` set; TODO confirm) and stacks
        image (and optional ControlNet hint) tensors.
        """
        prompts = torch.tensor(
            [
                self.tokenizer(
                    sample[self.cond_key],
                    padding="max_length",
                    truncation=True,
                ).input_ids
                for sample in samples
            ]
        )
        images = torch.stack(
            [sample[self.target_key] for sample in samples]
        ).contiguous()
        batch = {
            self.cond_key: prompts,
            self.target_key: images,
        }
        if self.controlnet_hint_key is not None:
            hint = torch.stack(
                [sample[self.controlnet_hint_key] for sample in samples]
            ).contiguous()
            batch[self.controlnet_hint_key] = hint
        return batch
class PNGDataModule:
    """Factory for ``DataLoader`` instances over :class:`PNGDataset`.

    Any extra keyword arguments are forwarded to the ``PNGDataset``
    constructor when :meth:`get_dataloader` builds a dataset.
    """

    def __init__(
        self,
        batch_size=1,
        num_workers=None,
        persistent_workers=True,
        **kwargs,  # passed to dataset class
    ):
        super().__init__()
        if num_workers is None:
            # Default to half the logical cores, but never 0, so the
            # DataLoader still gets at least one worker process.
            num_workers = max(1, cpu_count() // 2)
        self.batch_size = batch_size
        # BUG FIX: the old ``vars(self).update(locals())`` stored the
        # *pre-default* value (possibly None) in ``self.num_workers`` while
        # the DataLoader partial used the resolved local — keep them in sync.
        self.num_workers = num_workers
        # persistent_workers=True is invalid when num_workers == 0
        # (DataLoader raises ValueError), so disable it in that case.
        self.persistent_workers = persistent_workers and num_workers > 0
        self.ds_wrapper = partial(PNGDataset, **kwargs)
        self.dl_wrapper = partial(
            DataLoader,
            batch_size=batch_size,
            num_workers=num_workers,
            persistent_workers=self.persistent_workers,
        )

    def get_dataloader(self, data_dir, shuffle=False):
        """Build a ``PNGDataset`` rooted at *data_dir* and wrap it in a
        ``DataLoader`` using the dataset's own ``collate_fn``."""
        dataset = self.ds_wrapper(data_dir=data_dir)
        dataloader = self.dl_wrapper(
            dataset, shuffle=shuffle, collate_fn=dataset.collate_fn
        )
        return dataloader