# CM2000112 / internals/pipelines/controlnets.py
from typing import List, Union
import cv2
import numpy as np
import torch
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)
from PIL import Image
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.tileUpscalePipeline import \
StableDiffusionControlNetImg2ImgPipeline
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import get_hf_cache_dir, get_hf_token, get_model_dir
class ControlNet(AbstractPipeline):
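    """Lazy wrapper around Stable Diffusion ControlNet pipelines.

    Two pipelines share the same model components: ``self.pipe`` (img2img,
    used by the tile upscaler) and ``self.pipe2`` (text-to-image, used by the
    canny/pose/scribble/lineart tasks). The ``load_*`` methods swap the active
    ControlNet weights in place and the ``process_*`` methods run inference
    for the matching task.
    """
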
__current_task_name = ""
__loaded = False
def load(self):
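        """Build the shared pipelines once; loads the pose ControlNet by default if none is set."""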
if self.__loaded:
return
        if not getattr(self, "controlnet", None):
            self.load_pose()
# controlnet pipeline for tile upscaler
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
get_model_dir(),
controlnet=self.controlnet,
torch_dtype=torch.float16,
use_auth_token=get_hf_token(),
cache_dir=get_hf_cache_dir(),
).to("cuda")
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()
self.pipe = pipe
# controlnet pipeline for canny and pose
pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
pipe2.enable_xformers_memory_efficient_attention()
self.pipe2 = pipe2
self.__loaded = True
def load_canny(self):
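        """Switch the active ControlNet weights to the SD1.5 canny model."""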
if self.__current_task_name == "canny":
return
canny = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny",
torch_dtype=torch.float16,
cache_dir=get_hf_cache_dir(),
).to("cuda")
self.__current_task_name = "canny"
self.controlnet = canny
self.load()
if hasattr(self, "pipe"):
self.pipe.controlnet = canny
if hasattr(self, "pipe2"):
self.pipe2.controlnet = canny
clear_cuda_and_gc()
def load_pose(self):
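        """Switch the active ControlNet weights to the SD1.5 openpose model."""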
if self.__current_task_name == "pose":
return
pose = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-openpose",
torch_dtype=torch.float16,
cache_dir=get_hf_cache_dir(),
).to("cuda")
self.__current_task_name = "pose"
self.controlnet = pose
self.load()
if hasattr(self, "pipe"):
self.pipe.controlnet = pose
if hasattr(self, "pipe2"):
self.pipe2.controlnet = pose
clear_cuda_and_gc()
def load_tile_upscaler(self):
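        """Switch the active ControlNet weights to the SD1.5 tile model used for upscaling."""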
if self.__current_task_name == "tile_upscaler":
return
tile_upscaler = ControlNetModel.from_pretrained(
"lllyasviel/control_v11f1e_sd15_tile",
torch_dtype=torch.float16,
cache_dir=get_hf_cache_dir(),
).to("cuda")
self.__current_task_name = "tile_upscaler"
self.controlnet = tile_upscaler
self.load()
if hasattr(self, "pipe"):
self.pipe.controlnet = tile_upscaler
if hasattr(self, "pipe2"):
self.pipe2.controlnet = tile_upscaler
clear_cuda_and_gc()
def load_scribble(self):
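        """Switch the active ControlNet weights to the SD1.5 scribble model."""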
if self.__current_task_name == "scribble":
return
scribble = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_scribble",
torch_dtype=torch.float16,
cache_dir=get_hf_cache_dir(),
).to("cuda")
self.__current_task_name = "scribble"
self.controlnet = scribble
self.load()
if hasattr(self, "pipe"):
self.pipe.controlnet = scribble
if hasattr(self, "pipe2"):
self.pipe2.controlnet = scribble
clear_cuda_and_gc()
def load_linearart(self):
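        """Switch the active ControlNet weights to the SD1.5 lineart model."""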
if self.__current_task_name == "linearart":
return
linearart = ControlNetModel.from_pretrained(
"ControlNet-1-1-preview/control_v11p_sd15_lineart",
torch_dtype=torch.float16,
cache_dir=get_hf_cache_dir(),
).to("cuda")
self.__current_task_name = "linearart"
self.controlnet = linearart
self.load()
if hasattr(self, "pipe"):
self.pipe.controlnet = linearart
if hasattr(self, "pipe2"):
self.pipe2.controlnet = linearart
clear_cuda_and_gc()
def cleanup(self):
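        """Detach the active ControlNet from both pipelines and release GPU memory."""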
if hasattr(self, "pipe"):
self.pipe.controlnet = None
if hasattr(self, "pipe2"):
self.pipe2.controlnet = None
self.controlnet = None
del self.controlnet
self.__current_task_name = ""
clear_cuda_and_gc()
@torch.inference_mode()
def process_canny(
self,
prompt: List[str],
imageUrl: str,
seed: int,
steps: int,
negative_prompt: List[str],
guidance_scale: float,
height: int,
width: int,
):
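        """Generate images guided by canny edges extracted from the image at ``imageUrl``."""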
if self.__current_task_name != "canny":
raise Exception("ControlNet is not loaded with canny model")
torch.manual_seed(seed)
init_image = download_image(imageUrl).resize((width, height))
init_image = self.__canny_detect_edge(init_image)
result = self.pipe2.__call__(
prompt=prompt,
image=init_image,
guidance_scale=guidance_scale,
num_images_per_prompt=1,
negative_prompt=negative_prompt,
num_inference_steps=steps,
height=height,
width=width,
)
return Result.from_result(result)
@torch.inference_mode()
def process_pose(
self,
prompt: List[str],
image: List[Image.Image],
seed: int,
steps: int,
guidance_scale: float,
negative_prompt: List[str],
height: int,
width: int,
):
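        """Generate images guided by already-prepared pose conditioning images."""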
if self.__current_task_name != "pose":
raise Exception("ControlNet is not loaded with pose model")
torch.manual_seed(seed)
result = self.pipe2.__call__(
prompt=prompt,
image=image,
num_images_per_prompt=1,
num_inference_steps=steps,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
height=height,
width=width,
)
return Result.from_result(result)
@torch.inference_mode()
def process_tile_upscaler(
self,
imageUrl: str,
prompt: str,
negative_prompt: str,
steps: int,
seed: int,
height: int,
width: int,
resize_dimension: int,
guidance_scale: float,
):
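        """Upscale an image with the tile ControlNet.

        The downloaded image is resized so its longer side matches
        ``resize_dimension`` (snapped to multiples of 64) and is used as both
        the init image and the conditioning image for the img2img pipeline.
        """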
if self.__current_task_name != "tile_upscaler":
raise Exception("ControlNet is not loaded with tile_upscaler model")
torch.manual_seed(seed)
init_image = download_image(imageUrl).resize((width, height))
condition_image = self.__resize_for_condition_image(
init_image, resize_dimension
)
result = self.pipe.__call__(
image=condition_image,
prompt=prompt,
controlnet_conditioning_image=condition_image,
num_inference_steps=steps,
negative_prompt=negative_prompt,
height=condition_image.size[1],
width=condition_image.size[0],
guidance_scale=guidance_scale,
)
return Result.from_result(result)
@torch.inference_mode()
def process_scribble(
self,
imageUrl: Union[str, Image.Image],
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]],
steps: int,
seed: int,
height: int,
width: int,
guidance_scale: float = 7.5,
):
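        """Generate images guided by a HED scribble map derived from the input image (URL or PIL image)."""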
if self.__current_task_name != "scribble":
raise Exception("ControlNet is not loaded with scribble model")
torch.manual_seed(seed)
if isinstance(imageUrl, Image.Image):
init_image = imageUrl.resize((width, height))
else:
init_image = download_image(imageUrl).resize((width, height))
condition_image = self.__scribble_condition_image(init_image)
result = self.pipe2.__call__(
image=condition_image,
prompt=prompt,
num_inference_steps=steps,
negative_prompt=negative_prompt,
height=height,
width=width,
guidance_scale=guidance_scale,
)
return Result.from_result(result)
@torch.inference_mode()
def process_linearart(
self,
imageUrl: str,
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]],
steps: int,
seed: int,
height: int,
width: int,
guidance_scale: float = 7.5,
):
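        """Generate images guided by a lineart map derived from the image at ``imageUrl``."""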
if self.__current_task_name != "linearart":
raise Exception("ControlNet is not loaded with linearart model")
torch.manual_seed(seed)
init_image = download_image(imageUrl).resize((width, height))
condition_image = ControlNet.linearart_condition_image(init_image)
result = self.pipe2.__call__(
image=condition_image,
prompt=prompt,
num_inference_steps=steps,
negative_prompt=negative_prompt,
height=height,
width=width,
guidance_scale=guidance_scale,
)
return Result.from_result(result)
def detect_pose(self, imageUrl: str) -> Image.Image:
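        """Return the OpenPose keypoint map for the image at ``imageUrl``."""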
detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
image = download_image(imageUrl)
image = detector.__call__(image)
return image
def __scribble_condition_image(self, image: Image.Image) -> Image.Image:
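        """Convert an image to a HED-based scribble conditioning map."""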
processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
image = processor.__call__(input_image=image, scribble=True)
return image
@staticmethod
def linearart_condition_image(image: Image.Image) -> Image.Image:
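        """Convert an image to a lineart conditioning map."""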
processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
image = processor.__call__(input_image=image)
return image
def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
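        """Convert an image to a 3-channel canny edge map (OpenCV thresholds 100/200)."""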
image_array = np.array(image)
low_threshold = 100
high_threshold = 200
image_array = cv2.Canny(image_array, low_threshold, high_threshold)
image_array = image_array[:, :, None]
image_array = np.concatenate([image_array, image_array, image_array], axis=2)
canny_image = Image.fromarray(image_array)
return canny_image
def __resize_for_condition_image(self, image: Image.Image, resolution: int):
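        """Resize so the longer side equals ``resolution``, with both sides rounded to multiples of 64."""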
input_image = image.convert("RGB")
W, H = input_image.size
k = float(resolution) / max(W, H)
H *= k
W *= k
H = int(round(H / 64.0)) * 64
W = int(round(W / 64.0)) * 64
img = input_image.resize((W, H), resample=Image.LANCZOS)
return img
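

if __name__ == "__main__":
    # Minimal usage sketch. Assumes a CUDA device, access to the configured
    # model weights and HF token, and that ControlNet() takes no constructor
    # arguments; the image URL below is a placeholder.
    controlnet = ControlNet()
    controlnet.load_canny()
    result = controlnet.process_canny(
        prompt=["a detailed portrait photo"],
        imageUrl="https://example.com/input.png",  # placeholder URL
        seed=42,
        steps=30,
        negative_prompt=["low quality, blurry"],
        guidance_scale=7.5,
        height=512,
        width=512,
    )
    print(result)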