# internals/pipelines/controlnets.py
from typing import List, Union
import cv2
import numpy as np
import torch
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.tileUpscalePipeline import (
StableDiffusionControlNetImg2ImgPipeline,
)
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
class ControlNet(AbstractPipeline):
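    """ControlNet tasks (canny, pose, tile upscale, scribble, line art) on top of a
    shared Stable Diffusion checkpoint.

    A single ``ControlNetModel`` is swapped in and out of two shared pipelines via
    the ``load_*`` methods; ``__current_task_name`` tracks which task is active.
    """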
__current_task_name = ""
def load(self, model_dir: str):
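        """Build the shared pipelines from the checkpoint at ``model_dir``: ``pipe``
        (ControlNet img2img, used by the tile upscaler) and ``pipe2`` (ControlNet
        text-to-image, used by the other tasks). The scribble ControlNet is attached
        by default.
        """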
        # load the scribble ControlNet by default
self.load_scribble()
# controlnet pipeline for tile upscaler
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
model_dir,
controlnet=self.controlnet,
torch_dtype=torch.float16,
).to("cuda")
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()
self.pipe = pipe
# controlnet pipeline for canny and pose
pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
pipe2.enable_xformers_memory_efficient_attention()
self.pipe2 = pipe2
def load_canny(self):
if self.__current_task_name == "canny":
return
canny = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
).to("cuda")
self.__current_task_name = "canny"
self.controlnet = canny
if hasattr(self, "pipe"):
self.pipe.controlnet = canny
if hasattr(self, "pipe2"):
self.pipe2.controlnet = canny
clear_cuda_and_gc()
def load_pose(self):
if self.__current_task_name == "pose":
return
pose = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
).to("cuda")
self.__current_task_name = "pose"
self.controlnet = pose
if hasattr(self, "pipe"):
self.pipe.controlnet = pose
if hasattr(self, "pipe2"):
self.pipe2.controlnet = pose
clear_cuda_and_gc()
def load_tile_upscaler(self):
if self.__current_task_name == "tile_upscaler":
return
tile_upscaler = ControlNetModel.from_pretrained(
"lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.float16
).to("cuda")
self.__current_task_name = "tile_upscaler"
self.controlnet = tile_upscaler
if hasattr(self, "pipe"):
self.pipe.controlnet = tile_upscaler
if hasattr(self, "pipe2"):
self.pipe2.controlnet = tile_upscaler
clear_cuda_and_gc()
def load_scribble(self):
if self.__current_task_name == "scribble":
return
scribble = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_scribble", torch_dtype=torch.float16
).to("cuda")
self.__current_task_name = "scribble"
self.controlnet = scribble
if hasattr(self, "pipe"):
self.pipe.controlnet = scribble
if hasattr(self, "pipe2"):
self.pipe2.controlnet = scribble
clear_cuda_and_gc()
def load_linearart(self):
if self.__current_task_name == "linearart":
return
linearart = ControlNetModel.from_pretrained(
"ControlNet-1-1-preview/control_v11p_sd15_lineart",
torch_dtype=torch.float16,
).to("cuda")
self.__current_task_name = "linearart"
self.controlnet = linearart
if hasattr(self, "pipe"):
self.pipe.controlnet = linearart
if hasattr(self, "pipe2"):
self.pipe2.controlnet = linearart
clear_cuda_and_gc()
def cleanup(self):
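        """Detach the current ControlNet from both pipelines and free GPU memory."""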
        if hasattr(self, "pipe"):
            self.pipe.controlnet = None
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = None
self.controlnet = None
self.__current_task_name = ""
clear_cuda_and_gc()
@torch.inference_mode()
def process_canny(
self,
prompt: List[str],
imageUrl: str,
seed: int,
steps: int,
negative_prompt: List[str],
guidance_scale: float,
height: int,
width: int,
):
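        """Generate images guided by a Canny edge map computed from the image at
        ``imageUrl``. Requires ``load_canny()`` to have been called."""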
if self.__current_task_name != "canny":
raise Exception("ControlNet is not loaded with canny model")
torch.manual_seed(seed)
init_image = download_image(imageUrl).resize((width, height))
init_image = self.__canny_detect_edge(init_image)
        result = self.pipe2(
prompt=prompt,
image=init_image,
guidance_scale=guidance_scale,
num_images_per_prompt=1,
negative_prompt=negative_prompt,
num_inference_steps=steps,
height=height,
width=width,
)
return Result.from_result(result)
@torch.inference_mode()
def process_pose(
self,
prompt: List[str],
image: List[Image.Image],
seed: int,
steps: int,
guidance_scale: float,
negative_prompt: List[str],
height: int,
width: int,
):
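        """Generate images guided by pre-computed pose maps (see ``detect_pose``).
        Requires ``load_pose()`` to have been called."""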
if self.__current_task_name != "pose":
raise Exception("ControlNet is not loaded with pose model")
torch.manual_seed(seed)
        result = self.pipe2(
prompt=prompt,
image=image,
num_images_per_prompt=1,
num_inference_steps=steps,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
height=height,
width=width,
)
return Result.from_result(result)
@torch.inference_mode()
def process_tile_upscaler(
self,
imageUrl: str,
prompt: str,
negative_prompt: str,
steps: int,
seed: int,
height: int,
width: int,
resize_dimension: int,
guidance_scale: float,
):
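        """Upscale the image at ``imageUrl`` with the tile ControlNet. The input is
        resized to ``width`` x ``height``, then rescaled so its shorter side is
        roughly ``resize_dimension`` (rounded to multiples of 64) and used as both
        the init and the conditioning image. Requires ``load_tile_upscaler()``."""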
if self.__current_task_name != "tile_upscaler":
raise Exception("ControlNet is not loaded with tile_upscaler model")
torch.manual_seed(seed)
init_image = download_image(imageUrl).resize((width, height))
condition_image = self.__resize_for_condition_image(
init_image, resize_dimension
)
        result = self.pipe(
image=condition_image,
prompt=prompt,
controlnet_conditioning_image=condition_image,
num_inference_steps=steps,
negative_prompt=negative_prompt,
height=condition_image.size[1],
width=condition_image.size[0],
guidance_scale=guidance_scale,
)
return Result.from_result(result)
@torch.inference_mode()
def process_scribble(
self,
imageUrl: Union[str, Image.Image],
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]],
steps: int,
seed: int,
height: int,
width: int,
guidance_scale: float = 7.5,
):
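        """Generate images guided by an HED scribble map derived from ``imageUrl``
        (a URL or an already-loaded PIL image). Requires ``load_scribble()``."""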
if self.__current_task_name != "scribble":
raise Exception("ControlNet is not loaded with scribble model")
torch.manual_seed(seed)
if isinstance(imageUrl, Image.Image):
init_image = imageUrl.resize((width, height))
else:
init_image = download_image(imageUrl).resize((width, height))
condition_image = self.__scribble_condition_image(init_image)
        result = self.pipe2(
image=condition_image,
prompt=prompt,
num_inference_steps=steps,
negative_prompt=negative_prompt,
height=height,
width=width,
guidance_scale=guidance_scale,
)
return Result.from_result(result)
@torch.inference_mode()
def process_linearart(
self,
imageUrl: str,
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]],
steps: int,
seed: int,
height: int,
width: int,
guidance_scale: float = 7.5,
):
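        """Generate images guided by a line-art map derived from the image at
        ``imageUrl``. Requires ``load_linearart()`` to have been called."""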
if self.__current_task_name != "linearart":
raise Exception("ControlNet is not loaded with linearart model")
torch.manual_seed(seed)
init_image = download_image(imageUrl).resize((width, height))
condition_image = ControlNet.linearart_condition_image(init_image)
        result = self.pipe2(
image=condition_image,
prompt=prompt,
num_inference_steps=steps,
negative_prompt=negative_prompt,
height=height,
width=width,
guidance_scale=guidance_scale,
)
return Result.from_result(result)
def detect_pose(self, imageUrl: str) -> Image.Image:
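        """Run OpenPose (with hands and face) on the image at ``imageUrl`` and
        return the pose map as a PIL image."""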
detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
image = download_image(imageUrl)
        image = detector(image, hand_and_face=True)
return image
def __scribble_condition_image(self, image: Image.Image) -> Image.Image:
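        """Derive an HED-based scribble map from ``image`` for the scribble ControlNet."""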
processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        image = processor(input_image=image, scribble=True)
return image
@staticmethod
def linearart_condition_image(image: Image.Image) -> Image.Image:
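        """Derive a line-art map from ``image`` for the line-art ControlNet."""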
processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        image = processor(input_image=image)
return image
def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
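        """Run Canny edge detection (thresholds 100/200) on ``image`` and return a
        3-channel edge map."""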
image_array = np.array(image)
low_threshold = 100
high_threshold = 200
image_array = cv2.Canny(image_array, low_threshold, high_threshold)
image_array = image_array[:, :, None]
image_array = np.concatenate([image_array, image_array, image_array], axis=2)
canny_image = Image.fromarray(image_array)
return canny_image
    def __resize_for_condition_image(
        self, image: Image.Image, resolution: int
    ) -> Image.Image:
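        """Resize ``image`` so its shorter side is approximately ``resolution``,
        rounding both sides to multiples of 64."""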
input_image = image.convert("RGB")
W, H = input_image.size
k = float(resolution) / min(W, H)
H *= k
W *= k
H = int(round(H / 64.0)) * 64
W = int(round(W / 64.0)) * 64
img = input_image.resize((W, H), resample=Image.LANCZOS)
return img
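

# Minimal usage sketch. The checkpoint directory and image URL below are
# placeholders, and running this requires a CUDA GPU plus the surrounding
# `internals` package on the import path.
if __name__ == "__main__":
    controlnet = ControlNet()
    controlnet.load(model_dir="/path/to/sd15-checkpoint")  # placeholder path
    controlnet.load_canny()
    result = controlnet.process_canny(
        prompt=["a watercolor painting of a house"],
        imageUrl="https://example.com/input.png",  # placeholder URL
        seed=42,
        steps=30,
        negative_prompt=["blurry, low quality"],
        guidance_scale=7.5,
        height=512,
        width=512,
    )
    # `result` is an internals.data.result.Result wrapping the generated images.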