import random
from functools import partial
from multiprocessing import cpu_count
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class PNGDataset(Dataset):
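    """Image/prompt pairs where the prompt is stored in the image's metadata.

    Works with a local directory of PNGs (prompt read from the PNG text
    chunks via ``Image.info``) or an image dataset on the Hugging Face hub.
    Optionally applies unconditional-guidance prompt dropout (``ucg``) and
    produces a canny-edge ControlNet hint.
    """
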
    def __init__(
        self,
        data_dir,
        tokenizer,
        from_hf_hub=False,
        ucg=0.10,
        resolution=(512, 512),
        prompt_key="tags",
        cond_key="cond",
        target_key="image",
        controlnet_hint_key=None,
        file_extension="png",
    ):
        super().__init__()
        # Stash every constructor argument as an attribute in one go
        # (ucg, resolution, prompt_key, etc. are all set here).
        vars(self).update(locals())

        if from_hf_hub:
            # Hugging Face dataset rows; kept under `img_paths` so the rest
            # of the class can index both sources uniformly.
            self.img_paths = load_dataset(data_dir)["train"]
        else:
            self.img_paths = list(Path(data_dir).glob(f"*.{file_extension}"))

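        # Kept separate from `transforms` so the flip is applied to the PIL
        # image before the canny hint is computed from it.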
        self.flip_transform = torchvision.transforms.RandomHorizontalFlip(p=0.5)
        self.transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(resolution),
                torchvision.transforms.ToTensor(),
            ]
        )
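        # Maps [0, 1] tensors to [-1, 1], the range the diffusion target uses.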
        self.normalize = torchvision.transforms.Normalize([0.5], [0.5])

    def process_canny(self, image):
        # Canny edge hint, following the diffusers ControlNet example:
        # https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/controlnet
        image = np.array(image)
        low_threshold, high_threshold = (100, 200)
        image = cv2.Canny(image, low_threshold, high_threshold)
        # Replicate the single edge channel into a 3-channel image.
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)

        return Image.fromarray(image)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        if self.from_hf_hub:
            image = self.img_paths[idx]["image"]
        else:
            image = Image.open(self.img_paths[idx])

        # Prompts are read from the image's metadata (the PNG text chunks for
        # local files); images without one are skipped.
        if self.prompt_key not in image.info:
            print(f"Image {idx} lacks {self.prompt_key}, skipping to next image")
            return self.__getitem__((idx + 1) % len(self))

        # Unconditional-guidance dropout: blank the prompt with probability ucg.
        if random.random() < self.ucg:
            tags = ""
        else:
            tags = image.info[self.prompt_key]

        # Drop any alpha channel so ToTensor always yields 3 channels.
        image = image.convert("RGB")

        # Randomly flip the image here so the input image to canny has a
        # matching flip.
        image = self.flip_transform(image)

        target = self.normalize(self.transforms(image))

        output_dict = {self.target_key: target, self.cond_key: tags}

        if self.controlnet_hint_key == "canny":
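            # The hint is only resized + ToTensor'd, staying in [0, 1] as
            # ControlNet expects for conditioning images.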
            canny_image = self.transforms(self.process_canny(image))
            output_dict[self.controlnet_hint_key] = canny_image

        return output_dict

    def collate_fn(self, samples):
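        # Tokenize every prompt to the tokenizer's model_max_length so the
        # batch of input_ids has a uniform shape.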
        prompts = torch.tensor(
            [
                self.tokenizer(
                    sample[self.cond_key],
                    padding="max_length",
                    truncation=True,
                ).input_ids
                for sample in samples
            ]
        )

        images = torch.stack(
            [sample[self.target_key] for sample in samples]
        ).contiguous()

        batch = {
            self.cond_key: prompts,
            self.target_key: images,
        }

        if self.controlnet_hint_key is not None:
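            # Assumes __getitem__ added the hint (currently only "canny" is
            # produced there).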
            hint = torch.stack(
                [sample[self.controlnet_hint_key] for sample in samples]
            ).contiguous()
            batch[self.controlnet_hint_key] = hint

        return batch


class PNGDataModule:
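    """Builds PNGDataset/DataLoader pairs with shared settings.

    Loader options (batch_size, num_workers, persistent_workers) configure
    the DataLoader; all other keyword arguments are forwarded to PNGDataset.
    """
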
    def __init__(
        self,
        batch_size=1,
        num_workers=None,
        persistent_workers=True,
        **kwargs,  # passed to dataset class
    ):
        super().__init__()
        # Resolve the worker-count default *before* capturing locals so that
        # self.num_workers matches what the DataLoader actually receives.
        if num_workers is None:
            num_workers = cpu_count() // 2
        vars(self).update(locals())

        self.ds_wrapper = partial(PNGDataset, **kwargs)

        self.dl_wrapper = partial(
            DataLoader,
            batch_size=batch_size,
            num_workers=num_workers,
            # persistent_workers is only valid when num_workers > 0
            persistent_workers=persistent_workers and num_workers > 0,
        )

    def get_dataloader(self, data_dir, shuffle=False):
        dataset = self.ds_wrapper(data_dir=data_dir)
        dataloader = self.dl_wrapper(
            dataset, shuffle=shuffle, collate_fn=dataset.collate_fn
        )
        return dataloader
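

# Minimal usage sketch (assumptions, not part of this module: a folder of
# PNGs whose "tags" text chunk holds the prompt, and the CLIP tokenizer that
# Stable Diffusion v1.5 ships with):
#
#     from transformers import CLIPTokenizer
#
#     tokenizer = CLIPTokenizer.from_pretrained(
#         "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
#     )
#     datamodule = PNGDataModule(
#         batch_size=4,
#         tokenizer=tokenizer,
#         controlnet_hint_key="canny",
#     )
#     train_dl = datamodule.get_dataloader("data/train", shuffle=True)
#     batch = next(iter(train_dl))
#     # batch["image"]: (B, 3, 512, 512) float tensor in [-1, 1]
#     # batch["cond"]:  (B, 77) token ids
#     # batch["canny"]: (B, 3, 512, 512) float tensor in [0, 1]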