|
from dataclasses import dataclass |
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
|
|
import torch |
|
from diffusers import ( |
|
AutoencoderKL, |
|
StableDiffusionImg2ImgPipeline, |
|
StableDiffusionXLImg2ImgPipeline, |
|
StableDiffusionXLPipeline, |
|
) |
|
|
|
from internals.data.result import Result |
|
from internals.pipelines.twoStepPipeline import two_step_pipeline |
|
from internals.util.commons import disable_safety_checker, download_image |
|
from internals.util.config import get_hf_token, get_is_sdxl, num_return_sequences |
|
|
|
|
|
class AbstractPipeline:
    """Base type for pipeline wrappers; both hooks are no-ops here."""

    def load(self, model_dir: str):
        """Load a pipeline from ``model_dir``. Does nothing in the base class."""
        return None

    def create(self, pipe):
        """Build a pipeline from ``pipe``. Does nothing in the base class."""
        return None
|
|
|
|
|
class Text2Img(AbstractPipeline):
    """Text-to-image wrapper around a diffusers pipeline.

    When ``get_is_sdxl()`` is true the wrapped pipeline is a
    ``StableDiffusionXLPipeline``; otherwise it is the project's
    ``two_step_pipeline``. The active pipeline is stored on ``self.pipe`` by
    either ``load`` or ``create``.
    """

    @dataclass
    class Params:
        """Prompt bundle consumed by ``Text2Img.process``."""

        # Base prompts; element 0 is also used for the multi-character call.
        prompt: Optional[List[str]] = None
        # Sent as ``modified_prompts`` to the two-step pipeline, or as the
        # plain ``prompt`` when running SDXL.
        modified_prompt: Optional[List[str]] = None
        # When both of these are non-empty, ``process`` switches to
        # ``multi_character_diffusion`` and uses element 0 of each.
        prompt_left: Optional[List[str]] = None
        prompt_right: Optional[List[str]] = None

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` in float16 and move it to CUDA.

        Args:
            model_dir: Model path or identifier passed to ``from_pretrained``.
        """
        if get_is_sdxl():
            # The VAE is loaded separately and swapped into the SDXL pipeline.
            vae = AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
            )
            pipe = StableDiffusionXLPipeline.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                use_auth_token=get_hf_token(),
                use_safetensors=True,
            )
            pipe.vae = vae
            pipe.to("cuda")
            self.pipe = pipe
        else:
            self.pipe = two_step_pipeline.from_pretrained(
                model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
            ).to("cuda")
        self.__patch()

    def is_loaded(self):
        """Return True once ``load`` or ``create`` has set ``self.pipe``."""
        return hasattr(self, "pipe")

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of an existing one.

        Args:
            pipeline: A wrapper whose ``pipe.components`` are reused to
                construct the new pipeline, which is then moved to CUDA.
        """
        if get_is_sdxl():
            self.pipe = StableDiffusionXLPipeline(**pipeline.pipe.components).to("cuda")
        else:
            self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
        self.__patch()

    def __patch(self):
        """Enable memory-saving options on ``self.pipe``.

        NOTE(review): xformers attention is enabled for both model families
        here, with VAE tiling/slicing limited to SDXL - confirm this nesting
        against the upstream file.
        """
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        params: Params,
        num_inference_steps: int,
        height: int,
        width: int,
        negative_prompt: str,
        iteration: float = 3.0,
        **kwargs,
    ):
        """Generate images for ``params`` and wrap the output in a ``Result``.

        Args:
            params: Prompt bundle; see ``Params``.
            num_inference_steps: Forwarded to the pipeline.
            height: Output height forwarded to the pipeline.
            width: Output width forwarded to the pipeline.
            negative_prompt: Single negative prompt, repeated once per prompt;
                a falsy value is replaced with the empty string.
            iteration: Forwarded to the two-step pipeline only (not on SDXL).
            **kwargs: Extra pipeline arguments. They are applied last, so they
                override the values assembled here, on every code path.

        Returns:
            ``Result.from_result`` applied to the raw pipeline output.
        """
        prompt = params.prompt

        if params.prompt_left and params.prompt_right:
            # Multi-character mode: one base prompt plus a left and a right
            # prompt. The "pos" and "mix_val" values are passed as-is; their
            # meaning is defined by multi_character_diffusion.
            prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
            call_args = {
                "prompt": prompt,
                "pos": ["1:1-0:0", "1:2-0:0", "1:2-0:1"],
                "mix_val": [0.2, 0.8, 0.8],
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps,
                "negative_prompt": [negative_prompt or ""] * len(prompt),
                **kwargs,
            }
            result = self.pipe.multi_character_diffusion(**call_args)
        else:
            modified_prompt = params.modified_prompt

            if get_is_sdxl():
                print("Warning: Two step pipeline is not supported on SDXL")
                prompt_args = {
                    "prompt": modified_prompt,
                }
            else:
                prompt_args = {
                    "prompt": prompt,
                    "modified_prompts": modified_prompt,
                    "iteration": iteration,
                }

            # Caller kwargs used to be discarded on this path because the
            # ``kwargs`` name was rebound; they are now forwarded exactly as
            # in the multi-character path.
            call_args = {
                "height": height,
                "width": width,
                "negative_prompt": [negative_prompt or ""] * num_return_sequences,
                "num_inference_steps": num_inference_steps,
                **prompt_args,
                **kwargs,
            }
            result = self.pipe(**call_args)

        return Result.from_result(result)
|
|
|
|
|
class Img2Img(AbstractPipeline):
    """Image-to-image wrapper around a diffusers img2img pipeline.

    Uses ``StableDiffusionXLImg2ImgPipeline`` when ``get_is_sdxl()`` is true
    and ``StableDiffusionImg2ImgPipeline`` otherwise. The pipeline lives on
    ``self.pipe`` once ``load`` or ``create`` has run.
    """

    # Flipped to True on the instance after the first load()/create().
    __loaded = False

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` in float16 onto CUDA.

        Repeated calls are ignored once a pipeline has been loaded or created.
        """
        if self.__loaded:
            return

        if get_is_sdxl():
            pipeline_cls = StableDiffusionXLImg2ImgPipeline
            # Only the SDXL variant is loaded with use_safetensors=True.
            extra_load_args = {"use_safetensors": True}
        else:
            pipeline_cls = StableDiffusionImg2ImgPipeline
            extra_load_args = {}

        loaded = pipeline_cls.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,
            use_auth_token=get_hf_token(),
            **extra_load_args,
        )
        self.pipe = loaded.to("cuda")
        # NOTE(review): __patch is assumed to run for both model families,
        # i.e. after the if/else rather than inside the else - confirm.
        self.__patch()

        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build the pipeline from another wrapper's ``pipe.components``."""
        pipeline_cls = (
            StableDiffusionXLImg2ImgPipeline
            if get_is_sdxl()
            else StableDiffusionImg2ImgPipeline
        )
        built = pipeline_cls(**pipeline.pipe.components)
        self.pipe = built.to("cuda")
        self.__patch()

        self.__loaded = True

    def __patch(self):
        """Enable memory-saving options on ``self.pipe``.

        NOTE(review): xformers attention is assumed to be enabled for both
        model families and VAE tiling/slicing for SDXL only - confirm.
        """
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        num_inference_steps: int,
        width: int,
        height: int,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Run img2img on the image at ``imageUrl`` and return a ``Result``.

        ``width`` and ``height`` are used only to resize the downloaded image.
        Extra ``kwargs`` are applied last and override the arguments assembled
        here.
        """
        source_image = download_image(imageUrl)
        source_image = source_image.resize((width, height))

        call_args = {
            "prompt": prompt,
            "image": source_image,
            "strength": strength,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": 1,
            "num_inference_steps": num_inference_steps,
        }
        call_args.update(kwargs)

        output = self.pipe(**call_args)
        return Result.from_result(output)
|
|