import gc

import numpy as np
import PIL.Image
import torch
import torchvision
from controlnet_aux import (
    CannyDetector,
    ContentShuffleDetector,
    HEDdetector,
    LineartAnimeDetector,
    LineartDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
    PidiNetDetector,
)
from controlnet_aux.util import HWC3

from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor

from kornia.core import Tensor

# Load the preprocessors once, at module import time. Preprocessor.load()
# below does not create detectors; it only rebinds to one of these globals.
# NOTE(review): from_pretrained presumably fetches weights for the
# "lllyasviel/Annotators" repo, so importing this module may be slow and
# may need network access on first use — confirm.

# HEDdetector is imported but intentionally not loaded (disabled):
# HED = HEDdetector.from_pretrained("lllyasviel/Annotators")
Midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
MLSD = MLSDdetector.from_pretrained("lllyasviel/Annotators")
Canny = CannyDetector()  # constructed directly; no pretrained weights loaded here
OPENPOSE =  OpenposeDetector.from_pretrained("lllyasviel/Annotators")


class Preprocessor:
    """Run an image through one of the module-level controlnet_aux detectors.

    Usage: call :meth:`load` with a detector name, then call the instance on a
    PIL image to get the detector's output as a PIL image.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        self.model = None  # selected detector; None until load() succeeds
        self.name = ""  # name given to the last successful load()

    def load(self, name: str) -> None:
        """Select the detector called ``name``.

        Calling it again with the currently selected name is a no-op.

        Args:
            name: One of ``"Midas"``, ``"MLSD"``, ``"Openpose"``, ``"Canny"``.

        Raises:
            ValueError: If ``name`` is not a supported detector.
        """
        if name == self.name:
            return

        # The detectors are module-level singletons built at import time;
        # selecting one only rebinds self.model.
        detectors = {
            "Midas": Midas,
            "MLSD": MLSD,
            "Openpose": OPENPOSE,
            "Canny": Canny,
        }
        try:
            self.model = detectors[name]
        except KeyError:
            raise ValueError(
                f"Unknown preprocessor {name!r}; "
                f"expected one of {sorted(detectors)}"
            ) from None

        torch.cuda.empty_cache()
        gc.collect()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Apply the selected detector to ``image``.

        Args:
            image: Input image.
            **kwargs: ``detect_resolution`` (default 512) is passed as
                ``resolution`` to ``resize_image`` before detection;
                ``image_resolution`` (default 512) is passed as ``resolution``
                to ``resize_image`` on the detector output. Every other
                keyword argument is forwarded to the detector unchanged.

        Returns:
            The detector output as a PIL image. For ``"Canny"`` and ``"MLSD"``
            the result is explicitly converted to mode ``"RGB"``.

        Raises:
            RuntimeError: If no detector has been selected with :meth:`load`.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load(name) first.")

        # Both resolutions default to 512 for every detector (previously
        # Canny/MLSD raised KeyError when detect_resolution was omitted).
        detect_resolution = kwargs.pop("detect_resolution", 512)
        image_resolution = kwargs.pop("image_resolution", 512)

        array = HWC3(np.array(image))
        array = resize_image(array, resolution=detect_resolution)

        detected = self.model(array, **kwargs)
        detected = HWC3(np.array(detected))
        detected = resize_image(detected, resolution=image_resolution)

        result = PIL.Image.fromarray(detected)
        if self.name in ("Canny", "MLSD"):
            result = result.convert("RGB")
        return result