jayparmr's picture
Upload folder using huggingface_hub
2c6c92a
raw
history blame
6.02 kB
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers import (
AutoencoderKL,
StableDiffusionImg2ImgPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
)
from internals.data.result import Result
from internals.pipelines.twoStepPipeline import two_step_pipeline
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import get_hf_token, get_is_sdxl, num_return_sequences
class AbstractPipeline:
    """Minimal interface shared by the pipeline wrappers in this module.

    Both hooks are no-ops here and return ``None``; the concrete wrappers
    (``Text2Img``, ``Img2Img``) provide real implementations.
    """

    def load(self, model_dir: str):
        """Load pipeline weights from *model_dir* (no-op in the base class)."""
        return None

    def create(self, pipe):
        """Build this pipeline from another wrapper *pipe* (no-op in the base class)."""
        return None
class Text2Img(AbstractPipeline):
    """Text-to-image wrapper around a diffusers pipeline.

    When ``get_is_sdxl()`` is true the wrapped pipeline is a
    ``StableDiffusionXLPipeline``; otherwise it is the project's
    ``two_step_pipeline``. The pipeline is stored on ``self.pipe`` and moved
    to CUDA.
    """

    @dataclass
    class Params:
        """Prompt inputs for :meth:`Text2Img.process`."""

        # Base prompt(s). Only the first entry is used in multi-character mode.
        prompt: Optional[List[str]] = None
        # Alternate prompt(s) for the two-step pipeline; on SDXL this is used
        # as the sole prompt.
        modified_prompt: Optional[List[str]] = None
        # When both of these are non-empty, process() runs
        # multi-character diffusion instead of the two-step path.
        prompt_left: Optional[List[str]] = None
        prompt_right: Optional[List[str]] = None

    def load(self, model_dir: str):
        """Load the text-to-image pipeline from *model_dir* onto CUDA in fp16.

        On SDXL, the VAE is replaced with ``madebyollin/sdxl-vae-fp16-fix``
        and weights are loaded with ``use_safetensors=True``.
        """
        if get_is_sdxl():
            vae = AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
            )
            pipe = StableDiffusionXLPipeline.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                use_auth_token=get_hf_token(),
                use_safetensors=True,
            )
            # Swap in the fp16-safe VAE before moving everything to the GPU.
            pipe.vae = vae
            pipe.to("cuda")
            self.pipe = pipe
        else:
            self.pipe = two_step_pipeline.from_pretrained(
                model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
            ).to("cuda")
        self.__patch()

    def is_loaded(self):
        """Return True once ``load`` or ``create`` has set ``self.pipe``."""
        return hasattr(self, "pipe")

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of another loaded wrapper.

        The new pipeline is constructed from ``pipeline.pipe.components``, so
        the underlying modules come from *pipeline* rather than being reloaded.
        """
        if get_is_sdxl():
            self.pipe = StableDiffusionXLPipeline(**pipeline.pipe.components).to("cuda")
        else:
            self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
        self.__patch()

    def __patch(self):
        """Enable memory-saving options on ``self.pipe``.

        VAE tiling is enabled only on SDXL; VAE slicing and xformers attention
        are always enabled.
        """
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
        self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        params: Params,
        num_inference_steps: int,
        height: int,
        width: int,
        negative_prompt: str,
        iteration: float = 3.0,
        **kwargs,
    ):
        """Generate images from the prompts in *params*.

        If both ``params.prompt_left`` and ``params.prompt_right`` are set, the
        pipeline's ``multi_character_diffusion`` is used; otherwise the pipeline
        is called directly (two-step on non-SDXL, single prompt on SDXL).

        Args:
            params: Prompt inputs.
            num_inference_steps: Number of denoising steps.
            height: Output height in pixels.
            width: Output width in pixels.
            negative_prompt: Negative prompt; ``None``/empty becomes ``""``.
            iteration: Forwarded to the two-step pipeline (non-SDXL only).
            **kwargs: Extra pipeline arguments; they are applied last in both
                branches, so they override the defaults built here.

        Returns:
            ``Result.from_result`` applied to the pipeline output.
        """
        prompt = params.prompt
        if params.prompt_left and params.prompt_right:
            # Multi-character mode: base prompt plus one prompt per side.
            prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
            call_kwargs = {
                "prompt": prompt,
                # NOTE(review): region layout / mixing weights are interpreted
                # by multi_character_diffusion; semantics not visible here.
                "pos": ["1:1-0:0", "1:2-0:0", "1:2-0:1"],
                "mix_val": [0.2, 0.8, 0.8],
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps,
                "negative_prompt": [negative_prompt or ""] * len(prompt),
                **kwargs,
            }
            result = self.pipe.multi_character_diffusion(**call_kwargs)
        else:
            if get_is_sdxl():
                print("Warning: Two step pipeline is not supported on SDXL")
                # Fall back to the base prompt so SDXL never receives
                # prompt=None when no modified prompt was supplied.
                step_kwargs = {"prompt": params.modified_prompt or prompt}
            else:
                step_kwargs = {
                    "prompt": prompt,
                    "modified_prompts": params.modified_prompt,
                    "iteration": iteration,
                }
            call_kwargs = {
                "height": height,
                "width": width,
                # assumes num_return_sequences is an int from config — TODO confirm
                "negative_prompt": [negative_prompt or ""] * num_return_sequences,
                "num_inference_steps": num_inference_steps,
                **step_kwargs,
                # Caller kwargs used to be dropped on this path; forward them
                # the same way the multi-character branch does.
                **kwargs,
            }
            result = self.pipe(**call_kwargs)
        return Result.from_result(result)
class Img2Img(AbstractPipeline):
    """Image-to-image wrapper around a diffusers img2img pipeline.

    Uses ``StableDiffusionXLImg2ImgPipeline`` when ``get_is_sdxl()`` is true
    and ``StableDiffusionImg2ImgPipeline`` otherwise. The pipeline lives on
    ``self.pipe`` (on CUDA); a private flag prevents ``load`` from running
    again after ``load`` or ``create`` has succeeded.
    """

    __loaded = False

    def load(self, model_dir: str):
        """Load the img2img pipeline from *model_dir* in fp16, once.

        Subsequent calls (or calls after ``create``) return immediately.
        """
        if self.__loaded:
            return

        # Pick the pipeline class; only the SDXL variant loads safetensors.
        if get_is_sdxl():
            pipeline_cls = StableDiffusionXLImg2ImgPipeline
            extra_options = {"use_safetensors": True}
        else:
            pipeline_cls = StableDiffusionImg2ImgPipeline
            extra_options = {}

        loaded = pipeline_cls.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,
            use_auth_token=get_hf_token(),
            **extra_options,
        )
        self.pipe = loaded.to("cuda")
        self.__patch()
        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of another loaded wrapper."""
        pipeline_cls = (
            StableDiffusionXLImg2ImgPipeline
            if get_is_sdxl()
            else StableDiffusionImg2ImgPipeline
        )
        shared_components = pipeline.pipe.components
        self.pipe = pipeline_cls(**shared_components).to("cuda")
        self.__patch()
        self.__loaded = True

    def __patch(self):
        """Turn on VAE tiling (SDXL only), VAE slicing and xformers attention."""
        target = self.pipe
        if get_is_sdxl():
            target.enable_vae_tiling()
        target.enable_vae_slicing()
        target.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        num_inference_steps: int,
        width: int,
        height: int,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Run img2img on the image at *imageUrl* resized to (*width*, *height*).

        One image is requested per prompt; any extra ``**kwargs`` are merged
        last and so override the arguments built here.

        Returns:
            ``Result.from_result`` applied to the pipeline output.
        """
        source_image = download_image(imageUrl)
        init_image = source_image.resize((width, height))

        call_args = dict(
            prompt=prompt,
            image=init_image,
            strength=strength,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            num_inference_steps=num_inference_steps,
        )
        call_args.update(kwargs)

        output = self.pipe(**call_args)
        return Result.from_result(output)