# Unico3D/scripts/sd_model_zoo.py
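"""Model-zoo helpers: cached loading of Stable Diffusion 1.5 pipelines,
ControlNets, and IP-Adapter components via diffusers/transformers."""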
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionPipeline,
)
from transformers import CLIPVisionModelWithProjection
import torch
from copy import deepcopy

ENABLE_CPU_CACHE = False
DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5"

cached_models = {}  # cache for models to avoid repeated loading, key is model name


def cache_model(func):
    """Memoize func's result in the CPU cache when ENABLE_CPU_CACHE is set."""
    def wrapper(*args, **kwargs):
        if ENABLE_CPU_CACHE:
            model_name = func.__name__ + str(args) + str(kwargs)
            if model_name not in cached_models:
                cached_models[model_name] = func(*args, **kwargs)
            return cached_models[model_name]
        else:
            return func(*args, **kwargs)
    return wrapper


def copied_cache_model(func):
    """Like cache_model, but return a deepcopy so callers can mutate the result
    without corrupting the cached original."""
    def wrapper(*args, **kwargs):
        if ENABLE_CPU_CACHE:
            model_name = func.__name__ + str(args) + str(kwargs)
            if model_name not in cached_models:
                cached_models[model_name] = func(*args, **kwargs)
            return deepcopy(cached_models[model_name])
        else:
            return func(*args, **kwargs)
    return wrapper
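

# Usage sketch (illustrative, not part of the original file): with
# ENABLE_CPU_CACHE = True, a second call with identical arguments is served
# from `cached_models` instead of reloading, e.g.
#
#   cn_a = load_controlnet("lllyasviel/control_v11p_sd15_normalbae")  # loads from disk/hub
#   cn_b = load_controlnet("lllyasviel/control_v11p_sd15_normalbae")  # cache hit
#
# Note the cache key is func.__name__ + str(args) + str(kwargs), so passing the
# same value positionally vs. by keyword produces two distinct cache entries.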


def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
    """Load model_cls from a bare .safetensors checkpoint or a diffusers repo/directory."""
    if ckpt_or_pretrained.endswith(".safetensors"):
        # Single-file checkpoints need the original SD v1 inference config.
        pipe = model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs)
    else:
        pipe = model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs)
    return pipe


@copied_cache_model
def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
    """Load the base SD 1.5 pipeline on CPU and return its components
    (deep-copied when ENABLE_CPU_CACHE is on)."""
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        StableDiffusionPipeline,
        **model_kwargs,
    )
    pipe.to("cpu")
    return pipe.components


@cache_model
def load_controlnet(controlnet_path, torch_dtype=torch.float16):
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
    return controlnet


@cache_model
def load_image_encoder():
    # CLIP image encoder shared by the IP-Adapter variants below.
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch.float16,
    )
    return image_encoder


def load_common_sd15_pipe(
    base_model=DEFAULT_BASE_MODEL,
    device="balanced",  # currently unused; placement is handled by the offloading below
    controlnet=None,
    ip_adapter=False,
    plus_model=True,
    torch_dtype=torch.float16,
    model_cpu_offload_seq=None,
    enable_sequential_cpu_offload=False,
    vae_slicing=False,
    pipeline_class=None,
    **kwargs,
):
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        # device_map=device,
        requires_safety_checker=False,
        safety_checker=None,
    )
    # Reuse the (optionally cached) base components so repeated pipelines share setup cost.
    components = load_base_model_components(base_model=base_model, torch_dtype=torch_dtype)
    model_kwargs.update(components)
    model_kwargs.update(kwargs)

    if controlnet is not None:
        if isinstance(controlnet, list):
            controlnet = [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet]
        else:
            controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype)
        model_kwargs.update(controlnet=controlnet)

    if pipeline_class is None:
        if controlnet is not None:
            pipeline_class = StableDiffusionControlNetPipeline
        else:
            pipeline_class = StableDiffusionPipeline

    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        pipeline_class,
        **model_kwargs,
    )

    if ip_adapter:
        image_encoder = load_image_encoder()
        pipe.image_encoder = image_encoder
        if plus_model:
            pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors")
        else:
            pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors")
        pipe.set_ip_adapter_scale(1.0)
    else:
        # Components may be shared via the cache, so drop any IP-Adapter
        # weights left over from a previously built pipeline.
        pipe.unload_ip_adapter()

    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    if model_cpu_offload_seq is None:
        if isinstance(pipe, StableDiffusionControlNetPipeline):
            pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
        elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
            # vae is listed twice: img2img uses it to encode the input image
            # before the unet and to decode the latents afterwards.
            pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"
    else:
        pipe.model_cpu_offload_seq = model_cpu_offload_seq

    # Sequential offload is slower but lighter on VRAM; otherwise fall back
    # to whole-model CPU offload (the two are mutually exclusive).
    if enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.enable_model_cpu_offload()

    if vae_slicing:
        pipe.enable_vae_slicing()

    import gc
    gc.collect()
    return pipe
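

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not from the original file): build a
    # ControlNet pipeline with the plus IP-Adapter enabled. The ControlNet repo
    # id, prompt, and dummy images are example values; the first run downloads
    # several GB, and the offloading above assumes a CUDA device is available.
    import numpy as np
    from PIL import Image

    pipe = load_common_sd15_pipe(
        controlnet="lllyasviel/control_v11p_sd15_normalbae",
        ip_adapter=True,
    )
    control = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))   # blank control map
    reference = Image.fromarray(np.full((512, 512, 3), 127, dtype=np.uint8))  # dummy IP-Adapter image
    image = pipe(
        "a photo of an astronaut",
        image=control,
        ip_adapter_image=reference,
        num_inference_steps=20,
    ).images[0]
    image.save("demo.png")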