|
from typing import Optional |
|
|
|
import torch |
|
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline |
|
from PIL import Image |
|
|
|
import internals.util.image as ImageUtil |
|
from internals.pipelines.commons import AbstractPipeline |
|
from internals.pipelines.controlnets import ControlNet |
|
from internals.pipelines.high_res import HighRes |
|
from internals.pipelines.sdxl_llite_pipeline import SDXLLLiteImg2ImgPipeline |
|
from internals.util import get_generators |
|
from internals.util.config import ( |
|
get_base_dimension, |
|
get_hf_cache_dir, |
|
get_is_sdxl, |
|
get_num_return_sequences, |
|
) |
|
|
|
|
|
class RealtimeDraw(AbstractPipeline):
    """Low-step img2img generation guided by a user drawing.

    When ``get_is_sdxl()`` is true, this wraps an ``SDXLLLiteImg2ImgPipeline``
    loaded with the weights at ``LLITE_WEIGHTS_URL``. Otherwise it builds two
    SD1.5 ``StableDiffusionControlNetImg2ImgPipeline`` instances on CUDA that
    share the components of an already-loaded base pipeline:

    * ``self.pipe``  - segmentation ControlNet only (used by ``process_seg``)
    * ``self.pipe2`` - scribble + segmentation ControlNets (used by
      ``process_img``)
    """

    # Weights handed to the SDXL lite pipeline in ``load``. Kept as a class
    # attribute so deployments/subclasses can point at a different checkpoint.
    LLITE_WEIGHTS_URL = (
        "https://s3.ap-south-1.amazonaws.com/autodraft.model.assets/models/"
        "replicate-xl-llite.safetensors"
    )

    def load(self, pipeline: AbstractPipeline):
        """Build the generation pipeline(s) on top of ``pipeline``.

        Calling this again is a no-op once ``self.pipe`` exists.

        Args:
            pipeline: An already-loaded base pipeline. On SDXL it is passed to
                ``SDXLLLiteImg2ImgPipeline.load``; otherwise the components of
                ``pipeline.pipe`` are reused to build the ControlNet
                pipelines, so no second copy of the base model is loaded.
        """
        if hasattr(self, "pipe"):
            return

        if get_is_sdxl():
            lite_pipe = SDXLLLiteImg2ImgPipeline()
            lite_pipe.load(pipeline, [self.LLITE_WEIGHTS_URL])
            self.pipe = lite_pipe
            return

        self.__controlnet_scribble = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_scribble",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        )
        self.__controlnet_seg = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_seg",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        )

        # Share the base pipeline's modules with both ControlNet pipelines.
        kwargs = {**pipeline.pipe.components}
        # NOTE(review): presumably removed because the ControlNet img2img
        # constructor in the pinned diffusers version does not accept it -
        # confirm before dropping this line.
        kwargs.pop("image_encoder", None)

        # Segmentation-only pipeline, used by ``process_seg``.
        self.pipe = StableDiffusionControlNetImg2ImgPipeline(
            **kwargs, controlnet=self.__controlnet_seg
        ).to("cuda")
        self.pipe.safety_checker = None

        # Scribble + segmentation pipeline, used by ``process_img``. The order
        # of this list must match ``control_image`` and
        # ``controlnet_conditioning_scale`` in ``process_img``.
        self.pipe2 = StableDiffusionControlNetImg2ImgPipeline(
            **kwargs, controlnet=[self.__controlnet_scribble, self.__controlnet_seg]
        ).to("cuda")
        self.pipe2.safety_checker = None

    def process_seg(
        self,
        image: Image.Image,
        prompt: str,
        negative_prompt: str,
        seed: int,
    ):
        """Run segmentation-ControlNet img2img on ``image``.

        ``image`` (after ``ImageUtil.resize_image(image, 512)``) is used both
        as the img2img init image and as the segmentation control image.

        Args:
            image: Input image / segmentation map.
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            seed: Seed used to build the torch generator(s).

        Returns:
            The first generated ``PIL.Image.Image``.

        Raises:
            NotImplementedError: When running in SDXL mode, where no
                segmentation pipeline is built by ``load``.
        """
        if get_is_sdxl():
            # NotImplementedError subclasses Exception, so existing
            # ``except Exception`` handlers keep working.
            raise NotImplementedError("SDXL is not supported for this method")

        generator = get_generators(seed, get_num_return_sequences())
        image = ImageUtil.resize_image(image, 512)

        return self.pipe.__call__(
            image=image,
            control_image=image,
            prompt=prompt,
            num_inference_steps=15,
            negative_prompt=negative_prompt,
            generator=generator,
            guidance_scale=10,
            strength=0.8,
        ).images[0]

    def process_img(
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        image: Optional[Image.Image] = None,
        image2: Optional[Image.Image] = None,
    ):
        """Generate an image from a drawing and (on SD1.5) a segmentation map.

        Missing inputs are replaced by black RGB canvases. A missing canvas
        takes the size of the other input if one was given, otherwise
        ``(get_base_dimension(), get_base_dimension())``.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            seed: Seed for generation.
            image: The drawing / init image. On SD1.5 a scribble control map
                is derived from it via ``ControlNet.scribble_image``.
            image2: Segmentation control image (SD1.5 only; ignored on SDXL).

        Returns:
            The generated image, passed through
            ``ImageUtil.resize_image(img, 512)``.
        """
        b_dimen = get_base_dimension()

        # Fill in missing inputs with black canvases. ``image2`` is sized from
        # ``image`` *before* any resizing below, exactly as before.
        if image is None:
            size = image2.size if image2 is not None else (b_dimen, b_dimen)
            image = Image.new("RGB", size, color=0)
        if image2 is None:
            image2 = Image.new("RGB", image.size, color=0)

        if get_is_sdxl():
            size = HighRes.find_closest_sdxl_aspect_ratio(image.size[0], image.size[1])
            image = image.resize(size)

            # NOTE(review): the lite pipeline receives ``seed`` directly; the
            # global torch RNG is seeded as well, presumably for randomness it
            # draws without a generator - confirm.
            torch.manual_seed(seed)

            # NOTE(review): ``image2`` is not used on the SDXL path.
            images = self.pipe.__call__(
                image=image,
                condition_image=image,
                negative_prompt=negative_prompt,
                prompt=prompt,
                seed=seed,
                num_inference_steps=10,
                width=image.size[0],
                height=image.size[1],
            )
            img = images[0]
        else:
            # Generators are only consumed by the diffusers pipeline.
            generator = get_generators(seed, get_num_return_sequences())

            image = ImageUtil.resize_image(image, b_dimen)
            scribble = ControlNet.scribble_image(image)
            image2 = ImageUtil.resize_image(image2, b_dimen)

            img = self.pipe2.__call__(
                image=image,
                # Order matches the controlnet list built in ``load``:
                # [scribble, segmentation].
                control_image=[scribble, image2],
                prompt=prompt,
                num_inference_steps=15,
                negative_prompt=negative_prompt,
                guidance_scale=10,
                generator=generator,
                strength=0.9,
                width=image.size[0],
                height=image.size[1],
                controlnet_conditioning_scale=[1.0, 0.8],
            ).images[0]

        return ImageUtil.resize_image(img, 512)
|
|