jayparmr's picture
update : inference
35575bb verified
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers import (
AutoencoderKL,
StableDiffusionImg2ImgPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
)
from internals.data.result import Result
from internals.pipelines.twoStepPipeline import two_step_pipeline
from internals.util import get_generators
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import (
get_base_model_revision,
get_base_model_variant,
get_hf_token,
get_is_sdxl,
get_low_gpu_mem,
get_num_return_sequences,
)
class AbstractPipeline:
    """Common base for the pipeline wrappers defined in this module.

    Both hooks do nothing here and return ``None``; the concrete wrappers
    (``Text2Img``, ``Img2Img``) override them.
    """

    def load(self, model_dir: str):
        """Hook for loading a pipeline from ``model_dir``; a no-op in the base."""
        return None

    def create(self, pipe):
        """Hook for building a pipeline from ``pipe``; a no-op in the base."""
        return None
class Text2Img(AbstractPipeline):
    """Text-to-image wrapper around either an SDXL pipeline or ``two_step_pipeline``.

    The backend is selected by ``get_is_sdxl()`` each time ``load``, ``create``
    or ``process`` runs. The underlying pipeline lives on ``self.pipe`` once
    ``load`` or ``create`` has been called.
    """

    @dataclass
    class Params:
        """Prompt bundle consumed by ``Text2Img.process``."""

        # Every field defaults to None, hence Optional.
        prompt: Optional[List[str]] = None
        # Sent as "prompt" on SDXL and as "modified_prompts" otherwise.
        modified_prompt: Optional[List[str]] = None
        # When both are non-empty, process() takes the multi-character path.
        prompt_left: Optional[List[str]] = None
        prompt_right: Optional[List[str]] = None

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir``, move it to CUDA and patch it.

        Args:
            model_dir: Location handed to ``from_pretrained``.
        """
        if get_is_sdxl():
            variant = get_base_model_variant()
            revision = get_base_model_revision()
            print(f"Loading model {model_dir} - {variant}, {revision}")

            # NOTE(review): the stock VAE is swapped for this fp16 checkpoint,
            # presumably for half-precision stability — confirm on the model card.
            vae = AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
            )
            # Passing the replacement VAE to from_pretrained means the
            # checkpoint's own VAE is never loaded just to be discarded.
            pipe = StableDiffusionXLPipeline.from_pretrained(
                model_dir,
                vae=vae,
                torch_dtype=torch.float16,
                token=get_hf_token(),
                use_safetensors=True,
                variant=variant,
                revision=revision,
            )
            self.pipe = pipe.to("cuda")
        else:
            self.pipe = two_step_pipeline.from_pretrained(
                model_dir, torch_dtype=torch.float16, token=get_hf_token()
            ).to("cuda")

        self.__patch()

    def is_loaded(self):
        """Return True once ``self.pipe`` has been set by ``load`` or ``create``."""
        return hasattr(self, "pipe")

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of an already loaded wrapper.

        Args:
            pipeline: Wrapper whose ``pipe.components`` are reused.
        """
        pipeline_cls = StableDiffusionXLPipeline if get_is_sdxl() else two_step_pipeline
        self.pipe = pipeline_cls(**pipeline.pipe.components).to("cuda")
        self.__patch()

    def __patch(self):
        """Enable VAE tiling/slicing (SDXL or low-GPU-memory mode) and xformers attention."""
        if get_is_sdxl() or get_low_gpu_mem():
            self.pipe.vae.enable_tiling()
            self.pipe.vae.enable_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        params: Params,
        num_inference_steps: int,
        height: int,
        width: int,
        seed: int,
        negative_prompt: str,
        iteration: float = 3.0,
        **kwargs,
    ):
        """Generate images for ``params`` and wrap the output in a ``Result``.

        Args:
            params: Prompt bundle; non-empty ``prompt_left`` and ``prompt_right``
                select the multi-character path.
            num_inference_steps: Forwarded to the pipeline.
            height: Forwarded to the pipeline.
            width: Forwarded to the pipeline.
            seed: Passed to ``get_generators`` along with the configured
                number of return sequences.
            negative_prompt: Single negative prompt, repeated per prompt; a
                falsy value becomes ``""``.
            iteration: Forwarded only on the non-SDXL two-step path.
            **kwargs: Extra pipeline arguments; they override every value
                assembled here.

        Returns:
            ``Result.from_result`` applied to the raw pipeline output.
        """
        generator = get_generators(seed, get_num_return_sequences())

        if params.prompt_left and params.prompt_right:
            # Multi-character path: one global prompt plus one per side.
            prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
            call_args = {
                "prompt": prompt,
                "pos": ["1:1-0:0", "1:2-0:0", "1:2-0:1"],
                "mix_val": [0.2, 0.8, 0.8],
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps,
                "negative_prompt": [negative_prompt or ""] * len(prompt),
                "generator": generator,
                **kwargs,
            }
            result = self.pipe.multi_character_diffusion(**call_args)
        else:
            # Two-step path; on SDXL only the modified prompt is used.
            if get_is_sdxl():
                print("Warning: Two step pipeline is not supported on SDXL")
                prompt_args = {
                    "prompt": params.modified_prompt,
                    **kwargs,
                }
            else:
                prompt_args = {
                    "prompt": params.prompt,
                    "modified_prompts": params.modified_prompt,
                    "iteration": iteration,
                    **kwargs,
                }

            # Defaults first, so prompt_args (and caller kwargs in it) win.
            call_args = {
                "height": height,
                "width": width,
                "negative_prompt": [negative_prompt or ""] * get_num_return_sequences(),
                "num_inference_steps": num_inference_steps,
                "guidance_scale": 7.5,
                "generator": generator,
                **prompt_args,
            }
            print(call_args)
            result = self.pipe(**call_args)

        return Result.from_result(result)
class Img2Img(AbstractPipeline):
    """Image-to-image wrapper around the SD / SDXL img2img diffusers pipelines.

    The backend is selected by ``get_is_sdxl()``. The underlying pipeline lives
    on ``self.pipe`` once ``load`` or ``create`` has been called.
    """

    # Class-level default; load()/create() set it to True on the instance.
    __loaded = False

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` onto CUDA; a no-op if already loaded.

        Args:
            model_dir: Location handed to ``from_pretrained``.
        """
        if self.__loaded:
            return

        if get_is_sdxl():
            self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                token=get_hf_token(),
                variant=get_base_model_variant(),
                revision=get_base_model_revision(),
                use_safetensors=True,
            ).to("cuda")
        else:
            self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                model_dir, torch_dtype=torch.float16, token=get_hf_token()
            ).to("cuda")

        self.__patch()
        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of an already loaded wrapper.

        Args:
            pipeline: Wrapper whose ``pipe.components`` are reused.
        """
        pipeline_cls = (
            StableDiffusionXLImg2ImgPipeline
            if get_is_sdxl()
            else StableDiffusionImg2ImgPipeline
        )
        self.pipe = pipeline_cls(**pipeline.pipe.components).to("cuda")
        self.__patch()
        self.__loaded = True

    def __patch(self):
        """Enable VAE tiling/slicing on SDXL and xformers attention always."""
        # NOTE(review): Text2Img also enables these when get_low_gpu_mem() is
        # set; confirm whether Img2Img should do the same.
        if get_is_sdxl():
            # Call the VAE directly, matching Text2Img, rather than going
            # through the pipeline-level enable_vae_* forwarders.
            self.pipe.vae.enable_tiling()
            self.pipe.vae.enable_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        num_inference_steps: int,
        width: int,
        height: int,
        seed: int,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Run img2img on the image at ``imageUrl`` and wrap the output in a ``Result``.

        Args:
            prompt: Forwarded to the pipeline.
            imageUrl: Passed to ``download_image``; the result is resized to
                ``(width, height)``.
            negative_prompt: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            width: Target width used for the resize.
            height: Target height used for the resize.
            seed: Passed to ``get_generators`` along with the configured
                number of return sequences.
            strength: Forwarded to the pipeline.
            guidance_scale: Forwarded to the pipeline.
            **kwargs: Extra pipeline arguments; they override every value
                assembled here.

        Returns:
            ``Result.from_result`` applied to the raw pipeline output.
        """
        batch_size = get_num_return_sequences()
        image = download_image(imageUrl).resize((width, height))
        # Matches the original order: image fetched first, then generators.
        generator = get_generators(seed, get_num_return_sequences())

        call_args = {
            "prompt": prompt,
            # The same init image is repeated once per returned sequence.
            "image": [image] * batch_size,
            "strength": strength,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": 1,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            **kwargs,
        }
        result = self.pipe(**call_args)
        return Result.from_result(result)