File size: 6,946 Bytes
86248f3 19b3da3 10230ea 2c6c92a 10230ea 2c6c92a 10230ea 19b3da3 35575bb 19b3da3 22df957 35575bb 22df957 35575bb 61309b7 22df957 19b3da3 86248f3 19b3da3 10230ea 35575bb 2c6c92a 10230ea c95142c 10230ea 22df957 35575bb 2c6c92a 10230ea c95142c 10230ea 19b3da3 b71808f 19b3da3 10230ea 19b3da3 35575bb 19b3da3 86248f3 f70725b 35575bb f70725b 19b3da3 f70725b 19b3da3 86248f3 35575bb 86248f3 f70725b 35575bb f70725b 86248f3 10230ea 22df957 10230ea 22df957 10230ea f70725b 22df957 f70725b 35575bb f70725b 35575bb f70725b 86248f3 19b3da3 1bc457e 19b3da3 1bc457e 10230ea c95142c 22df957 35575bb 10230ea c95142c 10230ea 19b3da3 1bc457e 19b3da3 10230ea 19b3da3 1bc457e 19b3da3 10230ea 19b3da3 f70725b 19b3da3 35575bb f70725b 19b3da3 35575bb f70725b 35575bb f70725b 35575bb f70725b 19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers import (
AutoencoderKL,
StableDiffusionImg2ImgPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
)
from internals.data.result import Result
from internals.pipelines.twoStepPipeline import two_step_pipeline
from internals.util import get_generators
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import (
get_base_model_revision,
get_base_model_variant,
get_hf_token,
get_is_sdxl,
get_low_gpu_mem,
get_num_return_sequences,
)
class AbstractPipeline:
    """Minimal base class for the pipeline wrappers defined in this module.

    Both hooks are no-ops here and return ``None``.
    """

    def load(self, model_dir: str):
        """Hook for loading a pipeline from ``model_dir``; does nothing here."""
        return None

    def create(self, pipe):
        """Hook for building a pipeline from ``pipe``; does nothing here."""
        return None
class Text2Img(AbstractPipeline):
    """Text-to-image wrapper.

    ``self.pipe`` holds a ``StableDiffusionXLPipeline`` when ``get_is_sdxl()``
    is true and a ``two_step_pipeline`` otherwise. The attribute only exists
    after :meth:`load` or :meth:`create` has run.
    """

    @dataclass
    class Params:
        """Prompt bundle consumed by :meth:`Text2Img.process`."""

        # Prompts passed as "prompt" on the non-SDXL two-step path; element 0
        # is also used on the multi-character path.
        prompt: Optional[List[str]] = None
        # Passed as "modified_prompts" on the two-step path, or as "prompt"
        # when running SDXL.
        modified_prompt: Optional[List[str]] = None
        # When BOTH of these are non-empty, process() takes the
        # multi-character path and uses element 0 of each.
        prompt_left: Optional[List[str]] = None
        prompt_right: Optional[List[str]] = None

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` in fp16 and move it to CUDA.

        Args:
            model_dir: Path or hub id handed to ``from_pretrained``.
        """
        if get_is_sdxl():
            print(
                f"Loading model {model_dir} - {get_base_model_variant()}, {get_base_model_revision()}"
            )
            vae = AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
            )
            # Hand the replacement VAE to from_pretrained instead of swapping
            # it in afterwards, so the checkpoint's own VAE weights are not
            # loaded only to be thrown away.
            pipe = StableDiffusionXLPipeline.from_pretrained(
                model_dir,
                vae=vae,
                torch_dtype=torch.float16,
                token=get_hf_token(),
                use_safetensors=True,
                variant=get_base_model_variant(),
                revision=get_base_model_revision(),
            )
            pipe.to("cuda")
            self.pipe = pipe
        else:
            self.pipe = two_step_pipeline.from_pretrained(
                model_dir, torch_dtype=torch.float16, token=get_hf_token()
            ).to("cuda")

        self.__patch()

    def is_loaded(self):
        """Return True once ``self.pipe`` has been set by load() or create()."""
        return hasattr(self, "pipe")

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of an already loaded one.

        Args:
            pipeline: Another wrapper whose ``pipe.components`` are passed to
                the pipeline constructor.
        """
        components = pipeline.pipe.components
        if get_is_sdxl():
            self.pipe = StableDiffusionXLPipeline(**components).to("cuda")
        else:
            self.pipe = two_step_pipeline(**components).to("cuda")

        self.__patch()

    def __patch(self):
        """Apply memory-saving options to the freshly built pipeline."""
        if get_is_sdxl() or get_low_gpu_mem():
            self.pipe.vae.enable_tiling()
            self.pipe.vae.enable_slicing()
        # NOTE(review): assumed to apply regardless of the condition above;
        # confirm against the upstream file.
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        params: Params,
        num_inference_steps: int,
        height: int,
        width: int,
        seed: int,
        negative_prompt: str,
        iteration: float = 3.0,
        **kwargs,
    ):
        """Generate images for ``params`` and wrap them in a ``Result``.

        Args:
            params: Prompt bundle; see :class:`Text2Img.Params`.
            num_inference_steps: Forwarded to the pipeline unchanged.
            height: Forwarded to the pipeline unchanged.
            width: Forwarded to the pipeline unchanged.
            seed: Passed to ``get_generators`` together with
                ``get_num_return_sequences()``.
            negative_prompt: Single negative prompt, repeated once per
                prompt; a falsy value is replaced by ``""``.
            iteration: Forwarded as ``"iteration"`` on the non-SDXL two-step
                path only.
            **kwargs: Extra pipeline arguments. They are merged last, so they
                override every value assembled here.

        Returns:
            ``Result.from_result(...)`` of the pipeline output.
        """
        generator = get_generators(seed, get_num_return_sequences())

        if params.prompt_left and params.prompt_right:
            # Multi-character path.
            # NOTE(review): this calls `multi_character_diffusion` on whatever
            # pipeline is loaded; whether the SDXL pipeline provides it is not
            # visible from this file — confirm.
            prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
            call_args = {
                "prompt": prompt,
                "pos": ["1:1-0:0", "1:2-0:0", "1:2-0:1"],
                "mix_val": [0.2, 0.8, 0.8],
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps,
                "negative_prompt": [negative_prompt or ""] * len(prompt),
                "generator": generator,
                **kwargs,
            }
            result = self.pipe.multi_character_diffusion(**call_args)
        else:
            # Two-step path (plain prompt call on SDXL).
            if get_is_sdxl():
                print("Warning: Two step pipeline is not supported on SDXL")
                prompt_args = {"prompt": params.modified_prompt}
            else:
                prompt_args = {
                    "prompt": params.prompt,
                    "modified_prompts": params.modified_prompt,
                    "iteration": iteration,
                }

            # Precedence, lowest to highest: defaults, prompt_args, caller kwargs.
            call_args = {
                "height": height,
                "width": width,
                "negative_prompt": [negative_prompt or ""] * get_num_return_sequences(),
                "num_inference_steps": num_inference_steps,
                "guidance_scale": 7.5,
                "generator": generator,
                **prompt_args,
                **kwargs,
            }
            print(call_args)
            result = self.pipe(**call_args)

        return Result.from_result(result)
class Img2Img(AbstractPipeline):
    """Image-to-image wrapper.

    ``self.pipe`` holds a ``StableDiffusionXLImg2ImgPipeline`` when
    ``get_is_sdxl()`` is true and a ``StableDiffusionImg2ImgPipeline``
    otherwise.
    """

    # Set to True on the instance once load() or create() has built a pipeline;
    # load() becomes a no-op after that.
    __loaded = False

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` in fp16 and move it to CUDA.

        Does nothing if this instance already holds a pipeline.

        Args:
            model_dir: Path or hub id handed to ``from_pretrained``.
        """
        if self.__loaded:
            return

        if get_is_sdxl():
            self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                token=get_hf_token(),
                variant=get_base_model_variant(),
                revision=get_base_model_revision(),
                use_safetensors=True,
            ).to("cuda")
        else:
            self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                model_dir, torch_dtype=torch.float16, token=get_hf_token()
            ).to("cuda")

        self.__patch()
        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of an already loaded one.

        Args:
            pipeline: Another wrapper whose ``pipe.components`` are passed to
                the pipeline constructor.
        """
        components = pipeline.pipe.components
        if get_is_sdxl():
            self.pipe = StableDiffusionXLImg2ImgPipeline(**components).to("cuda")
        else:
            self.pipe = StableDiffusionImg2ImgPipeline(**components).to("cuda")

        self.__patch()
        self.__loaded = True

    def __patch(self):
        """Apply memory-saving options to the freshly built pipeline."""
        if get_is_sdxl():
            # Call the VAE directly, as Text2Img does, rather than going
            # through the pipeline-level enable_vae_* wrappers.
            self.pipe.vae.enable_tiling()
            self.pipe.vae.enable_slicing()
        # NOTE(review): assumed to apply regardless of the condition above;
        # confirm against the upstream file.
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        num_inference_steps: int,
        width: int,
        height: int,
        seed: int,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Run image-to-image generation and wrap the output in a ``Result``.

        Args:
            prompt: Prompts forwarded to the pipeline.
            imageUrl: Source image location, fetched with ``download_image``
                and resized to ``(width, height)``.
            negative_prompt: Negative prompts forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline unchanged.
            width: Target width used for the resize.
            height: Target height used for the resize.
            seed: Passed to ``get_generators`` together with
                ``get_num_return_sequences()``.
            strength: Forwarded to the pipeline unchanged.
            guidance_scale: Forwarded to the pipeline unchanged.
            **kwargs: Extra pipeline arguments. They are merged last, so they
                override every value assembled here.

        Returns:
            ``Result.from_result(...)`` of the pipeline output.
        """
        batch_size = get_num_return_sequences()
        image = download_image(imageUrl).resize((width, height))
        generator = get_generators(seed, batch_size)

        call_args = {
            "prompt": prompt,
            # The same source image is repeated once per requested output.
            "image": [image] * batch_size,
            "strength": strength,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": 1,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            **kwargs,
        }
        result = self.pipe(**call_args)
        return Result.from_result(result)
|