|
from typing import List, Union |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector |
|
from diffusers import ( |
|
ControlNetModel, |
|
DiffusionPipeline, |
|
StableDiffusionControlNetPipeline, |
|
UniPCMultistepScheduler, |
|
) |
|
from PIL import Image |
|
from torch.nn import Linear |
|
from tqdm import gui |
|
|
|
from internals.data.result import Result |
|
from internals.pipelines.commons import AbstractPipeline |
|
from internals.pipelines.tileUpscalePipeline import ( |
|
StableDiffusionControlNetImg2ImgPipeline, |
|
) |
|
from internals.util.cache import clear_cuda_and_gc |
|
from internals.util.commons import download_image |
|
|
|
|
|
class ControlNet(AbstractPipeline):
    """ControlNet-guided Stable Diffusion with a single swappable ControlNet.

    ``load`` builds two pipelines that share their components:

    * ``self.pipe``  - ``StableDiffusionControlNetImg2ImgPipeline`` (used by
      ``process_tile_upscaler``),
    * ``self.pipe2`` - ``StableDiffusionControlNetPipeline`` with a UniPC
      scheduler (used by every other ``process_*`` method).

    Only one ControlNet checkpoint is attached at a time. Each ``load_<task>``
    method swaps the matching checkpoint in, and each ``process_<task>``
    method raises unless that checkpoint is the active one.
    """

    # Hugging Face model id of the ControlNet checkpoint used for each task.
    _CONTROLNET_MODEL_IDS = {
        "canny": "lllyasviel/control_v11p_sd15_canny",
        "pose": "lllyasviel/sd-controlnet-openpose",
        "tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
        "scribble": "lllyasviel/control_v11p_sd15_scribble",
        "linearart": "ControlNet-1-1-preview/control_v11p_sd15_lineart",
    }

    # Task name of the ControlNet currently attached ("" when none is loaded).
    __current_task_name = ""

    def load(self, model_dir: str):
        """Build ``self.pipe`` and ``self.pipe2`` from the weights in ``model_dir``.

        The scribble ControlNet is loaded first because both pipelines must be
        constructed with some ControlNet attached.
        """
        self.load_scribble()

        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            model_dir,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        ).to("cuda")
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        self.pipe = pipe

        # pipe2 reuses every module of pipe (no second copy of the weights).
        # NOTE(review): moving these shared, offload-hooked modules to CUDA
        # largely defeats pipe's CPU offload, and newer diffusers releases may
        # refuse `.to("cuda")` on offloaded modules -- confirm intended setup.
        pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
        pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
        pipe2.enable_xformers_memory_efficient_attention()
        self.pipe2 = pipe2

    def _load_controlnet(self, task_name: str):
        """Make the ControlNet registered for ``task_name`` the active one.

        This is a no-op when ``task_name`` is already active. Otherwise the
        fp16 checkpoint from ``_CONTROLNET_MODEL_IDS`` is loaded onto CUDA and
        attached to ``self.controlnet`` and to whichever of ``pipe``/``pipe2``
        already exist. CUDA/GC caches are then cleared so the previously
        attached model can be reclaimed.
        """
        if self.__current_task_name == task_name:
            return
        controlnet = ControlNetModel.from_pretrained(
            self._CONTROLNET_MODEL_IDS[task_name], torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = task_name
        self.controlnet = controlnet
        # The pipelines only exist after load(); load() itself calls a loader
        # before building them.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = controlnet
        clear_cuda_and_gc()

    def load_canny(self):
        """Activate the canny-edge ControlNet."""
        self._load_controlnet("canny")

    def load_pose(self):
        """Activate the OpenPose ControlNet."""
        self._load_controlnet("pose")

    def load_tile_upscaler(self):
        """Activate the tile ControlNet used by ``process_tile_upscaler``."""
        self._load_controlnet("tile_upscaler")

    def load_scribble(self):
        """Activate the scribble ControlNet."""
        self._load_controlnet("scribble")

    def load_linearart(self):
        """Activate the line-art ControlNet."""
        self._load_controlnet("linearart")

    def cleanup(self):
        """Detach and drop the active ControlNet, then clear CUDA/GC caches.

        This is safe to call before ``load()``: pipelines that were never
        built are skipped instead of raising ``AttributeError``.
        """
        for attr in ("pipe", "pipe2"):
            pipeline = getattr(self, attr, None)
            if pipeline is not None:
                pipeline.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""

        clear_cuda_and_gc()

    def _ensure_task(self, task_name: str):
        """Raise unless the ControlNet for ``task_name`` is currently loaded."""
        if self.__current_task_name != task_name:
            raise Exception(f"ControlNet is not loaded with {task_name} model")

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        guidance_scale: float,
        height: int,
        width: int,
    ):
        """Generate from ``prompt`` guided by Canny edges of ``imageUrl``.

        The downloaded image is resized to ``width`` x ``height`` before edge
        detection. One image per prompt is generated with ``pipe2``.

        Raises:
            Exception: if the canny ControlNet is not the active one.
        """
        self._ensure_task("canny")

        # Seeds torch's global RNG; no explicit generator is given to pipe2.
        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        init_image = self.__canny_detect_edge(init_image)

        result = self.pipe2(
            prompt=prompt,
            image=init_image,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        guidance_scale: float,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate from ``prompt`` guided by the pose map(s) in ``image``.

        ``image`` is passed to ``pipe2`` unchanged. Presumably these are pose
        renderings such as those produced by ``detect_pose``; confirm with
        callers.

        Raises:
            Exception: if the pose ControlNet is not the active one.
        """
        self._ensure_task("pose")

        torch.manual_seed(seed)

        result = self.pipe2(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float,
    ):
        """Upscale ``imageUrl`` with the img2img pipeline and the tile ControlNet.

        The image is first resized to ``width`` x ``height``. It is then
        rescaled so its shorter side is about ``resize_dimension``, with both
        sides rounded to multiples of 64. That rescaled image is used as both
        the init image and the ControlNet condition, and it also sets the
        output size.

        Raises:
            Exception: if the tile_upscaler ControlNet is not the active one.
        """
        self._ensure_task("tile_upscaler")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__resize_for_condition_image(
            init_image, resize_dimension
        )

        result = self.pipe(
            image=condition_image,
            prompt=prompt,
            controlnet_conditioning_image=condition_image,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=condition_image.size[1],
            width=condition_image.size[0],
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_scribble(
        self,
        imageUrl: Union[str, Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
    ):
        """Generate from ``prompt`` guided by a HED scribble of the input image.

        ``imageUrl`` may be a URL or an already-loaded PIL image. Either way it
        is resized to ``width`` x ``height`` before the scribble is extracted.

        Raises:
            Exception: if the scribble ControlNet is not the active one.
        """
        self._ensure_task("scribble")

        torch.manual_seed(seed)

        if isinstance(imageUrl, Image.Image):
            init_image = imageUrl.resize((width, height))
        else:
            init_image = download_image(imageUrl).resize((width, height))

        condition_image = self.__scribble_condition_image(init_image)

        result = self.pipe2(
            image=condition_image,
            prompt=prompt,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_linearart(
        self,
        imageUrl: str,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
    ):
        """Generate from ``prompt`` guided by the line art of ``imageUrl``.

        Raises:
            Exception: if the linearart ControlNet is not the active one.
        """
        self._ensure_task("linearart")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = ControlNet.linearart_condition_image(init_image)

        result = self.pipe2(
            image=condition_image,
            prompt=prompt,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download ``imageUrl`` and run OpenPose on it, including hands and face.

        NOTE(review): the detector weights are re-loaded on every call.
        """
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = download_image(imageUrl)
        return detector(image, hand_and_face=True)

    def __scribble_condition_image(self, image: Image.Image) -> Image.Image:
        """Return the HED scribble map of ``image``."""
        processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        return processor(input_image=image, scribble=True)

    @staticmethod
    def linearart_condition_image(image: Image.Image) -> Image.Image:
        """Return the line-art map of ``image`` from ``LineartDetector``."""
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        return processor(input_image=image)

    def __canny_detect_edge(
        self,
        image: Image.Image,
        low_threshold: int = 100,
        high_threshold: int = 200,
    ) -> Image.Image:
        """Return a 3-channel (H x W x 3, uint8) Canny edge map of ``image``.

        The image is converted to RGB first. Without this, a palette ("P")
        image would hand Canny its palette indices instead of colours, and an
        RGBA image would arrive with 4 channels.
        """
        pixels = np.array(image.convert("RGB"))
        edges = cv2.Canny(pixels, low_threshold, high_threshold)
        # Canny returns a single-channel H x W map; replicate it to 3 channels.
        edges = np.repeat(edges[:, :, None], 3, axis=2)
        return Image.fromarray(edges)

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        """Scale ``image`` (as RGB) so its shorter side is about ``resolution`` px.

        Both sides are rounded to the nearest multiple of 64. Each side is at
        least 64, so a small ``resolution`` cannot yield a zero-size image.
        Resampling uses LANCZOS.
        """
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / min(W, H)
        H = max(64, int(round(H * k / 64.0)) * 64)
        W = max(64, int(round(W * k / 64.0)) * 64)
        return input_image.resize((W, H), resample=Image.LANCZOS)
|
|