from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import (
    AutoencoderKL,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLPipeline,
)

from internals.data.result import Result
from internals.pipelines.twoStepPipeline import two_step_pipeline
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import (
    get_base_model_variant,
    get_hf_token,
    get_is_sdxl,
    get_num_return_sequences,
)


class AbstractPipeline:
    """Minimal interface shared by the pipeline wrappers in this module.

    A wrapper is made usable either by ``load``-ing weights from a model
    directory or by ``create``-ing a new pipeline from the components of an
    already-loaded wrapper. Both are no-ops here.
    """

    def load(self, model_dir: str):
        pass

    def create(self, pipe):
        pass


class Text2Img(AbstractPipeline):
    """Text-to-image wrapper around either an SDXL pipeline or ``two_step_pipeline``.

    Which backend is used is decided by ``get_is_sdxl()``.
    """

    @dataclass
    class Params:
        # Every field may be left as None; ``process`` picks what it needs.
        prompt: Optional[List[str]] = None
        modified_prompt: Optional[List[str]] = None
        # When both side prompts are set, ``process`` uses multi-character mode.
        prompt_left: Optional[List[str]] = None
        prompt_right: Optional[List[str]] = None

    def load(self, model_dir: str):
        """Load fp16 weights from ``model_dir`` onto CUDA.

        On SDXL, the pipeline's VAE is replaced by
        ``madebyollin/sdxl-vae-fp16-fix``. Otherwise ``two_step_pipeline`` is
        loaded.
        """
        if get_is_sdxl():
            vae = AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
            )
            pipe = StableDiffusionXLPipeline.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                use_auth_token=get_hf_token(),
                use_safetensors=True,
                variant=get_base_model_variant(),
            )
            pipe.vae = vae
            pipe.to("cuda")
            self.pipe = pipe
        else:
            self.pipe = two_step_pipeline.from_pretrained(
                model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
            ).to("cuda")

        self.__patch()

    def is_loaded(self):
        """Return True once ``load`` or ``create`` has set ``self.pipe``."""
        return hasattr(self, "pipe")

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of another loaded wrapper."""
        if get_is_sdxl():
            self.pipe = StableDiffusionXLPipeline(**pipeline.pipe.components).to("cuda")
        else:
            self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")

        self.__patch()

    def __patch(self):
        # VAE tiling/slicing is only enabled for SDXL; xformers attention for both.
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        params: Params,
        num_inference_steps: int,
        height: int,
        width: int,
        negative_prompt: str,
        iteration: float = 3.0,
        **kwargs,
    ):
        """Run text-to-image generation and wrap the output in a ``Result``.

        Args:
            params: prompts. If both ``prompt_left`` and ``prompt_right`` are
                set, multi-character diffusion is used with the first element of
                ``prompt``, ``prompt_left`` and ``prompt_right``.
            num_inference_steps: forwarded to the pipeline.
            height: output height, forwarded to the pipeline.
            width: output width, forwarded to the pipeline.
            negative_prompt: a single string, repeated once per prompt
                (``None``/empty becomes ``""``).
            iteration: forwarded as ``iteration`` to ``two_step_pipeline``
                (non-SDXL only); its meaning is defined by that pipeline.
            **kwargs: extra pipeline arguments; they override the values built
                here.
        """
        prompt = params.prompt

        if params.prompt_left and params.prompt_right:
            # Multi-character pipeline: one base prompt plus one prompt per side,
            # with fixed region specs ("pos") and mixing weights ("mix_val").
            prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
            call_kwargs = {
                "prompt": prompt,
                "pos": ["1:1-0:0", "1:2-0:0", "1:2-0:1"],
                "mix_val": [0.2, 0.8, 0.8],
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps,
                "negative_prompt": [negative_prompt or ""] * len(prompt),
                **kwargs,
            }
            # NOTE(review): multi_character_diffusion is presumably provided by
            # two_step_pipeline only; confirm it exists on the SDXL pipeline.
            result = self.pipe.multi_character_diffusion(**call_kwargs)
        else:
            # Two-step pipeline.
            modified_prompt = params.modified_prompt

            if get_is_sdxl():
                print("Warning: Two step pipeline is not supported on SDXL")
                call_kwargs = {
                    # Fall back to the plain prompt when no modified prompt was
                    # given, instead of handing prompt=None to the pipeline.
                    "prompt": modified_prompt or prompt,
                    **kwargs,
                }
            else:
                call_kwargs = {
                    "prompt": prompt,
                    "modified_prompts": modified_prompt,
                    "iteration": iteration,
                    **kwargs,
                }

            call_kwargs = {
                "height": height,
                "width": width,
                "negative_prompt": [negative_prompt or ""] * get_num_return_sequences(),
                "num_inference_steps": num_inference_steps,
                **call_kwargs,
            }
            result = self.pipe(**call_kwargs)

        return Result.from_result(result)


class Img2Img(AbstractPipeline):
    """Image-to-image wrapper around the SD or SDXL img2img pipeline."""

    # Name-mangled flag; set to True by ``load``/``create`` on the instance.
    __loaded = False

    def load(self, model_dir: str):
        """Load fp16 weights from ``model_dir`` onto CUDA; no-op if already loaded."""
        if self.__loaded:
            return

        if get_is_sdxl():
            self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                use_auth_token=get_hf_token(),
                variant=get_base_model_variant(),
                use_safetensors=True,
            ).to("cuda")
        else:
            self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
            ).to("cuda")

        self.__patch()
        self.__loaded = True

    def is_loaded(self):
        """Return True once ``load`` or ``create`` has completed."""
        return self.__loaded

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of another loaded wrapper."""
        if get_is_sdxl():
            self.pipe = StableDiffusionXLImg2ImgPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        else:
            self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
                "cuda"
            )

        self.__patch()
        self.__loaded = True

    def __patch(self):
        # VAE tiling/slicing is only enabled for SDXL; xformers attention for both.
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        num_inference_steps: int,
        width: int,
        height: int,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Download ``imageUrl``, resize it to (width, height) and run img2img.

        ``num_images_per_prompt`` defaults to 1. Any ``**kwargs`` override the
        values built here. The pipeline output is wrapped in a ``Result``.
        """
        image = download_image(imageUrl).resize((width, height))

        call_kwargs = {
            "prompt": prompt,
            "image": image,
            "strength": strength,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": 1,
            "num_inference_steps": num_inference_steps,
            **kwargs,
        }
        result = self.pipe(**call_kwargs)

        return Result.from_result(result)