from io import BytesIO
from typing import List, Optional, Union

import torch
from cv2 import inpaint
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionInpaintPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image, ImageFilter, ImageOps

import internals.util.image as ImageUtil
from internals.data.result import Result
from internals.data.task import ModelType
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.inpainter import InPainter
from internals.pipelines.remove_background import RemoveBackgroundV2
from internals.pipelines.upscaler import Upscaler
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
    get_hf_cache_dir,
    get_hf_token,
    get_inpaint_model_path,
    get_model_dir,
)


class ReplaceBackground(AbstractPipeline):
    """Generate a new scene around the foreground subject of an image.

    The subject is cut out with a background-removal model. Its canny edges
    then condition a Stable Diffusion ControlNet pipeline that paints a full
    image. Finally the cut-out subject is pasted back on top of every
    generated image.
    """

    # Name-mangled to ``_ReplaceBackground__loaded``; lets ``load`` return
    # early instead of loading the models a second time.
    __loaded = False

    def load(
        self,
        upscaler: Optional[Upscaler] = None,
        remove_background: Optional[RemoveBackgroundV2] = None,
        base: Optional[AbstractPipeline] = None,
        high_res: Optional[HighRes] = None,
    ):
        """Build the canny ControlNet pipeline and attach helper models.

        Does nothing if this instance is already loaded.

        Args:
            upscaler: Upscaler instance to reuse; a new ``Upscaler`` is
                created when omitted. ``load()`` is called on it either way.
            remove_background: Background remover to reuse; a new
                ``RemoveBackgroundV2`` is created when omitted.
            base: Pipeline whose ``pipe.components`` are reused so the
                Stable Diffusion weights are not loaded twice. When omitted,
                the model is loaded from ``get_model_dir()``.
            high_res: ``HighRes`` instance to reuse; a new one is created
                when omitted. ``load()`` is called on it either way.
        """
        if self.__loaded:
            return

        controlnet_model = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_canny",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")
        if base:
            # Share the already-loaded SD components (unet, vae, text encoder,
            # ...) of the base pipeline instead of loading them again.
            pipe = StableDiffusionControlNetPipeline(
                **base.pipe.components,
                controlnet=controlnet_model,
            )
        else:
            pipe = StableDiffusionControlNetPipeline.from_pretrained(
                get_model_dir(),
                controlnet=controlnet_model,
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                use_auth_token=get_hf_token(),
            )
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_vae_slicing()
        pipe.to("cuda")

        self.pipe = pipe

        if not high_res:
            high_res = HighRes()
        high_res.load()
        self.high_res = high_res

        if not upscaler:
            upscaler = Upscaler()
        upscaler.load()
        self.upscaler = upscaler

        if not remove_background:
            remove_background = RemoveBackgroundV2()
        self.remove_background = remove_background

        self.__loaded = True

    def unload(self):
        """Drop references to all models and free CUDA memory."""
        self.__loaded = False
        self.pipe = None
        self.high_res = None
        self.upscaler = None
        self.remove_background = None

        clear_cuda_and_gc()

    @torch.inference_mode()
    def replace(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        prompt: List[str],
        negative_prompt: List[str],
        conditioning_scale: float,
        seed: int,
        steps: int,
        apply_high_res: bool = False,
        model_type: ModelType = ModelType.REAL,
    ):
        """Replace the background behind the subject of ``image``.

        Args:
            image: A PIL image, or a string passed to ``download_image``.
            width: Output width; coerced with ``int()``.
            height: Output height; coerced with ``int()``.
            prompt: Prompt(s) describing the new scene.
            negative_prompt: Negative prompt(s) for the diffusion pipeline.
            conditioning_scale: ControlNet conditioning scale.
            seed: Seed applied to the global torch CPU and CUDA RNGs.
            steps: Number of diffusion inference steps.
            apply_high_res: NOTE(review): accepted but currently ignored;
                ``self.high_res`` is loaded in ``load()`` but never used here.
            model_type: Forwarded to the background remover.

        Returns:
            ``(images, has_nsfw)`` as produced by ``Result.from_result``.
            When ``has_nsfw`` is falsy, the cut-out subject has been pasted
            onto each image at the top-left corner.
        """
        if isinstance(image, str):
            image = download_image(image)

        # The pipeline is called without an explicit generator, so seed the
        # global RNGs it draws from.
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        image = image.convert("RGB")
        if max(image.size) > 1024:
            image = ImageUtil.resize_image(image, dimension=1024)
        # Presumably returns an RGBA cut-out whose alpha channel marks the
        # subject (it is used as a paste mask below) -- TODO confirm.
        image = self.remove_background.remove(image, model_type=model_type)

        width = int(width)
        height = int(height)

        # Fit the subject to the larger output side, then pad to the exact
        # output size so it lines up pixel-for-pixel with generated images.
        image = ImageUtil.resize_image(image, max(width, height))
        image = ImageUtil.padd_image(image, width, height)

        condition_image = ControlNet.canny_detect_edge(image)

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=condition_image,
            controlnet_conditioning_scale=conditioning_scale,
            guidance_scale=9,
            height=height,
            num_inference_steps=steps,
            width=width,
        )
        images, has_nsfw = Result.from_result(result)

        if not has_nsfw:
            # Paste the original subject back over the generated scene, using
            # the cut-out itself as the transparency mask.
            for generated in images:
                generated.paste(image, (0, 0), image)

        return (images, has_nsfw)