import gc

import numpy as np
import PIL.Image
import torch
import torchvision
from controlnet_aux import (
    CannyDetector,
    ContentShuffleDetector,
    HEDdetector,
    LineartAnimeDetector,
    LineartDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
    PidiNetDetector,
)
from controlnet_aux.util import HWC3
from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor
from kornia.core import Tensor
from kornia.filters import canny


class Canny:
    """Canny edge detector backed by ``kornia.filters.canny``."""

    def __call__(
        self,
        images: np.ndarray,
        low_threshold: float = 0.1,
        high_threshold: float = 0.2,
        kernel_size: tuple[int, int] | int = (5, 5),
        sigma: tuple[float, float] | Tensor = (1, 1),
        hysteresis: bool = True,
        eps: float = 1e-6,
    ) -> np.ndarray:
        """Run Canny edge detection on a single image.

        Args:
            images: One image laid out as ``(H, W, C)`` (the code permutes
                axes ``(2, 0, 1)``). Values are divided by 255, so it is
                presumably uint8 in ``[0, 255]`` — TODO confirm for callers
                other than ``Preprocessor``.
            low_threshold: Lower threshold, forwarded to kornia.
            high_threshold: Upper threshold, forwarded to kornia.
            kernel_size: Gaussian kernel size, forwarded to kornia.
            sigma: Gaussian standard deviation, forwarded to kornia.
            hysteresis: Whether kornia applies hysteresis edge tracking.
            eps: Numerical-stability epsilon, forwarded to kornia.

        Returns:
            A ``(H, W)`` uint8 array: kornia's edge map multiplied by 255.

        Raises:
            ValueError: If ``low_threshold`` or ``high_threshold`` is None.
        """
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if low_threshold is None:
            raise ValueError("low_threshold must be provided")
        if high_threshold is None:
            raise ValueError("high_threshold must be provided")

        # (H, W, C) -> (1, C, H, W) float tensor scaled down by 255.
        batch = torch.from_numpy(images).permute(2, 0, 1).unsqueeze(0) / 255.0
        # kornia's canny returns (magnitude, edges); index 1 is the edge map.
        edges = canny(
            batch, low_threshold, high_threshold, kernel_size, sigma, hysteresis, eps
        )[1]
        # Drop the batch and channel axes, then scale back to uint8.
        return (edges[0][0].numpy() * 255).astype(np.uint8)


class Preprocessor:
    """Holds one active image preprocessor, loaded by name, and applies it."""

    # Hugging Face repo id for annotator weights. NOTE(review): not referenced
    # by any loader currently wired up in load().
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        self.model = None  # active preprocessor callable; set by load()
        self.name = ""  # name of the active preprocessor; "" until load()

    def load(self, name: str) -> None:
        """Switch the active preprocessor to ``name``; no-op if already active.

        Supported names are ``"Canny"`` and ``"DPT"``.

        Raises:
            ValueError: If ``name`` is not a supported preprocessor.
        """
        if name == self.name:
            return
        if name == "Canny":
            self.model = Canny()
        elif name == "DPT":
            self.model = DepthEstimator()
        else:
            raise ValueError(f"Unsupported preprocessor: {name!r}")
        # Run the garbage collector first so a replaced model that is only
        # reachable through reference cycles is actually freed before the
        # CUDA caching allocator is asked to release unused memory.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Apply the active preprocessor to ``image``.

        Branch-specific keyword arguments (``detect_resolution``, and
        ``image_resolution`` for Midas) are consumed here; the remaining
        ``kwargs`` are forwarded to the underlying model.
        """
        if self.name == "Canny":
            detect_resolution = kwargs.pop("detect_resolution", None)
            # Canny.__call__ feeds its input to torch.from_numpy, so the image
            # must always become a numpy array — not only when resizing.
            image = HWC3(np.array(image))
            if detect_resolution is not None:
                image = resize_image(image, resolution=detect_resolution)
            edges = self.model(image, **kwargs)
            return PIL.Image.fromarray(edges).convert('RGB')
        elif self.name == "Midas":
            # NOTE(review): load() never sets self.name to "Midas", so this
            # branch is only reachable if self.name is assigned externally.
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            image = np.array(image)
            image = HWC3(image)
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)
        else:
            # Other models (e.g. DPT) take the input image directly.
            return self.model(image, **kwargs)