from typing import List, Union

import torch
from diffusers import (
    StableDiffusionInpaintPipeline,
    StableDiffusionXLInpaintPipeline,
    UNet2DConditionModel,
)

from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.high_res import HighRes
from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor
from internals.util import get_generators
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import (
    get_base_inpaint_model_revision,
    get_base_inpaint_model_variant,
    get_hf_cache_dir,
    get_hf_token,
    get_inpaint_model_path,
    get_is_sdxl,
    get_model_dir,
    get_num_return_sequences,
)


class InPainter(AbstractPipeline):
    """Inpainting pipeline wrapper for Stable Diffusion and SDXL.

    The wrapper either builds its diffusers pipeline from the components of
    an already loaded base pipeline (see ``init``/``create``) or loads the
    inpainting weights from ``get_inpaint_model_path()``.
    """

    __loaded = False
    # Class-level default so ``load`` can test for a base pipeline without
    # ``hasattr``. ``hasattr(self, "__base")`` never matches, because the
    # attribute is stored under the mangled name ``_InPainter__base``.
    __base = None

    def init(self, pipeline: AbstractPipeline):
        """Remember the base pipeline whose components may be reused.

        Args:
            pipeline: Already loaded pipeline exposing a ``pipe`` attribute.
        """
        self.__base = pipeline

    def load(self):
        """Load the inpainting pipeline onto CUDA; no-op when already loaded.

        When a base pipeline was registered through ``init`` and the inpaint
        model path equals the base model directory, the base pipeline's
        components are reused instead of loading the weights a second time.

        Raises:
            AttributeError: In SDXL mode when no base pipeline was registered
                through ``init`` (its components are required).
        """
        if self.__loaded:
            return

        if self.__base is not None and get_inpaint_model_path() == get_model_dir():
            self.create(self.__base)
            self.__loaded = True
            return

        if get_is_sdxl():
            if self.__base is None:
                raise AttributeError(
                    "InPainter.init() must be called with a base pipeline "
                    "before load() when running in SDXL mode"
                )
            # only take UNet from the repo
            unet = UNet2DConditionModel.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
                subfolder="unet",
                variant=get_base_inpaint_model_variant(),
                revision=get_base_inpaint_model_revision(),
            ).to("cuda")
            # Every other component is shared with the base pipeline.
            kwargs = {**self.__base.pipe.components, "unet": unet}
            self.pipe = StableDiffusionXLInpaintPipeline(**kwargs).to("cuda")
            self.__install_sdxl_processors()
        else:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                get_inpaint_model_path(),
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                token=get_hf_token(),
            ).to("cuda")

        disable_safety_checker(self.pipe)

        self.__patch()

        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build the inpainting pipeline from another pipeline's components.

        Args:
            pipeline: Loaded pipeline whose ``pipe.components`` are passed to
                the inpainting pipeline constructor.
        """
        if get_is_sdxl():
            self.pipe = StableDiffusionXLInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
            # Keep this path consistent with ``load``: ``process`` calls
            # ``mask_processor.blur`` on the SDXL pipeline.
            self.__install_sdxl_processors()
        else:
            self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        disable_safety_checker(self.pipe)

        self.__patch()

    def __install_sdxl_processors(self):
        """Replace the SDXL pipeline's mask/image processors with the project ones."""
        scale_factor = self.pipe.vae_scale_factor
        self.pipe.mask_processor = VaeImageProcessor(
            vae_scale_factor=scale_factor,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
        self.pipe.image_processor = VaeImageProcessor(vae_scale_factor=scale_factor)

    def __patch(self):
        """Enable the memory optimisations used by this pipeline."""
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    def unload(self):
        """Drop the pipeline reference and release cached CUDA memory."""
        self.__loaded = False
        self.pipe = None
        clear_cuda_and_gc()

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        **kwargs,
    ):
        """Inpaint the masked region of an image.

        Args:
            image_url: Location of the input image, passed to ``download_image``.
            mask_image_url: Location of the mask image, passed to
                ``download_image``.
            width: Width both images are resized to.
            height: Height both images are resized to.
            seed: Seed passed to ``get_generators``.
            prompt: Positive prompt(s).
            negative_prompt: Negative prompt(s).
            num_inference_steps: Number of denoising steps.
            **kwargs: Extra arguments forwarded to the pipeline call; they
                take precedence over the defaults assembled here. In SDXL
                mode ``strength`` and ``padding_mask_crop`` are always
                overwritten with 0.999 and 1000.

        Returns:
            A tuple ``(images, mask_img)``: the ``images`` attribute of the
            pipeline output and the mask that was actually used (blurred in
            SDXL mode).
        """
        generator = get_generators(seed, get_num_return_sequences())

        target_size = (width, height)
        input_img = download_image(image_url).resize(target_size)
        mask_img = download_image(mask_image_url).resize(target_size)

        if get_is_sdxl():
            # NOTE(review): the images keep the requested size while the
            # height/width passed to the pipeline are the adjusted ones -
            # confirm this is intended together with ``padding_mask_crop``.
            width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)
            mask_img = self.pipe.mask_processor.blur(mask_img, blur_factor=33)

            kwargs["strength"] = 0.999
            kwargs["padding_mask_crop"] = 1000

        # Defaults first, so anything present in ``kwargs`` wins.
        call_args = {
            "prompt": prompt,
            "image": input_img,
            "mask_image": mask_img,
            "height": height,
            "width": width,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "strength": 1.0,
            "generator": generator,
            **kwargs,
        }
        return self.pipe.__call__(**call_args).images, mask_img