diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..eddf290e7e469beceab4cb4b833d2f08cd2c9936
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "MiDaS"]
+	path = live2diff/MiDaS
+	url = git@github.com:lewiji/MiDaS.git
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..f35b5779d8bbf315a38d0b780dc7e5aa4d80c315
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,62 @@
+FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+ENV PYTHONUNBUFFERED=1
+ENV NODE_MAJOR=20
+
+RUN apt-get update && apt-get install --no-install-recommends -y \
+    build-essential \
+    python3.9 \
+    python3-pip \
+    python3-dev \
+    git \
+    ffmpeg \
+    google-perftools \
+    ca-certificates curl gnupg \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /code
+
+RUN mkdir -p /etc/apt/keyrings
+RUN curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg
+RUN echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list > /dev/null
+RUN apt-get update && apt-get install nodejs -y
+
+COPY ./requirements.txt /code/requirements.txt
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+# Switch to the "user" user
+USER user
+# Set home to the user's home directory
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH \
+    PYTHONPATH=$HOME/app \
+    PYTHONUNBUFFERED=1 \
+    SYSTEM=spaces
+
+# Set the working directory to the user's home directory
+WORKDIR $HOME/app
+
+# Copy the current directory contents into the container at $HOME/app, setting the owner to the user
+COPY --chown=user . $HOME/app
+
+ENV LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
+
+# install dependencies
+RUN git submodule update --init --recursive
+RUN pip install -e ."[tensorrt]"
+
+# download models
+RUN mkdir models
+RUN huggingface-cli download Leoxing/Live2Diff live2diff.ckpt --local-dir ./models
+RUN huggingface-cli download runwayml/stable-diffusion-v1-5 \
+    --local-dir ./models/Model/stable-diffusion-v1-5
+RUN bash scripts/download.sh felted
+RUN wget https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt -P models
+
+WORKDIR $HOME/app/demo
+RUN pip install -r requirements.txt
+CMD ["bash", "./start.sh"]
+
diff --git a/demo/.gitattributes b/demo/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b
--- /dev/null
+++ b/demo/.gitattributes
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/demo/.gitignore b/demo/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..3f6ac0a6477df8a139d73e921ba3078bd9fadd45
--- /dev/null
+++ b/demo/.gitignore
@@ -0,0 +1,6 @@
+__pycache__/
+venv/
+public/
+*.pem
+!lib/
+!static/
diff --git a/demo/README.md b/demo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f504977e6657cb71588aeb19ad741ea8d2b601ba
--- /dev/null
+++ b/demo/README.md
@@ -0,0 +1,70 @@
+# Video2Video Example
+
+| Human Face (Web Camera Input) | Anime Character (Screen Video Input) |
+| :---: | :---: |
+| (example video) | (example video) |
+
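For a local (non-Docker) setup, the steps below simply mirror what the Dockerfile in this patch runs. This is a sketch, not the README's own instructions: the model IDs, the `scripts/download.sh felted` argument, and the `demo/start.sh` entry point are all taken from that Dockerfile and assumed to apply unchanged outside the container.

```bash
# install the package (with TensorRT extras) and its MiDaS submodule
git submodule update --init --recursive
pip install -e ".[tensorrt]"

# download the checkpoints the demo expects under ./models (same downloads as the Dockerfile)
mkdir -p models
huggingface-cli download Leoxing/Live2Diff live2diff.ckpt --local-dir ./models
huggingface-cli download runwayml/stable-diffusion-v1-5 --local-dir ./models/Model/stable-diffusion-v1-5
bash scripts/download.sh felted
wget https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt -P models

# install the demo's own requirements and start the server
cd demo
pip install -r requirements.txt
bash start.sh
```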
+    There are {currentQueueSize}
+    user(s) sharing the same GPU, affecting real-time performance.
+    Maximum queue size is {maxQueueSize}.
+    Duplicate and run it on your own GPU.
+  {/if}
+Loading...
+
+    This demo showcases the Live2Diff pipeline using LCM-LoRA with an MJPEG stream server.
+""" + + +WARMUP_FRAMES = 8 +WINDOW_SIZE = 16 + + +class Pipeline: + class Info(BaseModel): + name: str = "Live2Diff" + input_mode: str = "image" + page_content: str = page_content + + def build_input_params(self, default_prompt: str = default_prompt, width=512, height=512): + class InputParams(BaseModel): + prompt: str = Field( + default_prompt, + title="Prompt", + field="textarea", + id="prompt", + ) + width: int = Field( + 512, + min=2, + max=15, + title="Width", + disabled=True, + hide=True, + id="width", + ) + height: int = Field( + 512, + min=2, + max=15, + title="Height", + disabled=True, + hide=True, + id="height", + ) + + return InputParams + + def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype): + config_path = args.config + + cfg = load_config(config_path) + prompt = args.prompt or cfg.prompt or default_prompt + + self.InputParams = self.build_input_params(default_prompt=prompt) + params = self.InputParams() + + num_inference_steps = args.num_inference_steps or cfg.get("num_inference_steps", None) + strength = args.strength or cfg.get("strength", None) + t_index_list = args.t_index_list or cfg.get("t_index_list", None) + + self.stream = StreamAnimateDiffusionDepthWrapper( + few_step_model_type="lcm", + config_path=config_path, + cfg_type="none", + strength=strength, + num_inference_steps=num_inference_steps, + t_index_list=t_index_list, + frame_buffer_size=1, + width=params.width, + height=params.height, + acceleration=args.acceleration, + do_add_noise=True, + output_type="pil", + enable_similar_image_filter=True, + similar_image_filter_threshold=0.98, + use_denoising_batch=True, + use_tiny_vae=True, + seed=args.seed, + engine_dir=args.engine_dir, + ) + + self.last_prompt = prompt + + self.warmup_frame_list = [] + self.has_prepared = False + + def predict(self, params: "Pipeline.InputParams") -> Image.Image: + prompt = params.prompt + if prompt != self.last_prompt: + self.last_prompt = prompt + self.warmup_frame_list.clear() + + if len(self.warmup_frame_list) < WARMUP_FRAMES: + # from PIL import Image + self.warmup_frame_list.append(self.stream.preprocess_image(params.image)) + + elif len(self.warmup_frame_list) == WARMUP_FRAMES and not self.has_prepared: + warmup_frames = torch.stack(self.warmup_frame_list) + self.stream.prepare( + warmup_frames=warmup_frames, + prompt=prompt, + guidance_scale=1, + ) + self.has_prepared = True + + if self.has_prepared: + image_tensor = self.stream.preprocess_image(params.image) + output_image = self.stream(image=image_tensor) + return output_image + else: + return Image.new("RGB", (params.width, params.height)) diff --git a/live2diff/__init__.py b/live2diff/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58caedb0ce3c3bd5aac9143a392b231a5c77fde4 --- /dev/null +++ b/live2diff/__init__.py @@ -0,0 +1,4 @@ +from .pipeline_stream_animation_depth import StreamAnimateDiffusionDepth + + +__all__ = ["StreamAnimateDiffusionDepth"] diff --git a/live2diff/acceleration/__init__.py b/live2diff/acceleration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/live2diff/acceleration/tensorrt/__init__.py b/live2diff/acceleration/tensorrt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c718c06b1d60ee041a975ea92eb47ecd69ab36 --- /dev/null +++ b/live2diff/acceleration/tensorrt/__init__.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn +from diffusers import AutoencoderKL +from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( + retrieve_latents, +) + +from .builder import EngineBuilder +from .models import BaseModel + + +class TorchVAEEncoder(torch.nn.Module): + def __init__(self, vae: AutoencoderKL): + super().__init__() + self.vae = vae + + def forward(self, x: torch.Tensor): + return retrieve_latents(self.vae.encode(x)) + + +def compile_engine( + torch_model: nn.Module, + model_data: BaseModel, + onnx_path: str, + onnx_opt_path: str, + engine_path: str, + opt_image_height: int = 512, + opt_image_width: int = 512, + opt_batch_size: int = 1, + engine_build_options: dict = {}, +): + builder = EngineBuilder( + model_data, + torch_model, + device=torch.device("cuda"), + ) + builder.build( + onnx_path, + onnx_opt_path, + engine_path, + opt_image_height=opt_image_height, + opt_image_width=opt_image_width, + opt_batch_size=opt_batch_size, + **engine_build_options, + ) diff --git a/live2diff/acceleration/tensorrt/builder.py b/live2diff/acceleration/tensorrt/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c2c83fb403f6ea521bfa555c63635bdad1882946 --- /dev/null +++ b/live2diff/acceleration/tensorrt/builder.py @@ -0,0 +1,103 @@ +import gc +import os +from typing import * + +import torch + +from .models import BaseModel +from .utilities import ( + build_engine, + export_onnx, + handle_onnx_batch_norm, + optimize_onnx, +) + + +class EngineBuilder: + def __init__( + self, + model: BaseModel, + network: Any, + device=torch.device("cuda"), + ): + self.device = device + + self.model = model + self.network = network + + def build( + self, + onnx_path: str, + onnx_opt_path: str, + engine_path: str, + opt_image_height: int = 512, + opt_image_width: int = 512, + opt_batch_size: int = 1, + min_image_resolution: int = 256, + max_image_resolution: int = 1024, + build_enable_refit: bool = False, + build_static_batch: bool = False, + build_dynamic_shape: bool = False, + build_all_tactics: bool = False, + onnx_opset: int = 17, + force_engine_build: bool = False, + force_onnx_export: bool = False, + force_onnx_optimize: bool = False, + ignore_onnx_optimize: bool = False, + auto_cast: bool = True, + handle_batch_norm: bool = False, + ): + if not force_onnx_export and os.path.exists(onnx_path): + print(f"Found cached model: {onnx_path}") + else: + print(f"Exporting model: {onnx_path}") + export_onnx( + self.network, + onnx_path=onnx_path, + model_data=self.model, + opt_image_height=opt_image_height, + opt_image_width=opt_image_width, + opt_batch_size=opt_batch_size, + onnx_opset=onnx_opset, + auto_cast=auto_cast, + ) + del self.network + gc.collect() + torch.cuda.empty_cache() + + if handle_batch_norm: + print(f"Handle Batch Norm for {onnx_path}") + handle_onnx_batch_norm(onnx_path) + + if ignore_onnx_optimize: + print(f"Ignore onnx optimize for {onnx_path}.") + onnx_opt_path = onnx_path + elif not force_onnx_optimize and os.path.exists(onnx_opt_path): + print(f"Found cached model: {onnx_opt_path}") + else: + print(f"Generating optimizing model: {onnx_opt_path}") + optimize_onnx( + onnx_path=onnx_path, + onnx_opt_path=onnx_opt_path, + model_data=self.model, + ) + self.model.min_latent_shape = min_image_resolution // 8 + self.model.max_latent_shape = max_image_resolution // 8 + if not force_engine_build and os.path.exists(engine_path): + print(f"Found cached engine: {engine_path}") + else: + build_engine( + engine_path=engine_path, + onnx_opt_path=onnx_opt_path, + model_data=self.model, + opt_image_height=opt_image_height, + 
opt_image_width=opt_image_width, + opt_batch_size=opt_batch_size, + build_static_batch=build_static_batch, + build_dynamic_shape=build_dynamic_shape, + build_all_tactics=build_all_tactics, + build_enable_refit=build_enable_refit, + ) + + gc.collect() + torch.cuda.empty_cache() diff --git a/live2diff/acceleration/tensorrt/engine.py b/live2diff/acceleration/tensorrt/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..5b677d1305a7f60aa5441d6021ea8ed3db0f9f83 --- /dev/null +++ b/live2diff/acceleration/tensorrt/engine.py @@ -0,0 +1,239 @@ +from typing import * + +import torch +from polygraphy import cuda + +from live2diff.animatediff.models.unet_depth_streaming import UNet3DConditionStreamingOutput + +from .utilities import Engine + + +try: + from diffusers.models.autoencoder_tiny import AutoencoderTinyOutput +except ImportError: + from dataclasses import dataclass + + from diffusers.utils import BaseOutput + + @dataclass + class AutoencoderTinyOutput(BaseOutput): + """ + Output of AutoencoderTiny encoding method. + + Args: + latents (`torch.Tensor`): Encoded outputs of the `Encoder`. + + """ + + latents: torch.Tensor + + +try: + from diffusers.models.vae import DecoderOutput +except ImportError: + from dataclasses import dataclass + + from diffusers.utils import BaseOutput + + @dataclass + class DecoderOutput(BaseOutput): + r""" + Output of decoding method. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.FloatTensor + + +class AutoencoderKLEngine: + def __init__( + self, + encoder_path: str, + decoder_path: str, + stream: cuda.Stream, + scaling_factor: int, + use_cuda_graph: bool = False, + ): + self.encoder = Engine(encoder_path) + self.decoder = Engine(decoder_path) + self.stream = stream + self.vae_scale_factor = scaling_factor + self.use_cuda_graph = use_cuda_graph + + self.encoder.load() + self.decoder.load() + self.encoder.activate() + self.decoder.activate() + + def encode(self, images: torch.Tensor, **kwargs): + self.encoder.allocate_buffers( + shape_dict={ + "images": images.shape, + "latent": ( + images.shape[0], + 4, + images.shape[2] // self.vae_scale_factor, + images.shape[3] // self.vae_scale_factor, + ), + }, + device=images.device, + ) + latents = self.encoder.infer( + {"images": images}, + self.stream, + use_cuda_graph=self.use_cuda_graph, + )["latent"] + return AutoencoderTinyOutput(latents=latents) + + def decode(self, latent: torch.Tensor, **kwargs): + self.decoder.allocate_buffers( + shape_dict={ + "latent": latent.shape, + "images": ( + latent.shape[0], + 3, + latent.shape[2] * self.vae_scale_factor, + latent.shape[3] * self.vae_scale_factor, + ), + }, + device=latent.device, + ) + images = self.decoder.infer( + {"latent": latent}, + self.stream, + use_cuda_graph=self.use_cuda_graph, + )["images"] + return DecoderOutput(sample=images) + + def to(self, *args, **kwargs): + pass + + def forward(self, *args, **kwargs): + pass + + +class UNet2DConditionModelDepthEngine: + def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False): + self.engine = Engine(filepath) + self.stream = stream + self.use_cuda_graph = use_cuda_graph + + self.init_profiler() + + self.engine.load() + self.engine.activate(profiler=self.profiler) + self.has_allocated = False + + def init_profiler(self): + import tensorrt + + class Profiler(tensorrt.IProfiler): + def __init__(self): + tensorrt.IProfiler.__init__(self) + + 
def report_layer_time(self, layer_name, ms): + print(f"{layer_name}: {ms} ms") + + self.profiler = Profiler() + + def __call__( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temporal_attention_mask: torch.Tensor, + depth_sample: torch.Tensor, + kv_cache: List[torch.Tensor], + pe_idx: torch.Tensor, + update_idx: torch.Tensor, + **kwargs, + ) -> Any: + if timestep.dtype != torch.float32: + timestep = timestep.float() + + feed_dict = { + "sample": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "temporal_attention_mask": temporal_attention_mask, + "depth_sample": depth_sample, + "pe_idx": pe_idx, + "update_idx": update_idx, + } + for idx, cache in enumerate(kv_cache): + feed_dict[f"kv_cache_{idx}"] = cache + shape_dict = {k: v.shape for k, v in feed_dict.items()} + + if not self.has_allocated: + self.engine.allocate_buffers( + shape_dict=shape_dict, + device=latent_model_input.device, + ) + self.has_allocated = True + + output = self.engine.infer( + feed_dict, + self.stream, + use_cuda_graph=self.use_cuda_graph, + ) + + noise_pred = output["latent"] + kv_cache = [output[f"kv_cache_out_{idx}"] for idx in range(len(kv_cache))] + return UNet3DConditionStreamingOutput(sample=noise_pred, kv_cache=kv_cache) + + def to(self, *args, **kwargs): + pass + + def forward(self, *args, **kwargs): + pass + + +class MidasEngine: + def __init__(self, filepath: str, stream: cuda.Stream, use_cuda_graph: bool = False): + self.engine = Engine(filepath) + self.stream = stream + self.use_cuda_graph = use_cuda_graph + + self.engine.load() + self.engine.activate() + self.has_allocated = False + self.default_batch_size = 1 + + def __call__( + self, + images: torch.Tensor, + **kwargs, + ) -> Any: + if not self.has_allocated or images.shape[0] != self.default_batch_size: + bz = images.shape[0] + self.engine.allocate_buffers( + shape_dict={ + "images": (bz, 3, 384, 384), + "depth_map": (bz, 384, 384), + }, + device=images.device, + ) + self.has_allocated = True + self.default_batch_size = bz + + depth_map = self.engine.infer( + { + "images": images, + }, + self.stream, + use_cuda_graph=self.use_cuda_graph, + )["depth_map"] # (1, 384, 384) + + return depth_map + + def norm(self, x): + return (x - x.min()) / (x.max() - x.min()) + + def to(self, *args, **kwargs): + pass + + def forward(self, *args, **kwargs): + pass diff --git a/live2diff/acceleration/tensorrt/models.py b/live2diff/acceleration/tensorrt/models.py new file mode 100644 index 0000000000000000000000000000000000000000..e48bfb681af42ac600aa157148665cdcdb6beadc --- /dev/null +++ b/live2diff/acceleration/tensorrt/models.py @@ -0,0 +1,600 @@ +#! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/models.py + +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import gc + +import onnx +import onnx_graphsurgeon as gs +import torch +from onnx import shape_inference +from polygraphy.backend.onnx.loader import fold_constants + + +class Optimizer: + def __init__(self, onnx_path, verbose=False): + self.graph = gs.import_onnx(onnx.load(onnx_path)) + self.verbose = verbose + + def info(self, prefix): + if self.verbose: + print( + f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs" + ) + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + if return_onnx: + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self, return_onnx=False): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes(self, return_onnx=False): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + raise TypeError(f"ERROR: model size exceeds supported 2GB limit, {onnx_graph.ByteSize() / 2147483648}") + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes_with_external(self, save_path, return_onnx=False): + # https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#running-shape-inference-on-an-onnx-model + onnx_graph = gs.export_onnx(self.graph) + onnx.save_model( + onnx_graph, + save_path, + save_as_external_data=True, + all_tensors_to_one_file=False, + size_threshold=1024, + ) + shape_inference.infer_shapes_path(save_path, save_path) + self.graph = gs.import_onnx(onnx.load(save_path)) + if return_onnx: + return onnx.load(save_path) + + +class BaseModel: + def __init__( + self, + fp16=False, + device="cuda", + verbose=True, + max_batch_size=16, + min_batch_size=1, + embedding_dim=768, + text_maxlen=77, + ): + self.name = "SD Model" + self.fp16 = fp16 + self.device = device + self.verbose = verbose + + self.min_batch = min_batch_size + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + def get_model(self): + pass + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize(self, onnx_path, onnx_opt_path): + opt = Optimizer(onnx_path, verbose=self.verbose) + opt.info(self.name + ": original") + opt.cleanup() + opt.info(self.name + ": cleanup") + opt.fold_constants() + opt.info(self.name + ": fold constants") + opt.infer_shapes() + opt.info(self.name + ": shape inference") + onnx_opt_graph = opt.cleanup(return_onnx=True) + opt.info(self.name + ": finished") + onnx.save(onnx_opt_graph, onnx_opt_path) + opt.info(self.name + f": saved to 
{onnx_opt_path}") + + del onnx_opt_graph + gc.collect() + torch.cuda.empty_cache() + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_shape else self.min_image_shape + max_image_height = image_height if static_shape else self.max_image_shape + min_image_width = image_width if static_shape else self.min_image_shape + max_image_width = image_width if static_shape else self.max_image_shape + min_latent_height = latent_height if static_shape else self.min_latent_shape + max_latent_height = latent_height if static_shape else self.max_latent_shape + min_latent_width = latent_width if static_shape else self.min_latent_shape + max_latent_width = latent_width if static_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +class CLIP(BaseModel): + def __init__(self, device, max_batch_size, embedding_dim, min_batch_size=1): + super(CLIP, self).__init__( + device=device, + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + embedding_dim=embedding_dim, + ) + self.name = "CLIP" + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + return ["text_embeddings", "pooler_output"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "input_ids": [ + (min_batch, self.text_maxlen), + (batch_size, self.text_maxlen), + (max_batch, self.text_maxlen), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) + + def optimize(self, onnx_path, onnx_opt_path): + opt = Optimizer(onnx_path) + opt.info(self.name + ": original") + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.info(self.name + ": remove output[1]") + opt.fold_constants() + opt.info(self.name + ": fold constants") + opt.infer_shapes() + opt.info(self.name + ": shape inference") + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt.info(self.name + ": remove 
output[0]") + onnx_opt_graph = opt.cleanup(return_onnx=True) + opt.info(self.name + ": finished") + onnx.save(onnx_opt_graph, onnx_opt_path) + opt.info(self.name + f": saved to {onnx_opt_path}") + + del onnx_opt_graph + gc.collect() + torch.cuda.empty_cache() + + +class InflatedUNetDepth(BaseModel): + def __init__( + self, + fp16=False, + device="cuda", + max_batch_size=16, + min_batch_size=1, + embedding_dim=768, + text_maxlen=77, + unet_dim=4, + kv_cache_list=None, + ): + super().__init__( + fp16=fp16, + device=device, + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + + self.kv_cache_list = kv_cache_list + self.unet_dim = unet_dim + self.name = "UNet" + + self.streaming_length = 1 + self.window_size = 16 + + def get_input_names(self): + input_list = ["sample", "timestep", "encoder_hidden_states", "temporal_attention_mask", "depth_sample"] + input_list += [f"kv_cache_{i}" for i in range(len(self.kv_cache_list))] + input_list += ["pe_idx", "update_idx"] + return input_list + + def get_output_names(self): + output_list = ["latent"] + output_list += [f"kv_cache_out_{i}" for i in range(len(self.kv_cache_list))] + return output_list + + def get_dynamic_axes(self): + # NOTE: disable dynamic axes + return {} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + + input_profile = { + "sample": [ + (min_batch, self.unet_dim, self.streaming_length, min_latent_height, min_latent_width), + (batch_size, self.unet_dim, self.streaming_length, latent_height, latent_width), + (max_batch, self.unet_dim, self.streaming_length, max_latent_height, max_latent_width), + ], + "timestep": [(min_batch,), (batch_size,), (max_batch,)], + "encoder_hidden_states": [ + (min_batch, self.text_maxlen, self.embedding_dim), + (batch_size, self.text_maxlen, self.embedding_dim), + (max_batch, self.text_maxlen, self.embedding_dim), + ], + "temporal_attention_mask": [ + (min_batch, self.window_size), + (batch_size, self.window_size), + (max_batch, self.window_size), + ], + "depth_sample": [ + (min_batch, self.unet_dim, self.streaming_length, min_latent_height, min_latent_width), + (batch_size, self.unet_dim, self.streaming_length, latent_height, latent_width), + (max_batch, self.unet_dim, self.streaming_length, max_latent_height, max_latent_width), + ], + } + for idx, tensor in enumerate(self.kv_cache_list): + input_profile[f"kv_cache_{idx}"] = [tuple(tensor.shape)] * 3 + + input_profile["pe_idx"] = [ + (min_batch, self.window_size), + (batch_size, self.window_size), + (max_batch, self.window_size), + ] + input_profile["update_idx"] = [ + (min_batch,), + (batch_size,), + (max_batch,), + ] + + return input_profile + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + attn_mask = torch.zeros((batch_size, self.window_size), dtype=torch.bool, device=self.device) + + attn_mask[:, :8] = True + attn_mask[0, -1] = True + attn_bias = torch.zeros_like(attn_mask, dtype=dtype, device=self.device) + attn_bias.masked_fill_(attn_mask.logical_not(), 
float("-inf")) + + pe_idx = torch.arange(self.window_size).unsqueeze(0).repeat(batch_size, 1).cuda() + update_idx = torch.ones(batch_size, dtype=torch.int64).cuda() * 8 + update_idx[1] = 8 + 1 + + return ( + torch.randn( + batch_size, + self.unet_dim, + self.streaming_length, + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + torch.ones((batch_size,), dtype=dtype, device=self.device), + torch.randn(batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + attn_bias, + torch.randn( + batch_size, + self.unet_dim, + self.streaming_length, + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + self.kv_cache_list, + pe_idx, + update_idx, + ) + + def optimize(self, onnx_path, onnx_opt_path): + """Onnx graph optimization function for model with external data.""" + opt = Optimizer(onnx_path, verbose=self.verbose) + opt.info(self.name + ": original") + opt.cleanup() + opt.info(self.name + ": cleanup") + opt.fold_constants() + opt.info(self.name + ": fold constants") + opt.infer_shapes_with_external(onnx_opt_path) + opt.info(self.name + ": shape inference") + onnx_opt_graph = opt.cleanup(return_onnx=True) + opt.info(self.name + ": finished") + onnx.save( + onnx_opt_graph, + onnx_opt_path, + save_as_external_data=True, + all_tensors_to_one_file=False, + size_threshold=1024, + ) + opt.info(self.name + f": saved to {onnx_opt_path}") + del onnx_opt_graph + gc.collect() + torch.cuda.empty_cache() + + +class Midas(BaseModel): + def __init__( + self, + fp16=False, + device="cuda", + max_batch_size=16, + min_batch_size=1, + embedding_dim=768, + text_maxlen=77, + ): + super().__init__( + fp16=fp16, + device=device, + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + self.img_dim = 3 + self.name = "midas" + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["depth_map"] + + def get_dynamic_axes(self): + return { + "images": {0: "F"}, + "depth_map": {0: "F"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + "images": [ + (min_batch, self.img_dim, image_height, image_width), + (batch_size, self.img_dim, image_height, image_width), + (max_batch, self.img_dim, image_height, image_width), + ], + } + + def get_sample_input(self, batch_size, image_height, image_width): + dtype = torch.float16 if self.fp16 else torch.float32 + return torch.randn(batch_size, self.img_dim, image_height, image_width, dtype=dtype, device=self.device) + + +class VAE(BaseModel): + def __init__(self, device, max_batch_size, min_batch_size=1): + super(VAE, self).__init__( + device=device, + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + embedding_dim=None, + ) + self.name = "VAE decoder" + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return { + "latent": {0: "B", 2: "H", 3: "W"}, + "images": {0: "B", 2: "8H", 3: "8W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = 
self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return torch.randn( + batch_size, + 4, + latent_height, + latent_width, + dtype=torch.float32, + device=self.device, + ) + + +class VAEEncoder(BaseModel): + def __init__(self, device, max_batch_size, min_batch_size=1): + super(VAEEncoder, self).__init__( + device=device, + max_batch_size=max_batch_size, + min_batch_size=min_batch_size, + embedding_dim=None, + ) + self.name = "VAE encoder" + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "images": {0: "B", 2: "8H", 3: "8W"}, + "latent": {0: "B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + + return { + "images": [ + (min_batch, 3, min_image_height, min_image_width), + (batch_size, 3, image_height, image_width), + (max_batch, 3, max_image_height, max_image_width), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "images": (batch_size, 3, image_height, image_width), + "latent": (batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.randn( + batch_size, + 3, + image_height, + image_width, + dtype=torch.float32, + device=self.device, + ) diff --git a/live2diff/acceleration/tensorrt/utilities.py b/live2diff/acceleration/tensorrt/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..a025cd0a1583a8b6643a63750853e2dad35fa061 --- /dev/null +++ b/live2diff/acceleration/tensorrt/utilities.py @@ -0,0 +1,434 @@ +#! fork: https://github.com/NVIDIA/TensorRT/blob/main/demo/Diffusion/utilities.py + +# +# Copyright 2022 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import gc +from collections import OrderedDict +from typing import * + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import tensorrt as trt +import torch +from cuda import cudart +from PIL import Image +from polygraphy import cuda +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.trt import ( + CreateConfig, + Profile, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) + +from .models import BaseModel + + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} +if np.version.full_version >= "1.24.0": + numpy_to_torch_dtype_dict[np.bool_] = torch.bool +else: + numpy_to_torch_dtype_dict[np.bool] = torch.bool + +# Map of torch dtype -> numpy dtype +torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} + + +def CUASSERT(cuda_ret): + err = cuda_ret[0] + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +class Engine: + def __init__( + self, + engine_path, + ): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + self.cuda_graph_instance = None # cuda graph + + def __del__(self): + [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)] + del self.engine + del self.context + del self.buffers + del self.tensors + + def refit(self, onnx_path, onnx_refit_path): + def convert_int64(arr): + # TODO: smarter conversion + if len(arr.shape) == 0: + return np.int32(arr) + return arr + + def add_to_map(refit_dict, name, values): + if name in refit_dict: + assert refit_dict[name] is None + if values.dtype == np.int64: + values = convert_int64(values) + refit_dict[name] = values + + print(f"Refitting TensorRT engine with {onnx_refit_path} weights") + refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes + + # Construct mapping from weight names in refit model -> original model + name_map = {} + for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): + refit_node = refit_nodes[n] + assert node.op == refit_node.op + # Constant nodes in ONNX do not have inputs but have a constant output + if node.op == "Constant": + name_map[refit_node.outputs[0].name] = node.outputs[0].name + # Handle scale and bias weights + elif node.op == "Conv": + if node.inputs[1].__class__ == gs.Constant: + name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" + if node.inputs[2].__class__ == gs.Constant: + name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" + # For all other nodes: 
find node inputs that are initializers (gs.Constant) + else: + for i, inp in enumerate(node.inputs): + if inp.__class__ == gs.Constant: + name_map[refit_node.inputs[i].name] = inp.name + + def map_name(name): + if name in name_map: + return name_map[name] + return name + + # Construct refit dictionary + refit_dict = {} + refitter = trt.Refitter(self.engine, TRT_LOGGER) + all_weights = refitter.get_all() + for layer_name, role in zip(all_weights[0], all_weights[1]): + # for speciailized roles, use a unique name in the map: + if role == trt.WeightsRole.KERNEL: + name = layer_name + "_TRTKERNEL" + elif role == trt.WeightsRole.BIAS: + name = layer_name + "_TRTBIAS" + else: + name = layer_name + + assert name not in refit_dict, "Found duplicate layer: " + name + refit_dict[name] = None + + for n in refit_nodes: + # Constant nodes in ONNX do not have inputs but have a constant output + if n.op == "Constant": + name = map_name(n.outputs[0].name) + print(f"Add Constant {name}\n") + add_to_map(refit_dict, name, n.outputs[0].values) + + # Handle scale and bias weights + elif n.op == "Conv": + if n.inputs[1].__class__ == gs.Constant: + name = map_name(n.name + "_TRTKERNEL") + add_to_map(refit_dict, name, n.inputs[1].values) + + if n.inputs[2].__class__ == gs.Constant: + name = map_name(n.name + "_TRTBIAS") + add_to_map(refit_dict, name, n.inputs[2].values) + + # For all other nodes: find node inputs that are initializers (AKA gs.Constant) + else: + for inp in n.inputs: + name = map_name(inp.name) + if inp.__class__ == gs.Constant: + add_to_map(refit_dict, name, inp.values) + + for layer_name, weights_role in zip(all_weights[0], all_weights[1]): + if weights_role == trt.WeightsRole.KERNEL: + custom_name = layer_name + "_TRTKERNEL" + elif weights_role == trt.WeightsRole.BIAS: + custom_name = layer_name + "_TRTBIAS" + else: + custom_name = layer_name + + # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model + if layer_name.startswith("onnx::Trilu"): + continue + + if refit_dict[custom_name] is not None: + refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) + else: + print(f"[W] No refit weights for layer: {layer_name}") + + if not refitter.refit_cuda_engine(): + print("Failed to refit!") + exit(0) + + def build( + self, + onnx_path, + fp16, + input_profile=None, + enable_refit=False, + enable_all_tactics=False, + timing_cache=None, + workspace_size=0, + ): + print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + p = Profile() + if input_profile: + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + + config_kwargs = {} + + if workspace_size > 0: + config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} + if not enable_all_tactics: + config_kwargs["tactic_sources"] = [] + + engine = engine_from_network( + network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), + config=CreateConfig( + fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs + ), + save_timing_cache=timing_cache, + ) + save_engine(engine, path=self.engine_path) + + def load(self): + print(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self, reuse_device_memory=None, profiler=None): + if reuse_device_memory: + self.context = self.engine.create_execution_context_without_device_memory() + self.context.device_memory = 
reuse_device_memory + else: + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device="cuda"): + # NOTE: API for tensorrt 10.01 + from tensorrt import TensorIOMode + + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_tensor_shape(binding) + dtype = trt.nptype(self.engine.get_tensor_dtype(binding)) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype], device=device) + self.tensors[binding] = tensor + + binding_mode = self.engine.get_tensor_mode(binding) + if binding_mode == TensorIOMode.INPUT: + self.context.set_input_shape(binding, shape) + self.has_allocated = True + + def infer(self, feed_dict, stream, use_cuda_graph=False): + for name, buf in feed_dict.items(): + self.tensors[name].copy_(buf) + + for name, tensor in self.tensors.items(): + self.context.set_tensor_address(name, tensor.data_ptr()) + + if use_cuda_graph: + if self.cuda_graph_instance is not None: + CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr)) + CUASSERT(cudart.cudaStreamSynchronize(stream.ptr)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream.ptr) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + CUASSERT( + cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal) + ) + self.context.execute_async_v3(stream.ptr) + self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr)) + self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0)) + else: + noerror = self.context.execute_async_v3(stream.ptr) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +def decode_images(images: torch.Tensor): + images = ( + ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() + ) + return [Image.fromarray(x) for x in images] + + +def preprocess_image(image: Image.Image): + w, h = image.size + w, h = [x - x % 32 for x in (w, h)] # resize to integer multiple of 32 + image = image.resize((w, h)) + init_image = np.array(image).astype(np.float32) / 255.0 + init_image = init_image[None].transpose(0, 3, 1, 2) + init_image = torch.from_numpy(init_image).contiguous() + return 2.0 * init_image - 1.0 + + +def prepare_mask_and_masked_image(image: Image.Image, mask: Image.Image): + if isinstance(image, Image.Image): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0 + if isinstance(mask, Image.Image): + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous() + + masked_image = image * (mask < 0.5) + + return mask, masked_image + + +def build_engine( + engine_path: str, + onnx_opt_path: str, + model_data: BaseModel, + opt_image_height: int, + opt_image_width: int, + opt_batch_size: int, + build_static_batch: bool = False, + build_dynamic_shape: bool = False, + build_all_tactics: bool = False, + build_enable_refit: bool = False, +): + _, free_mem, _ = cudart.cudaMemGetInfo() + GiB = 2**30 + if free_mem > 6 * GiB: + activation_carveout = 4 * GiB + max_workspace_size = free_mem - activation_carveout + else: + 
max_workspace_size = 0 + engine = Engine(engine_path) + input_profile = model_data.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=build_static_batch, + static_shape=not build_dynamic_shape, + ) + engine.build( + onnx_opt_path, + fp16=True, + input_profile=input_profile, + enable_refit=build_enable_refit, + enable_all_tactics=build_all_tactics, + workspace_size=max_workspace_size, + ) + + return engine + + +def export_onnx( + model, + onnx_path: str, + model_data: BaseModel, + opt_image_height: int, + opt_image_width: int, + opt_batch_size: int, + onnx_opset: int, + auto_cast: bool = True, +): + from contextlib import contextmanager + + @contextmanager + def auto_cast_manager(enabled): + if enabled: + with torch.inference_mode(), torch.autocast("cuda"): + yield + else: + yield + + with auto_cast_manager(auto_cast): + inputs = model_data.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_data.get_input_names(), + output_names=model_data.get_output_names(), + dynamic_axes=model_data.get_dynamic_axes(), + ) + del model + gc.collect() + torch.cuda.empty_cache() + + +def optimize_onnx( + onnx_path: str, + onnx_opt_path: str, + model_data: BaseModel, +): + model_data.optimize(onnx_path, onnx_opt_path) + # # onnx_opt_graph = model_data.optimize(onnx.load(onnx_path)) + # onnx_opt_graph = model_data.optimize(onnx_path) + # onnx.save(onnx_opt_graph, onnx_opt_path) + # del onnx_opt_graph + # gc.collect() + # torch.cuda.empty_cache() + + +def handle_onnx_batch_norm(onnx_path: str): + onnx_model = onnx.load(onnx_path) + for node in onnx_model.graph.node: + if node.op_type == "BatchNormalization": + for attribute in node.attribute: + if attribute.name == "training_mode": + if attribute.i == 1: + node.output.remove(node.output[1]) + node.output.remove(node.output[1]) + attribute.i = 0 + + onnx.save_model(onnx_model, onnx_path) diff --git a/live2diff/animatediff/__init__.py b/live2diff/animatediff/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/live2diff/animatediff/converter/__init__.py b/live2diff/animatediff/converter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84490b47685e7bf6f9671c75657df3de995b7955 --- /dev/null +++ b/live2diff/animatediff/converter/__init__.py @@ -0,0 +1,4 @@ +from .convert import load_third_party_checkpoints, load_third_party_unet + + +__all__ = ["load_third_party_checkpoints", "load_third_party_unet"] diff --git a/live2diff/animatediff/converter/convert.py b/live2diff/animatediff/converter/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..a326fdd4782076ee10d179eb716633d79ea6d43e --- /dev/null +++ b/live2diff/animatediff/converter/convert.py @@ -0,0 +1,134 @@ +from typing import Optional + +import torch +from diffusers.pipelines import StableDiffusionPipeline +from safetensors import safe_open + +from .convert_from_ckpt import convert_ldm_clip_checkpoint, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint +from .convert_lora_safetensor_to_diffusers import convert_lora_model_level + + +def load_third_party_checkpoints( + pipeline: StableDiffusionPipeline, + third_party_dict: dict, + dreambooth_path: Optional[str] = None, +): + """ + Modified from 
https://github.com/open-mmlab/PIA/blob/4b1ee136542e807a13c1adfe52f4e8e5fcc65cdb/animatediff/pipelines/i2v_pipeline.py#L165 + """ + vae = third_party_dict.get("vae", None) + lora_list = third_party_dict.get("lora_list", []) + + dreambooth = dreambooth_path or third_party_dict.get("dreambooth", None) + + text_embedding_dict = third_party_dict.get("text_embedding_dict", {}) + + if dreambooth is not None: + dreambooth_state_dict = {} + if dreambooth.endswith(".safetensors"): + with safe_open(dreambooth, framework="pt", device="cpu") as f: + for key in f.keys(): + dreambooth_state_dict[key] = f.get_tensor(key) + else: + dreambooth_state_dict = torch.load(dreambooth, map_location="cpu") + if "state_dict" in dreambooth_state_dict: + dreambooth_state_dict = dreambooth_state_dict["state_dict"] + # load unet + converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config) + pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) + + # load vae from dreambooth (if need) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config) + # add prefix for compiled model + if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]: + converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()} + pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True) + + # load text encoder (if need) + text_encoder_checkpoint = convert_ldm_clip_checkpoint(dreambooth_state_dict) + if text_encoder_checkpoint: + pipeline.text_encoder.load_state_dict(text_encoder_checkpoint, strict=False) + + if vae is not None: + vae_state_dict = {} + if vae.endswith("safetensors"): + with safe_open(vae, framework="pt", device="cpu") as f: + for key in f.keys(): + vae_state_dict[key] = f.get_tensor(key) + elif vae.endswith("ckpt") or vae.endswith("pt"): + vae_state_dict = torch.load(vae, map_location="cpu") + if "state_dict" in vae_state_dict: + vae_state_dict = vae_state_dict["state_dict"] + + vae_state_dict = {f"first_stage_model.{k}": v for k, v in vae_state_dict.items()} + + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, pipeline.vae.config) + # add prefix for compiled model + if "_orig_mod" in list(pipeline.vae.state_dict().keys())[0]: + converted_vae_checkpoint = {f"_orig_mod.{k}": v for k, v in converted_vae_checkpoint.items()} + pipeline.vae.load_state_dict(converted_vae_checkpoint, strict=True) + + if lora_list: + for lora_dict in lora_list: + lora, lora_alpha = lora_dict["lora"], lora_dict["lora_alpha"] + lora_state_dict = {} + with safe_open(lora, framework="pt", device="cpu") as file: + for k in file.keys(): + lora_state_dict[k] = file.get_tensor(k) + pipeline.unet, pipeline.text_encoder = convert_lora_model_level( + lora_state_dict, + pipeline.unet, + pipeline.text_encoder, + alpha=lora_alpha, + ) + print(f'Add LoRA "{lora}":{lora_alpha} to pipeline.') + + if text_embedding_dict is not None: + from diffusers.loaders import TextualInversionLoaderMixin + + assert isinstance( + pipeline, TextualInversionLoaderMixin + ), "Pipeline must inherit from TextualInversionLoaderMixin." 
+ + for token, embedding_path in text_embedding_dict.items(): + pipeline.load_textual_inversion(embedding_path, token) + + return pipeline + + +def load_third_party_unet(unet, third_party_dict: dict, dreambooth_path: Optional[str] = None): + lora_list = third_party_dict.get("lora_list", []) + dreambooth = dreambooth_path or third_party_dict.get("dreambooth", None) + + if dreambooth is not None: + dreambooth_state_dict = {} + if dreambooth.endswith(".safetensors"): + with safe_open(dreambooth, framework="pt", device="cpu") as f: + for key in f.keys(): + dreambooth_state_dict[key] = f.get_tensor(key) + else: + dreambooth_state_dict = torch.load(dreambooth, map_location="cpu") + if "state_dict" in dreambooth_state_dict: + dreambooth_state_dict = dreambooth_state_dict["state_dict"] + # load unet + converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, unet.config) + unet.load_state_dict(converted_unet_checkpoint, strict=False) + + if lora_list: + for lora_dict in lora_list: + lora, lora_alpha = lora_dict["lora"], lora_dict["lora_alpha"] + lora_state_dict = {} + + with safe_open(lora, framework="pt", device="cpu") as file: + for k in file.keys(): + if "text" not in k: + lora_state_dict[k] = file.get_tensor(k) + unet, _ = convert_lora_model_level( + lora_state_dict, + unet, + None, + alpha=lora_alpha, + ) + print(f'Add LoRA "{lora}":{lora_alpha} to Warmup UNet.') + + return unet diff --git a/live2diff/animatediff/converter/convert_from_ckpt.py b/live2diff/animatediff/converter/convert_from_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..9c1f512ba7937b7897574a18b92eb1e0916550a1 --- /dev/null +++ b/live2diff/animatediff/converter/convert_from_ckpt.py @@ -0,0 +1,599 @@ +# Modified from https://github.com/open-mmlab/PIA/blob/main/animatediff/utils/convert_from_ckpt.py and +# https://github.com/guoyww/AnimateDiff/blob/main/animatediff/utils/convert_from_ckpt.py +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the Stable Diffusion checkpoints.""" + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. 
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. 
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif "to_out.0.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]].squeeze() + elif any(qkv in new_path for qkv in ["to_q", "to_k", "to_v"]): + checkpoint[new_path] = old_checkpoint[path["old"]].squeeze() + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + unet_params = original_config.model.params.unet_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + class_embed_type = None + projection_class_embeddings_input_dim = None + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": unet_params.context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + } + + if not controlnet: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. 
If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + 
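+        # e.g. with an SD-1.5 style config (layers_per_block=2) this maps
+        # "input_blocks.4.0.in_layers.0.weight" -> "down_blocks.1.resnets.0.norm1.weight"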
assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. 
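+            # (when "output_blocks.{i}.1" is actually the upsampler, it contributes exactly
+            # the two conv.weight/conv.bias keys that were already consumed above)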
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, only_encoder=False): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." 
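+    # LDM checkpoints keep the VAE under the "first_stage_model." prefix; strip it here
+    # and then rename the remaining keys to the diffusers AutoencoderKL layout.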
+ keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, 
new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + if only_decoder: + new_checkpoint = { + k: v for k, v in new_checkpoint.items() if k.startswith("decoder") or k.startswith("post_quant") + } + elif only_encoder: + new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith("encoder") or k.startswith("quant")} + + return new_checkpoint + + +def convert_ldm_clip_checkpoint(checkpoint): + keys = list(checkpoint.keys()) + + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + return text_model_dict diff --git a/live2diff/animatediff/converter/convert_lora_safetensor_to_diffusers.py b/live2diff/animatediff/converter/convert_lora_safetensor_to_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..9ebf5cae91816aad6c5991c565656e60ecea80d4 --- /dev/null +++ b/live2diff/animatediff/converter/convert_lora_safetensor_to_diffusers.py @@ -0,0 +1,101 @@ +# Modified from https://github.com/open-mmlab/PIA/blob/main/animatediff/utils/convert_lora_safetensor_to_diffusers.py and +# https://github.com/guoyww/AnimateDiff/blob/main/animatediff/utils/convert_lora_safetensor_to_diffusers.py +# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""Conversion script for the LoRA's safetensors checkpoints.""" + +import torch + + +def convert_lora_model_level( + state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6 +): + """convert lora in model level instead of pipeline leval""" + + visited = [] + + # directly update weight in diffusers model + for key in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + + # as we have set the alpha beforehand, so just skip + if ".alpha" in key or key in visited: + continue + + if "text" in key: + layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + assert text_encoder is not None, "text_encoder must be passed since lora contains text encoder layers" + curr_layer = text_encoder + else: + layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = unet + + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + + # update weight + # NOTE: load lycon, maybe have bugs :( + if "conv_in" in pair_keys[0]: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + weight_up = weight_up.view(weight_up.size(0), -1) + weight_down = weight_down.view(weight_down.size(0), -1) + shape = list(curr_layer.weight.data.shape) + shape[1] = 4 + curr_layer.weight.data[:, :4, ...] 
+= alpha * (weight_up @ weight_down).view(*shape) + elif "conv" in pair_keys[0]: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + weight_up = weight_up.view(weight_up.size(0), -1) + weight_down = weight_down.view(weight_down.size(0), -1) + shape = list(curr_layer.weight.data.shape) + curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape) + elif len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to( + curr_layer.weight.data.device + ) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + + # update visited list + for item in pair_keys: + visited.append(item) + + return unet, text_encoder diff --git a/live2diff/animatediff/models/__init__.py b/live2diff/animatediff/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/live2diff/animatediff/models/attention.py b/live2diff/animatediff/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0a77dfdb18e83010ddff26fd396d6c33019e73d5 --- /dev/null +++ b/live2diff/animatediff/models/attention.py @@ -0,0 +1,648 @@ +# Adapted from https://github.com/guoyww/AnimateDiff + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import ModelMixin +from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange, repeat +from torch import nn + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + 
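+        # each BasicTransformerBlock applies self-attention (sparse-causal across frames
+        # when unet_use_cross_frame_attention is set), optional cross-attention to the
+        # text embeddings, and a feed-forward layer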
self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + if encoder_hidden_states is not None: + encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length, + ) + + # Output + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.unet_use_cross_frame_attention = unet_use_cross_frame_attention + self.unet_use_temporal_attention = unet_use_temporal_attention + + # SC-Attn + assert unet_use_cross_frame_attention is not None + if unet_use_cross_frame_attention: + self.attn1 = SparseCausalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if 
only_cross_attention else None, + upcast_attention=upcast_attention, + ) + else: + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + # Temp-Attn + assert unet_use_temporal_attention is not None + if unet_use_temporal_attention: + self.attn_temp = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def forward( + self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None + ): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + # if self.only_cross_attention: + # hidden_states = ( + # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + # ) + # else: + # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + + # pdb.set_trace() + if self.unet_use_cross_frame_attention: + hidden_states = ( + self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + if self.unet_use_temporal_attention: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. 
If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + *args, + **kwargs, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + self.kv_channels = cross_attention_dim + + def set_info(self, h: int, w: int, *args, **kwargs): + """ + Useful function to pre-assign buffer for cacheable temporal-attn + """ + self.h = h + self.w = w + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + @torch.no_grad() + def vis_attn_mask( + self, + query: Optional[torch.Tensor] = None, + key: Optional[torch.Tensor] = None, + attn_map: Optional[torch.Tensor] = None, + attn_bias: Optional[torch.Tensor] = None, + ): + # dtype = torch.float + dtype = torch.half + if attn_map is None: + attn_map = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=dtype, 
device=query.device), + query.to(dtype), + key.transpose(-1, -2).to(dtype), + beta=0, + alpha=self.scale, + ) + + if attn_bias is not None: + attn_map = attn_map + attn_bias.to(dtype) + + attn_map = attn_map.softmax(dim=-1) + + hw_head = self.h * self.w * self.heads + assert ( + attn_map.shape[0] % hw_head == 0 + ), "height-width-heads must be divisible by the first dimension of attn map. " + # NOTE: here we strict batch size is 1, + assert attn_map.shape[0] // hw_head in [1, 2], "input batch size must be 1 or 2 (for cfg)." + + if (attn_map.shape[0] // hw_head) == 2: + # NOTE: only visualize cond one + attn_map = attn_map[hw_head:] + + attn_map = attn_map.mean(0).cpu().numpy() + + # AttnMapVisualizer.visualize_attn_map(attn_map, 'f16-at-one-time-sink.png') + # exit() + + return attn_map + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + 
) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def set_use_memory_efficient_attention_xformers(self, *args, **kwargs): + print("Set Xformers for MotionModule's Attention.") + self._use_memory_efficient_attention_xformers = True + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_pt20(self, query, key, value, attention_mask): + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0, is_causal=False + ) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class SparseCausalAttention(CrossAttention): + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = 
self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + former_frame_index = torch.arange(video_length) - 1 + former_frame_index[0] = 0 + + key = rearrange(key, "(b f) d c -> b f d c", f=video_length) + # key = torch.cat([key[:, [0] * video_length], key[:, [0] * video_length]], dim=2) + key = key[:, [0] * video_length] + key = rearrange(key, "b f d c -> (b f) d c") + + value = rearrange(value, "(b f) d c -> b f d c", f=video_length) + # value = torch.cat([value[:, [0] * video_length], value[:, [0] * video_length]], dim=2) + # value = value[:, former_frame_index] + value = rearrange(value, "b f d c -> (b f) d c") + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +class AttnMapVisualizer: + def __init__(self): + pass + + def set_visualizer(self, unet: nn.Module): + pass + + def add_attn_map(self): + pass + + @staticmethod + def visualize_attn_map(attn_map: torch.Tensor, save_path: str): + import numpy as np + from matplotlib import pyplot as plt + + plt.imshow(attn_map) + ax = plt.gca() + ax.set_xticks(np.arange(-0.5, attn_map.shape[0] - 1, 1)) + ax.set_yticks(np.arange(-0.5, attn_map.shape[1] - 1, 1)) + ax.set_xticklabels(np.arange(0, attn_map.shape[0], 1)) + ax.set_yticklabels(np.arange(0, attn_map.shape[1], 1)) + ax.grid(color="r", linestyle="-", linewidth=1) + plt.colorbar() + plt.savefig(save_path) + print(f"Saved to {save_path}") diff --git a/live2diff/animatediff/models/depth_utils.py b/live2diff/animatediff/models/depth_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..afa6c79fc170247573384cf10b07e6071da4c91d --- /dev/null +++ b/live2diff/animatediff/models/depth_utils.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + + +try: + from ...MiDaS.midas.dpt_depth import DPTDepthModel +except ImportError: + print('Please pull the MiDaS submodule via "git submodule update --init --recursive"!') + + +class MidasDetector(nn.Module): + def __init__(self, model_path="./models/dpt_hybrid-midas-501f0c75.pt"): + super().__init__() + + self.model = DPTDepthModel(path=model_path, backbone="vitb_rn50_384", non_negative=True) + self.model.requires_grad_(False) + self.model.eval() + + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + + @torch.no_grad() + def forward(self, images: torch.Tensor): + """ + Input: [b, c, h, w] + """ + return self.model(images) diff --git a/live2diff/animatediff/models/motion_module.py 
b/live2diff/animatediff/models/motion_module.py new file mode 100644 index 0000000000000000000000000000000000000000..1d55825683554477d18dd32bc9d73f1cfe02e0d8 --- /dev/null +++ b/live2diff/animatediff/models/motion_module.py @@ -0,0 +1,530 @@ +# Adapted from https://github.com/guoyww/AnimateDiff +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from diffusers.models.attention import FeedForward +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange +from torch import nn + +from .attention import CrossAttention +from .positional_encoding import PositionalEncoding +from .resnet import zero_module +from .stream_motion_module import StreamTemporalAttention + + +def attn_mask_to_bias(attn_mask: torch.Tensor): + """ + Convert bool attention mask to float attention bias tensor. + """ + if attn_mask.dtype in [torch.float, torch.half]: + return attn_mask + elif attn_mask.dtype == torch.bool: + attn_bias = torch.zeros_like(attn_mask).float().masked_fill(attn_mask.logical_not(), float("-inf")) + return attn_bias + else: + raise TypeError("Only support float or bool tensor for attn_mask input. " f"But receive {type(attn_mask)}.") + + +@dataclass +class TemporalTransformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +def get_motion_module( + in_channels, + motion_module_type: str, + motion_module_kwargs: dict, +): + if motion_module_type == "Vanilla": + return VanillaTemporalModule( + in_channels=in_channels, + **motion_module_kwargs, + ) + elif motion_module_type == "Streaming": + return VanillaTemporalModule( + in_channels=in_channels, + enable_streaming=True, + **motion_module_kwargs, + ) + else: + raise ValueError + + +class VanillaTemporalModule(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads=8, + num_transformer_block=2, + attention_block_types=("Temporal_Self", "Temporal_Self"), + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=32, + temporal_attention_dim_div=1, + # parameters for 3d conv + num_3d_conv_layers=0, + kernel_size=3, + down_up_sample=False, + zero_initialize=True, + attention_class_name="versatile", + attention_kwargs={}, + enable_streaming=False, + *args, + **kwargs, + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + attention_class_name=attention_class_name, + attention_kwargs=attention_kwargs, + enable_streaming=enable_streaming, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) + + self.enable_streaming = enable_streaming + + def forward(self, *args, **kwargs): + fwd_fn = self.forward_streaming if self.enable_streaming else self.forward_orig + return fwd_fn(*args, **kwargs) + + def forward_orig( + self, + input_tensor, + temb, + encoder_hidden_states, + attention_mask=None, + temporal_attention_mask=None, + kv_cache=None, + ): + hidden_states = input_tensor 
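+        # the temporal transformer takes and returns (b, c, f, h, w) tensors and applies
+        # its own residual connection internally, so its output is used directly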
+ hidden_states = self.temporal_transformer( + hidden_states, encoder_hidden_states, attention_mask, temporal_attention_mask, kv_cache=kv_cache + ) + + output = hidden_states + return output + + def forward_streaming( + self, + input_tensor, + temb, + encoder_hidden_states, + attention_mask=None, + temporal_attention_mask=None, + kv_cache=None, + pe_idx=None, + update_idx=None, + ): + hidden_states = input_tensor + hidden_states = self.temporal_transformer( + hidden_states, + encoder_hidden_states, + attention_mask, + temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + + output = hidden_states + return output + + +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + num_layers, + attention_block_types=( + "Temporal_Self", + "Temporal_Self", + ), + dropout=0.0, + norm_num_groups=32, + cross_attention_dim=1280, + activation_fn="geglu", + attention_bias=False, + upcast_attention=False, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=32, + attention_class_name="versatile", + attention_kwargs={}, + enable_streaming=False, + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + attention_class_name=attention_class_name, + attention_extra_args=attention_kwargs, + enable_streaming=enable_streaming, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + self.enable_streaming = enable_streaming + + def forward(self, *args, **kwargs): + fwd_fn = self.forward_streaming if self.enable_streaming else self.forward_orig + return fwd_fn(*args, **kwargs) + + def forward_orig( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temporal_attention_mask=None, + kv_cache=None, + ): + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
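+        # layout is (batch, channels, frames, height, width); frames are folded into the
+        # batch dimension so spatial tokens can be flattened, and video_length is passed
+        # to the blocks so the temporal attention can operate along the frame axis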
+ video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, channel, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + video_length=video_length, + height=height, + width=width, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output + + def forward_streaming( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temporal_attention_mask=None, + kv_cache=None, + pe_idx=None, + update_idx=None, + ): + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, channel, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + video_length=video_length, + height=height, + width=width, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + attention_block_types=( + "Temporal_Self", + "Temporal_Self", + ), + dropout=0.0, + norm_num_groups=32, + cross_attention_dim=768, + activation_fn="geglu", + attention_bias=False, + upcast_attention=False, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=32, + attention_class_name: str = "versatile", + attention_extra_args={}, + enable_streaming=False, + ): + super().__init__() + + attention_blocks = [] + norms = [] + + if attention_class_name == "versatile": + attention_cls = VersatileAttention + elif attention_class_name == "stream": + attention_cls = StreamTemporalAttention + assert enable_streaming, "StreamTemporalAttention can only used under streaming mode" + else: + raise ValueError(f"Do not support attention_cls: {attention_class_name}.") + + for block_name in attention_block_types: + attention_blocks.append( + attention_cls( + attention_mode=block_name.split("_")[0], + cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, + query_dim=dim, + 
heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + **attention_extra_args, + ) + ) + norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) + + self.enable_streaming = enable_streaming + + def forward(self, *args, **kwargs): + fwd_func = self.forward_streaming if self.enable_streaming else self.forward_orig + return fwd_func(*args, **kwargs) + + def forward_orig( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + height=None, + width=None, + temporal_attention_mask=None, + kv_cache=None, + ): + for attention_block, norm in zip(self.attention_blocks, self.norms): + norm_hidden_states = norm(hidden_states) + kv_cache_ = kv_cache[attention_block.motion_module_idx] + hidden_states = ( + attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, + video_length=video_length, + height=height, + width=width, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache_, + ) + + hidden_states + ) + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + def forward_streaming( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + height=None, + width=None, + temporal_attention_mask=None, + kv_cache=None, + pe_idx=None, + update_idx=None, + ): + for attention_block, norm in zip(self.attention_blocks, self.norms): + norm_hidden_states = norm(hidden_states) + kv_cache_ = kv_cache[attention_block.motion_module_idx] + hidden_states = ( + attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, + video_length=video_length, + height=height, + width=width, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache_, + pe_idx=pe_idx, + update_idx=update_idx, + ) + + hidden_states + ) + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +class VersatileAttention(CrossAttention): + def __init__( + self, + attention_mode=None, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=32, + stream_cache_mode=None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.stream_cache_mode = stream_cache_mode + self.timestep = None + + assert attention_mode in ["Temporal"] + + self.attention_mode = self._orig_attention_mode = attention_mode + self.is_cross_attention = kwargs.get("cross_attention_dim", None) is not None + + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len + ) + + def extra_repr(self): + return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + + def set_index(self, idx): + self.motion_module_idx = idx + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + kv_cache=None, + *args, + **kwargs, + ): + 
batch_size_frame, sequence_length, _ = hidden_states.shape + + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + kv_cache[0, :, :video_length, :] = key.clone() + kv_cache[1, :, :video_length, :] = value.clone() + + pe = self.pos_encoder.pe[:, :video_length] + + pe_q = self.to_q(pe) + pe_k = self.to_k(pe) + pe_v = self.to_v(pe) + + query = query + pe_q + key = key + pe_k + value = value + pe_v + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + attention_bias = attn_mask_to_bias(attention_mask) + if attention_bias.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_bias = F.pad(attention_mask, (0, target_length), value=float("-inf")) + attention_bias = attention_bias.repeat_interleave(self.heads, dim=0) + attention_bias = attention_bias.to(query) + else: + attention_bias = None + + hidden_states = self._memory_efficient_attention_pt20(query, key, value, attention_bias) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states diff --git a/live2diff/animatediff/models/positional_encoding.py b/live2diff/animatediff/models/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..bee09b552f47f9f8ddae89c84001cf8cdc19f071 --- /dev/null +++ b/live2diff/animatediff/models/positional_encoding.py @@ -0,0 +1,41 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.0, max_len=32): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x, roll: Optional[int] = None, full_video_length: Optional[int] = None): + """ + Support roll for positional encoding. + We select the first `full_video_length` elements and roll it by `roll`. + And then select the first `x.size(1)` elements and add them to `x`. + + Take full_video_length = 4, roll = 2, and x.size(1) = 1 as example. + + If the original positional encoding is: + [1, 2, 3, 4, 5, 6, 7, 8] + The rolled encoding is: + [3, 4, 1, 2] + And the selected encoding added to input is: + [3, 4] + + """ + if roll is None: + pe = self.pe[:, : x.size(1)] + else: + assert full_video_length is not None, "full_video_length must be passed when roll is not None." 
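# Worked example of the roll-based selection performed below: with full_video_length = 4 and
# roll = 2, the first four encodings (positions 0..3, i.e. the docstring's 1-based [1, 2, 3, 4])
# are rolled to [3, 4, 1, 2]; slicing the first x.size(1) entries then yields [3] for
# x.size(1) = 1, or [3, 4] for x.size(1) = 2 as shown in the docstring above.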
+ pe = self.pe[:, :full_video_length].roll(shifts=roll, dims=1)[:, : x.size(1)] + x = x + pe + return self.dropout(x) diff --git a/live2diff/animatediff/models/resnet.py b/live2diff/animatediff/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbeba4208322924d3347d193d582ef6d0938d68 --- /dev/null +++ b/live2diff/animatediff/models/resnet.py @@ -0,0 +1,264 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def zero_module(module): + # Zero out the parameters of a module and return it. + for p in module.parameters(): + p.detach().zero_() + return module + + +class MappingNetwork(nn.Module): + """ + Modified from https://github.com/huggingface/diffusers/blob/196835695ed6fa3ec53b888088d9d5581e8f8e94/src/diffusers/models/controlnet.py#L66-L108 # noqa + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1)) + + self.conv_out = zero_module( + InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class InflatedGroupNorm(nn.GroupNorm): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + # conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # if self.use_conv: + # if self.name == "conv": + # hidden_states = self.conv(hidden_states) + # else: + # hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + use_inflated_groupnorm=False, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + assert use_inflated_groupnorm is not None + if use_inflated_groupnorm: + self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + else: + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + if use_inflated_groupnorm: + self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + else: + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + 
self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) diff --git a/live2diff/animatediff/models/stream_motion_module.py b/live2diff/animatediff/models/stream_motion_module.py new file mode 100644 index 0000000000000000000000000000000000000000..f91f990f89c10db9304129ee474035c0ecc3a917 --- /dev/null +++ b/live2diff/animatediff/models/stream_motion_module.py @@ -0,0 +1,213 @@ +import torch +import torch.nn.functional as F +from einops import rearrange + +from .attention import CrossAttention +from .positional_encoding import PositionalEncoding + + +class StreamTemporalAttention(CrossAttention): + """ + + * window_size: The max length of attention window. + * sink_size: The number sink token. + * positional_rule: absolute, relative + + Therefore, the seq length of temporal self-attention will be: + sink_length + cache_size + + """ + + def __init__( + self, + attention_mode=None, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=32, + window_size=8, + sink_size=0, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.attention_mode = self._orig_attention_mode = attention_mode + self.is_cross_attention = kwargs["cross_attention_dim"] is not None + + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + dropout=0.0, + max_len=temporal_position_encoding_max_len, + ) + + self.window_size = window_size + self.sink_size = sink_size + self.cache_size = self.window_size - self.sink_size + assert self.cache_size >= 0, ( + "cache_size must be greater or equal to 0. Please check your configuration. 
" + f"window_size: {window_size}, sink_size: {sink_size}, " + f"cache_size: {self.cache_size}" + ) + + self.motion_module_idx = None + + def set_index(self, idx): + self.motion_module_idx = idx + + @torch.no_grad() + def set_cache(self, denoising_steps_num: int): + """ + larger buffer index means cleaner latent + """ + device = next(self.parameters()).device + dtype = next(self.parameters()).dtype + + # [t, 2, hw, L, c], 2 means k and v + kv_cache = torch.zeros( + denoising_steps_num, + 2, + self.h * self.w, + self.window_size, + self.kv_channels, + device=device, + dtype=dtype, + ) + self.denoising_steps_num = denoising_steps_num + + return kv_cache + + @torch.no_grad() + def prepare_pe_buffer(self): + """In AnimateDiff, Temporal Self-attention use absolute positional encoding: + q = w_q * (x + pe) + bias + k = w_k * (x + pe) + bias + v = w_v * (x + pe) + bias + + If we want to conduct relative positional encoding with kv-cache, we should pre-calcute + `w_q/k/v * pe` and then cache `w_q/k/v * x + bias` + """ + + pe_list = self.pos_encoder.pe[:, : self.window_size] # [1, window_size, ch] + q_pe = F.linear(pe_list, self.to_q.weight) + k_pe = F.linear(pe_list, self.to_k.weight) + v_pe = F.linear(pe_list, self.to_v.weight) + + self.register_buffer("q_pe", q_pe) + self.register_buffer("k_pe", k_pe) + self.register_buffer("v_pe", v_pe) + + def prepare_qkv_full_and_cache(self, hidden_states, kv_cache, pe_idx, update_idx): + """ + hidden_states: [(N * bhw), F, c], + kv_cache: [2, N, hw, L, c] + + * for warmup case: `N` should be 1 and `F` should be warmup_size (`sink_size`) + * for streaming case: `N` should be `denoising_steps_num` and `F` should be `chunk_size` + + """ + q_layer = self.to_q(hidden_states) + k_layer = self.to_k(hidden_states) + v_layer = self.to_v(hidden_states) + + q_layer = rearrange(q_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num) + k_layer = rearrange(k_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num) + v_layer = rearrange(v_layer, "(n bhw) f c -> n bhw f c", n=self.denoising_steps_num) + + # onnx & trt friendly indexing + for idx in range(self.denoising_steps_num): + kv_cache[idx, 0, :, update_idx[idx]] = k_layer[idx, :, 0] + kv_cache[idx, 1, :, update_idx[idx]] = v_layer[idx, :, 0] + + k_full = kv_cache[:, 0] + v_full = kv_cache[:, 1] + + kv_idx = pe_idx + q_idx = torch.stack([kv_idx[idx, update_idx[idx]] for idx in range(self.denoising_steps_num)]).unsqueeze_( + 1 + ) # [timesteps, 1] + + pe_k = torch.cat( + [self.k_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0 + ) # [n, window_size, c] + pe_v = torch.cat( + [self.v_pe.index_select(1, kv_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0 + ) # [n, window_size, c] + pe_q = torch.cat( + [self.q_pe.index_select(1, q_idx[idx]) for idx in range(self.denoising_steps_num)], dim=0 + ) # [n, window_size, c] + + q_layer = q_layer + pe_q.unsqueeze(1) + k_full = k_full + pe_k.unsqueeze(1) + v_full = v_full + pe_v.unsqueeze(1) + + q_layer = rearrange(q_layer, "n bhw f c -> (n bhw) f c") + k_full = rearrange(k_full, "n bhw f c -> (n bhw) f c") + v_full = rearrange(v_full, "n bhw f c -> (n bhw) f c") + + return q_layer, k_full, v_full + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + temporal_attention_mask=None, + kv_cache=None, + pe_idx=None, + update_idx=None, + *args, + **kwargs, + ): + """ + temporal_attention_mask: attention mask specific for the temporal self-attention. 
+ """ + + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query_layer, key_full, value_full = self.prepare_qkv_full_and_cache( + hidden_states, kv_cache, pe_idx, update_idx + ) + + # [(n * hw * b), f, c] -> [(n * hw * b * head), f, c // head] + query_layer = self.reshape_heads_to_batch_dim(query_layer) + key_full = self.reshape_heads_to_batch_dim(key_full) + value_full = self.reshape_heads_to_batch_dim(value_full) + + if temporal_attention_mask is not None: + q_size = query_layer.shape[1] + # [n, self.window_size] -> [n, hw, q_size, window_size] + temporal_attention_mask_ = temporal_attention_mask[:, None, None, :].repeat(1, self.h * self.w, q_size, 1) + temporal_attention_mask_ = rearrange(temporal_attention_mask_, "n hw Q KV -> (n hw) Q KV") + temporal_attention_mask_ = temporal_attention_mask_.repeat_interleave(self.heads, dim=0) + else: + temporal_attention_mask_ = None + + # attention, what we cannot get enough of + if hasattr(F, "scaled_dot_product_attention"): + hidden_states = self._memory_efficient_attention_pt20( + query_layer, key_full, value_full, attention_mask=temporal_attention_mask_ + ) + + elif self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers( + query_layer, key_full, value_full, attention_mask=temporal_attention_mask_ + ) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query_layer.dtype) + else: + hidden_states = self._attention(query_layer, key_full, value_full, temporal_attention_mask_) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states diff --git a/live2diff/animatediff/models/unet_blocks_streaming.py b/live2diff/animatediff/models/unet_blocks_streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..6391f90ce9e1bb5cd7ee2e431419324c645b9394 --- /dev/null +++ b/live2diff/animatediff/models/unet_blocks_streaming.py @@ -0,0 +1,850 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py + +import torch +from torch import nn + +from .attention import Transformer3DModel +from .motion_module import get_motion_module +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3DStreaming( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + 
resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3DStreaming( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3DStreaming( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3DStreaming( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + 
unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttnStreaming(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=in_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + temporal_attention_mask=None, + kv_cache=None, + pe_idx=None, + update_idx=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + if motion_module is not None: + hidden_states = motion_module( + hidden_states, + temb, + 
encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + hidden_states = resnet(hidden_states, temb) + + # return hidden_states, kv_cache + return hidden_states + + +class CrossAttnDownBlock3DStreaming(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + temporal_attention_mask=None, + kv_cache=None, + pe_idx=None, + update_idx=None, + ): + output_states = () + + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, 
temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + # add motion module + if motion_module is not None: + hidden_states = motion_module( + hidden_states, + temb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + # return hidden_states, output_states, kv_cache + return hidden_states, output_states + + +class DownBlock3DStreaming(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + temporal_attention_mask=None, + kv_cache=None, + pe_idx=None, + update_idx=None, + ): + output_states = () + + for resnet, motion_module in zip(self.resnets, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + else: + hidden_states = resnet(hidden_states, temb) + + # add motion module + if motion_module is not None: + 
hidden_states = motion_module( + hidden_states, + temb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + # return hidden_states, output_states, kv_cache + return hidden_states, output_states + + +class CrossAttnUpBlock3DStreaming(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + temporal_attention_mask=None, + kv_cache=None, + pe_idx=None, + update_idx=None, + ): + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + # pop res hidden states + res_hidden_states = 
res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + # add motion module + if motion_module is not None: + hidden_states = motion_module( + hidden_states, + temb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + # return hidden_states, kv_cache + return hidden_states + + +class UpBlock3DStreaming(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None, + encoder_hidden_states=None, + temporal_attention_mask=None, + kv_cache=None, + pe_idx=None, + update_idx=None, + ): + for resnet, motion_module in zip(self.resnets, self.motion_modules): + # pop res hidden states + res_hidden_states = 
res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + else: + hidden_states = resnet(hidden_states, temb) + if motion_module is not None: + hidden_states = motion_module( + hidden_states, + temb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + # return hidden_states, kv_cache + return hidden_states diff --git a/live2diff/animatediff/models/unet_blocks_warmup.py b/live2diff/animatediff/models/unet_blocks_warmup.py new file mode 100644 index 0000000000000000000000000000000000000000..b031f805084e81a8d596a4a3d8f7173f303764e8 --- /dev/null +++ b/live2diff/animatediff/models/unet_blocks_warmup.py @@ -0,0 +1,833 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py + +import torch +from torch import nn + +from .attention import Transformer3DModel +from .motion_module import get_motion_module +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3DWarmup( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3DWarmup( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + 
attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3DWarmup( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3DWarmup( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttnWarmup(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + 
unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + # enable_cache=True, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=in_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + temporal_attention_mask=None, + kv_cache=None, + ): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + if motion_module is not None: + hidden_states = motion_module( + hidden_states, + temb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + hidden_states = resnet(hidden_states, temb) + + # return hidden_states, kv_cache + return hidden_states + + +class CrossAttnDownBlock3DWarmup(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + 
use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + # enable_cache=True, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + temporal_attention_mask=None, + kv_cache=None, + ): + output_states = () + + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + # add motion module + if motion_module is not None: + hidden_states = motion_module( + hidden_states, + temb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + # return 
hidden_states, output_states, kv_cache + return hidden_states, output_states + + +class DownBlock3DWarmup(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + temporal_attention_mask=None, + kv_cache=None, + ): + output_states = () + + for resnet, motion_module in zip(self.resnets, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + else: + hidden_states = resnet(hidden_states, temb) + + # add motion module + if motion_module is not None: + hidden_states = motion_module( + hidden_states, + temb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + # return hidden_states, output_states, kv_cache + return hidden_states, output_states + + +class CrossAttnUpBlock3DWarmup(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + 
only_cross_attention=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + # enable_cache=True, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + temporal_attention_mask=None, + kv_cache=None, + ): + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + # add motion module + if 
motion_module is not None: + hidden_states = motion_module( + hidden_states, + temb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + # return hidden_states, kv_cache + return hidden_states + + +class UpBlock3DWarmup(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None, + encoder_hidden_states=None, + temporal_attention_mask=None, + kv_cache=None, + ): + for resnet, motion_module in zip(self.resnets, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + else: + hidden_states = resnet(hidden_states, temb) + if motion_module is not None: + hidden_states = motion_module( + hidden_states, + temb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + # return hidden_states, kv_cache + return hidden_states diff --git a/live2diff/animatediff/models/unet_depth_streaming.py 
b/live2diff/animatediff/models/unet_depth_streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c1bd5a96de069430c2227917d2dcf2dd1a9b84 --- /dev/null +++ b/live2diff/animatediff/models/unet_depth_streaming.py @@ -0,0 +1,663 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +import json +import os +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput, logging + +from .resnet import InflatedConv3d, InflatedGroupNorm, MappingNetwork +from .unet_blocks_streaming import ( + UNetMidBlock3DCrossAttnStreaming, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionStreamingOutput(BaseOutput): + sample: torch.FloatTensor + kv_cache: List[torch.FloatTensor] + + +class UNet3DConditionStreamingModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_inflated_groupnorm=False, + # Additional + use_motion_module=False, + motion_module_resolutions=(1, 2, 4, 8), + motion_module_mid_block=False, + motion_module_decoder_only=False, + motion_module_type=None, + motion_module_kwargs={}, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + cond_mapping: bool = False, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + if cond_mapping: + self.flow_conv_in = MappingNetwork( + conditioning_embedding_channels=block_out_channels[0], + conditioning_channels=in_channels, + ) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # 
class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + res = 2**i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module + and (res in motion_module_resolutions) + and (not motion_module_decoder_only), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttnStreaming( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + res = 2 ** (3 - i) + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + 
output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if use_inflated_groupnorm: + self.conv_norm_out = InflatedGroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + else: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_info_for_attn(self, height: int, width: int, *args, **kwargs): + """set height and width for each attention module.""" + motion_module_idx = 0 + + def _assign_info(module: nn.Module, height: int, width: int, *args, **kwargs): + nonlocal motion_module_idx + + for n, m in module.named_children(): + if hasattr(m, "set_info"): + m.set_info(height, width, *args, **kwargs) + if hasattr(m, "set_index"): + m.set_index(motion_module_idx) + motion_module_idx += 1 + else: + _assign_info(m, height, width, *args, **kwargs) + + h_scale, w_scale = height, width + for down_block in self.down_blocks: + _assign_info(down_block, h_scale, w_scale, *args, **kwargs) + if down_block.downsamplers is not None: + h_scale = h_scale // 2 + w_scale = w_scale // 2 + + _assign_info(self.mid_block, h_scale, w_scale, *args, **kwargs) + + for up_block in self.up_blocks: + _assign_info(up_block, h_scale, w_scale, *args, **kwargs) + if up_block.upsamplers is not None: + h_scale = h_scale * 2 + w_scale = w_scale * 2 + + def prepare_cache(self, denoising_steps_num: int): + """prepare cache for temporal self attention.""" + kv_cache_dict = {} # no non local, i think + + def _prepare_cache(module: nn.Module): + for n, m in module.named_children(): + if hasattr(m, "set_cache"): + kv_cache = m.set_cache(denoising_steps_num) + idx = m.motion_module_idx + kv_cache_dict[idx] = kv_cache + if hasattr(m, "prepare_pe_buffer"): + m.prepare_pe_buffer() + _prepare_cache(m) + + _prepare_cache(self) + + max_idx = max(list(kv_cache_dict.keys())) + kv_cache_list = [kv_cache_dict[idx] for idx in range(max_idx + 1)] + + return kv_cache_list + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. 
+ + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + temporal_attention_mask: torch.Tensor, + depth_sample: torch.Tensor, + kv_cache: List[torch.Tensor], + # support only update one element in kv-cache + pe_idx: torch.Tensor, + update_idx: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + # support ip-adapter + image_embeds: Optional[torch.Tensor] = None, + # support controlnet + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionStreamingOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, 1, height, width) noisy inputs tensor. + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + temporal_attention_maks: (batch, window_size, window_size) The attention mask for temporal self-attention. + depth_sample (`torch.FloatTensor`): (batch, channel, 1, height, width) depth inputs tensor. + kv_cache (`List[torch.FloatTensor]`): kv-cache for each temporal attention module. + pe_idx (`torch.FloatTensor`): The positional encoding of temporal attention module for current forward pass. + update_idx (`torch.LongTensor`): The index of kv-cache to update. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling ayears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
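The comment above is the reason the streaming `forward` later computes `forward_upsample_size`: latent height and width must be divisible by `2**num_upsamplers`, otherwise the skip-connection shapes cannot line up without forcing an explicit upsample size. A minimal, self-contained sketch of that divisibility check (the three-upsampler count and the 64x64 latent size are illustrative assumptions, not values read from this file):

```python
import torch

# Assumed for illustration: 3 upsampling blocks, as in a standard SD-1.5-style UNet.
num_upsamplers = 3
default_overall_up_factor = 2 ** num_upsamplers

# Assumed latent shape for illustration: (batch, channels, frames, height, width).
sample = torch.randn(1, 4, 1, 64, 64)

# Mirrors the check in the forward pass: an explicit upsample size is only forwarded
# when height or width is not a multiple of the overall upsampling factor.
forward_upsample_size = any(s % default_overall_up_factor != 0 for s in sample.shape[-2:])
print(forward_upsample_size)  # False for 64x64, True for e.g. 66x66
```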
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # prepare for ip-adapter + if image_embeds is not None: + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + + # pre-process + sample = self.conv_in(sample) + if depth_sample is not None: + depth_sample = self.flow_conv_in(depth_sample) + sample = depth_sample + sample + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + + down_block_res_samples += res_samples + + # support controlnet + down_block_res_samples = list(down_block_res_samples) + if down_block_additional_residuals is not None: + for i, down_block_additional_residual in enumerate(down_block_additional_residuals): + if down_block_additional_residual.dim() == 4: # broadcast + down_block_additional_residual = down_block_additional_residual.unsqueeze(2) + down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual + + # mid + sample = 
self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + + # support controlnet + if mid_block_additional_residual is not None: + if mid_block_additional_residual.dim() == 4: # broadcast + mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) + sample = sample + mid_block_additional_residual + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + pe_idx=pe_idx, + update_idx=update_idx, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample, kv_cache) + + return UNet3DConditionStreamingOutput(sample=sample, kv_cache=kv_cache) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, "config.json") + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ] + config["up_block_types"] = ["UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"] + + from diffusers.utils import WEIGHTS_NAME + + model = cls.from_config(config, **unet_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + + params = [p.numel() if "motion_modules." 
in n else 0 for n, p in model.named_parameters()] + print(f"### Motion Module Parameters: {sum(params) / 1e6} M") + + return model diff --git a/live2diff/animatediff/models/unet_depth_warmup.py b/live2diff/animatediff/models/unet_depth_warmup.py new file mode 100644 index 0000000000000000000000000000000000000000..2faa28f968e5f84cf82982b05c624012deafffc6 --- /dev/null +++ b/live2diff/animatediff/models/unet_depth_warmup.py @@ -0,0 +1,630 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +import json +import os +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput, logging + +from .resnet import InflatedConv3d, InflatedGroupNorm, MappingNetwork +from .unet_blocks_warmup import ( + UNetMidBlock3DCrossAttnWarmup, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionWarmupModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_inflated_groupnorm=False, + # Additional + use_motion_module=False, + motion_module_resolutions=(1, 2, 4, 8), + motion_module_mid_block=False, + motion_module_decoder_only=False, + motion_module_type=None, + motion_module_kwargs={}, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + cond_mapping: bool = False, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + if cond_mapping: + self.flow_conv_in = MappingNetwork( + conditioning_embedding_channels=block_out_channels[0], + conditioning_channels=in_channels, + ) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 
+ timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + res = 2**i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module + and (res in motion_module_resolutions) + and (not motion_module_decoder_only), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttnWarmup( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in 
enumerate(up_block_types): + res = 2 ** (3 - i) + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if use_inflated_groupnorm: + self.conv_norm_out = InflatedGroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + else: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_info_for_attn(self, height: int, width: int, *args, **kwargs): + """set height and width for each attention module.""" + motion_module_idx = 0 + + def _assign_info(module: nn.Module, height: int, width: int, *args, **kwargs): + nonlocal motion_module_idx + + for n, m in module.named_children(): + if hasattr(m, "set_info"): + m.set_info(height, width, *args, **kwargs) + if hasattr(m, "set_index"): + m.set_index(motion_module_idx) + motion_module_idx += 1 + else: + _assign_info(m, height, width, *args, **kwargs) + + h_scale, w_scale = height, width + for down_block in self.down_blocks: + _assign_info(down_block, h_scale, w_scale, *args, **kwargs) + if down_block.downsamplers is not None: + h_scale = h_scale // 2 + w_scale = w_scale // 2 + + _assign_info(self.mid_block, h_scale, w_scale, *args, **kwargs) + + for up_block in self.up_blocks: + _assign_info(up_block, h_scale, w_scale, *args, **kwargs) + if up_block.upsamplers is not None: + h_scale = h_scale * 2 + w_scale = w_scale * 2 + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. 
If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + temporal_attention_mask: torch.Tensor, + depth_sample: torch.Tensor, + kv_cache: List[torch.Tensor], + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + # support ip-adapter + image_embeds: Optional[torch.Tensor] = None, + # support controlnet + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, 1, height, width) noisy inputs tensor. + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + temporal_attention_maks: (batch, window_size, window_size) The attention mask for temporal self-attention. + depth_sample (`torch.FloatTensor`): (batch, channel, 1, height, width) depth inputs tensor. + kv_cache (`List[torch.FloatTensor]`): kv-cache for each temporal attention module. kv feature for warmup frames will be filled in warmup stage forward. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling ayears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # prepare for ip-adapter + if image_embeds is not None: + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + + # pre-process + sample = self.conv_in(sample) + if depth_sample is not None: + depth_sample = self.flow_conv_in(depth_sample) + sample = depth_sample + sample + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + + down_block_res_samples += res_samples + + # support controlnet + down_block_res_samples = list(down_block_res_samples) + if down_block_additional_residuals is not None: + for i, down_block_additional_residual in enumerate(down_block_additional_residuals): + if down_block_additional_residual.dim() == 4: # broadcast + down_block_additional_residual = down_block_additional_residual.unsqueeze(2) + down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual + + # mid + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + 
attention_mask=attention_mask, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + + # support controlnet + if mid_block_additional_residual is not None: + if mid_block_additional_residual.dim() == 4: # broadcast + mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) + sample = sample + mid_block_additional_residual + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + encoder_hidden_states=encoder_hidden_states, + temporal_attention_mask=temporal_attention_mask, + kv_cache=kv_cache, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") + + config_file = os.path.join(pretrained_model_path, "config.json") + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ] + config["up_block_types"] = ["UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"] + # use normal temporal attention + unet_additional_kwargs["motion_module_type"] = "Vanilla" + unet_additional_kwargs["motion_module_kwargs"]["attention_class_name"] = "versatile" + unet_additional_kwargs["motion_module_kwargs"]["attention_kwargs"] = {} + + from diffusers.utils import WEIGHTS_NAME + + model = cls.from_config(config, **unet_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + + params = [p.numel() if "motion_modules." 
in n else 0 for n, p in model.named_parameters()] + print(f"### Motion Module Parameters: {sum(params) / 1e6} M") + + return model diff --git a/live2diff/animatediff/pipeline/__init__.py b/live2diff/animatediff/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd648e00bbbb73932e99ed86aab6413deb12f9b --- /dev/null +++ b/live2diff/animatediff/pipeline/__init__.py @@ -0,0 +1,4 @@ +from .pipeline_animatediff_depth import AnimationDepthPipeline + + +__all__ = ["AnimationDepthPipeline"] diff --git a/live2diff/animatediff/pipeline/loader.py b/live2diff/animatediff/pipeline/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b765067b60ed24628a86508be0c280f196c09b --- /dev/null +++ b/live2diff/animatediff/pipeline/loader.py @@ -0,0 +1,68 @@ +from typing import Dict, List, Optional, Union + +import torch +from diffusers.loaders.lora import LoraLoaderMixin +from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT +from diffusers.utils import USE_PEFT_BACKEND + + +class LoraLoaderWithWarmup(LoraLoaderMixin): + unet_warmup_name = "unet_warmup" + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name=None, + **kwargs, + ): + # load lora for text encoder and unet-streaming + super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs) + + # load lora for unet-warmup + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + self.load_lora_into_unet( + state_dict, + network_alphas=network_alphas, + unet=getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup, + low_cpu_mem_usage=low_cpu_mem_usage, + adapter_name=adapter_name, + _pipeline=self, + ) + + def fuse_lora( + self, + fuse_unet: bool = True, + fuse_text_encoder: bool = True, + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + ): + # fuse lora for text encoder and unet-streaming + super().fuse_lora(fuse_unet, fuse_text_encoder, lora_scale, safe_fusing, adapter_names) + + # fuse lora for unet-warmup + if fuse_unet: + unet_warmup = ( + getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup + ) + unet_warmup.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) + + def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True): + # unfuse lora for text encoder and unet-streaming + super().unfuse_lora(unfuse_unet, unfuse_text_encoder) + + # unfuse lora for unet-warmup + if unfuse_unet: + unet_warmup = ( + getattr(self, self.unet_warmup_name) if not hasattr(self, "unet_warmup") else self.unet_warmup + ) + if not USE_PEFT_BACKEND: + unet_warmup.unfuse_lora() + else: + from peft.tuners.tuners_utils import BaseTunerLayer + + for module in unet_warmup.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() diff --git a/live2diff/animatediff/pipeline/pipeline_animatediff_depth.py b/live2diff/animatediff/pipeline/pipeline_animatediff_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..b06c41cf768343468d1bbf756f7884170b901d3f --- /dev/null +++ b/live2diff/animatediff/pipeline/pipeline_animatediff_depth.py @@ -0,0 +1,350 @@ +# Adapted from https://github.com/open-mmlab/PIA/blob/main/animatediff/pipelines/i2v_pipeline.py + +from 
dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch +from diffusers.configuration_utils import FrozenDict +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer + +from ..models.depth_utils import MidasDetector +from ..models.unet_depth_streaming import UNet3DConditionStreamingModel +from .loader import LoraLoaderWithWarmup + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class AnimationPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + input_images: Optional[Union[torch.Tensor, np.ndarray]] = None + + +class AnimationDepthPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderWithWarmup): + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionStreamingModel, + depth_model: MidasDetector, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + depth_model=depth_model, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.log_denoising_mean = False + + def enable_vae_slicing(self): + self.vae.enable_slicing() + + def disable_vae_slicing(self): + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt, clip_skip=None + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not 
torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + else: + # support ckip skip here, suitable for model based on NAI~ + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + output_hidden_states=True, + ) + text_embeddings = text_embeddings[-1][-(clip_skip + 1)] + text_embeddings = self.text_encoder.text_model.final_layer_norm(text_embeddings) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + @classmethod + def build_pipeline(cls, config_path: str, dreambooth: Optional[str] = None): + """We build pipeline from config path""" + from omegaconf import OmegaConf + + from ...utils.config import load_config + from ..converter import load_third_party_checkpoints + from ..models.unet_depth_streaming import UNet3DConditionStreamingModel + + cfg = load_config(config_path) + pretrained_model_path = cfg.pretrained_model_path + unet_additional_kwargs = cfg.get("unet_additional_kwargs", {}) + noise_scheduler_kwargs = cfg.noise_scheduler_kwargs + third_party_dict = cfg.get("third_party_dict", {}) + + noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) + + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + unet = UNet3DConditionStreamingModel.from_pretrained_2d( + pretrained_model_path, + subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) if unet_additional_kwargs else {}, + ) + + motion_module_path = cfg.motion_module_path + # load motion module to unet + mm_checkpoint = torch.load(motion_module_path, map_location="cpu") + if "global_step" in mm_checkpoint: + print(f"global_step: {mm_checkpoint['global_step']}") + state_dict = mm_checkpoint["state_dict"] if "state_dict" in mm_checkpoint else mm_checkpoint + # NOTE: hard code here: remove `grid` from state_dict + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if "grid" not in k} + + m, u = unet.load_state_dict(state_dict, strict=False) + assert len(u) == 0, f"Find unexpected keys ({len(u)}): {u}" + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + + unet = unet.to(device="cuda", dtype=torch.float16) + vae = vae.to(device="cuda", dtype=torch.bfloat16) + text_encoder = text_encoder.to(device="cuda", dtype=torch.float16) + depth_model = MidasDetector(cfg.depth_model_path).to(device="cuda", dtype=torch.float16) + + pipeline = cls( + unet=unet, + vae=vae, + tokenizer=tokenizer, + text_encoder=text_encoder, + depth_model=depth_model, + scheduler=noise_scheduler, + ) + pipeline = load_third_party_checkpoints(pipeline, third_party_dict, dreambooth) + + return pipeline + + @classmethod + def build_warmup_unet(cls, config_path: str, dreambooth: Optional[str] = None): + from omegaconf import OmegaConf + + from ...utils.config import load_config + from ..converter import load_third_party_unet + from ..models.unet_depth_warmup import UNet3DConditionWarmupModel + + cfg = load_config(config_path) + pretrained_model_path = cfg.pretrained_model_path + unet_additional_kwargs = cfg.get("unet_additional_kwargs", {}) + third_party_dict = cfg.get("third_party_dict", {}) + + unet = UNet3DConditionWarmupModel.from_pretrained_2d( + pretrained_model_path, + subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) if unet_additional_kwargs else {}, + ) + motion_module_path = cfg.motion_module_path + # load motion module to unet + mm_checkpoint = torch.load(motion_module_path, map_location="cpu") + if "global_step" in mm_checkpoint: + print(f"global_step: {mm_checkpoint['global_step']}") + state_dict = 
mm_checkpoint["state_dict"] if "state_dict" in mm_checkpoint else mm_checkpoint + # NOTE: hard code here: remove `grid` from state_dict + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if "grid" not in k} + + m, u = unet.load_state_dict(state_dict, strict=False) + assert len(u) == 0, f"Find unexpected keys ({len(u)}): {u}" + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + + unet = load_third_party_unet(unet, third_party_dict, dreambooth) + return unet + + def prepare_cache(self, height: int, width: int, denoising_steps_num: int): + vae = self.vae + scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + self.unet.set_info_for_attn(height // scale_factor, width // scale_factor) + kv_cache_list = self.unet.prepare_cache(denoising_steps_num) + return kv_cache_list + + def prepare_warmup_unet(self, height: int, width: int, unet): + vae = self.vae + scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + unet.set_info_for_attn(height // scale_factor, width // scale_factor) diff --git a/live2diff/image_filter.py b/live2diff/image_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..0a507474dee884a6ca7d0b1b3537d5ac00a88e30 --- /dev/null +++ b/live2diff/image_filter.py @@ -0,0 +1,45 @@ +import random +from typing import Optional + +import torch + + +class SimilarImageFilter: + def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None: + self.threshold = threshold + self.prev_tensor = None + self.cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6) + self.max_skip_frame = max_skip_frame + self.skip_count = 0 + + def __call__(self, x: torch.Tensor) -> Optional[torch.Tensor]: + if self.prev_tensor is None: + self.prev_tensor = x.detach().clone() + return x + else: + cos_sim = self.cos(self.prev_tensor.reshape(-1), x.reshape(-1)).item() + sample = random.uniform(0, 1) + if self.threshold >= 1: + skip_prob = 0 + else: + skip_prob = max(0, 1 - (1 - cos_sim) / (1 - self.threshold)) + + # not skip frame + if skip_prob < sample: + self.prev_tensor = x.detach().clone() + return x + # skip frame + else: + if self.skip_count > self.max_skip_frame: + self.skip_count = 0 + self.prev_tensor = x.detach().clone() + return x + else: + self.skip_count += 1 + return None + + def set_threshold(self, threshold: float) -> None: + self.threshold = threshold + + def set_max_skip_frame(self, max_skip_frame: float) -> None: + self.max_skip_frame = max_skip_frame diff --git a/live2diff/image_utils.py b/live2diff/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77d7275c23e6c156985c41a8a4934a01d2bd0f65 --- /dev/null +++ b/live2diff/image_utils.py @@ -0,0 +1,89 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torchvision + + +def denormalize(images: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: + """ + Denormalize an image array to [0,1]. + """ + return (images / 2 + 0.5).clamp(0, 1) + + +def pt_to_numpy(images: torch.Tensor) -> np.ndarray: + """ + Convert a PyTorch tensor to a NumPy image. + """ + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + return images + + +def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: + """ + Convert a NumPy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] 
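+    # a single HWC image is promoted above to a batch of one; pixel values in [0, 1] are then scaled to uint8 [0, 255]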
+ images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [PIL.Image.fromarray(image) for image in images] + + return pil_images + + +def postprocess_image( + image: torch.Tensor, + output_type: str = "pil", + do_denormalize: Optional[List[bool]] = None, +) -> Union[torch.Tensor, np.ndarray, PIL.Image.Image]: + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" + ) + + if output_type == "latent": + return image + + do_normalize_flg = True + if do_denormalize is None: + do_denormalize = [do_normalize_flg] * image.shape[0] + + image = torch.stack([denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]) + + if output_type == "pt": + return image + + image = pt_to_numpy(image) + + if output_type == "np": + return image + + if output_type == "pil": + return numpy_to_pil(image) + + +def process_image( + image_pil: PIL.Image.Image, range: Tuple[int, int] = (-1, 1) +) -> Tuple[torch.Tensor, PIL.Image.Image]: + image = torchvision.transforms.ToTensor()(image_pil) + r_min, r_max = range[0], range[1] + image = image * (r_max - r_min) + r_min + return image[None, ...], image_pil + + +def pil2tensor(image_pil: PIL.Image.Image) -> torch.Tensor: + height = image_pil.height + width = image_pil.width + imgs = [] + img, _ = process_image(image_pil) + imgs.append(img) + imgs = torch.vstack(imgs) + images = torch.nn.functional.interpolate(imgs, size=(height, width), mode="bilinear") + image_tensors = images.to(torch.float16) + return image_tensors diff --git a/live2diff/pipeline_stream_animation_depth.py b/live2diff/pipeline_stream_animation_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..57e62274e16afccd652142191446aff0df4a9972 --- /dev/null +++ b/live2diff/pipeline_stream_animation_depth.py @@ -0,0 +1,666 @@ +import time +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from diffusers import LCMScheduler +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( + retrieve_latents, +) +from einops import rearrange + +from live2diff.image_filter import SimilarImageFilter + +from .animatediff.pipeline import AnimationDepthPipeline + + +WARMUP_FRAMES = 8 +WINDOW_SIZE = 16 + + +class StreamAnimateDiffusionDepth: + def __init__( + self, + pipe: AnimationDepthPipeline, + num_inference_steps: int, + t_index_list: Optional[List[int]] = None, + strength: Optional[float] = None, + torch_dtype: torch.dtype = torch.float16, + width: int = 512, + height: int = 512, + do_add_noise: bool = True, + use_denoising_batch: bool = True, + frame_buffer_size: int = 1, + clip_skip: int = 1, + cfg_type: Literal["none", "full", "self", "initialize"] = "none", + ) -> None: + self.device = pipe.device + self.dtype = torch_dtype + self.generator = None + + self.height = height + self.width = width + + self.pipe = pipe + + self.latent_height = int(height // pipe.vae_scale_factor) + self.latent_width = int(width // pipe.vae_scale_factor) + + self.clip_skip = clip_skip + + self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config) + self.scheduler.set_timesteps(num_inference_steps, 
self.device) + if strength is not None: + t_index_list, timesteps = self.get_timesteps(num_inference_steps, strength, self.device) + print( + f"Generate t_index_list: {t_index_list} via " + f"num_inference_steps: {num_inference_steps}, strength: {strength}" + ) + self.timesteps = timesteps + else: + print( + f"t_index_list is passed: {t_index_list}. " + f"Number Inference Steps: {num_inference_steps}, " + f"equivalents to strength {1 - t_index_list[0] / num_inference_steps}." + ) + self.timesteps = self.scheduler.timesteps.to(self.device) + + self.frame_bff_size = frame_buffer_size + self.denoising_steps_num = len(t_index_list) + self.strength = strength + + assert cfg_type == "none", f'cfg_type must be "none" for now, but got {cfg_type}.' + self.cfg_type = cfg_type + + if use_denoising_batch: + self.batch_size = self.denoising_steps_num * frame_buffer_size + if self.cfg_type == "initialize": + self.trt_unet_batch_size = (self.denoising_steps_num + 1) * self.frame_bff_size + elif self.cfg_type == "full": + self.trt_unet_batch_size = 2 * self.denoising_steps_num * self.frame_bff_size + else: + self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size + else: + self.trt_unet_batch_size = self.frame_bff_size + self.batch_size = frame_buffer_size + + self.t_list = t_index_list + + self.do_add_noise = do_add_noise + self.use_denoising_batch = use_denoising_batch + + self.similar_image_filter = False + self.similar_filter = SimilarImageFilter() + self.prev_image_result = None + + self.image_processor = VaeImageProcessor(pipe.vae_scale_factor) + + self.text_encoder = pipe.text_encoder + self.unet = pipe.unet + self.vae = pipe.vae + + self.depth_detector = pipe.depth_model + self.inference_time_ema = 0 + self.depth_time_ema = 0 + self.inference_time_list = [] + self.depth_time_list = [] + self.mask_shift = 1 + + self.is_tensorrt = False + + def prepare_cache(self, height, width, denoising_steps_num): + kv_cache_list = self.pipe.prepare_cache( + height=height, + width=width, + denoising_steps_num=denoising_steps_num, + ) + self.pipe.prepare_warmup_unet(height=height, width=width, unet=self.unet_warmup) + self.kv_cache_list = kv_cache_list + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:].to(device) + t_index = list(range(len(timesteps))) + + return t_index, timesteps + + def load_lora( + self, + pretrained_lora_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[Any] = None, + **kwargs, + ) -> None: + self.pipe.load_lora_weights( + pretrained_lora_model_name_or_path_or_dict, + adapter_name, + **kwargs, + ) + + def fuse_lora( + self, + fuse_unet: bool = True, + fuse_text_encoder: bool = True, + lora_scale: float = 1.0, + safe_fusing: bool = False, + ) -> None: + self.pipe.fuse_lora( + fuse_unet=fuse_unet, + fuse_text_encoder=fuse_text_encoder, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + ) + + def enable_similar_image_filter( + self, + threshold: float = 0.98, + max_skip_frame: float = 10, + ) -> None: + self.similar_image_filter = True + self.similar_filter.set_threshold(threshold) + self.similar_filter.set_max_skip_frame(max_skip_frame) + + def disable_similar_image_filter(self) -> None: + self.similar_image_filter = False + + @torch.no_grad() + def prepare( + self, + warmup_frames: 
torch.Tensor, + prompt: str, + negative_prompt: str = "", + guidance_scale: float = 1.2, + delta: float = 1.0, + generator: Optional[torch.Generator] = None, + seed: int = 2, + ) -> None: + """ + Forward warm-up frames and fill the buffer + images: [warmup_size, 3, h, w] in [0, 1] + """ + + if generator is None: + self.generator = torch.Generator(device=self.device) + self.generator.manual_seed(seed) + else: + self.generator = generator + # initialize x_t_latent (it can be any random tensor) + if self.denoising_steps_num > 1: + self.x_t_latent_buffer = torch.zeros( + ( + (self.denoising_steps_num - 1) * self.frame_bff_size, + 4, + 1, # for video + self.latent_height, + self.latent_width, + ), + dtype=self.dtype, + device=self.device, + ) + + self.depth_latent_buffer = torch.zeros_like(self.x_t_latent_buffer) + else: + self.x_t_latent_buffer = None + self.depth_latent_buffer = None + + self.attn_bias, self.pe_idx, self.update_idx = self.initialize_attn_bias_pe_and_update_idx() + + if self.cfg_type == "none": + self.guidance_scale = 1.0 + else: + self.guidance_scale = guidance_scale + self.delta = delta + + do_classifier_free_guidance = False + if self.guidance_scale > 1.0: + do_classifier_free_guidance = True + + encoder_output = self.pipe._encode_prompt( + prompt=prompt, + device=self.device, + num_videos_per_prompt=1, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + clip_skip=self.clip_skip, + ) + self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1) + + if self.use_denoising_batch and self.cfg_type == "full": + uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1) + elif self.cfg_type == "initialize": + uncond_prompt_embeds = encoder_output[1].repeat(self.frame_bff_size, 1, 1) + + if self.guidance_scale > 1.0 and (self.cfg_type == "initialize" or self.cfg_type == "full"): + self.prompt_embeds = torch.cat([uncond_prompt_embeds, self.prompt_embeds], dim=0) + + # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list + self.sub_timesteps = [] + for t in self.t_list: + self.sub_timesteps.append(self.timesteps[t]) + + sub_timesteps_tensor = torch.tensor(self.sub_timesteps, dtype=torch.long, device=self.device) + self.sub_timesteps_tensor = torch.repeat_interleave( + sub_timesteps_tensor, + repeats=self.frame_bff_size if self.use_denoising_batch else 1, + dim=0, + ) + + self.init_noise = torch.randn( + (self.batch_size, 4, WARMUP_FRAMES, self.latent_height, self.latent_width), + generator=generator, + ).to(device=self.device, dtype=self.dtype) + + self.stock_noise = torch.zeros_like(self.init_noise) + + c_skip_list = [] + c_out_list = [] + for timestep in self.sub_timesteps: + c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep) + c_skip_list.append(c_skip) + c_out_list.append(c_out) + + self.c_skip = ( + torch.stack(c_skip_list).view(len(self.t_list), 1, 1, 1, 1).to(dtype=self.dtype, device=self.device) + ) + self.c_out = ( + torch.stack(c_out_list).view(len(self.t_list), 1, 1, 1, 1).to(dtype=self.dtype, device=self.device) + ) + # print(self.c_skip) + + alpha_prod_t_sqrt_list = [] + beta_prod_t_sqrt_list = [] + for timestep in self.sub_timesteps: + alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt() + beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt() + alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt) + beta_prod_t_sqrt_list.append(beta_prod_t_sqrt) + alpha_prod_t_sqrt = ( + 
torch.stack(alpha_prod_t_sqrt_list) + .view(len(self.t_list), 1, 1, 1, 1) + .to(dtype=self.dtype, device=self.device) + ) + beta_prod_t_sqrt = ( + torch.stack(beta_prod_t_sqrt_list) + .view(len(self.t_list), 1, 1, 1, 1) + .to(dtype=self.dtype, device=self.device) + ) + self.alpha_prod_t_sqrt = torch.repeat_interleave( + alpha_prod_t_sqrt, + repeats=self.frame_bff_size if self.use_denoising_batch else 1, + dim=0, + ) + self.beta_prod_t_sqrt = torch.repeat_interleave( + beta_prod_t_sqrt, + repeats=self.frame_bff_size if self.use_denoising_batch else 1, + dim=0, + ) + # do warmup + # 1. encode images + warmup_x_list = [] + for f in warmup_frames: + x = self.image_processor.preprocess(f, self.height, self.width) + warmup_x_list.append(x.to(device=self.device, dtype=self.dtype)) + warmup_x = torch.cat(warmup_x_list, dim=0) # [warmup_size, c, h, w] + warmup_x_t = self.encode_image(warmup_x) + x_t_latent = rearrange(warmup_x_t, "f c h w -> c f h w")[None, ...] + depth_latent = self.encode_depth(warmup_x) + depth_latent = rearrange(depth_latent, "f c h w -> c f h w")[None, ...] + + # 2. run warmup denoising + self.unet_warmup = self.unet_warmup.to(device="cuda", dtype=self.dtype) + warmup_prompt = self.prompt_embeds[0:1] + for idx, t in enumerate(self.sub_timesteps_tensor): + t = t.view(1).repeat(x_t_latent.shape[0]) + + output_t = self.unet_warmup( + x_t_latent, + t, + temporal_attention_mask=None, + depth_sample=depth_latent, + encoder_hidden_states=warmup_prompt, + kv_cache=[cache[idx] for cache in self.kv_cache_list], + return_dict=True, + ) + model_pred = output_t["sample"] + x_0_pred = self.scheduler_step_batch(model_pred, x_t_latent, idx) + if idx < len(self.sub_timesteps_tensor) - 1: + # x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred + + x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred + self.beta_prod_t_sqrt[ + idx + 1 + ] * torch.randn_like(x_0_pred, device=self.device, dtype=self.dtype) + + self.unet_warmup = self.unet_warmup.to(device="cpu") + x_0_pred = rearrange(x_0_pred, "b c f h w -> b f c h w")[0] # [f, c, h, w] + denoisied_frame = self.decode_image(x_0_pred) + + self.warmup_engine() + + return denoisied_frame + + def warmup_engine(self): + """Warmup tensorrt engine.""" + + if not self.is_tensorrt: + return + + print("Warmup TensorRT engine.") + pseudo_latent = self.init_noise[:, :, 0:1, ...] 
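+        # feed `batch_size` dummy single-frame latents through the UNet so the TensorRT engine
+        # is warmed up before real streaming begins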
+ for _ in range(self.batch_size): + self.unet( + pseudo_latent, + self.sub_timesteps_tensor, + depth_sample=pseudo_latent, + encoder_hidden_states=self.prompt_embeds, + temporal_attention_mask=self.attn_bias, + kv_cache=self.kv_cache_list, + pe_idx=self.pe_idx, + update_idx=self.update_idx, + return_dict=True, + ) + print("Warmup TensorRT engine finished.") + + @torch.no_grad() + def update_prompt(self, prompt: str) -> None: + encoder_output = self.pipe._encode_prompt( + prompt=prompt, + device=self.device, + num_images_per_prompt=1, + do_classifier_free_guidance=False, + ) + self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + t_index: int, + ) -> torch.Tensor: + noisy_samples = self.alpha_prod_t_sqrt[t_index] * original_samples + self.beta_prod_t_sqrt[t_index] * noise + return noisy_samples + + def scheduler_step_batch( + self, + model_pred_batch: torch.Tensor, + x_t_latent_batch: torch.Tensor, + idx: Optional[int] = None, + ) -> torch.Tensor: + # TODO: use t_list to select beta_prod_t_sqrt + if idx is None: + F_theta = (x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch) / self.alpha_prod_t_sqrt + denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch + else: + F_theta = (x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch) / self.alpha_prod_t_sqrt[idx] + denoised_batch = self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch + + return denoised_batch + + def initialize_attn_bias_pe_and_update_idx(self): + attn_mask = torch.zeros((self.denoising_steps_num, WINDOW_SIZE), dtype=torch.bool, device=self.device) + attn_mask[:, :WARMUP_FRAMES] = True + attn_mask[0, WARMUP_FRAMES] = True + attn_bias = torch.zeros_like(attn_mask, dtype=self.dtype) + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + + pe_idx = torch.arange(WINDOW_SIZE).unsqueeze(0).repeat(self.denoising_steps_num, 1).cuda() + update_idx = torch.ones(self.denoising_steps_num, dtype=torch.int64, device=self.device) * WARMUP_FRAMES + update_idx[1] = WARMUP_FRAMES + 1 + + return attn_bias, pe_idx, update_idx + + def update_attn_bias(self, attn_bias, pe_idx, update_idx): + """ + attn_bias: (timesteps, prev_len), init value: [[0, 0, 0, inf], [0, 0, inf, inf]] + pe_idx: (timesteps, prev_len), init value: [[0, 1, 2, 3], [0, 1, 2, 3]] + update_idx: (timesteps, ), init value: [2, 1] + """ + + for idx in range(self.denoising_steps_num): + # update pe_idx and update_idx based on attn_bias from last iteration + if torch.isinf(attn_bias[idx]).any(): + # some position not filled, do not change pe + # some position not filled, fill the last position + update_idx[idx] = (attn_bias[idx] == 0).sum() + else: + # all position are filled, roll pe + pe_idx[idx, WARMUP_FRAMES:] = pe_idx[idx, WARMUP_FRAMES:].roll(shifts=1, dims=0) + # all position are filled, fill the position with largest PE + update_idx[idx] = pe_idx[idx].argmax() + + num_unmask = (attn_bias[idx] == 0).sum() + attn_bias[idx, : min(num_unmask + 1, WINDOW_SIZE)] = 0 + + return attn_bias, pe_idx, update_idx + + def unet_step( + self, + x_t_latent: torch.Tensor, + depth_latent: torch.Tensor, + t_list: Union[torch.Tensor, list[int]], + idx: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): + x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0) + t_list = torch.concat([t_list[0:1], t_list], dim=0) + elif self.guidance_scale > 1.0 
and (self.cfg_type == "full"): + x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0) + t_list = torch.concat([t_list, t_list], dim=0) + else: + x_t_latent_plus_uc = x_t_latent + + output = self.unet( + x_t_latent_plus_uc, + t_list, + depth_sample=depth_latent, + encoder_hidden_states=self.prompt_embeds, + temporal_attention_mask=self.attn_bias, + kv_cache=self.kv_cache_list, + pe_idx=self.pe_idx, + update_idx=self.update_idx, + return_dict=True, + ) + model_pred = output["sample"] + kv_cache_list = output["kv_cache"] + self.kv_cache_list = kv_cache_list + + if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"): + noise_pred_text = model_pred[1:] + self.stock_noise = torch.concat( + [model_pred[0:1], self.stock_noise[1:]], dim=0 + ) # ここコメントアウトでself out cfg + elif self.guidance_scale > 1.0 and (self.cfg_type == "full"): + noise_pred_uncond, noise_pred_text = model_pred.chunk(2) + else: + noise_pred_text = model_pred + if self.guidance_scale > 1.0 and (self.cfg_type == "self" or self.cfg_type == "initialize"): + noise_pred_uncond = self.stock_noise * self.delta + if self.guidance_scale > 1.0 and self.cfg_type != "none": + model_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + model_pred = noise_pred_text + + # compute the previous noisy sample x_t -> x_t-1 + if self.use_denoising_batch: + denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx) + if self.cfg_type == "self" or self.cfg_type == "initialize": + scaled_noise = self.beta_prod_t_sqrt * self.stock_noise + delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx) + alpha_next = torch.concat( + [ + self.alpha_prod_t_sqrt[1:], + torch.ones_like(self.alpha_prod_t_sqrt[0:1]), + ], + dim=0, + ) + delta_x = alpha_next * delta_x + beta_next = torch.concat( + [ + self.beta_prod_t_sqrt[1:], + torch.ones_like(self.beta_prod_t_sqrt[0:1]), + ], + dim=0, + ) + delta_x = delta_x / beta_next + init_noise = torch.concat([self.init_noise[1:], self.init_noise[0:1]], dim=0) + self.stock_noise = init_noise + delta_x + + else: + denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx) + + return denoised_batch, model_pred + + def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor: + """ + image_tensors: [f, c, h, w] + """ + # num_frames = image_tensors.shape[2] + image_tensors = image_tensors.to( + device=self.device, + dtype=self.vae.dtype, + ) + img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator) + img_latent = img_latent * self.vae.config.scaling_factor + noise = torch.randn( + img_latent.shape, + device=img_latent.device, + dtype=img_latent.dtype, + generator=self.generator, + ) + x_t_latent = self.add_noise(img_latent, noise, 0) + return x_t_latent + + def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor: + """ + x_0_pred: [f, c, h, w] + """ + output_latent = self.vae.decode(x_0_pred_out / self.vae.config.scaling_factor, return_dict=False)[0] + return output_latent.clip(-1, 1) + + def encode_depth(self, image_tensors: torch.Tensor) -> Tuple[torch.Tensor]: + """ + image_tensor: [f, c, h, w], [-1, 1] + """ + image_tensors = image_tensors.to( + device=self.device, + dtype=self.depth_detector.dtype, + ) + # depth_map = self.depth_detector(image_tensors) + # depth_map_norm = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) + # depth_map_norm = depth_map_norm[:, None].repeat(1, 3, 1, 1) * 2 - 1 + # depth_latent = 
retrieve_latents(self.vae.encode(depth_map_norm.to(dtype=self.vae.dtype)), self.generator) + # depth_latent = depth_latent * self.vae.config.scaling_factor + # return depth_latent + + # preprocess + h, w = image_tensors.shape[2], image_tensors.shape[3] + images_input = F.interpolate(image_tensors, (384, 384), mode="bilinear", align_corners=False) + # forward + depth_map = self.depth_detector(images_input) + # postprocess + depth_map_norm = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) + depth_map_norm = depth_map_norm[:, None].repeat(1, 3, 1, 1) * 2 - 1 + depth_map_norm = F.interpolate(depth_map_norm, (h, w), mode="bilinear", align_corners=False) + # encode + depth_latent = retrieve_latents(self.vae.encode(depth_map_norm.to(dtype=self.vae.dtype)), self.generator) + depth_latent = depth_latent * self.vae.config.scaling_factor + return depth_latent + + def predict_x0_batch(self, x_t_latent: torch.Tensor, depth_latent: torch.Tensor) -> torch.Tensor: + prev_latent_batch = self.x_t_latent_buffer + prev_depth_latent_batch = self.depth_latent_buffer + + if self.use_denoising_batch: + t_list = self.sub_timesteps_tensor + if self.denoising_steps_num > 1: + x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0) + depth_latent = torch.cat((depth_latent, prev_depth_latent_batch), dim=0) + + self.stock_noise = torch.cat((self.init_noise[0:1], self.stock_noise[:-1]), dim=0) + x_0_pred_batch, model_pred = self.unet_step(x_t_latent, depth_latent, t_list) + self.attn_bias, self.pe_idx, self.update_idx = self.update_attn_bias( + self.attn_bias, self.pe_idx, self.update_idx + ) + + if self.denoising_steps_num > 1: + x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0) + if self.do_add_noise: + # self.x_t_latent_buffer = ( + # self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] + # + self.beta_prod_t_sqrt[1:] * self.init_noise[1:] + # ) + self.x_t_latent_buffer = self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] + self.beta_prod_t_sqrt[ + 1: + ] * torch.randn_like(x_0_pred_batch[:-1]) + else: + self.x_t_latent_buffer = self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] + self.depth_latent_buffer = depth_latent[:-1] + else: + x_0_pred_out = x_0_pred_batch + self.x_t_latent_buffer = None + else: + self.init_noise = x_t_latent + for idx, t in enumerate(self.sub_timesteps_tensor): + t = t.view( + 1, + ).repeat( + self.frame_bff_size, + ) + x_0_pred, model_pred = self.unet_step(x_t_latent, depth_latent, t, idx) + if idx < len(self.sub_timesteps_tensor) - 1: + if self.do_add_noise: + x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred + self.beta_prod_t_sqrt[ + idx + 1 + ] * torch.randn_like(x_0_pred, device=self.device, dtype=self.dtype) + else: + x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred + x_0_pred_out = x_0_pred + + return x_0_pred_out + + @torch.no_grad() + def __call__(self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> torch.Tensor: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + x = self.image_processor.preprocess(x, self.height, self.width).to(device=self.device, dtype=self.dtype) + if self.similar_image_filter: + x = self.similar_filter(x) + if x is None: + time.sleep(self.inference_time_ema) + return self.prev_image_result + x_t_latent = self.encode_image(x) + + start_depth = torch.cuda.Event(enable_timing=True) + end_depth = torch.cuda.Event(enable_timing=True) + start_depth.record() + depth_latent = self.encode_depth(x) + end_depth.record() + torch.cuda.synchronize() + depth_time = 
start_depth.elapsed_time(end_depth) / 1000 + + x_t_latent = x_t_latent.unsqueeze(2) + depth_latent = depth_latent.unsqueeze(2) + x_0_pred_out = self.predict_x0_batch(x_t_latent, depth_latent) # [1, c, 1, h, w] + x_0_pred_out = rearrange(x_0_pred_out, "b c f h w -> (b f) c h w") + x_output = self.decode_image(x_0_pred_out).detach().clone() + + self.prev_image_result = x_output + end.record() + torch.cuda.synchronize() + inference_time = start.elapsed_time(end) / 1000 + self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time + self.depth_time_ema = 0.9 * self.depth_time_ema + 0.1 * depth_time + self.inference_time_list.append(inference_time) + self.depth_time_list.append(depth_time) + return x_output + + def load_warmup_unet(self, config): + unet_warmup = self.pipe.build_warmup_unet(config) + self.unet_warmup = unet_warmup + self.pipe.unet_warmup = unet_warmup + print("Load Warmup UNet.") diff --git a/live2diff/utils/__init__.py b/live2diff/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/live2diff/utils/config.py b/live2diff/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..60fd9072d480259c553ccb67fa7ec54bd2e4fd72 --- /dev/null +++ b/live2diff/utils/config.py @@ -0,0 +1,35 @@ +import os +import os.path as osp + +from omegaconf import OmegaConf + + +config_suffix = [".yaml"] + + +def load_config(config: str) -> OmegaConf: + config = OmegaConf.load(config) + base_config = config.pop("base", None) + + if base_config: + config = OmegaConf.merge(OmegaConf.load(base_config), config) + + return config + + +def dump_config(config: OmegaConf, save_path: str = None): + from omegaconf import Container + + if isinstance(config, Container): + if not save_path.endswith(".yaml"): + save_dir = save_path + save_path = osp.join(save_dir, "config.yaml") + else: + save_dir = osp.basename(config) + os.makedirs(save_dir, exist_ok=True) + OmegaConf.save(config, save_path) + + else: + raise TypeError("Only support saving `Config` from `OmegaConf`.") + + print(f"Dump Config to {save_path}.") diff --git a/live2diff/utils/io.py b/live2diff/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..5917e0e26499a374014caec7cda490c39b4bd3cd --- /dev/null +++ b/live2diff/utils/io.py @@ -0,0 +1,48 @@ +import os +import os.path as osp + +import imageio +import numpy as np +import torch +import torchvision +from einops import rearrange +from PIL import Image + + +def read_video_frames(folder: str, height=None, width=None): + """ + Read video frames from the given folder. 
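+    Frames are expected to be image files named by an integer index (e.g. ``0.png``, ``1.png``),
+    since they are sorted numerically by their filename stem.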
+ + Output: + frames, in [0, 255], uint8, THWC + """ + _SUPPORTED_EXTENSIONS = [".png", ".jpg", ".jpeg"] + + frames = [f for f in os.listdir(folder) if osp.splitext(f)[1] in _SUPPORTED_EXTENSIONS] + # sort frames + sorted_frames = sorted(frames, key=lambda x: int(osp.splitext(x)[0])) + sorted_frames = [osp.join(folder, f) for f in sorted_frames] + + if height is not None and width is not None: + sorted_frames = [np.array(Image.open(f).resize((width, height))) for f in sorted_frames] + else: + sorted_frames = [np.array(Image.open(f)) for f in sorted_frames] + sorted_frames = torch.stack([torch.from_numpy(f) for f in sorted_frames], dim=0) + return sorted_frames + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + parent_dir = os.path.dirname(path) + if parent_dir != "": + os.makedirs(parent_dir, exist_ok=True) + imageio.mimsave(path, outputs, fps=fps, loop=0) diff --git a/live2diff/utils/wrapper.py b/live2diff/utils/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f890d8dbe1f1f8fbaaa48626978015edbf8c8f27 --- /dev/null +++ b/live2diff/utils/wrapper.py @@ -0,0 +1,640 @@ +import gc +import os +import traceback +from pathlib import Path +from typing import Dict, List, Literal, Optional, Union + +import numpy as np +import torch +from diffusers import AutoencoderTiny +from PIL import Image + +from live2diff import StreamAnimateDiffusionDepth +from live2diff.image_utils import postprocess_image +from live2diff.pipeline_stream_animation_depth import WARMUP_FRAMES + + +class StreamAnimateDiffusionDepthWrapper: + def __init__( + self, + config_path: str, + few_step_model_type: str, + num_inference_steps: int, + t_index_list: Optional[List[int]] = None, + strength: Optional[float] = None, + dreambooth_path: Optional[str] = None, + lora_dict: Optional[Dict[str, float]] = None, + output_type: Literal["pil", "pt", "np", "latent"] = "pil", + vae_id: Optional[str] = None, + device: Literal["cpu", "cuda"] = "cuda", + dtype: torch.dtype = torch.float16, + frame_buffer_size: int = 1, + width: int = 512, + height: int = 512, + acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", + do_add_noise: bool = True, + device_ids: Optional[List[int]] = None, + use_tiny_vae: bool = True, + enable_similar_image_filter: bool = False, + similar_image_filter_threshold: float = 0.98, + similar_image_filter_max_skip_frame: int = 10, + use_denoising_batch: bool = True, + cfg_type: Literal["none", "full", "self", "initialize"] = "self", + seed: int = 42, + engine_dir: Optional[Union[str, Path]] = "engines", + opt_unet: bool = False, + ): + """ + Initializes the StreamAnimateDiffusionWrapper. + + Parameters + ---------- + config_path : str + The model id or path to load. + few_step_model_type : str + The few step model type to use. + num_inference_steps : int + The number of inference steps to perform. If `t_index_list` + is passed, `num_infernce_steps` will parsed as the number + of denoising steps before apply few-step lora. Otherwise, + `num_inference_steps` will be parsed as the number of + steps after applying few-step lora. + t_index_list : List[int] + The t_index_list to use for inference. 
+ strength : Optional[float] + The strength to use for inference. + dreambooth_path : Optional[str] + The dreambooth path to use for inference. If not passed, + will use dreambooth from config. + lora_dict : Optional[Dict[str, float]], optional + The lora_dict to load, by default None. + Keys are the LoRA names and values are the LoRA scales. + Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} + output_type : Literal["pil", "pt", "np", "latent"], optional + The output type of image, by default "pil". + vae_id : Optional[str], optional + The vae_id to load, by default None. + If None, the default TinyVAE + ("madebyollin/taesd") will be used. + device : Literal["cpu", "cuda"], optional + The device to use for inference, by default "cuda". + dtype : torch.dtype, optional + The dtype for inference, by default torch.float16. + frame_buffer_size : int, optional + The frame buffer size for denoising batch, by default 1. + width : int, optional + The width of the image, by default 512. + height : int, optional + The height of the image, by default 512. + acceleration : Literal["none", "xformers", "tensorrt"], optional + The acceleration method, by default "tensorrt". + do_add_noise : bool, optional + Whether to add noise for following denoising steps or not, + by default True. + device_ids : Optional[List[int]], optional + The device ids to use for DataParallel, by default None. + use_lcm_lora : bool, optional + Whether to use LCM-LoRA or not, by default True. + use_tiny_vae : bool, optional + Whether to use TinyVAE or not, by default True. + enable_similar_image_filter : bool, optional + Whether to enable similar image filter or not, + by default False. + similar_image_filter_threshold : float, optional + The threshold for similar image filter, by default 0.98. + similar_image_filter_max_skip_frame : int, optional + The max skip frame for similar image filter, by default 10. + use_denoising_batch : bool, optional + Whether to use denoising batch or not, by default True. + cfg_type : Literal["none", "full", "self", "initialize"], + optional + The cfg_type for img2img mode, by default "self". + You cannot use anything other than "none" for txt2img mode. + seed : int, optional + The seed, by default 42. + engine_dir : Optional[Union[str, Path]], optional + The directory to save TensorRT engines, by default "engines". + opt_unet : bool, optional + Whether to optimize UNet or not, by default False. 
+ """ + self.sd_turbo = False + + self.device = device + self.dtype = dtype + self.width = width + self.height = height + self.output_type = output_type + self.frame_buffer_size = frame_buffer_size + + self.use_denoising_batch = use_denoising_batch + + self.stream: StreamAnimateDiffusionDepth = self._load_model( + config_path=config_path, + lora_dict=lora_dict, + dreambooth_path=dreambooth_path, + few_step_model_type=few_step_model_type, + vae_id=vae_id, + num_inference_steps=num_inference_steps, + t_index_list=t_index_list, + strength=strength, + height=height, + width=width, + acceleration=acceleration, + do_add_noise=do_add_noise, + use_tiny_vae=use_tiny_vae, + cfg_type=cfg_type, + seed=seed, + engine_dir=engine_dir, + opt_unet=opt_unet, + ) + self.batch_size = len(self.stream.t_list) * frame_buffer_size if use_denoising_batch else frame_buffer_size + + if device_ids is not None: + self.stream.unet = torch.nn.DataParallel(self.stream.unet, device_ids=device_ids) + + if enable_similar_image_filter: + self.stream.enable_similar_image_filter( + similar_image_filter_threshold, similar_image_filter_max_skip_frame + ) + + def prepare( + self, + warmup_frames: torch.Tensor, + prompt: str, + negative_prompt: str = "", + guidance_scale: float = 1.2, + delta: float = 1.0, + ) -> torch.Tensor: + """ + Prepares the model for inference. + + Parameters + ---------- + prompt : str + The prompt to generate images from. + num_inference_steps : int, optional + The number of inference steps to perform, by default 50. + guidance_scale : float, optional + The guidance scale to use, by default 1.2. + delta : float, optional + The delta multiplier of virtual residual noise, + by default 1.0. + + Returns + ---------- + warmup_frames : torch.Tensor + generated warmup-frames. + + """ + warmup_frames = self.stream.prepare( + warmup_frames=warmup_frames, + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + delta=delta, + ) + + warmup_frames = warmup_frames.permute(0, 2, 3, 1) + warmup_frames = (warmup_frames.clip(-1, 1) + 1) / 2 + return warmup_frames + + def __call__( + self, + image: Optional[Union[str, Image.Image, torch.Tensor]] = None, + prompt: Optional[str] = None, + ) -> Union[Image.Image, List[Image.Image]]: + """ + Performs img2img or txt2img based on the mode. + + Parameters + ---------- + image : Optional[Union[str, Image.Image, torch.Tensor]] + The image to generate from. + prompt : Optional[str] + The prompt to generate images from. + + Returns + ------- + Union[Image.Image, List[Image.Image]] + The generated image. + """ + return self.img2img(image, prompt) + + def img2img( + self, image: Union[str, Image.Image, torch.Tensor], prompt: Optional[str] = None + ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]: + """ + Performs img2img. + + Parameters + ---------- + image : Union[str, Image.Image, torch.Tensor] + The image to generate from. + + Returns + ------- + Image.Image + The generated image. + """ + if prompt is not None: + self.stream.update_prompt(prompt) + + if isinstance(image, str) or isinstance(image, Image.Image): + image = self.preprocess_image(image) + + image_tensor = self.stream(image) + image = self.postprocess_image(image_tensor, output_type=self.output_type) + + return image + + def preprocess_image(self, image: Union[str, Image.Image]) -> torch.Tensor: + """ + Preprocesses the image. + + Parameters + ---------- + image : Union[str, Image.Image, torch.Tensor] + The image to preprocess. 
+ + Returns + ------- + torch.Tensor + The preprocessed image. + """ + if isinstance(image, str): + image = Image.open(image).convert("RGB").resize((self.width, self.height)) + if isinstance(image, Image.Image): + image = image.convert("RGB").resize((self.width, self.height)) + + return self.stream.image_processor.preprocess(image, self.height, self.width).to( + device=self.device, dtype=self.dtype + ) + + def postprocess_image( + self, image_tensor: torch.Tensor, output_type: str = "pil" + ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]: + """ + Postprocesses the image. + + Parameters + ---------- + image_tensor : torch.Tensor + The image tensor to postprocess. + + Returns + ------- + Union[Image.Image, List[Image.Image]] + The postprocessed image. + """ + if self.frame_buffer_size > 1: + output = postprocess_image(image_tensor, output_type=output_type) + else: + output = postprocess_image(image_tensor, output_type=output_type)[0] + + if output_type not in ["pil", "np"]: + return output.cpu() + else: + return output + + @staticmethod + def get_model_prefix( + config_path: str, + few_step_model_type: str, + use_tiny_vae: bool, + num_denoising_steps: int, + height: int, + width: int, + dreambooth: Optional[str] = None, + lora_dict: Optional[dict] = None, + ) -> str: + from omegaconf import OmegaConf + + config = OmegaConf.load(config_path) + third_party = config.third_party_dict + dreambooth_path = dreambooth or third_party.dreambooth + if dreambooth_path is None: + dreambooth_name = "sd15" + else: + dreambooth_name = Path(dreambooth_path).stem + + base_lora_list = third_party.get("lora_list", []) + lora_dict = lora_dict or {} + for lora_alpha in base_lora_list: + lora_name = lora_alpha["lora"] + alpha = lora_alpha["lora_alpha"] + if lora_name not in lora_dict: + lora_dict[lora_name] = alpha + + prefix = f"{dreambooth_name}--{few_step_model_type}--step{num_denoising_steps}--" + for k, v in lora_dict.items(): + prefix += f"{Path(k).stem}-{v}--" + prefix += f"tiny_vae-{use_tiny_vae}--h-{height}--w-{width}" + return prefix + + def _load_model( + self, + config_path: str, + num_inference_steps: int, + height: int, + width: int, + t_index_list: Optional[List[int]] = None, + strength: Optional[float] = None, + dreambooth_path: Optional[str] = None, + lora_dict: Optional[Dict[str, float]] = None, + vae_id: Optional[str] = None, + acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt", + do_add_noise: bool = True, + few_step_model_type: Optional[str] = None, + use_tiny_vae: bool = True, + cfg_type: Literal["none", "full", "self", "initialize"] = "self", + seed: int = 2, + engine_dir: Optional[Union[str, Path]] = "engines", + opt_unet: bool = False, + ) -> StreamAnimateDiffusionDepth: + """ + Loads the model. + + This method does the following: + + 1. Loads the model from the model_id_or_path. + 3. Loads the VAE model from the vae_id if needed. + 4. Enables acceleration if needed. + 6. Load the safety checker if needed. + + Parameters + ---------- + config_path : str + The path to config, all needed checkpoints are list in config file. + t_index_list : List[int] + The t_index_list to use for inference. + dreambooth_path : Optional[str] + The dreambooth path to use for inference. If not passed, + will use dreambooth from config. + lora_dict : Optional[Dict[str, float]], optional + The lora_dict to load, by default None. + Keys are the LoRA names and values are the LoRA scales. 
+ Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...} + vae_id : Optional[str], optional + The vae_id to load, by default None. + acceleration : Literal["none", "xfomers", "sfast", "tensorrt"], optional + The acceleration method, by default "tensorrt". + warmup : int, optional + The number of warmup steps to perform, by default 10. + do_add_noise : bool, optional + Whether to add noise for following denoising steps or not, + by default True. + use_lcm_lora : bool, optional + Whether to use LCM-LoRA or not, by default True. + use_tiny_vae : bool, optional + Whether to use TinyVAE or not, by default True. + cfg_type : Literal["none", "full", "self", "initialize"], + optional + The cfg_type for img2img mode, by default "self". + You cannot use anything other than "none" for txt2img mode. + seed : int, optional + The seed, by default 2. + opt_unet : bool, optional + Whether to optimize UNet or not, by default False. + + Returns + ------- + AnimatePipeline + The loaded pipeline. + """ + supported_few_step_model = ["LCM"] + assert ( + few_step_model_type.upper() in supported_few_step_model + ), f"Only support few_step_model: {supported_few_step_model}, but receive {few_step_model_type}." + + # NOTE: build animatediff pipeline + from live2diff.animatediff.pipeline import AnimationDepthPipeline + + try: + pipe = AnimationDepthPipeline.build_pipeline( + config_path, + ).to(device=self.device, dtype=self.dtype) + except Exception: # No model found + traceback.print_exc() + print("Model load has failed. Doesn't exist.") + exit() + + if few_step_model_type.upper() == "LCM": + few_step_lora = "latent-consistency/lcm-lora-sdv1-5" + stream_pipeline_cls = StreamAnimateDiffusionDepth + + print(f"Pipeline class: {stream_pipeline_cls}") + print(f"Few-step LoRA: {few_step_lora}") + + # parse clip skip from config + from .config import load_config + + cfg = load_config(config_path) + third_party_dict = cfg.third_party_dict + clip_skip = third_party_dict.get("clip_skip", 1) + + stream = stream_pipeline_cls( + pipe=pipe, + num_inference_steps=num_inference_steps, + t_index_list=t_index_list, + strength=strength, + torch_dtype=self.dtype, + width=self.width, + height=self.height, + do_add_noise=do_add_noise, + frame_buffer_size=self.frame_buffer_size, + use_denoising_batch=self.use_denoising_batch, + cfg_type=cfg_type, + clip_skip=clip_skip, + ) + + stream.load_warmup_unet(config_path) + stream.load_lora(few_step_lora) + stream.fuse_lora() + + denoising_steps_num = len(stream.t_list) + stream.prepare_cache( + height=height, + width=width, + denoising_steps_num=denoising_steps_num, + ) + kv_cache_list = stream.kv_cache_list + + if lora_dict is not None: + for lora_name, lora_scale in lora_dict.items(): + stream.load_lora(lora_name) + stream.fuse_lora(lora_scale=lora_scale) + print(f"Use LoRA: {lora_name} in weights {lora_scale}") + + if use_tiny_vae: + vae_id = "madebyollin/taesd" if vae_id is None else vae_id + stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(device=pipe.device, dtype=pipe.dtype) + + try: + if acceleration == "none": + stream.pipe.unet = torch.compile(stream.pipe.unet, options={"triton.cudagraphs": True}, fullgraph=True) + stream.vae = torch.compile(stream.vae, options={"triton.cudagraphs": True}, fullgraph=True) + if acceleration == "xformers": + stream.pipe.enable_xformers_memory_efficient_attention() + if acceleration == "tensorrt": + from polygraphy import cuda + + from live2diff.acceleration.tensorrt import ( + TorchVAEEncoder, + compile_engine, + ) + from 
live2diff.acceleration.tensorrt.engine import ( + AutoencoderKLEngine, + MidasEngine, + UNet2DConditionModelDepthEngine, + ) + from live2diff.acceleration.tensorrt.models import ( + VAE, + InflatedUNetDepth, + Midas, + VAEEncoder, + ) + + prefix = self.get_model_prefix( + config_path=config_path, + few_step_model_type=few_step_model_type, + use_tiny_vae=use_tiny_vae, + num_denoising_steps=denoising_steps_num, + height=height, + width=width, + dreambooth=dreambooth_path, + lora_dict=lora_dict, + ) + + engine_dir = os.path.join(Path(engine_dir), prefix) + unet_path = os.path.join(engine_dir, "unet", "unet.engine") + unet_opt_path = os.path.join(engine_dir, "unet-opt", "unet.engine.opt") + midas_path = os.path.join(engine_dir, "depth", "midas.engine") + vae_encoder_path = os.path.join(engine_dir, "vae", "vae_encoder.engine") + vae_decoder_path = os.path.join(engine_dir, "vae", "vae_decoder.engine") + + if not os.path.exists(unet_path): + os.makedirs(os.path.dirname(unet_path), exist_ok=True) + os.makedirs(os.path.dirname(unet_opt_path), exist_ok=True) + unet_model = InflatedUNetDepth( + fp16=True, + device=stream.device, + max_batch_size=stream.trt_unet_batch_size, + min_batch_size=stream.trt_unet_batch_size, + embedding_dim=stream.text_encoder.config.hidden_size, + unet_dim=stream.unet.config.in_channels, + kv_cache_list=kv_cache_list, + ) + compile_engine( + torch_model=stream.unet, + model_data=unet_model, + onnx_path=unet_path + ".onnx", + onnx_opt_path=unet_opt_path, # use specific folder for external data + engine_path=unet_path, + opt_image_height=height, + opt_image_width=width, + opt_batch_size=stream.trt_unet_batch_size, + engine_build_options={"ignore_onnx_optimize": not opt_unet}, + ) + + if not os.path.exists(vae_decoder_path): + os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True) + stream.vae.forward = stream.vae.decode + max_bz = WARMUP_FRAMES + opt_bz = min_bz = 1 + vae_decoder_model = VAE( + device=stream.device, + max_batch_size=max_bz, + min_batch_size=min_bz, + ) + compile_engine( + torch_model=stream.vae, + model_data=vae_decoder_model, + onnx_path=vae_decoder_path + ".onnx", + onnx_opt_path=vae_decoder_path + ".opt.onnx", + engine_path=vae_decoder_path, + opt_image_height=height, + opt_image_width=width, + opt_batch_size=opt_bz, + ) + delattr(stream.vae, "forward") + + if not os.path.exists(midas_path): + os.makedirs(os.path.dirname(midas_path), exist_ok=True) + max_bz = WARMUP_FRAMES + opt_bz = min_bz = 1 + midas = Midas( + fp16=True, + device=stream.device, + max_batch_size=max_bz, + min_batch_size=min_bz, + ) + compile_engine( + torch_model=stream.depth_detector.half(), + model_data=midas, + onnx_path=midas_path + ".onnx", + onnx_opt_path=midas_path + ".opt.onnx", + engine_path=midas_path, + opt_batch_size=opt_bz, + opt_image_height=384, + opt_image_width=384, + engine_build_options={ + "auto_cast": False, + "handle_batch_norm": True, + "ignore_onnx_optimize": True, + }, + ) + + if not os.path.exists(vae_encoder_path): + os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True) + vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda")) + max_bz = WARMUP_FRAMES + opt_bz = min_bz = 1 + vae_encoder_model = VAEEncoder( + device=stream.device, + max_batch_size=max_bz, + min_batch_size=min_bz, + ) + compile_engine( + torch_model=vae_encoder, + model_data=vae_encoder_model, + onnx_path=vae_encoder_path + ".onnx", + onnx_opt_path=vae_encoder_path + ".opt.onnx", + engine_path=vae_encoder_path, + opt_batch_size=opt_bz, + opt_image_height=height, + 
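+ # Note: the UNet and VAE engines above are profiled at the streaming
+ # resolution (height x width), while the MiDaS depth engine is built
+ # for its fixed 384x384 input.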
opt_image_width=width, + ) + cuda_stream = cuda.Stream() + + vae_config = stream.vae.config + vae_dtype = stream.vae.dtype + midas_dtype = stream.depth_detector.dtype + + stream.unet = UNet2DConditionModelDepthEngine(unet_path, cuda_stream, use_cuda_graph=False) + stream.depth_detector = MidasEngine(midas_path, cuda_stream, use_cuda_graph=False) + setattr(stream.depth_detector, "dtype", midas_dtype) + stream.vae = AutoencoderKLEngine( + vae_encoder_path, + vae_decoder_path, + cuda_stream, + stream.pipe.vae_scale_factor, + use_cuda_graph=False, + ) + setattr(stream.vae, "config", vae_config) + setattr(stream.vae, "dtype", vae_dtype) + + stream.is_tensorrt = True + + gc.collect() + torch.cuda.empty_cache() + + print("TensorRT acceleration enabled.") + + except Exception: + traceback.print_exc() + print("Acceleration has failed. Falling back to normal mode.") + + if seed < 0: # Random seed + seed = np.random.randint(0, 1000000) + + return stream diff --git a/scripts/download.sh b/scripts/download.sh new file mode 100644 index 0000000000000000000000000000000000000000..6369795541f14d7d3997045afcb897e2b9bc71d1 --- /dev/null +++ b/scripts/download.sh @@ -0,0 +1,91 @@ +#!/bin/bash +TOKEN=$2 + +download_disney() { + echo "Download checkpoint for Disney..." + wget https://civitai.com/api/download/models/69832\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate +} + +download_moxin () { + echo "Download checkpoints for MoXin..." + wget https://civitai.com/api/download/models/106289\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate + wget https://civitai.com/api/download/models/14856\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate +} + +download_pixart () { + echo "Download checkpoint for PixArt..." + wget https://civitai.com/api/download/models/220049\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate +} + +download_origami () { + echo "Download checkpoints for origami..." + wget https://civitai.com/api/download/models/270085\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate + wget https://civitai.com/api/download/models/266928\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate +} + +download_threeDelicacy () { + echo "Download checkpoints for threeDelicacy..." + wget https://civitai.com/api/download/models/36473\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate +} + +download_toonyou () { + echo "Download checkpoint for Toonyou..." + wget https://civitai.com/api/download/models/125771\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate +} + +download_zaum () { + echo "Download checkpoints for Zaum..." + wget https://civitai.com/api/download/models/428862\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate + wget https://civitai.com/api/download/models/18989\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate +} + +download_felted () { + echo "Download checkpoints for Felted..." + wget https://civitai.com/api/download/models/428862\?token\=${TOKEN} -P ./models/Model --content-disposition --no-check-certificate + wget https://civitai.com/api/download/models/86725\?token\=${TOKEN} -P ./models/LoRA --content-disposition --no-check-certificate +} + +if [ -z "$1" ]; then + echo "Please input the model you want to download." + echo "Supported model: all, disney, moxin, pixart, paperArt, threeDelicacy, toonyou, zaum." 
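+ # Note: an optional Civitai API token can be supplied as the second
+ # positional argument (captured as TOKEN at the top of this script); it is
+ # appended to each download URL as ?token=... for checkpoints that require
+ # authentication.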
+ exit 1 +fi + +declare -A download_func=( + ["disney"]="download_disney" + ["moxin"]="download_moxin" + ["pixart"]="download_pixart" + ["origami"]="download_origami" + ["threeDelicacy"]="download_threeDelicacy" + ["toonyou"]="download_toonyou" + ["zaum"]="download_zaum" + ["felted"]="download_felted" +) + +execute_function() { + local key="$1" + if [[ -n "${download_func[$key]}" ]]; then + ${download_func[$key]} + else + echo "Function not found for key: $key" + fi +} + + +for arg in "$@"; do + case "$arg" in + disney|moxin|pixart|origami|threeDelicacy|toonyou|zaum|felted) + model_name="$arg" + execute_function "$model_name" + ;; + all) + for model_name in "${!download_func[@]}"; do + execute_function "$model_name" + done + ;; + *) + echo "Invalid argument: $arg." + exit 1 + ;; + esac +done diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..f653b7fbee93eb6647e7c7863ad12b08aca608df --- /dev/null +++ b/setup.py @@ -0,0 +1,69 @@ +from typing import Literal, Optional + +from setuptools import find_packages, setup + + +deps = [ + "diffusers==0.25.0", + "transformers", + "accelerate", + "fire", + "einops", + "omegaconf", + "imageio", + "timm==0.6.7", + "lightning", + "peft", + "av", + "decord", + "pillow", + "pywin32;sys_platform == 'win32'", +] + + +def get_cuda_version_from_torch() -> Optional[Literal["11", "12"]]: + try: + import torch + + return torch.version.cuda.split(".")[0] + except ImportError: + raise ImportError("Please install PyTorch first. See https://pytorch.org/get-started/locally/.") + + +cu = get_cuda_version_from_torch() + +assert cu in ["11", "12"], f"Unsupported CUDA version: {cu}" + +deps_tensorrt = [ + "onnx==1.16.0", + "onnxruntime==1.16.3", + "protobuf==5.27.0", + "polygraphy", + "onnx-graphsurgeon", + "cuda-python", + f"tensorrt_cu{cu}_libs==10.0.1", + f"tensorrt_cu{cu}_bindings==10.0.1", + "tensorrt==10.0.1", + "colored", +] +extras = {"tensorrt": deps_tensorrt} + + +if __name__ == "__main__": + setup( + name="Live2Diff", + version="0.1", + description="real-time interactive video translation pipeline", + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", + keywords="deep learning diffusion pytorch stable diffusion streamdiffusion real-time next-frame prediction", + license="Apache 2.0 License", + author="leo", + author_email="xingzhening@pjlab.org.cn", + url="https://github.com/LeoXing1996/NextFramePredictionPreview", + package_dir={"": "live2diff"}, + packages=find_packages("live2diff"), + python_requires=">=3.10.0", + install_requires=deps, + extras_require=extras, + )
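+
+
+# Hedged sketch, not used by setup() above: a slightly more defensive variant
+# of get_cuda_version_from_torch() that also reports CPU-only PyTorch builds
+# (where torch.version.cuda is None) instead of raising an AttributeError.
+# The "tensorrt" extra defined above is then installed with
+#   pip install -e ".[tensorrt]"
+# which pulls the tensorrt_cu11_*/tensorrt_cu12_* wheels matching this value.
+def _cuda_major_or_none() -> Optional[str]:
+    try:
+        import torch
+    except ImportError as err:
+        raise ImportError(
+            "Please install PyTorch first. See https://pytorch.org/get-started/locally/."
+        ) from err
+    if torch.version.cuda is None:  # CPU-only build: no cuXX TensorRT wheels apply
+        return None
+    return torch.version.cuda.split(".")[0]  # e.g. "12"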