diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..bab002189355d42f5aba56499d615fa23e43e745
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,28 @@
+BSD 3-Clause License
+
+Copyright 2023 MagicAnimate Team All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/README.md b/README.md
index 689d44137e986820d38ecc07d357da5c865a53b2..851dd02426b3fa1a28bb50d1d644f75b222cac6d 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,128 @@
----
-title: Magicanimate
-emoji: 📊
-colorFrom: yellow
-colorTo: pink
-sdk: gradio
-sdk_version: 4.7.1
-app_file: app.py
-pinned: false
-license: bsd-3-clause
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model
+
+Zhongcong Xu · Jianfeng Zhang · Jun Hao Liew · Hanshu Yan · Jia-Wei Liu · Chenxu Zhang · Jiashi Feng · Mike Zheng Shou
+
+National University of Singapore | ByteDance
+
+## 📢 News
+* **[2023.12.4]** Release inference code and Gradio demo. We are working to improve MagicAnimate; stay tuned!
+* **[2023.11.23]** Release MagicAnimate paper and project page.
+
+## 🏃‍♂️ Getting Started
+Please download the pretrained base models for [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and the [MSE-finetuned VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse).
+
+Download our MagicAnimate [checkpoints](https://huggingface.co/zcxu-eric/MagicAnimate).
+
+**Place them as follows:**
+```bash
+magic-animate
+|----pretrained_models
+ |----MagicAnimate
+ |----appearance_encoder
+ |----diffusion_pytorch_model.safetensors
+ |----config.json
+ |----densepose_controlnet
+ |----diffusion_pytorch_model.safetensors
+ |----config.json
+ |----temporal_attention
+ |----temporal_attention.ckpt
+ |----sd-vae-ft-mse
+ |----...
+ |----stable-diffusion-v1-5
+ |----...
+|----...
+```
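+
+If you prefer scripting the downloads, here is a minimal sketch using `huggingface_hub.snapshot_download` (an assumption: `huggingface_hub` is typically pulled in by `diffusers`/`transformers`; if you rely on a mirror endpoint, set `HF_ENDPOINT` accordingly):
+```python
+# Sketch: download the three checkpoints into ./pretrained_models,
+# matching the directory layout shown above.
+from huggingface_hub import snapshot_download
+
+for repo_id, local_dir in [
+    ("zcxu-eric/MagicAnimate", "pretrained_models/MagicAnimate"),
+    ("runwayml/stable-diffusion-v1-5", "pretrained_models/stable-diffusion-v1-5"),
+    ("stabilityai/sd-vae-ft-mse", "pretrained_models/sd-vae-ft-mse"),
+]:
+    snapshot_download(repo_id=repo_id, local_dir=local_dir)
+```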
+
+## ⚒️ Installation
+Prerequisites: `python>=3.8`, `CUDA>=11.3`, and `ffmpeg`.
+
+Install with `conda`:
+```bash
+conda env create -f environment.yml
+conda activate manimate
+```
+or `pip`:
+```bash
+pip3 install -r requirements.txt
+```
+
+## 💃 Inference
+Run inference on a single GPU:
+```bash
+bash scripts/animate.sh
+```
+Run inference with multiple GPUs:
+```bash
+bash scripts/animate_dist.sh
+```
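+
+If you prefer to run inference from Python rather than the shell scripts, here is a minimal sketch based on `demo/animate.py` in this repo; it assumes the pretrained weights above are in place and a CUDA GPU is available, and uses the bundled example inputs with the seed/steps/guidance values from `configs/prompts/animation.yaml`:
+```python
+# Minimal programmatic inference sketch (single GPU).
+import numpy as np
+from PIL import Image
+from demo.animate import MagicAnimate
+
+animator = MagicAnimate()  # loads configs/prompts/animation.yaml and the pretrained weights
+source = np.array(Image.open("inputs/applications/source_image/monalisa.png").convert("RGB"))
+result_path = animator(
+    source,                                                # reference image as an H x W x 3 array
+    "inputs/applications/driving/densepose/running.mp4",   # DensePose motion sequence
+    random_seed=1,
+    step=25,
+    guidance_scale=7.5,
+)
+print(result_path)  # the animation is saved under demo/outputs/
+```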
+
+## 🎨 Gradio Demo
+
+#### Online Gradio Demo:
+Try our [online Gradio demo]() for a quick start.
+
+#### Local Gradio Demo:
+Launch the local Gradio demo on a single GPU:
+```bash
+python3 -m demo.gradio_animate
+```
+Launch the local Gradio demo if you have multiple GPUs:
+```bash
+python3 -m demo.gradio_animate_dist
+```
+Then open the Gradio demo in your local browser (Gradio serves at http://localhost:7860 by default).
+
+## 🎓 Citation
+If you find this codebase useful for your research, please cite it using the following entry.
+```BibTeX
+@inproceedings{xu2023magicanimate,
+ author = {Xu, Zhongcong and Zhang, Jianfeng and Liew, Jun Hao and Yan, Hanshu and Liu, Jia-Wei and Zhang, Chenxu and Feng, Jiashi and Shou, Mike Zheng},
+ title = {MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model},
+ booktitle = {arXiv},
+ year = {2023}
+}
+```
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..01b8fd5a9e113f6a78257a6d34501dfa4a266c6f
--- /dev/null
+++ b/app.py
@@ -0,0 +1,107 @@
+# Copyright 2023 ByteDance and/or its affiliates.
+#
+# Copyright (2023) MagicAnimate Authors
+#
+# ByteDance, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from ByteDance or
+# its affiliates is strictly prohibited.
+import argparse
+import imageio
+import numpy as np
+import gradio as gr
+from PIL import Image
+from subprocess import PIPE, run
+
+from demo.animate import MagicAnimate
+
+# Fetch pretrained weights at startup. Each `run(..., shell=True)` call spawns its
+# own shell, so a standalone `cd` would not persist between commands; chaining the
+# steps into a single command ensures the clones land inside ./pretrained_models.
+setup_command = ' && '.join([
+    'mkdir -p ./pretrained_models',
+    'cd pretrained_models',
+    'git lfs clone https://huggingface.co/zcxu-eric/MagicAnimate',
+    'git lfs clone https://huggingface.co/runwayml/stable-diffusion-v1-5',
+    'git lfs clone https://huggingface.co/stabilityai/sd-vae-ft-mse',
+])
+run(setup_command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True)
+
+animator = MagicAnimate()
+
+def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale):
+ return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale)
+
+with gr.Blocks() as demo:
+
+    gr.HTML(
+        """
+        MagicAnimate: Temporally Consistent Human Image Animation
+        """)
+ animation = gr.Video(format="mp4", label="Animation Results", autoplay=True)
+
+ with gr.Row():
+ reference_image = gr.Image(label="Reference Image")
+ motion_sequence = gr.Video(format="mp4", label="Motion Sequence")
+
+ with gr.Column():
+        random_seed = gr.Textbox(label="Random seed", value=1, info="set to -1 for a random seed")
+ sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25")
+ guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5")
+ submit = gr.Button("Animate")
+
+    def read_video(video):
+        # Validate that the uploaded motion sequence runs at the expected 25 fps.
+        reader = imageio.get_reader(video)
+        fps = reader.get_meta_data()['fps']
+        assert fps == 25.0, f'Expected video fps: 25, but {fps} fps found'
+        return video
+
+ def read_image(image, size=512):
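+        # Resize the uploaded reference image to a square of side size (512 by default).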
+ return np.array(Image.fromarray(image).resize((size, size)))
+
+ # when user uploads a new video
+ motion_sequence.upload(
+ read_video,
+ motion_sequence,
+ motion_sequence
+ )
+    # when the reference image is updated
+ reference_image.upload(
+ read_image,
+ reference_image,
+ reference_image
+ )
+ # when the `submit` button is clicked
+ submit.click(
+ animate,
+ [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale],
+ animation
+ )
+
+ # Examples
+ gr.Markdown("## Examples")
+ gr.Examples(
+ examples=[
+ ["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"],
+ ["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"],
+ ["inputs/applications/source_image/0002.png", "inputs/applications/driving/densepose/demo4.mp4"],
+ ["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"],
+ ["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"],
+ ["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"],
+ ],
+ inputs=[reference_image, motion_sequence],
+ outputs=animation,
+ )
+
+
+demo.launch(share=True)
\ No newline at end of file
diff --git a/configs/inference/inference.yaml b/configs/inference/inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..86f377721c0637372229c024e2b5127bc6ca1604
--- /dev/null
+++ b/configs/inference/inference.yaml
@@ -0,0 +1,26 @@
+unet_additional_kwargs:
+ unet_use_cross_frame_attention: false
+ unet_use_temporal_attention: false
+ use_motion_module: true
+ motion_module_resolutions:
+ - 1
+ - 2
+ - 4
+ - 8
+ motion_module_mid_block: false
+ motion_module_decoder_only: false
+ motion_module_type: Vanilla
+ motion_module_kwargs:
+ num_attention_heads: 8
+ num_transformer_block: 1
+ attention_block_types:
+ - Temporal_Self
+ - Temporal_Self
+ temporal_position_encoding: true
+ temporal_position_encoding_max_len: 24
+ temporal_attention_dim_div: 1
+
+noise_scheduler_kwargs:
+ beta_start: 0.00085
+ beta_end: 0.012
+ beta_schedule: "linear"
diff --git a/configs/prompts/animation.yaml b/configs/prompts/animation.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d4c15903c4a396cc9e8e2fa59e62bd7d721539f3
--- /dev/null
+++ b/configs/prompts/animation.yaml
@@ -0,0 +1,42 @@
+pretrained_model_path: "pretrained_models/stable-diffusion-v1-5"
+pretrained_vae_path: "pretrained_models/sd-vae-ft-mse"
+pretrained_controlnet_path: "pretrained_models/MagicAnimate/densepose_controlnet"
+pretrained_appearance_encoder_path: "pretrained_models/MagicAnimate/appearance_encoder"
+pretrained_unet_path: ""
+
+motion_module: "pretrained_models/MagicAnimate/temporal_attention/temporal_attention.ckpt"
+
+savename: null
+
+fusion_blocks: "midup"
+
+seed: [1]
+steps: 25
+guidance_scale: 7.5
+
+source_image:
+ - "inputs/applications/source_image/monalisa.png"
+ - "inputs/applications/source_image/0002.png"
+ - "inputs/applications/source_image/demo4.png"
+ - "inputs/applications/source_image/dalle2.jpeg"
+ - "inputs/applications/source_image/dalle8.jpeg"
+ - "inputs/applications/source_image/multi1_source.png"
+video_path:
+ - "inputs/applications/driving/densepose/running.mp4"
+ - "inputs/applications/driving/densepose/demo4.mp4"
+ - "inputs/applications/driving/densepose/demo4.mp4"
+ - "inputs/applications/driving/densepose/running2.mp4"
+ - "inputs/applications/driving/densepose/dancing2.mp4"
+ - "inputs/applications/driving/densepose/multi_dancing.mp4"
+
+inference_config: "configs/inference/inference.yaml"
+size: 512
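+# L: temporal window length; driving videos are padded to a multiple of L before sampling (see demo/animate.py)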
+L: 16
+S: 1
+I: 0
+clip: 0
+offset: 0
+max_length: null
+video_type: "condition"
+invert_video: false
+save_individual_videos: false
diff --git a/demo/animate.py b/demo/animate.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71f1940212bcc4901828644be613a8d7f0525c8
--- /dev/null
+++ b/demo/animate.py
@@ -0,0 +1,195 @@
+# Copyright 2023 ByteDance and/or its affiliates.
+#
+# Copyright (2023) MagicAnimate Authors
+#
+# ByteDance, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from ByteDance or
+# its affiliates is strictly prohibited.
+import argparse
+import datetime
+import inspect
+import os
+import numpy as np
+from PIL import Image
+from omegaconf import OmegaConf
+from collections import OrderedDict
+
+import torch
+
+from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
+
+from tqdm import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from magicanimate.models.unet_controlnet import UNet3DConditionModel
+from magicanimate.models.controlnet import ControlNetModel
+from magicanimate.models.appearance_encoder import AppearanceEncoderModel
+from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
+from magicanimate.pipelines.pipeline_animation import AnimationPipeline
+from magicanimate.utils.util import save_videos_grid
+from accelerate.utils import set_seed
+
+from magicanimate.utils.videoreader import VideoReader
+
+from einops import rearrange, repeat
+
+import csv, pdb, glob
+from safetensors import safe_open
+import math
+from pathlib import Path
+
+class MagicAnimate():
+ def __init__(self, config="configs/prompts/animation.yaml") -> None:
+ print("Initializing MagicAnimate Pipeline...")
+ *_, func_args = inspect.getargvalues(inspect.currentframe())
+ func_args = dict(func_args)
+
+ config = OmegaConf.load(config)
+
+ inference_config = OmegaConf.load(config.inference_config)
+
+ motion_module = config.motion_module
+
+ ### >>> create animation pipeline >>> ###
+ tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
+ if config.pretrained_unet_path:
+ unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+ else:
+ unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+ self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda()
+ self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
+ self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
+ if config.pretrained_vae_path is not None:
+ vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
+ else:
+ vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
+
+ ### Load controlnet
+ controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)
+
+ vae.to(torch.float16)
+ unet.to(torch.float16)
+ text_encoder.to(torch.float16)
+ controlnet.to(torch.float16)
+ self.appearance_encoder.to(torch.float16)
+
+ unet.enable_xformers_memory_efficient_attention()
+ self.appearance_encoder.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+
+ self.pipeline = AnimationPipeline(
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
+ scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
+ # NOTE: UniPCMultistepScheduler
+ ).to("cuda")
+
+ # 1. unet ckpt
+ # 1.1 motion module
+ motion_module_state_dict = torch.load(motion_module, map_location="cpu")
+ if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
+ motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
+ try:
+ # extra steps for self-trained models
+ state_dict = OrderedDict()
+ for key in motion_module_state_dict.keys():
+ if key.startswith("module."):
+ _key = key.split("module.")[-1]
+ state_dict[_key] = motion_module_state_dict[key]
+ else:
+ state_dict[key] = motion_module_state_dict[key]
+ motion_module_state_dict = state_dict
+ del state_dict
+ missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
+ assert len(unexpected) == 0
+        except Exception:
+ _tmp_ = OrderedDict()
+ for key in motion_module_state_dict.keys():
+ if "motion_modules" in key:
+ if key.startswith("unet."):
+ _key = key.split('unet.')[-1]
+ _tmp_[_key] = motion_module_state_dict[key]
+ else:
+ _tmp_[key] = motion_module_state_dict[key]
+ missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
+ assert len(unexpected) == 0
+ del _tmp_
+ del motion_module_state_dict
+
+ self.pipeline.to("cuda")
+ self.L = config.L
+
+ print("Initialization Done!")
+
+ def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512):
+ prompt = n_prompt = ""
+ random_seed = int(random_seed)
+ step = int(step)
+ guidance_scale = float(guidance_scale)
+ samples_per_video = []
+ # manually set random seed for reproduction
+ if random_seed != -1:
+ torch.manual_seed(random_seed)
+ set_seed(random_seed)
+ else:
+ torch.seed()
+
+        if motion_sequence.endswith('.mp4'):
+            control = VideoReader(motion_sequence).read()
+            if control[0].shape[0] != size:
+                control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
+            control = np.array(control)
+        else:
+            raise ValueError(f"Unsupported motion sequence format: {motion_sequence} (expected an .mp4 file)")
+
+ if source_image.shape[0] != size:
+ source_image = np.array(Image.fromarray(source_image).resize((size, size)))
+ H, W, C = source_image.shape
+
+ init_latents = None
+ original_length = control.shape[0]
+ if control.shape[0] % self.L > 0:
+ control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')
+ generator = torch.Generator(device=torch.device("cuda:0"))
+ generator.manual_seed(torch.initial_seed())
+ sample = self.pipeline(
+ prompt,
+ negative_prompt = n_prompt,
+ num_inference_steps = step,
+ guidance_scale = guidance_scale,
+ width = W,
+ height = H,
+ video_length = len(control),
+ controlnet_condition = control,
+ init_latents = init_latents,
+ generator = generator,
+ appearance_encoder = self.appearance_encoder,
+ reference_control_writer = self.reference_control_writer,
+ reference_control_reader = self.reference_control_reader,
+ source_image = source_image,
+ ).videos
+
+ source_images = np.array([source_image] * original_length)
+ source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
+ samples_per_video.append(source_images)
+
+ control = control / 255.0
+ control = rearrange(control, "t h w c -> 1 c t h w")
+ control = torch.from_numpy(control)
+ samples_per_video.append(control[:, :, :original_length])
+
+ samples_per_video.append(sample[:, :, :original_length])
+
+ samples_per_video = torch.cat(samples_per_video)
+
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+        savedir = "demo/outputs"
+ animation_path = f"{savedir}/{time_str}.mp4"
+
+ os.makedirs(savedir, exist_ok=True)
+ save_videos_grid(samples_per_video, animation_path)
+
+ return animation_path
+
\ No newline at end of file
diff --git a/inputs/applications/driving/densepose/.nfs006c000000039d6800000023 b/inputs/applications/driving/densepose/.nfs006c000000039d6800000023
new file mode 100644
index 0000000000000000000000000000000000000000..b8949299d2c18a8eece9e9687f3a2b9230fcedae
Binary files /dev/null and b/inputs/applications/driving/densepose/.nfs006c000000039d6800000023 differ
diff --git a/inputs/applications/driving/densepose/.nfs006c00000003a32d00000024 b/inputs/applications/driving/densepose/.nfs006c00000003a32d00000024
new file mode 100644
index 0000000000000000000000000000000000000000..4c83ce4f197a580ef873be6b54ef320f0f228fb6
Binary files /dev/null and b/inputs/applications/driving/densepose/.nfs006c00000003a32d00000024 differ
diff --git a/inputs/applications/driving/densepose/dancing2.mp4 b/inputs/applications/driving/densepose/dancing2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4af7ef94f3fa7791ff95a96c8850607cb070f978
Binary files /dev/null and b/inputs/applications/driving/densepose/dancing2.mp4 differ
diff --git a/inputs/applications/driving/densepose/demo4.mp4 b/inputs/applications/driving/densepose/demo4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..740d9c08967e243fd7a93d213ffd3d7e328c90a9
Binary files /dev/null and b/inputs/applications/driving/densepose/demo4.mp4 differ
diff --git a/inputs/applications/driving/densepose/multi_dancing.mp4 b/inputs/applications/driving/densepose/multi_dancing.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7a51a8337a8c1fbde56548b21c81ba94958891a5
Binary files /dev/null and b/inputs/applications/driving/densepose/multi_dancing.mp4 differ
diff --git a/inputs/applications/driving/densepose/running.mp4 b/inputs/applications/driving/densepose/running.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..1189d254f34f82a7ab8bd7f96afb3b713c7f3547
Binary files /dev/null and b/inputs/applications/driving/densepose/running.mp4 differ
diff --git a/inputs/applications/driving/densepose/running2.mp4 b/inputs/applications/driving/densepose/running2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2a1ccccf5101089eeab3e671a9a0fa5315c889c9
Binary files /dev/null and b/inputs/applications/driving/densepose/running2.mp4 differ
diff --git a/inputs/applications/source_image/0002.png b/inputs/applications/source_image/0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..fd236a8ba11aa5698a264f7b1a96df2ca47ac2fb
Binary files /dev/null and b/inputs/applications/source_image/0002.png differ
diff --git a/inputs/applications/source_image/dalle2.jpeg b/inputs/applications/source_image/dalle2.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..9dd175f1739865343c7cc0f5e50f88079f4cb064
Binary files /dev/null and b/inputs/applications/source_image/dalle2.jpeg differ
diff --git a/inputs/applications/source_image/dalle8.jpeg b/inputs/applications/source_image/dalle8.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..04574c3fbc7956124c3639e092e3489089a2bd63
Binary files /dev/null and b/inputs/applications/source_image/dalle8.jpeg differ
diff --git a/inputs/applications/source_image/demo4.png b/inputs/applications/source_image/demo4.png
new file mode 100644
index 0000000000000000000000000000000000000000..110684f9baac2931306524f20439eb7e181d3dc7
Binary files /dev/null and b/inputs/applications/source_image/demo4.png differ
diff --git a/inputs/applications/source_image/monalisa.png b/inputs/applications/source_image/monalisa.png
new file mode 100644
index 0000000000000000000000000000000000000000..15eeb1c8b40499b44852bbdaca5ce7e8e7ac9db0
Binary files /dev/null and b/inputs/applications/source_image/monalisa.png differ
diff --git a/inputs/applications/source_image/multi1_source.png b/inputs/applications/source_image/multi1_source.png
new file mode 100644
index 0000000000000000000000000000000000000000..5791d81c7bf90ab56f187f65664cb2e1a023f996
Binary files /dev/null and b/inputs/applications/source_image/multi1_source.png differ
diff --git a/magicanimate/models/__pycache__/appearance_encoder.cpython-38.pyc b/magicanimate/models/__pycache__/appearance_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8912a67e911ee76135846f91566e2b0d69ba0810
Binary files /dev/null and b/magicanimate/models/__pycache__/appearance_encoder.cpython-38.pyc differ
diff --git a/magicanimate/models/__pycache__/attention.cpython-38.pyc b/magicanimate/models/__pycache__/attention.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94f6f4c520a51b2c9ae95e33d482ddf2779c26ee
Binary files /dev/null and b/magicanimate/models/__pycache__/attention.cpython-38.pyc differ
diff --git a/magicanimate/models/__pycache__/controlnet.cpython-38.pyc b/magicanimate/models/__pycache__/controlnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b8fc075e92b8d4c2aa21632aff66abe7125aa4c
Binary files /dev/null and b/magicanimate/models/__pycache__/controlnet.cpython-38.pyc differ
diff --git a/magicanimate/models/__pycache__/embeddings.cpython-38.pyc b/magicanimate/models/__pycache__/embeddings.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f81d7c99111b9f5d8898740a43a337ec90807f8
Binary files /dev/null and b/magicanimate/models/__pycache__/embeddings.cpython-38.pyc differ
diff --git a/magicanimate/models/__pycache__/motion_module.cpython-38.pyc b/magicanimate/models/__pycache__/motion_module.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46b49ee55eece55056b428c26781fa7f9433b0eb
Binary files /dev/null and b/magicanimate/models/__pycache__/motion_module.cpython-38.pyc differ
diff --git a/magicanimate/models/__pycache__/mutual_self_attention.cpython-38.pyc b/magicanimate/models/__pycache__/mutual_self_attention.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b632ff58ffa5eb480b07606db27c052da06ae38
Binary files /dev/null and b/magicanimate/models/__pycache__/mutual_self_attention.cpython-38.pyc differ
diff --git a/magicanimate/models/__pycache__/orig_attention.cpython-38.pyc b/magicanimate/models/__pycache__/orig_attention.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..736e3e2834636c6c9f4b82dc0fa563febbf6c610
Binary files /dev/null and b/magicanimate/models/__pycache__/orig_attention.cpython-38.pyc differ
diff --git a/magicanimate/models/__pycache__/resnet.cpython-38.pyc b/magicanimate/models/__pycache__/resnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6379480913b4d142976095bba360b3963eff28b4
Binary files /dev/null and b/magicanimate/models/__pycache__/resnet.cpython-38.pyc differ
diff --git a/magicanimate/models/__pycache__/stable_diffusion_controlnet_reference.cpython-38.pyc b/magicanimate/models/__pycache__/stable_diffusion_controlnet_reference.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d17a09a13e7ce520932d7fbc369f219fc740a469
Binary files /dev/null and b/magicanimate/models/__pycache__/stable_diffusion_controlnet_reference.cpython-38.pyc differ
diff --git a/magicanimate/models/__pycache__/unet_3d_blocks.cpython-38.pyc b/magicanimate/models/__pycache__/unet_3d_blocks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5e37e70c10eb4ba8b69aa6ef899c73ab4e1d514
Binary files /dev/null and b/magicanimate/models/__pycache__/unet_3d_blocks.cpython-38.pyc differ
diff --git a/magicanimate/models/__pycache__/unet_controlnet.cpython-38.pyc b/magicanimate/models/__pycache__/unet_controlnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9f78a15cd972abe1a57ef7fea51af166f01420c
Binary files /dev/null and b/magicanimate/models/__pycache__/unet_controlnet.cpython-38.pyc differ
diff --git a/magicanimate/models/appearance_encoder.py b/magicanimate/models/appearance_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..81736db540c7ddd029b639e5978fc5b63762013a
--- /dev/null
+++ b/magicanimate/models/appearance_encoder.py
@@ -0,0 +1,1066 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.lora import LoRALinearLayer
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ PositionNet,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unet_2d_blocks import (
+ UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DSimpleCrossAttn,
+ get_down_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class Identity(torch.nn.Module):
+ r"""A placeholder identity operator that is argument-insensitive.
+
+ Args:
+ args: any argument (unused)
+ kwargs: any keyword argument (unused)
+
+ Shape:
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.
+
+ Examples::
+
+ >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
+ >>> input = torch.randn(128, 20)
+ >>> output = m(input)
+ >>> print(output.size())
+ torch.Size([128, 20])
+
+ """
+ def __init__(self, scale=None, *args, **kwargs) -> None:
+ super(Identity, self).__init__()
+
+ def forward(self, input, *args, **kwargs):
+ return input
+
+
+
+class _LoRACompatibleLinear(nn.Module):
+ """
+    A pass-through stand-in for a LoRA-compatible linear layer: it exposes the LoRA
+    hooks but returns hidden states unchanged, effectively disabling the layer.
+ """
+
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.lora_layer = lora_layer
+
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
+ self.lora_layer = lora_layer
+
+ def _fuse_lora(self):
+ pass
+
+ def _unfuse_lora(self):
+ pass
+
+ def forward(self, hidden_states, scale=None, lora_scale: int = 1):
+ return hidden_states
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class AppearanceEncoderModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers is skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: int = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+        # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
+ self.up_blocks[3].attentions[2].proj_out = Identity()
+
+ if attention_type in ["gated", "gated-text-image"]:
+ positive_len = 768
+ if isinstance(cross_attention_dim, int):
+ positive_len = cross_attention_dim
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+ positive_len = cross_attention_dim[0]
+
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
+ self.position_net = PositionNet(
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
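+ # Illustrative usage (assumption): `unet.set_attention_slice("auto")` halves every sliceable head
+ # dim, whereas `unet.set_attention_slice(2)` applies the same fixed slice size to every layer.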
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)} slice sizes, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
+ are passed along to the UNet blocks.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
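+ # Shape note (assumption, SD-1.5-style config): `sample` holds (batch, 4, H/8, W/8) latents and
+ # `encoder_hidden_states` holds (batch, 77, 768) text embeddings; the output keeps the spatial
+ # shape of `sample`.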
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_block_additional_residuals) > 0
+ and sample.shape == down_block_additional_residuals[0].shape
+ ):
+ sample += down_block_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
\ No newline at end of file
diff --git a/magicanimate/models/attention.py b/magicanimate/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..796d861a6d6727aff179fcb452ca351d87148a7f
--- /dev/null
+++ b/magicanimate/models/attention.py
@@ -0,0 +1,320 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import FeedForward, AdaLayerNorm
+from diffusers.models.attention import Attention as CrossAttention
+
+from einops import rearrange, repeat
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # Define input layers
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ # Define transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # Define output layers
+ if use_linear_projection:
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
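+ # Expected shapes (from the checks below): `hidden_states` is (b, c, f, h, w) video latents and
+ # `encoder_hidden_states` is (b, n, c) text tokens, repeated per frame when a single prompt is shared
+ # across the clip.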
+ # Input
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ # JH: no need to repeat when a list of prompts is given
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
+
+ batch, channel, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ video_length=video_length
+ )
+
+ # Output
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+
+ unet_use_cross_frame_attention = None,
+ unet_use_temporal_attention = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
+ self.unet_use_temporal_attention = unet_use_temporal_attention
+
+ # SC-Attn
+ assert unet_use_cross_frame_attention is not None
+ if unet_use_cross_frame_attention:
+ self.attn1 = SparseCausalAttention2D(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ else:
+ self.attn1 = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # Cross-Attn
+ if cross_attention_dim is not None:
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ else:
+ self.attn2 = None
+
+ if cross_attention_dim is not None:
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ else:
+ self.norm2 = None
+
+ # Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.norm3 = nn.LayerNorm(dim)
+ self.use_ada_layer_norm_zero = False
+
+ # Temp-Attn
+ assert unet_use_temporal_attention is not None
+ if unet_use_temporal_attention:
+ self.attn_temp = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ if self.attn2 is not None:
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
+ # SparseCausal-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ )
+
+ # if self.only_cross_attention:
+ # hidden_states = (
+ # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+ # )
+ # else:
+ # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+
+ if self.unet_use_cross_frame_attention:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+ else:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
+
+ if self.attn2 is not None:
+ # Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ # Temporal-Attention
+ if self.unet_use_temporal_attention:
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+ norm_hidden_states = (
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+ )
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
diff --git a/magicanimate/models/controlnet.py b/magicanimate/models/controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a74923e91e9211e59fa22d168983d5db95c0f7d
--- /dev/null
+++ b/magicanimate/models/controlnet.py
@@ -0,0 +1,578 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from .embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+)
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
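+ # Shape note (assumption): with the default block_out_channels (16, 32, 96, 256), the three stride-2
+ # convolutions map a 512x512 conditioning image down to the 64x64 latent resolution.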
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
+
+
+class ControlNetModel(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ # Check inputs
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ # control net conditioning embedding
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ ):
+ r"""
+ Instantiate a ControlNet model from a UNet2DConditionModel.
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model whose weights are copied to the ControlNet. Note that all configuration options are also
+ copied where applicable.
+ """
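+ # Illustrative usage (assumption): `controlnet = ControlNetModel.from_unet(unet)` mirrors the UNet's
+ # encoder configuration and, with `load_weights_from_unet=True` (the default), copies its conv_in,
+ # time embedding, down-block, and mid-block weights.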
+ controlnet = cls(
+ in_channels=unet.config.in_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ attention_head_dim=unet.config.attention_head_dim,
+ use_linear_projection=unet.config.use_linear_projection,
+ class_embed_type=unet.config.class_embed_type,
+ num_class_embeds=unet.config.num_class_embeds,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ )
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ if controlnet.class_embedding:
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+ return controlnet
+
+ # @property
+ # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ # def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ # r"""
+ # Returns:
+ # `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ # indexed by its weight name.
+ # """
+ # # set recursively
+ # processors = {}
+
+ # def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ # if hasattr(module, "set_processor"):
+ # processors[f"{name}.processor"] = module.processor
+
+ # for sub_name, child in module.named_children():
+ # fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ # return processors
+
+ # for name, module in self.named_children():
+ # fn_recursive_add_processors(name, module, processors)
+
+ # return processors
+
+ # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ # def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ # r"""
+ # Parameters:
+ # `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+ # The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ # of **all** `Attention` layers.
+ # In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
+
+ # """
+ # count = len(self.attn_processors.keys())
+
+ # if isinstance(processor, dict) and len(processor) != count:
+ # raise ValueError(
+ # f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ # f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ # )
+
+ # def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ # if hasattr(module, "set_processor"):
+ # if not isinstance(processor, dict):
+ # module.set_processor(processor)
+ # else:
+ # module.set_processor(processor.pop(f"{name}.processor"))
+
+ # for sub_name, child in module.named_children():
+ # fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ # for name, module in self.named_children():
+ # fn_recursive_attn_processor(name, module, processor)
+
+ # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ # def set_default_attn_processor(self):
+ # """
+ # Disables custom attention processors and sets the default attention implementation.
+ # """
+ # self.set_attn_processor(AttnProcessor())
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)} slice sizes, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: torch.FloatTensor,
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple]:
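+ # Overview (descriptive comment, an addition): the conditioning image is embedded and added to the
+ # latent `sample`, the copied down/mid blocks run as in the UNet encoder, and zero-initialized 1x1
+ # convolutions produce the residuals, scaled by `conditioning_scale`, that are injected into the UNet.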
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+
+ sample += controlnet_cond
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ # cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ # cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ # 5. Control net blocks
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples += (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample *= conditioning_scale
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
\ No newline at end of file
diff --git a/magicanimate/models/embeddings.py b/magicanimate/models/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..25fb3b2115a3559c2dc994820ee8297f96dbc146
--- /dev/null
+++ b/magicanimate/models/embeddings.py
@@ -0,0 +1,385 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import nn
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ :param embedding_dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
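+ # Frequency layout (descriptive note): column j of the embedding uses frequency
+ # max_period ** (-j / (half_dim - downscale_freq_shift)), matching the DDPM sinusoidal schedule.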
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
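+ # Illustrative shapes (assumption): get_2d_sincos_pos_embed(768, 16) returns a (256, 768) array,
+ # one row per grid position.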
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ def forward(self, latent):
+ latent = self.proj(latent)
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+ return latent + self.pos_embed
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ time_embed_dim: int,
+ act_fn: str = "silu",
+ out_dim: int = None,
+ post_act_fn: Optional[str] = None,
+ cond_proj_dim=None,
+ ):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+
+ if cond_proj_dim is not None:
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+ else:
+ self.cond_proj = None
+
+ if act_fn == "silu":
+ self.act = nn.SiLU()
+ elif act_fn == "mish":
+ self.act = nn.Mish()
+ elif act_fn == "gelu":
+ self.act = nn.GELU()
+ else:
+ raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+
+ if post_act_fn is None:
+ self.post_act = None
+ elif post_act_fn == "silu":
+ self.post_act = nn.SiLU()
+ elif post_act_fn == "mish":
+ self.post_act = nn.Mish()
+ elif post_act_fn == "gelu":
+ self.post_act = nn.GELU()
+ else:
+ raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+
+ def forward(self, sample, condition=None):
+ if condition is not None:
+ sample = sample + self.cond_proj(condition)
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+
+ if self.post_act is not None:
+ sample = self.post_act(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
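+
+
+# Minimal sketch (an added documentation helper, not part of the upstream module) of how
+# `Timesteps` and `TimestepEmbedding` are typically chained: integer timesteps are first
+# turned into a sinusoidal vector and then projected by a small MLP. The channel counts
+# are arbitrary assumptions for the example.
+def _demo_time_conditioning():
+    time_proj = Timesteps(num_channels=128, flip_sin_to_cos=True, downscale_freq_shift=1)
+    time_embed = TimestepEmbedding(in_channels=128, time_embed_dim=512)
+    timesteps = torch.randint(0, 1000, (4,))    # one diffusion step index per sample
+    t_emb = time_embed(time_proj(timesteps).float())
+    assert t_emb.shape == (4, 512)
+    return t_emb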
+
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
+ ):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.log = log
+ self.flip_sin_to_cos = flip_sin_to_cos
+
+ if set_W_to_weight:
+ # to delete later
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ self.weight = self.W
+
+ def forward(self, x):
+ if self.log:
+ x = torch.log(x)
+
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+
+ if self.flip_sin_to_cos:
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+ else:
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
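+
+
+# Illustrative sketch (an added documentation helper, not part of the upstream module) of
+# `GaussianFourierProjection`: a batch of positive noise levels is mapped to random
+# Fourier features. The embedding size and input values are arbitrary assumptions.
+def _demo_fourier_projection():
+    proj = GaussianFourierProjection(embedding_size=16)
+    sigmas = torch.rand(4) + 0.1                # strictly positive, since log() is applied
+    features = proj(sigmas)                     # (4, 32): sin features then cos features
+    assert features.shape == (4, 32)
+    return features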
+
+
+class ImagePositionalEmbeddings(nn.Module):
+ """
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
+ height and width of the latent space.
+
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+
+ For VQ-diffusion:
+
+ Output vector embeddings are used as input for the transformer.
+
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
+
+ Args:
+ num_embed (`int`):
+ Number of embeddings for the latent pixels embeddings.
+ height (`int`):
+ Height of the latent image i.e. the number of height embeddings.
+ width (`int`):
+ Width of the latent image i.e. the number of width embeddings.
+ embed_dim (`int`):
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
+ """
+
+ def __init__(
+ self,
+ num_embed: int,
+ height: int,
+ width: int,
+ embed_dim: int,
+ ):
+ super().__init__()
+
+ self.height = height
+ self.width = width
+ self.num_embed = num_embed
+ self.embed_dim = embed_dim
+
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
+ self.height_emb = nn.Embedding(self.height, embed_dim)
+ self.width_emb = nn.Embedding(self.width, embed_dim)
+
+ def forward(self, index):
+ emb = self.emb(index)
+
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
+
+ # 1 x H x D -> 1 x H x 1 x D
+ height_emb = height_emb.unsqueeze(2)
+
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
+
+ # 1 x W x D -> 1 x 1 x W x D
+ width_emb = width_emb.unsqueeze(1)
+
+ pos_emb = height_emb + width_emb
+
+        # 1 x H x W x D -> 1 x L x D
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
+
+ emb = emb + pos_emb[:, : emb.shape[1], :]
+
+ return emb
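+
+
+# Illustrative sketch (an added documentation helper, not part of the upstream module) of
+# `ImagePositionalEmbeddings` as used by VQ-diffusion style models: each latent pixel is
+# a class index, and its lookup is summed with learned height/width position embeddings.
+# The vocabulary size and 8x8 latent grid are arbitrary assumptions.
+def _demo_image_positional_embeddings():
+    embed = ImagePositionalEmbeddings(num_embed=512, height=8, width=8, embed_dim=64)
+    index = torch.randint(0, 512, (2, 64))      # (batch, height * width) class indices
+    out = embed(index)                          # (batch, height * width, embed_dim)
+    assert out.shape == (2, 64, 64)
+    return out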
+
+
+class LabelEmbedding(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+
+ Args:
+ num_classes (`int`): The number of classes.
+ hidden_size (`int`): The size of the vector embeddings.
+ dropout_prob (`float`): The probability of dropping a label.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = torch.tensor(force_drop_ids == 1)
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (self.training and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
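+
+
+# Illustrative sketch (an added documentation helper, not part of the upstream module) of
+# `LabelEmbedding`: with a non-zero dropout probability an extra "null" row is allocated,
+# and during training a random subset of labels is remapped to it so the model also
+# learns an unconditional embedding for classifier-free guidance. Sizes are arbitrary
+# assumptions.
+def _demo_label_embedding():
+    embedder = LabelEmbedding(num_classes=10, hidden_size=32, dropout_prob=0.1)
+    labels = torch.randint(0, 10, (8,))
+    emb = embedder(labels)                      # (8, 32); in train mode ~10% of rows use the null class
+    assert emb.shape == (8, 32)
+    return emb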
+
+
+class CombinedTimestepLabelEmbeddings(nn.Module):
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
+
+ def forward(self, timestep, class_labels, hidden_dtype=None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ class_labels = self.class_embedder(class_labels) # (N, D)
+
+ conditioning = timesteps_emb + class_labels # (N, D)
+
+ return conditioning
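+
+
+# Minimal sketch (an added documentation helper, not part of the upstream module) of
+# `CombinedTimestepLabelEmbeddings`, which sums the timestep embedding and the class
+# embedding into a single conditioning vector. The class count and embedding width are
+# arbitrary assumptions.
+def _demo_combined_conditioning():
+    cond = CombinedTimestepLabelEmbeddings(num_classes=1000, embedding_dim=512)
+    timestep = torch.randint(0, 1000, (4,))
+    class_labels = torch.randint(0, 1000, (4,))
+    out = cond(timestep, class_labels, hidden_dtype=torch.float32)   # (4, 512)
+    assert out.shape == (4, 512)
+    return out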
\ No newline at end of file
diff --git a/magicanimate/models/motion_module.py b/magicanimate/models/motion_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..b272b69d4fe57c7baf18574a300bab51c7bebb83
--- /dev/null
+++ b/magicanimate/models/motion_module.py
@@ -0,0 +1,334 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Adapted from https://github.com/guoyww/AnimateDiff
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import FeedForward
+from magicanimate.models.orig_attention import CrossAttention
+
+from einops import rearrange, repeat
+import math
+
+
+def zero_module(module):
+ # Zero out the parameters of a module and return it.
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+@dataclass
+class TemporalTransformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+def get_motion_module(
+ in_channels,
+ motion_module_type: str,
+ motion_module_kwargs: dict
+):
+ if motion_module_type == "Vanilla":
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
+ else:
+        raise ValueError(f"Unsupported motion_module_type: {motion_module_type}")
+
+
+class VanillaTemporalModule(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_attention_heads = 8,
+ num_transformer_block = 2,
+ attention_block_types =( "Temporal_Self", "Temporal_Self" ),
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 24,
+ temporal_attention_dim_div = 1,
+ zero_initialize = True,
+ ):
+ super().__init__()
+
+ self.temporal_transformer = TemporalTransformer3DModel(
+ in_channels=in_channels,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
+ num_layers=num_transformer_block,
+ attention_block_types=attention_block_types,
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_position_encoding=temporal_position_encoding,
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+ )
+
+ if zero_initialize:
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
+
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
+ hidden_states = input_tensor
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
+
+ output = hidden_states
+ return output
+
+
+class TemporalTransformer3DModel(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_attention_heads,
+ attention_head_dim,
+
+ num_layers,
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
+ dropout = 0.0,
+ norm_num_groups = 32,
+ cross_attention_dim = 768,
+ activation_fn = "geglu",
+ attention_bias = False,
+ upcast_attention = False,
+
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 24,
+ ):
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ TemporalTransformerBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ attention_block_types=attention_block_types,
+ dropout=dropout,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ upcast_attention=upcast_attention,
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_position_encoding=temporal_position_encoding,
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+ )
+ for d in range(num_layers)
+ ]
+ )
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Transformer Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
+
+ # output
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+ output = hidden_states + residual
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+
+ return output
+
+
+class TemporalTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_attention_heads,
+ attention_head_dim,
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
+ dropout = 0.0,
+ norm_num_groups = 32,
+ cross_attention_dim = 768,
+ activation_fn = "geglu",
+ attention_bias = False,
+ upcast_attention = False,
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 24,
+ ):
+ super().__init__()
+
+ attention_blocks = []
+ norms = []
+
+ for block_name in attention_block_types:
+ attention_blocks.append(
+ VersatileAttention(
+ attention_mode=block_name.split("_")[0],
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
+
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_position_encoding=temporal_position_encoding,
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+ )
+ )
+ norms.append(nn.LayerNorm(dim))
+
+ self.attention_blocks = nn.ModuleList(attention_blocks)
+ self.norms = nn.ModuleList(norms)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.ff_norm = nn.LayerNorm(dim)
+
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
+ norm_hidden_states = norm(hidden_states)
+ hidden_states = attention_block(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
+ video_length=video_length,
+ ) + hidden_states
+
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
+
+ output = hidden_states
+ return output
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ dropout = 0.,
+ max_len = 24
+ ):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout)
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+ pe = torch.zeros(1, max_len, d_model)
+ pe[0, :, 0::2] = torch.sin(position * div_term)
+ pe[0, :, 1::2] = torch.cos(position * div_term)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ x = x + self.pe[:, :x.size(1)]
+ return self.dropout(x)
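+
+
+# Illustrative sketch (an added documentation helper, not part of the upstream module):
+# `PositionalEncoding` adds the classic sinusoidal code along the frame axis so that the
+# temporal attention can tell frames apart. The feature width and frame count are
+# arbitrary assumptions.
+def _demo_positional_encoding():
+    pe = PositionalEncoding(d_model=64, max_len=24)
+    frames = torch.zeros(2, 16, 64)             # (batch * spatial tokens, frames, channels)
+    encoded = pe(frames)                        # sin/cos pattern added per frame index
+    assert encoded.shape == (2, 16, 64)
+    return encoded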
+
+
+class VersatileAttention(CrossAttention):
+ def __init__(
+ self,
+ attention_mode = None,
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 24,
+ *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ assert attention_mode == "Temporal"
+
+ self.attention_mode = attention_mode
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
+
+ self.pos_encoder = PositionalEncoding(
+ kwargs["query_dim"],
+ dropout=0.,
+ max_len=temporal_position_encoding_max_len
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
+
+ def extra_repr(self):
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ if self.attention_mode == "Temporal":
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+
+ if self.pos_encoder is not None:
+ hidden_states = self.pos_encoder(hidden_states)
+
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
+ else:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+
+ if self.attention_mode == "Temporal":
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
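+
+
+# End-to-end sketch (an added documentation helper, not part of the upstream module) of
+# the temporal motion module. The channel count, head count and the 5-D feature shape are
+# arbitrary assumptions. Because `zero_initialize=True` zeroes `proj_out`, a freshly
+# built module acts as an identity mapping, which the asserts below check.
+def _demo_vanilla_temporal_module():
+    motion = get_motion_module(
+        in_channels=64,
+        motion_module_type="Vanilla",
+        motion_module_kwargs={"num_attention_heads": 8, "num_transformer_block": 1},
+    )
+    video_feats = torch.randn(1, 64, 4, 8, 8)   # (batch, channels, frames, height, width)
+    out = motion(video_feats, temb=None, encoder_hidden_states=None)
+    assert out.shape == video_feats.shape
+    assert torch.allclose(out, video_feats)     # zero-initialized residual branch at init
+    return out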
diff --git a/magicanimate/models/mutual_self_attention.py b/magicanimate/models/mutual_self_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..688332e6f6790d683d2fadbe544175d657cba65a
--- /dev/null
+++ b/magicanimate/models/mutual_self_attention.py
@@ -0,0 +1,642 @@
+# Copyright 2023 ByteDance and/or its affiliates.
+#
+# Copyright (2023) MagicAnimate Authors
+#
+# ByteDance, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from ByteDance or
+# its affiliates is strictly prohibited.
+
+import torch
+import torch.nn.functional as F
+
+from einops import rearrange
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from diffusers.models.attention import BasicTransformerBlock
+from magicanimate.models.attention import BasicTransformerBlock as _BasicTransformerBlock
+from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
+from .stable_diffusion_controlnet_reference import torch_dfs
+
+
+class AttentionBase:
+ def __init__(self):
+ self.cur_step = 0
+ self.num_att_layers = -1
+ self.cur_att_layer = 0
+
+ def after_step(self):
+ pass
+
+ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+ self.cur_att_layer += 1
+ if self.cur_att_layer == self.num_att_layers:
+ self.cur_att_layer = 0
+ self.cur_step += 1
+ # after step
+ self.after_step()
+ return out
+
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
+ return out
+
+ def reset(self):
+ self.cur_step = 0
+ self.cur_att_layer = 0
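+
+
+# Minimal sketch (an added documentation helper, not part of the upstream module) of the
+# `AttentionBase` editor contract: it receives per-head queries/keys/values plus the
+# attention map and returns the attended values with the heads folded back into the
+# channel axis. The head count, token count and dimensions are arbitrary assumptions.
+def _demo_attention_base():
+    editor = AttentionBase()
+    num_heads, tokens, head_dim = 2, 4, 8
+    q = k = v = torch.randn(num_heads, tokens, head_dim)               # batch of 1, heads stacked on dim 0
+    attn = torch.softmax(q @ k.transpose(-1, -2) / head_dim ** 0.5, dim=-1)
+    out = editor(q, k, v, None, attn, False, "mid", num_heads=num_heads)
+    assert out.shape == (1, tokens, num_heads * head_dim)
+    return out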
+
+
+class MutualSelfAttentionControl(AttentionBase):
+
+ def __init__(self, total_steps=50, hijack_init_state=True, with_negative_guidance=False, appearance_control_alpha=0.5, mode='enqueue'):
+ """
+        Mutual self-attention control for the Stable Diffusion model.
+        Args:
+            total_steps: the total number of denoising steps
+            appearance_control_alpha: mutual self-attention intensity (stored as `self.alpha`)
+            mode: 'enqueue' caches keys/values for later reuse, 'dequeue' consumes the cached keys/values
+ """
+ super().__init__()
+ self.total_steps = total_steps
+ self.hijack = hijack_init_state
+ self.with_negative_guidance = with_negative_guidance
+
+ # alpha: mutual self attention intensity
+ # TODO: make alpha learnable
+ self.alpha = appearance_control_alpha
+ self.GLOBAL_ATTN_QUEUE = []
+ assert mode in ['enqueue', 'dequeue']
+        self.MODE = mode
+        self.kv_queue = []
+
+ def attn_batch(self, q, k, v, num_heads, **kwargs):
+ """
+ Performing attention for a batch of queries, keys, and values
+ """
+ b = q.shape[0] // num_heads
+ q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
+ k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
+ v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
+
+ sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
+ attn = sim.softmax(-1)
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
+ out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
+ return out
+
+ def mutual_self_attn(self, q, k, v, num_heads, **kwargs):
+ q_tgt, q_src = q.chunk(2)
+ k_tgt, k_src = k.chunk(2)
+ v_tgt, v_src = v.chunk(2)
+
+ # out_tgt = self.attn_batch(q_tgt, k_src, v_src, num_heads, **kwargs) * self.alpha + \
+ # self.attn_batch(q_tgt, k_tgt, v_tgt, num_heads, **kwargs) * (1 - self.alpha)
+ out_tgt = self.attn_batch(q_tgt, torch.cat([k_tgt, k_src], dim=1), torch.cat([v_tgt, v_src], dim=1), num_heads, **kwargs)
+ out_src = self.attn_batch(q_src, k_src, v_src, num_heads, **kwargs)
+ out = torch.cat([out_tgt, out_src], dim=0)
+ return out
+
+ def mutual_self_attn_wq(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ if self.MODE == 'dequeue' and len(self.kv_queue) > 0:
+ k_src, v_src = self.kv_queue.pop(0)
+ out = self.attn_batch(q, torch.cat([k, k_src], dim=1), torch.cat([v, v_src], dim=1), num_heads, **kwargs)
+ return out
+ else:
+ self.kv_queue.append([k.clone(), v.clone()])
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+
+ def get_queue(self):
+ return self.GLOBAL_ATTN_QUEUE
+
+ def set_queue(self, attn_queue):
+ self.GLOBAL_ATTN_QUEUE = attn_queue
+
+ def clear_queue(self):
+ self.GLOBAL_ATTN_QUEUE = []
+
+ def to(self, dtype):
+ self.GLOBAL_ATTN_QUEUE = [p.to(dtype) for p in self.GLOBAL_ATTN_QUEUE]
+
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ """
+ Attention forward function
+ """
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
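+
+
+# Illustrative sketch (an added documentation helper, not part of the upstream module) of
+# the mutual self-attention step: the target queries attend to the concatenation of their
+# own keys/values and the source keys/values, while the source branch attends only to
+# itself. The head count, token count and dimensions are arbitrary assumptions.
+def _demo_mutual_self_attn():
+    ctrl = MutualSelfAttentionControl()
+    num_heads, tokens, head_dim = 2, 4, 8
+    # target and source batches stacked along dim 0, with heads already split out
+    q = torch.randn(2 * num_heads, tokens, head_dim)
+    k = torch.randn(2 * num_heads, tokens, head_dim)
+    v = torch.randn(2 * num_heads, tokens, head_dim)
+    out = ctrl.mutual_self_attn(q, k, v, num_heads=num_heads, scale=head_dim ** -0.5)
+    assert out.shape == (2, tokens, num_heads * head_dim)
+    return out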
+
+
+class ReferenceAttentionControl():
+
+ def __init__(self,
+ unet,
+ mode="write",
+ do_classifier_free_guidance=False,
+ attention_auto_machine_weight = float('inf'),
+ gn_auto_machine_weight = 1.0,
+ style_fidelity = 1.0,
+ reference_attn=True,
+ reference_adain=False,
+ fusion_blocks="midup",
+ batch_size=1,
+ ) -> None:
+ # 10. Modify self attention and group norm
+ self.unet = unet
+ assert mode in ["read", "write"]
+ assert fusion_blocks in ["midup", "full"]
+ self.reference_attn = reference_attn
+ self.reference_adain = reference_adain
+ self.fusion_blocks = fusion_blocks
+ self.register_reference_hooks(
+ mode,
+ do_classifier_free_guidance,
+ attention_auto_machine_weight,
+ gn_auto_machine_weight,
+ style_fidelity,
+ reference_attn,
+ reference_adain,
+ fusion_blocks,
+ batch_size=batch_size,
+ )
+
+ def register_reference_hooks(
+ self,
+ mode,
+ do_classifier_free_guidance,
+ attention_auto_machine_weight,
+ gn_auto_machine_weight,
+ style_fidelity,
+ reference_attn,
+ reference_adain,
+ dtype=torch.float16,
+ batch_size=1,
+ num_images_per_prompt=1,
+ device=torch.device("cpu"),
+ fusion_blocks='midup',
+ ):
+ MODE = mode
+ do_classifier_free_guidance = do_classifier_free_guidance
+ attention_auto_machine_weight = attention_auto_machine_weight
+ gn_auto_machine_weight = gn_auto_machine_weight
+ style_fidelity = style_fidelity
+ reference_attn = reference_attn
+ reference_adain = reference_adain
+ fusion_blocks = fusion_blocks
+ num_images_per_prompt = num_images_per_prompt
+ dtype=dtype
+ if do_classifier_free_guidance:
+ uc_mask = (
+ torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16)
+ .to(device)
+ .bool()
+ )
+ else:
+ uc_mask = (
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
+ .to(device)
+ .bool()
+ )
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ video_length=None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.clone())
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), "b t l c -> (b t) l c")[:hidden_states.shape[0]] for d in self.bank]
+ hidden_states_uc = self.attn1(norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
+ attention_mask=attention_mask) + hidden_states
+ hidden_states_c = hidden_states_uc.clone()
+ _uc_mask = uc_mask.clone()
+ if do_classifier_free_guidance:
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
+ _uc_mask = (
+ torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2))
+ .to(device)
+ .bool()
+ )
+ hidden_states_c[_uc_mask] = self.attn1(
+ norm_hidden_states[_uc_mask],
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
+ attention_mask=attention_mask,
+ ) + hidden_states[_uc_mask]
+ hidden_states = hidden_states_c.clone()
+
+ self.bank.clear()
+ if self.attn2 is not None:
+ # Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ # Temporal-Attention
+ if self.unet_use_temporal_attention:
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+ norm_hidden_states = (
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+ )
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
+
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+ def hacked_mid_forward(self, *args, **kwargs):
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+ def hack_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ if self.reference_attn:
+ if self.fusion_blocks == "midup":
+ attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
+ elif self.fusion_blocks == "full":
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if self.reference_adain:
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
+ elif isinstance(module, CrossAttnDownBlock2D):
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
+ def update(self, writer, dtype=torch.float16):
+ if self.reference_attn:
+ if self.fusion_blocks == "midup":
+ reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)]
+ writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
+ elif self.fusion_blocks == "full":
+ reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock)]
+ writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, BasicTransformerBlock)]
+ reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+ writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
+ r.bank = [v.clone().to(dtype) for v in w.bank]
+ # w.bank.clear()
+ if self.reference_adain:
+ reader_gn_modules = [self.unet.mid_block]
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ reader_gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ reader_gn_modules.append(module)
+
+ writer_gn_modules = [writer.unet.mid_block]
+
+ down_blocks = writer.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ writer_gn_modules.append(module)
+
+ up_blocks = writer.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ writer_gn_modules.append(module)
+
+ for r, w in zip(reader_gn_modules, writer_gn_modules):
+ if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list):
+ r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank]
+ r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank]
+ else:
+ r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank]
+ r.var_bank = [v.clone().to(dtype) for v in w.var_bank]
+
+ def clear(self):
+ if self.reference_attn:
+ if self.fusion_blocks == "midup":
+ reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
+ elif self.fusion_blocks == "full":
+ reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
+ reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+ for r in reader_attn_modules:
+ r.bank.clear()
+ if self.reference_adain:
+ reader_gn_modules = [self.unet.mid_block]
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ reader_gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ reader_gn_modules.append(module)
+
+ for r in reader_gn_modules:
+ r.mean_bank.clear()
+ r.var_bank.clear()
+
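+
+
+# Hypothetical driver (an added documentation helper, not part of the upstream module)
+# showing how a "write" controller on an appearance/reference UNet and a "read"
+# controller on the denoising UNet are meant to cooperate. `appearance_unet`,
+# `denoising_unet` and the tensor arguments are placeholders, not objects defined here.
+def _demo_reference_control(appearance_unet, denoising_unet, reference_latents, timestep, text_emb):
+    writer = ReferenceAttentionControl(appearance_unet, mode="write", fusion_blocks="midup")
+    reader = ReferenceAttentionControl(denoising_unet, mode="read", fusion_blocks="midup")
+    # Running the appearance UNet fills the writer's self-attention banks.
+    appearance_unet(reference_latents, timestep, encoder_hidden_states=text_emb)
+    # Copy the cached reference features into the reader UNet, denoise, then reset.
+    reader.update(writer)
+    # ... run `denoising_unet` on the target latents here ...
+    reader.clear()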
\ No newline at end of file
diff --git a/magicanimate/models/orig_attention.py b/magicanimate/models/orig_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c3eba094cd37665839911535d5fb7bb76a0cb18
--- /dev/null
+++ b/magicanimate/models/orig_attention.py
@@ -0,0 +1,988 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.embeddings import ImagePositionalEmbeddings
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
+ for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+class Transformer2DModel(ModelMixin, ConfigMixin):
+ """
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
+ embeddings) inputs.
+
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
+ transformer action. Finally, reshape to image.
+
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
+ classes of unnoised image.
+
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input and output.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+ `ImagePositionalEmbeddings`.
+ num_vector_embeds (`int`, *optional*):
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+            up to but not more steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = in_channels is not None
+ self.is_input_vectorized = num_vector_embeds is not None
+
+ if self.is_input_continuous and self.is_input_vectorized:
+ raise ValueError(
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is None."
+ )
+ elif not self.is_input_continuous and not self.is_input_vectorized:
+ raise ValueError(
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
+ )
+
+ # 2. Define input layers
+ if self.is_input_continuous:
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = sample_size
+ self.width = sample_size
+ self.num_vector_embeds = num_vector_embeds
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
+ )
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if self.is_input_continuous:
+ if use_linear_projection:
+ self.proj_out = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+ """
+ Args:
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+ hidden_states
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.long`, *optional*):
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
+ tensor.
+ """
+ # 1. Input
+ if self.is_input_continuous:
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+ elif self.is_input_vectorized:
+ hidden_states = self.latent_image_embedding(hidden_states)
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
+
+ # 3. Output
+ if self.is_input_continuous:
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+ elif self.is_input_vectorized:
+ hidden_states = self.norm_out(hidden_states)
+ logits = self.out(hidden_states)
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+ logits = logits.permute(0, 2, 1)
+
+ # log(p(x_0))
+ output = F.log_softmax(logits.double(), dim=1).float()
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
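+
+
+# Minimal sketch (an added documentation helper, not part of the upstream module) of the
+# continuous-input path of `Transformer2DModel`: a spatial feature map is normalized,
+# projected to tokens, run through the transformer blocks and projected back, with a
+# residual connection. The channel and head sizes are arbitrary assumptions.
+def _demo_transformer_2d():
+    model = Transformer2DModel(
+        num_attention_heads=2, attention_head_dim=8, in_channels=32, norm_num_groups=8
+    )
+    feature_map = torch.randn(1, 32, 8, 8)      # (batch, channels, height, width)
+    out = model(feature_map).sample             # same shape as the input
+    assert out.shape == feature_map.shape
+    return out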
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+ to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ Uses three q, k, v linear layers to compute attention.
+
+ Parameters:
+ channels (`int`): The number of channels in the input and output.
+ num_head_channels (`int`, *optional*):
+ The number of channels in each head. If None, then `num_heads` = 1.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
+ """
+
+    # IMPORTANT: TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
+
+ def __init__(
+ self,
+ channels: int,
+ num_head_channels: Optional[int] = None,
+ norm_num_groups: int = 32,
+ rescale_output_factor: float = 1.0,
+ eps: float = 1e-5,
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
+ self.num_head_size = num_head_channels
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
+
+ # define q,k,v as linear layers
+ self.query = nn.Linear(channels, channels)
+ self.key = nn.Linear(channels, channels)
+ self.value = nn.Linear(channels, channels)
+
+ self.rescale_output_factor = rescale_output_factor
+ self.proj_attn = nn.Linear(channels, channels, 1)
+
+ self._use_memory_efficient_attention_xformers = False
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.num_heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.num_heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ batch, channel, height, width = hidden_states.shape
+
+ # norm
+ hidden_states = self.group_norm(hidden_states)
+
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+ # proj to q, k, v
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(hidden_states)
+ value_proj = self.value(hidden_states)
+
+ scale = 1 / math.sqrt(self.channels / self.num_heads)
+
+ query_proj = self.reshape_heads_to_batch_dim(query_proj)
+ key_proj = self.reshape_heads_to_batch_dim(key_proj)
+ value_proj = self.reshape_heads_to_batch_dim(value_proj)
+
+ if self._use_memory_efficient_attention_xformers:
+ # Memory efficient attention
+ hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+ hidden_states = hidden_states.to(query_proj.dtype)
+ else:
+ attention_scores = torch.baddbmm(
+ torch.empty(
+ query_proj.shape[0],
+ query_proj.shape[1],
+ key_proj.shape[1],
+ dtype=query_proj.dtype,
+ device=query_proj.device,
+ ),
+ query_proj,
+ key_proj.transpose(-1, -2),
+ beta=0,
+ alpha=scale,
+ )
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+ hidden_states = torch.bmm(attention_probs, value_proj)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+
+ # compute next hidden_states
+ hidden_states = self.proj_attn(hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+ # res connect and rescale
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
+ return hidden_states
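+
+
+# Illustrative sketch (an added documentation helper, not part of the upstream module) of
+# the (soon to be deprecated) spatial `AttentionBlock`: every spatial position attends to
+# every other one and the result is added back to the input. Sizes are arbitrary
+# assumptions.
+def _demo_attention_block():
+    block = AttentionBlock(channels=32, num_head_channels=8, norm_num_groups=8)
+    feature_map = torch.randn(1, 32, 8, 8)
+    out = block(feature_map)                    # residual connection keeps the shape
+    assert out.shape == feature_map.shape
+    return out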
+
+
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (`int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (`bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+
+ # 1. Self-Attn
+ self.attn1 = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ ) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None:
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.attn2 = None
+
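+ # pre-attention norms; AdaLayerNorm additionally conditions the normalization on the diffusion timestep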
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ if cross_attention_dim is not None:
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ else:
+ self.norm2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim)
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ if self.attn2 is not None:
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
+ # 1. Self-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ )
+
+ if self.only_cross_attention:
+ hidden_states = (
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+ )
+ else:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
+
+ if self.attn2 is not None:
+ # 2. Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+ # 3. Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ return hidden_states
+
+
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = False
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
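+ # computes attention in chunks of `self._slice_size` along the (batch * heads) axis to bound peak memory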
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
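+ # q/k/v are made contiguous for xformers; the per-head batch dimension is folded back afterwards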
+ # TODO attention_mask
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out))
+
+ def forward(self, hidden_states):
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+class GELU(nn.Module):
+ r"""
+ GELU activation function
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+
+ def gelu(self, gate):
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+# feedforward
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def gelu(self, gate):
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
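+ # the projection doubles the width; one half is the value, the other half gates it through GELU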
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ """
+ The approximate form of Gaussian Error Linear Unit (GELU)
+
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+
+ def forward(self, x):
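+ # sigmoid-based approximation: GELU(x) ≈ x * sigmoid(1.702 * x)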
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
+
+
+class AdaLayerNorm(nn.Module):
+ """
+ Norm layer modified to incorporate timestep embeddings.
+ """
+
+ def __init__(self, embedding_dim, num_embeddings):
+ super().__init__()
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
+
+ def forward(self, x, timestep):
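+ # timestep embedding -> SiLU -> linear yields a per-timestep scale and shift applied to the normalized input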
+ emb = self.linear(self.silu(self.emb(timestep)))
+ scale, shift = torch.chunk(emb, 2)
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+
+class DualTransformer2DModel(nn.Module):
+ """
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input and output.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+ `ImagePositionalEmbeddings`.
+ num_vector_embeds (`int`, *optional*):
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+ up to, but not more than, `num_embeds_ada_norm` steps.
+ attention_bias (`bool`, *optional*):
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ ):
+ super().__init__()
+ self.transformers = nn.ModuleList(
+ [
+ Transformer2DModel(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ in_channels=in_channels,
+ num_layers=num_layers,
+ dropout=dropout,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_bias=attention_bias,
+ sample_size=sample_size,
+ num_vector_embeds=num_vector_embeds,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ )
+ for _ in range(2)
+ ]
+ )
+
+ # Variables that can be set by a pipeline:
+
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
+ self.mix_ratio = 0.5
+
+ # The shape of `encoder_hidden_states` is expected to be
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
+ self.condition_lengths = [77, 257]
+
+ # Which transformer to use to encode which condition.
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
+ self.transformer_index_for_condition = [1, 0]
+
+ def forward(
+ self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True
+ ):
+ """
+ Args:
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+ hidden_states
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.long`, *optional*):
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Optional attention mask to be applied in CrossAttention
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
+ tensor.
+ """
+ input_states = hidden_states
+
+ encoded_states = []
+ tokens_start = 0
+ # attention_mask is not used yet
+ for i in range(2):
+ # for each of the two transformers, pass the corresponding condition tokens
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
+ transformer_index = self.transformer_index_for_condition[i]
+ encoded_state = self.transformers[transformer_index](
+ input_states,
+ encoder_hidden_states=condition_state,
+ timestep=timestep,
+ return_dict=False,
+ )[0]
+ encoded_states.append(encoded_state - input_states)
+ tokens_start += self.condition_lengths[i]
+
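+ # each encoded_state stores only the residual, so blend the two residuals and add the input back once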
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
+ output_states = output_states + input_states
+
+ if not return_dict:
+ return (output_states,)
+
+ return Transformer2DModelOutput(sample=output_states)
\ No newline at end of file
diff --git a/magicanimate/models/resnet.py b/magicanimate/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1231ee0668f1989ec9075b8f3c741cdc34816f5
--- /dev/null
+++ b/magicanimate/models/resnet.py
@@ -0,0 +1,212 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Adapted from https://github.com/guoyww/AnimateDiff
+
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+
+
+class InflatedConv3d(nn.Conv2d):
+ def forward(self, x):
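+ # fold the frame axis into the batch axis, apply the 2D convolution per frame, then restore (b, c, f, h, w)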
+ video_length = x.shape[2]
+
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ x = super().forward(x)
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+ return x
+
+
+class Upsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ raise NotImplementedError
+ elif use_conv:
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, hidden_states, output_size=None):
+ assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ raise NotImplementedError
+
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
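+ # scale_factor=[1.0, 2.0, 2.0] upsamples height and width only; the frame axis is left unchanged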
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class Downsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ raise NotImplementedError
+
+ def forward(self, hidden_states):
+ assert hidden_states.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ raise NotImplementedError
+
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ time_emb_proj_out_channels = out_channels
+ elif self.time_embedding_norm == "scale_shift":
+ time_emb_proj_out_channels = out_channels * 2
+ else:
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
+class Mish(torch.nn.Module):
+ def forward(self, hidden_states):
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
\ No newline at end of file
diff --git a/magicanimate/models/stable_diffusion_controlnet_reference.py b/magicanimate/models/stable_diffusion_controlnet_reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..436e74cce0b7705f72db463376a384c895bcc0ac
--- /dev/null
+++ b/magicanimate/models/stable_diffusion_controlnet_reference.py
@@ -0,0 +1,840 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers import StableDiffusionControlNetPipeline
+from diffusers.models import ControlNetModel
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import cv2
+ >>> import torch
+ >>> import numpy as np
+ >>> from PIL import Image
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+ >>> # get canny image
+ >>> image = cv2.Canny(np.array(input_image), 100, 200)
+ >>> image = image[:, :, None]
+ >>> image = np.concatenate([image, image, image], axis=2)
+ >>> canny_image = Image.fromarray(image)
+
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ controlnet=controlnet,
+ safety_checker=None,
+ torch_dtype=torch.float16
+ ).to('cuda:0')
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+ >>> result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ image=canny_image,
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+
+ >>> result_img.show()
+ ```
+"""
+
+
+def torch_dfs(model: torch.nn.Module):
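+ # depth-first traversal returning the module itself plus all of its descendants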
+ result = [model]
+ for child in model.children():
+ result += torch_dfs(child)
+ return result
+
+
+class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
+ refimage = refimage.to(device=device, dtype=dtype)
+
+ # encode the mask image into latents space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ ref_image_latents = [
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
+ else:
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
+
+ # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
+ if ref_image_latents.shape[0] < batch_size:
+ if not batch_size % ref_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
+
+ ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+ return ref_image_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ attention_auto_machine_weight: float = 1.0,
+ gn_auto_machine_weight: float = 1.0,
+ style_fidelity: float = 0.5,
+ reference_attn: bool = True,
+ reference_adain: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+ the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+ specified in init, images must be passed as a list such that each element of the list can be correctly
+ batched for input to a single controlnet.
+ ref_image (`torch.FloatTensor`, `PIL.Image.Image`):
+ The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
+ the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
+ also be accepted as an image.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ attention_auto_machine_weight (`float`):
+ Weight of using reference query for self attention's context.
+ If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
+ gn_auto_machine_weight (`float`):
+ Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
+ style_fidelity (`float`):
+ Style fidelity of ref_uncond_xt. If style_fidelity=1.0, the control is prioritized; if style_fidelity=0.0,
+ the prompt is prioritized; intermediate values balance the two.
+ reference_attn (`bool`):
+ Whether to use reference query for self attention's context.
+ reference_adain (`bool`):
+ Whether to use reference adain.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetModel):
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = image.shape[-2:]
+ elif isinstance(controlnet, MultiControlNetModel):
+ images = []
+
+ for image_ in image:
+ image_ = self.prepare_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+
+ image = images
+ height, width = image[0].shape[-2:]
+ else:
+ assert False
+
+ # 5. Preprocess reference image
+ ref_image = self.prepare_image(
+ image=ref_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=prompt_embeds.dtype,
+ )
+
+ # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 7. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Prepare reference latent variables
+ ref_image_latents = self.prepare_ref_latents(
+ ref_image,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Modify self attention and group norm
+ MODE = "write"
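+ # "write" caches reference features into per-module banks; "read" mixes them back in during denoising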
+ uc_mask = (
+ torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
+ .type_as(ref_image_latents)
+ .bool()
+ )
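+ # boolean mask selecting the unconditional half of the classifier-free-guidance batch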
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.detach().clone())
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ if attention_auto_machine_weight > self.attn_weight:
+ attn_output_uc = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
+ # attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output_c = attn_output_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ attn_output_c[uc_mask] = self.attn1(
+ norm_hidden_states[uc_mask],
+ encoder_hidden_states=norm_hidden_states[uc_mask],
+ **cross_attention_kwargs,
+ )
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
+ self.bank.clear()
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+ def hacked_mid_forward(self, *args, **kwargs):
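+ # AdaIN-style statistics: "write" stores the reference mean/var, "read" renormalizes the current features to match them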
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+ def hack_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
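+        # Patch the hacked forwards onto the UNet: attention blocks exchange reference features
+        # through `module.bank`, GroupNorm blocks through their mean/var banks.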
+ if reference_attn:
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if reference_adain:
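+            # gn_weight controls which blocks participate in AdaIN: statistics are cached only when
+            # gn_auto_machine_weight >= gn_weight, so the blocks closest to the bottleneck (smallest
+            # weights) take part first.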
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
+ elif isinstance(module, CrossAttnDownBlock2D):
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
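+        # Every step below runs the UNet twice: a "write" pass on the noised reference latents that
+        # fills the attention/statistics banks, then a "read" pass on the denoising latents that
+        # consumes them.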
+ # 11. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=image,
+ conditioning_scale=controlnet_conditioning_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and do_classifier_free_guidance:
+                    # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # ref only part
+ noise = randn_tensor(
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
+ )
+ ref_xt = self.scheduler.add_noise(
+ ref_image_latents,
+ noise,
+ t.reshape(
+ 1,
+ ),
+ )
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
+ MODE = "write"
+ self.unet(
+ ref_xt,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ MODE = "read"
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/magicanimate/models/unet.py b/magicanimate/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..09e5e11fdb9c985fd8627ce90d56791d0d8f469c
--- /dev/null
+++ b/magicanimate/models/unet.py
@@ -0,0 +1,508 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Adapted from https://github.com/guoyww/AnimateDiff
+
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import os
+import json
+import pdb
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from .unet_3d_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+)
+from .resnet import InflatedConv3d
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D",
+ ),
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+
+ # Additional
+ use_motion_module = False,
+ motion_module_resolutions = ( 1,2,4,8 ),
+ motion_module_mid_block = False,
+ motion_module_decoder_only = False,
+ motion_module_type = None,
+ motion_module_kwargs = {},
+ unet_use_cross_frame_attention = None,
+ unet_use_temporal_attention = None,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
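+        # InflatedConv3d applies the pretrained 2D convolution to every frame independently,
+        # so Stable Diffusion weights load without modification.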
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ res = 2 ** i
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
+ self.mid_block = UNetMidBlock3DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+
+ use_motion_module=use_motion_module and motion_module_mid_block,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the videos
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ res = 2 ** (3 - i)
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_slicable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_slicable_dims(module)
+
+ num_slicable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_slicable_layers * [1]
+
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+        # Any child that exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
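+        # e.g. (illustrative) `unet.set_attention_slice("auto")` halves every sliceable head
+        # dimension, trading a little speed for lower peak attention memory.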
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ # pre-process
+ sample = self.conv_in(sample)
+
+ # down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
+
+ down_block_res_samples += res_samples
+
+ # mid
+ sample = self.mid_block(
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+
+ # up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
+ )
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet3DConditionOutput(sample=sample)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+ config["_class_name"] = cls.__name__
+ config["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+ config["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+
+ from diffusers.utils import WEIGHTS_NAME
+        model = cls.from_config(config, **(unet_additional_kwargs or {}))  # tolerate unet_additional_kwargs=None
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
+
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
+ print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
+
+ return model
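+
+# Example usage (illustrative; the path and kwargs below are placeholders):
+#
+#   unet = UNet3DConditionModel.from_pretrained_2d(
+#       "pretrained_models/stable-diffusion-v1-5",
+#       subfolder="unet",
+#       unet_additional_kwargs={"use_motion_module": True},
+#   )
+#
+# The 2D weights populate the spatial layers; keys belonging to the temporal/motion modules are
+# reported as missing and keep their random initialization.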
diff --git a/magicanimate/models/unet_3d_blocks.py b/magicanimate/models/unet_3d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..55cf1559598a34dcfcc48756a71549f5b0fd9d7b
--- /dev/null
+++ b/magicanimate/models/unet_3d_blocks.py
@@ -0,0 +1,751 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Adapted from https://github.com/guoyww/AnimateDiff
+
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from torch import nn
+
+from .attention import Transformer3DModel
+from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
+from .motion_module import get_motion_module
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+
+ use_motion_module=None,
+
+ motion_module_type=None,
+ motion_module_kwargs=None,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ use_motion_module=use_motion_module,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ elif down_block_type == "CrossAttnDownBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+ return CrossAttnDownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+
+ use_motion_module=use_motion_module,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+
+ use_motion_module=None,
+ motion_module_type=None,
+ motion_module_kwargs=None,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ use_motion_module=use_motion_module,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ elif up_block_type == "CrossAttnUpBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+ return CrossAttnUpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+
+ use_motion_module=use_motion_module,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+
+ use_motion_module=None,
+
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+ motion_modules = []
+
+ for _ in range(num_layers):
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ )
+ )
+ motion_modules.append(
+ get_motion_module(
+ in_channels=in_channels,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ ) if use_motion_module else None
+ )
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
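+        # Each layer interleaves spatial/cross attention, an optional temporal motion module,
+        # and a ResNet block.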
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class CrossAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+
+ use_motion_module=None,
+
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ motion_modules = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ )
+ )
+ motion_modules.append(
+ get_motion_module(
+ in_channels=out_channels,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ ) if use_motion_module else None
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ output_states = ()
+
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ if motion_module is not None:
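+                    # .requires_grad_() guarantees checkpoint() receives a grad-requiring input
+                    # (useful when only the motion modules are trainable).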
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
+
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ # add motion module
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+
+ use_motion_module=None,
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ ):
+ super().__init__()
+ resnets = []
+ motion_modules = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ motion_modules.append(
+ get_motion_module(
+ in_channels=out_channels,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ ) if use_motion_module else None
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ output_states = ()
+
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
+ if self.training and self.gradient_checkpointing:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ if motion_module is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ # add motion module
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
+
+ use_motion_module=None,
+
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ motion_modules = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+ )
+ )
+ motion_modules.append(
+ get_motion_module(
+ in_channels=out_channels,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ ) if use_motion_module else None
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ attention_mask=None,
+ ):
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ if motion_module is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
+
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ # add motion module
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+
+ use_motion_module=None,
+ motion_module_type=None,
+ motion_module_kwargs=None,
+ ):
+ super().__init__()
+ resnets = []
+ motion_modules = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ motion_modules.append(
+ get_motion_module(
+ in_channels=out_channels,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ ) if use_motion_module else None
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ if motion_module is not None:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
\ No newline at end of file
diff --git a/magicanimate/models/unet_controlnet.py b/magicanimate/models/unet_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ccd9cad7d1f9e02aed3592c07a064d2a04dbd69
--- /dev/null
+++ b/magicanimate/models/unet_controlnet.py
@@ -0,0 +1,525 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import os
+import json
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from magicanimate.models.unet_3d_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+)
+from .resnet import InflatedConv3d
+
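+# Mirrors unet.py, but forward() additionally consumes ControlNet residuals
+# (down_block_additional_residuals / mid_block_additional_residual) from an external ControlNet.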
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D",
+ ),
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+
+ # Additional
+ use_motion_module = False,
+ motion_module_resolutions = ( 1,2,4,8 ),
+ motion_module_mid_block = False,
+ motion_module_decoder_only = False,
+ motion_module_type = None,
+ motion_module_kwargs = {},
+ unet_use_cross_frame_attention = None,
+ unet_use_temporal_attention = None,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ res = 2 ** i
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
+ self.mid_block = UNetMidBlock3DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+
+ use_motion_module=use_motion_module and motion_module_mid_block,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the videos
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ res = 2 ** (3 - i)
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+ unet_use_temporal_attention=unet_use_temporal_attention,
+
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_slicable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_slicable_dims(module)
+
+ num_slicable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_slicable_layers * [1]
+
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ # for controlnet
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+ Args:
+            sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+            [`UNet3DConditionOutput`] or `tuple`:
+            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+ """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
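+        # (e.g. with the usual 4-block SD UNet there are 3 upsamplers, so the latent
+        # height/width should normally be multiples of 2**3 = 8)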
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
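+        # convert a {0, 1} mask into an additive bias: positions where the mask is 0
+        # receive a large negative value that is added to the attention scores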
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ # pre-process
+ sample = self.conv_in(sample)
+
+ # down
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # mid
+ sample = self.mid_block(
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
+ )
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet3DConditionOutput(sample=sample)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+ config["_class_name"] = cls.__name__
+ config["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+ config["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+ # config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
+
+ from diffusers.utils import WEIGHTS_NAME
+ model = cls.from_config(config, **unet_additional_kwargs)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
+
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
+ print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
+
+ return model
diff --git a/magicanimate/pipelines/__pycache__/animation.cpython-37.pyc b/magicanimate/pipelines/__pycache__/animation.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9739b1d9a8220f44aadf90537dc32314d648738
Binary files /dev/null and b/magicanimate/pipelines/__pycache__/animation.cpython-37.pyc differ
diff --git a/magicanimate/pipelines/__pycache__/animation.cpython-38.pyc b/magicanimate/pipelines/__pycache__/animation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08d058d7ca8b2a8532ad5e97e109d7190c16127e
Binary files /dev/null and b/magicanimate/pipelines/__pycache__/animation.cpython-38.pyc differ
diff --git a/magicanimate/pipelines/__pycache__/context.cpython-38.pyc b/magicanimate/pipelines/__pycache__/context.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b37b304d6644940cda2985ef107a6f6aece61bf3
Binary files /dev/null and b/magicanimate/pipelines/__pycache__/context.cpython-38.pyc differ
diff --git a/magicanimate/pipelines/__pycache__/dist_animation.cpython-37.pyc b/magicanimate/pipelines/__pycache__/dist_animation.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae41088706b703811c27a47becd7d34519045571
Binary files /dev/null and b/magicanimate/pipelines/__pycache__/dist_animation.cpython-37.pyc differ
diff --git a/magicanimate/pipelines/__pycache__/dist_animation.cpython-38.pyc b/magicanimate/pipelines/__pycache__/dist_animation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b79af89fea8aa02000c36379ff57c831f174390
Binary files /dev/null and b/magicanimate/pipelines/__pycache__/dist_animation.cpython-38.pyc differ
diff --git a/magicanimate/pipelines/__pycache__/pipeline_animation.cpython-38.pyc b/magicanimate/pipelines/__pycache__/pipeline_animation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..016270e040ee9631dbbc932680ce7b5c5c868e91
Binary files /dev/null and b/magicanimate/pipelines/__pycache__/pipeline_animation.cpython-38.pyc differ
diff --git a/magicanimate/pipelines/animation.py b/magicanimate/pipelines/animation.py
new file mode 100644
index 0000000000000000000000000000000000000000..899583ed4ac71dac05e1971553696d3ccc8cde81
--- /dev/null
+++ b/magicanimate/pipelines/animation.py
@@ -0,0 +1,282 @@
+# Copyright 2023 ByteDance and/or its affiliates.
+#
+# Copyright (2023) MagicAnimate Authors
+#
+# ByteDance, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from ByteDance or
+# its affiliates is strictly prohibited.
+import argparse
+import datetime
+import inspect
+import os
+import random
+import numpy as np
+
+from PIL import Image
+from omegaconf import OmegaConf
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+
+from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
+
+from tqdm import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from magicanimate.models.unet_controlnet import UNet3DConditionModel
+from magicanimate.models.controlnet import ControlNetModel
+from magicanimate.models.appearance_encoder import AppearanceEncoderModel
+from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
+from magicanimate.pipelines.pipeline_animation import AnimationPipeline
+from magicanimate.utils.util import save_videos_grid
+from magicanimate.utils.dist_tools import distributed_init
+from accelerate.utils import set_seed
+
+from magicanimate.utils.videoreader import VideoReader
+
+from einops import rearrange
+
+from pathlib import Path
+
+
+def main(args):
+
+ *_, func_args = inspect.getargvalues(inspect.currentframe())
+ func_args = dict(func_args)
+
+ config = OmegaConf.load(args.config)
+
+ # Initialize distributed training
+ device = torch.device(f"cuda:{args.rank}")
+ dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist}
+
+ if config.savename is None:
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+ savedir = f"samples/{Path(args.config).stem}-{time_str}"
+ else:
+ savedir = f"samples/{config.savename}"
+
+ if args.dist:
+ dist.broadcast_object_list([savedir], 0)
+ dist.barrier()
+
+ if args.rank == 0:
+ os.makedirs(savedir, exist_ok=True)
+
+ inference_config = OmegaConf.load(config.inference_config)
+
+ motion_module = config.motion_module
+
+ ### >>> create animation pipeline >>> ###
+ tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
+ if config.pretrained_unet_path:
+ unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+ else:
+ unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+ appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device)
+ reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
+ reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
+ if config.pretrained_vae_path is not None:
+ vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
+ else:
+ vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
+
+ ### Load controlnet
+ controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)
+
+ unet.enable_xformers_memory_efficient_attention()
+ appearance_encoder.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+
+ vae.to(torch.float16)
+ unet.to(torch.float16)
+ text_encoder.to(torch.float16)
+ appearance_encoder.to(torch.float16)
+ controlnet.to(torch.float16)
+
+ pipeline = AnimationPipeline(
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
+ scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
+ # NOTE: UniPCMultistepScheduler
+ )
+
+ # 1. unet ckpt
+ # 1.1 motion module
+ motion_module_state_dict = torch.load(motion_module, map_location="cpu")
+ if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
+ motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
+ try:
+ # extra steps for self-trained models
+ state_dict = OrderedDict()
+ for key in motion_module_state_dict.keys():
+ if key.startswith("module."):
+ _key = key.split("module.")[-1]
+ state_dict[_key] = motion_module_state_dict[key]
+ else:
+ state_dict[key] = motion_module_state_dict[key]
+ motion_module_state_dict = state_dict
+ del state_dict
+ missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
+ assert len(unexpected) == 0
+ except:
+ _tmp_ = OrderedDict()
+ for key in motion_module_state_dict.keys():
+ if "motion_modules" in key:
+ if key.startswith("unet."):
+ _key = key.split('unet.')[-1]
+ _tmp_[_key] = motion_module_state_dict[key]
+ else:
+ _tmp_[key] = motion_module_state_dict[key]
+ missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
+ assert len(unexpected) == 0
+ del _tmp_
+ del motion_module_state_dict
+
+ pipeline.to(device)
+ ### <<< create validation pipeline <<< ###
+
+ random_seeds = config.get("seed", [-1])
+ random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
+ random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds
+
+    # input test videos (either source videos or condition sequences)
+
+ test_videos = config.video_path
+ source_images = config.source_image
+ num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps)
+
+ # read size, step from yaml file
+ sizes = [config.size] * len(test_videos)
+ steps = [config.S] * len(test_videos)
+
+ config.random_seed = []
+ prompt = n_prompt = ""
+ for idx, (source_image, test_video, random_seed, size, step) in tqdm(
+ enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)),
+ total=len(test_videos),
+ disable=(args.rank!=0)
+ ):
+ samples_per_video = []
+ samples_per_clip = []
+        # manually set random seed for reproducibility
+ if random_seed != -1:
+ torch.manual_seed(random_seed)
+ set_seed(random_seed)
+ else:
+ torch.seed()
+ config.random_seed.append(torch.initial_seed())
+
+ if test_video.endswith('.mp4'):
+ control = VideoReader(test_video).read()
+ if control[0].shape[0] != size:
+ control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
+ if config.max_length is not None:
+ control = control[config.offset: (config.offset+config.max_length)]
+ control = np.array(control)
+
+ if source_image.endswith(".mp4"):
+ source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size)))
+ else:
+ source_image = np.array(Image.open(source_image).resize((size, size)))
+ H, W, C = source_image.shape
+
+ print(f"current seed: {torch.initial_seed()}")
+ init_latents = None
+
+ # print(f"sampling {prompt} ...")
+ original_length = control.shape[0]
+ if control.shape[0] % config.L > 0:
+ control = np.pad(control, ((0, config.L-control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge')
+ generator = torch.Generator(device=torch.device("cuda:0"))
+ generator.manual_seed(torch.initial_seed())
+ sample = pipeline(
+ prompt,
+ negative_prompt = n_prompt,
+ num_inference_steps = config.steps,
+ guidance_scale = config.guidance_scale,
+ width = W,
+ height = H,
+ video_length = len(control),
+ controlnet_condition = control,
+ init_latents = init_latents,
+ generator = generator,
+ num_actual_inference_steps = num_actual_inference_steps,
+ appearance_encoder = appearance_encoder,
+ reference_control_writer = reference_control_writer,
+ reference_control_reader = reference_control_reader,
+ source_image = source_image,
+ **dist_kwargs,
+ ).videos
+
+ if args.rank == 0:
+ source_images = np.array([source_image] * original_length)
+ source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
+ samples_per_video.append(source_images)
+
+ control = control / 255.0
+ control = rearrange(control, "t h w c -> 1 c t h w")
+ control = torch.from_numpy(control)
+ samples_per_video.append(control[:, :, :original_length])
+
+ samples_per_video.append(sample[:, :, :original_length])
+
+ samples_per_video = torch.cat(samples_per_video)
+
+ video_name = os.path.basename(test_video)[:-4]
+ source_name = os.path.basename(config.source_image[idx]).split(".")[0]
+ save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4")
+ save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4")
+
+ if config.save_individual_videos:
+ save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4")
+ save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4")
+
+ if args.dist:
+ dist.barrier()
+
+ if args.rank == 0:
+ OmegaConf.save(config, f"{savedir}/config.yaml")
+
+
+def distributed_main(device_id, args):
+ args.rank = device_id
+ args.device_id = device_id
+ if torch.cuda.is_available():
+ torch.cuda.set_device(args.device_id)
+ torch.cuda.init()
+ distributed_init(args)
+ main(args)
+
+
+def run(args):
+
+ if args.dist:
+ args.world_size = max(1, torch.cuda.device_count())
+ assert args.world_size <= torch.cuda.device_count()
+
+ if args.world_size > 0 and torch.cuda.device_count() > 1:
+ port = random.randint(10000, 20000)
+ args.init_method = f"tcp://localhost:{port}"
+ torch.multiprocessing.spawn(
+ fn=distributed_main,
+ args=(args,),
+ nprocs=args.world_size,
+ )
+ else:
+ main(args)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, required=True)
+ parser.add_argument("--dist", action="store_true", required=False)
+ parser.add_argument("--rank", type=int, default=0, required=False)
+ parser.add_argument("--world_size", type=int, default=1, required=False)
+
+ args = parser.parse_args()
+ run(args)
diff --git a/magicanimate/pipelines/context.py b/magicanimate/pipelines/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..59fb8864ad14844e7c018234f9030f12e801707b
--- /dev/null
+++ b/magicanimate/pipelines/context.py
@@ -0,0 +1,76 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Adapted from https://github.com/s9roll7/animatediff-cli-prompt-travel/tree/main
+import numpy as np
+from typing import Callable, Optional, List
+
+
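+# Map an integer step to a fraction in [0, 1) by bit-reversing its 64-bit binary
+# representation, e.g. ordered_halving(1) == 0.5 and ordered_halving(2) == 0.25.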
+def ordered_halving(val):
+ bin_str = f"{val:064b}"
+ bin_flip = bin_str[::-1]
+ as_int = int(bin_flip, 2)
+
+ return as_int / (1 << 64)
+
+
+def uniform(
+ step: int = ...,
+ num_steps: Optional[int] = None,
+ num_frames: int = ...,
+ context_size: Optional[int] = None,
+ context_stride: int = 3,
+ context_overlap: int = 4,
+ closed_loop: bool = True,
+):
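+    # Yield overlapping windows of `context_size` frame indices at strides of
+    # 1, 2, 4, ... up to 2 ** (context_stride - 1); indices wrap modulo `num_frames`.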
+ if num_frames <= context_size:
+ yield list(range(num_frames))
+ return
+
+ context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
+
+ for context_step in 1 << np.arange(context_stride):
+ pad = int(round(num_frames * ordered_halving(step)))
+ for j in range(
+ int(ordered_halving(step) * context_step) + pad,
+ num_frames + pad + (0 if closed_loop else -context_overlap),
+ (context_size * context_step - context_overlap),
+ ):
+ yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]
+
+
+def get_context_scheduler(name: str) -> Callable:
+ if name == "uniform":
+ return uniform
+ else:
+ raise ValueError(f"Unknown context_overlap policy {name}")
+
+
+def get_total_steps(
+ scheduler,
+ timesteps: List[int],
+ num_steps: Optional[int] = None,
+ num_frames: int = ...,
+ context_size: Optional[int] = None,
+ context_stride: int = 3,
+ context_overlap: int = 4,
+ closed_loop: bool = True,
+):
+ return sum(
+ len(
+ list(
+ scheduler(
+ i,
+ num_steps,
+ num_frames,
+ context_size,
+ context_stride,
+ context_overlap,
+ )
+ )
+ )
+ for i in range(len(timesteps))
+ )
diff --git a/magicanimate/pipelines/pipeline_animation.py b/magicanimate/pipelines/pipeline_animation.py
new file mode 100644
index 0000000000000000000000000000000000000000..08a77bacc4f11f66bee1703e76cb1df21368a234
--- /dev/null
+++ b/magicanimate/pipelines/pipeline_animation.py
@@ -0,0 +1,799 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
+# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
+# ytedance Inc..
+# *************************************************************************
+
+# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
+
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+TODO:
+1. support multi-controlnet
+2. [DONE] support DDIM inversion
+3. support Prompt-to-prompt
+"""
+
+import inspect, math
+from typing import Callable, List, Optional, Union
+from dataclasses import dataclass
+from PIL import Image
+import numpy as np
+import torch
+import torch.distributed as dist
+from tqdm import tqdm
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL
+from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from diffusers.utils import deprecate, logging, BaseOutput
+
+from einops import rearrange
+
+from magicanimate.models.unet_controlnet import UNet3DConditionModel
+from magicanimate.models.controlnet import ControlNetModel
+from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
+from magicanimate.pipelines.context import (
+ get_context_scheduler,
+ get_total_steps
+)
+from magicanimate.utils.util import get_tensor_interpolation_method
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class AnimationPipelineOutput(BaseOutput):
+ videos: Union[torch.Tensor, np.ndarray]
+
+
+class AnimationPipeline(DiffusionPipeline):
+ _optional_components = []
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet3DConditionModel,
+ controlnet: ControlNetModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ def enable_vae_slicing(self):
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ self.vae.disable_slicing()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+
+ @property
+ def _execution_device(self):
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def decode_latents(self, latents, rank, decoder_consistency=None):
+ video_length = latents.shape[2]
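+        # 0.18215 is the Stable Diffusion VAE scaling factor; undo it before decoding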
+ latents = 1 / 0.18215 * latents
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
+ # video = self.vae.decode(latents).sample
+ video = []
+ for frame_idx in tqdm(range(latents.shape[0]), disable=(rank!=0)):
+ if decoder_consistency is not None:
+ video.append(decoder_consistency(latents[frame_idx:frame_idx+1]))
+ else:
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
+ video = torch.cat(video)
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+ video = (video / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.cpu().float().numpy()
+ return video
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, clip_length=16):
+ shape = (batch_size, num_channels_latents, clip_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ rand_device = "cpu" if device.type == "mps" else device
+
+ if isinstance(generator, list):
+ latents = [
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
+ for i in range(batch_size)
+ ]
+ latents = torch.cat(latents, dim=0).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+
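+            # tile the clip-length noise along the temporal axis so every
+            # `clip_length`-frame block starts from the same initial noise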
+ latents = latents.repeat(1, 1, video_length//clip_length, 1, 1)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_condition(self, condition, num_videos_per_prompt, device, dtype, do_classifier_free_guidance):
+ # prepare conditions for controlnet
+ condition = torch.from_numpy(condition.copy()).to(device=device, dtype=dtype) / 255.0
+ condition = torch.stack([condition for _ in range(num_videos_per_prompt)], dim=0)
+ condition = rearrange(condition, 'b f h w c -> (b f) c h w').clone()
+ if do_classifier_free_guidance:
+ condition = torch.cat([condition] * 2)
+ return condition
+
+ def next_step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ x: torch.FloatTensor,
+ eta=0.,
+ verbose=False
+ ):
+ """
+ Inverse sampling for DDIM Inversion
+ """
+ if verbose:
+ print("timestep: ", timestep)
+ next_step = timestep
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+ beta_prod_t = 1 - alpha_prod_t
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+ pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+ return x_next, pred_x0
+
+ @torch.no_grad()
+ def images2latents(self, images, dtype):
+ """
+ Convert RGB image to VAE latents
+ """
+ device = self._execution_device
+ images = torch.from_numpy(images).float().to(dtype) / 127.5 - 1
+ images = rearrange(images, "f h w c -> f c h w").to(device)
+ latents = []
+ for frame_idx in range(images.shape[0]):
+ latents.append(self.vae.encode(images[frame_idx:frame_idx+1])['latent_dist'].mean * 0.18215)
+ latents = torch.cat(latents)
+ return latents
+
+ @torch.no_grad()
+ def invert(
+ self,
+ image: torch.Tensor,
+ prompt,
+ num_inference_steps=20,
+ num_actual_inference_steps=10,
+ eta=0.0,
+ return_intermediates=False,
+ **kwargs):
+ """
+ Adapted from: https://github.com/Yujun-Shi/DragDiffusion/blob/main/drag_pipeline.py#L440
+        Invert a real image into a noise map with deterministic DDIM inversion.
+ """
+ device = self._execution_device
+ batch_size = image.shape[0]
+ if isinstance(prompt, list):
+ if batch_size == 1:
+ image = image.expand(len(prompt), -1, -1, -1)
+ elif isinstance(prompt, str):
+ if batch_size > 1:
+ prompt = [prompt] * batch_size
+
+ # text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ return_tensors="pt"
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
+ print("input text embeddings :", text_embeddings.shape)
+ # define initial latents
+        latents = self.images2latents(image, text_embeddings.dtype)  # images2latents expects a dtype argument
+
+ print("latents shape: ", latents.shape)
+        # iterative sampling
+ self.scheduler.set_timesteps(num_inference_steps)
+ print("Valid timesteps: ", reversed(self.scheduler.timesteps))
+ latents_list = [latents]
+ pred_x0_list = [latents]
+ for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
+
+ if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
+ continue
+ model_inputs = latents
+
+ # predict the noise
+ # NOTE: the u-net here is UNet3D, therefore the model_inputs need to be of shape (b c f h w)
+ model_inputs = rearrange(model_inputs, "f c h w -> 1 c f h w")
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
+ noise_pred = rearrange(noise_pred, "b c f h w -> (b f) c h w")
+
+ # compute the previous noise sample x_t-1 -> x_t
+ latents, pred_x0 = self.next_step(noise_pred, t, latents)
+ latents_list.append(latents)
+ pred_x0_list.append(pred_x0)
+
+ if return_intermediates:
+            # return the intermediate latents during inversion
+ return latents, latents_list
+ return latents
+
+    def interpolate_latents(self, latents: torch.Tensor, interpolation_factor: int, device):
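+        # Insert `interpolation_factor - 1` interpolated latent frames (via
+        # get_tensor_interpolation_method()) between each pair of adjacent frames
+        # along the temporal axis.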
+ if interpolation_factor < 2:
+ return latents
+
+        new_latents = torch.zeros(
+            (latents.shape[0], latents.shape[1], (latents.shape[2] - 1) * interpolation_factor + 1, latents.shape[3], latents.shape[4]),
+            device=latents.device,
+            dtype=latents.dtype,
+        )
+
+ org_video_length = latents.shape[2]
+ rate = [i/interpolation_factor for i in range(interpolation_factor)][1:]
+
+ new_index = 0
+
+ v0 = None
+ v1 = None
+
+        for i0, i1 in zip(range(org_video_length), range(org_video_length)[1:]):
+            v0 = latents[:, :, i0, :, :]
+            v1 = latents[:, :, i1, :, :]
+
+            new_latents[:, :, new_index, :, :] = v0
+            new_index += 1
+
+            for f in rate:
+                v = get_tensor_interpolation_method()(v0.to(device=device), v1.to(device=device), f)
+                new_latents[:, :, new_index, :, :] = v.to(latents.device)
+                new_index += 1
+
+        # write the final source frame once, outside the pairwise loop
+        new_latents[:, :, new_index, :, :] = v1
+        new_index += 1
+
+ return new_latents
+
+ def select_controlnet_res_samples(self, controlnet_res_samples_cache_dict, context, do_classifier_free_guidance, b, f):
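+        # Gather the cached per-frame ControlNet residuals for the frames in `context`
+        # and reshape them from (b f) c h w to b c f h w so they can be added to the
+        # corresponding 3D UNet features.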
+ _down_block_res_samples = []
+ _mid_block_res_sample = []
+ for i in np.concatenate(np.array(context)):
+ _down_block_res_samples.append(controlnet_res_samples_cache_dict[i][0])
+ _mid_block_res_sample.append(controlnet_res_samples_cache_dict[i][1])
+ down_block_res_samples = [[] for _ in range(len(controlnet_res_samples_cache_dict[i][0]))]
+ for res_t in _down_block_res_samples:
+ for i, res in enumerate(res_t):
+ down_block_res_samples[i].append(res)
+ down_block_res_samples = [torch.cat(res) for res in down_block_res_samples]
+ mid_block_res_sample = torch.cat(_mid_block_res_sample)
+
+ # reshape controlnet output to match the unet3d inputs
+ b = b // 2 if do_classifier_free_guidance else b
+ _down_block_res_samples = []
+ for sample in down_block_res_samples:
+ sample = rearrange(sample, '(b f) c h w -> b c f h w', b=b, f=f)
+ if do_classifier_free_guidance:
+ sample = sample.repeat(2, 1, 1, 1, 1)
+ _down_block_res_samples.append(sample)
+ down_block_res_samples = _down_block_res_samples
+ mid_block_res_sample = rearrange(mid_block_res_sample, '(b f) c h w -> b c f h w', b=b, f=f)
+ if do_classifier_free_guidance:
+ mid_block_res_sample = mid_block_res_sample.repeat(2, 1, 1, 1, 1)
+
+ return down_block_res_samples, mid_block_res_sample
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ video_length: Optional[int],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "tensor",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ controlnet_condition: list = None,
+ controlnet_conditioning_scale: float = 1.0,
+ context_frames: int = 16,
+ context_stride: int = 1,
+ context_overlap: int = 4,
+ context_batch_size: int = 1,
+ context_schedule: str = "uniform",
+ init_latents: Optional[torch.FloatTensor] = None,
+ num_actual_inference_steps: Optional[int] = None,
+ appearance_encoder = None,
+ reference_control_writer = None,
+ reference_control_reader = None,
+ source_image: str = None,
+ decoder_consistency = None,
+ **kwargs,
+ ):
+ """
+ New args:
+ - controlnet_condition : condition map (e.g., depth, canny, keypoints) for controlnet
+ - controlnet_conditioning_scale : conditioning scale for controlnet
+ - init_latents : initial latents to begin with (used along with invert())
+ - num_actual_inference_steps : number of actual inference steps (while total steps is num_inference_steps)
+ """
+ controlnet = self.controlnet
+
+ # Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # Define call parameters
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ batch_size = 1
+ if latents is not None:
+ batch_size = latents.shape[0]
+ if isinstance(prompt, list):
+ batch_size = len(prompt)
+
+ device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # Encode input prompt
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
+ if negative_prompt is not None:
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+ text_embeddings = torch.cat([text_embeddings] * context_batch_size)
+
+ reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', batch_size=context_batch_size)
+ reference_control_reader = ReferenceAttentionControl(self.unet, do_classifier_free_guidance=True, mode='read', batch_size=context_batch_size)
+
+ is_dist_initialized = kwargs.get("dist", False)
+ rank = kwargs.get("rank", 0)
+ world_size = kwargs.get("world_size", 1)
+
+ # Prepare video
+ assert num_videos_per_prompt == 1 # FIXME: verify if num_videos_per_prompt > 1 works
+ assert batch_size == 1 # FIXME: verify if batch_size > 1 works
+ control = self.prepare_condition(
+ condition=controlnet_condition,
+ device=device,
+ dtype=controlnet.dtype,
+ num_videos_per_prompt=num_videos_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+ controlnet_uncond_images, controlnet_cond_images = control.chunk(2)
+
+ # Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # Prepare latent variables
+ if init_latents is not None:
+ latents = rearrange(init_latents, "(b f) c h w -> b c f h w", f=video_length)
+ else:
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ video_length,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+ latents_dtype = latents.dtype
+
+ # Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # Prepare text embeddings for controlnet
+ controlnet_text_embeddings = text_embeddings.repeat_interleave(video_length, 0)
+ _, controlnet_text_embeddings_c = controlnet_text_embeddings.chunk(2)
+
+ controlnet_res_samples_cache_dict = {i:None for i in range(video_length)}
+
+ # For img2img setting
+ if num_actual_inference_steps is None:
+ num_actual_inference_steps = num_inference_steps
+
+ if isinstance(source_image, str):
+ ref_image_latents = self.images2latents(np.array(Image.open(source_image).resize((width, height)))[None, :], latents_dtype).cuda()
+ elif isinstance(source_image, np.ndarray):
+ ref_image_latents = self.images2latents(source_image[None, :], latents_dtype).cuda()
+
+ context_scheduler = get_context_scheduler(context_schedule)
+
+ # Denoising loop
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps), disable=(rank!=0)):
+ if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps:
+ continue
+
+ noise_pred = torch.zeros(
+ (latents.shape[0] * (2 if do_classifier_free_guidance else 1), *latents.shape[1:]),
+ device=latents.device,
+ dtype=latents.dtype,
+ )
+ counter = torch.zeros(
+ (1, 1, latents.shape[2], 1, 1), device=latents.device, dtype=latents.dtype
+ )
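+            # `counter` records how many overlapping context windows contributed to
+            # each frame so the accumulated noise predictions can be averaged below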
+
+ appearance_encoder(
+ ref_image_latents.repeat(context_batch_size * (2 if do_classifier_free_guidance else 1), 1, 1, 1),
+ t,
+ encoder_hidden_states=text_embeddings,
+ return_dict=False,
+ )
+
+ context_queue = list(context_scheduler(
+ 0, num_inference_steps, latents.shape[2], context_frames, context_stride, 0
+ ))
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
+ for i in range(num_context_batches):
+ context = context_queue[i*context_batch_size: (i+1)*context_batch_size]
+ # expand the latents if we are doing classifier free guidance
+ controlnet_latent_input = (
+ torch.cat([latents[:, :, c] for c in context])
+ .to(device)
+ )
+ controlnet_latent_input = self.scheduler.scale_model_input(controlnet_latent_input, t)
+
+ # prepare inputs for controlnet
+ b, c, f, h, w = controlnet_latent_input.shape
+ controlnet_latent_input = rearrange(controlnet_latent_input, "b c f h w -> (b f) c h w")
+
+ # controlnet inference
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ controlnet_latent_input,
+ t,
+ encoder_hidden_states=torch.cat([controlnet_text_embeddings_c[c] for c in context]),
+ controlnet_cond=torch.cat([controlnet_cond_images[c] for c in context]),
+ conditioning_scale=controlnet_conditioning_scale,
+ return_dict=False,
+ )
+
+ for j, k in enumerate(np.concatenate(np.array(context))):
+ controlnet_res_samples_cache_dict[k] = ([sample[j:j+1] for sample in down_block_res_samples], mid_block_res_sample[j:j+1])
+
+ context_queue = list(context_scheduler(
+ 0, num_inference_steps, latents.shape[2], context_frames, context_stride, context_overlap
+ ))
+
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
+ global_context = []
+ for i in range(num_context_batches):
+ global_context.append(context_queue[i*context_batch_size: (i+1)*context_batch_size])
+
+ for context in global_context[rank::world_size]:
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents[:, :, c] for c in context])
+ .to(device)
+ .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ b, c, f, h, w = latent_model_input.shape
+ down_block_res_samples, mid_block_res_sample = self.select_controlnet_res_samples(
+ controlnet_res_samples_cache_dict,
+ context,
+ do_classifier_free_guidance,
+ b, f
+ )
+
+ reference_control_reader.update(reference_control_writer)
+
+ # predict the noise residual
+ pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=text_embeddings[:b],
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+
+ reference_control_reader.clear()
+
+ pred_uc, pred_c = pred.chunk(2)
+ pred = torch.cat([pred_uc.unsqueeze(0), pred_c.unsqueeze(0)])
+ for j, c in enumerate(context):
+ noise_pred[:, :, c] = noise_pred[:, :, c] + pred[:, j]
+ counter[:, :, c] = counter[:, :, c] + 1
+
+ if is_dist_initialized:
+ noise_pred_gathered = [torch.zeros_like(noise_pred) for _ in range(world_size)]
+ if rank == 0:
+ dist.gather(tensor=noise_pred, gather_list=noise_pred_gathered, dst=0)
+ else:
+ dist.gather(tensor=noise_pred, gather_list=[], dst=0)
+ dist.barrier()
+
+ if rank == 0:
+ for k in range(1, world_size):
+ for context in global_context[k::world_size]:
+ for j, c in enumerate(context):
+ noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_gathered[k][:, :, c]
+ counter[:, :, c] = counter[:, :, c] + 1
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if is_dist_initialized:
+ dist.broadcast(latents, 0)
+ dist.barrier()
+
+ reference_control_writer.clear()
+
+ interpolation_factor = 1
+ latents = self.interpolate_latents(latents, interpolation_factor, device)
+ # Post-processing
+ video = self.decode_latents(latents, rank, decoder_consistency=decoder_consistency)
+
+ if is_dist_initialized:
+ dist.barrier()
+
+ # Convert to tensor
+ if output_type == "tensor":
+ video = torch.from_numpy(video)
+
+ if not return_dict:
+ return video
+
+ return AnimationPipelineOutput(videos=video)
diff --git a/magicanimate/utils/__pycache__/dist_tools.cpython-38.pyc b/magicanimate/utils/__pycache__/dist_tools.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc362cbaa73edc4be0ab32379c983a089477b803
Binary files /dev/null and b/magicanimate/utils/__pycache__/dist_tools.cpython-38.pyc differ
diff --git a/magicanimate/utils/__pycache__/util.cpython-38.pyc b/magicanimate/utils/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..203fd0d0bf60eef6d6657a785b6208b3476b376c
Binary files /dev/null and b/magicanimate/utils/__pycache__/util.cpython-38.pyc differ
diff --git a/magicanimate/utils/__pycache__/videoreader.cpython-38.pyc b/magicanimate/utils/__pycache__/videoreader.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5b9d98e5208b4bee650f5fa487628acd82be67e
Binary files /dev/null and b/magicanimate/utils/__pycache__/videoreader.cpython-38.pyc differ
diff --git a/magicanimate/utils/dist_tools.py b/magicanimate/utils/dist_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..0761c6d197a14b0d77d167d0e27bde3f4cc2ec6b
--- /dev/null
+++ b/magicanimate/utils/dist_tools.py
@@ -0,0 +1,105 @@
+# Copyright 2023 ByteDance and/or its affiliates.
+#
+# Copyright (2023) MagicAnimate Authors
+#
+# ByteDance, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from ByteDance or
+# its affiliates is strictly prohibited.
+import os
+import socket
+import warnings
+import torch
+from torch import distributed as dist
+
+
+def distributed_init(args):
+
+ if dist.is_initialized():
+ warnings.warn("Distributed is already initialized, cannot initialize twice!")
+ args.rank = dist.get_rank()
+ else:
+ print(
+ f"Distributed Init (Rank {args.rank}): "
+ f"{args.init_method}"
+ )
+ dist.init_process_group(
+ backend='nccl',
+ init_method=args.init_method,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ print(
+ f"Initialized Host {socket.gethostname()} as Rank "
+ f"{args.rank}"
+ )
+
+ if "MASTER_ADDR" not in os.environ or "MASTER_PORT" not in os.environ:
+ # Set for onboxdataloader support
+ split = args.init_method.split("//")
+ assert len(split) == 2, (
+ "host url for distributed should be split by '//' "
+ + "into exactly two elements"
+ )
+
+ split = split[1].split(":")
+ assert (
+ len(split) == 2
+ ), "host url should be of the form :"
+ os.environ["MASTER_ADDR"] = split[0]
+ os.environ["MASTER_PORT"] = split[1]
+
+ # perform a dummy all-reduce to initialize the NCCL communicator
+ dist.all_reduce(torch.zeros(1).cuda())
+
+ suppress_output(is_master())
+ args.rank = dist.get_rank()
+ return args.rank
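+# Usage sketch (hypothetical values): `args` is expected to be a namespace already
+# populated by the launcher with `rank`, `world_size`, and `init_method`, e.g.
+#
+#   import argparse
+#   args = argparse.Namespace(rank=0, world_size=1, init_method="tcp://127.0.0.1:23456")
+#   local_rank = distributed_init(args)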
+
+
+def get_rank():
+ if not dist.is_available():
+ return 0
+ if not dist.is_nccl_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_master():
+ return get_rank() == 0
+
+
+def synchronize():
+ if dist.is_initialized():
+ dist.barrier()
+
+
+def suppress_output(is_master):
+ """Suppress printing on the current device. Force printing with `force=True`."""
+ import builtins as __builtin__
+
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+ import warnings
+
+ builtin_warn = warnings.warn
+
+ def warn(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_master or force:
+ builtin_warn(*args, **kwargs)
+
+ # Log warnings only once
+ warnings.warn = warn
+ warnings.simplefilter("once", UserWarning)
\ No newline at end of file
diff --git a/magicanimate/utils/util.py b/magicanimate/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..76f2245cf13367694ad00faf3be721de70ff57bd
--- /dev/null
+++ b/magicanimate/utils/util.py
@@ -0,0 +1,138 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s
+# Modifications”). All Bytedance Inc.'s Modifications are Copyright (2023)
+# Bytedance Inc.
+# *************************************************************************
+
+# Adapted from https://github.com/guoyww/AnimateDiff
+import os
+import imageio
+import numpy as np
+
+import torch
+import torchvision
+
+from PIL import Image
+from typing import Union
+from tqdm import tqdm
+from einops import rearrange
+
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=25):
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = (x * 255).numpy().astype(np.uint8)
+ outputs.append(x)
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ imageio.mimsave(path, outputs, fps=fps)
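+# Usage sketch (hypothetical tensor and path): `videos` is laid out as
+# (batch, channel, time, height, width) with values in [0, 1]
+# (or in [-1, 1] together with rescale=True), e.g.
+#
+#   videos = torch.rand(2, 3, 16, 256, 256)
+#   save_videos_grid(videos, "outputs/sample.gif", n_rows=2, fps=8)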
+
+def save_images_grid(images: torch.Tensor, path: str):
+ assert images.shape[2] == 1 # expects a single frame along the time dimension
+ images = images.squeeze(2)
+ grid = torchvision.utils.make_grid(images)
+ grid = (grid * 255).numpy().transpose(1, 2, 0).astype(np.uint8)
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ Image.fromarray(grid).save(path)
+
+# DDIM Inversion
+@torch.no_grad()
+def init_prompt(prompt, pipeline):
+ uncond_input = pipeline.tokenizer(
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
+ return_tensors="pt"
+ )
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
+ text_input = pipeline.tokenizer(
+ [prompt],
+ padding="max_length",
+ max_length=pipeline.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
+ context = torch.cat([uncond_embeddings, text_embeddings])
+
+ return context
+
+
+def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
+ timestep, next_timestep = min(
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+ return next_sample
+
+
+def get_noise_pred_single(latents, t, context, unet):
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
+ return noise_pred
+
+
+@torch.no_grad()
+def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
+ context = init_prompt(prompt, pipeline)
+ uncond_embeddings, cond_embeddings = context.chunk(2)
+ all_latent = [latent]
+ latent = latent.clone().detach()
+ for i in tqdm(range(num_inv_steps)):
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
+ all_latent.append(latent)
+ return all_latent
+
+
+@torch.no_grad()
+def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
+ return ddim_latents
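+# Usage sketch (hypothetical call): assumes `ddim_scheduler` is a diffusers
+# DDIMScheduler on which set_timesteps(num_inv_steps) has already been called;
+# the returned list runs from the clean latent to its fully inverted noise, e.g.
+#
+#   latents_trajectory = ddim_inversion(pipeline, ddim_scheduler, video_latent,
+#                                       num_inv_steps=50, prompt="")
+#   noisy_start = latents_trajectory[-1]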
+
+
+def video2images(path, step=4, length=16, start=0):
+ reader = imageio.get_reader(path)
+ frames = []
+ for frame in reader:
+ frames.append(np.array(frame))
+ frames = frames[start::step][:length]
+ return frames
+
+
+def images2video(video, path, fps=8):
+ imageio.mimsave(path, video, fps=fps)
+ return
+
+
+tensor_interpolation = None
+
+def get_tensor_interpolation_method():
+ return tensor_interpolation
+
+def set_tensor_interpolation_method(is_slerp):
+ global tensor_interpolation
+ tensor_interpolation = slerp if is_slerp else linear
+
+def linear(v1, v2, t):
+ return (1.0 - t) * v1 + t * v2
+
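+# Spherical linear interpolation between two latent tensors; falls back to plain
+# linear interpolation when the inputs are nearly parallel (|cos| > DOT_THRESHOLD).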
+def slerp(
+ v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
+) -> torch.Tensor:
+ u0 = v0 / v0.norm()
+ u1 = v1 / v1.norm()
+ dot = (u0 * u1).sum()
+ if dot.abs() > DOT_THRESHOLD:
+ #logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
+ return (1.0 - t) * v0 + t * v1
+ omega = dot.acos()
+ return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
\ No newline at end of file
diff --git a/magicanimate/utils/videoreader.py b/magicanimate/utils/videoreader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e26a37e01a1deb223e0ca8d24beb306d1703d01e
--- /dev/null
+++ b/magicanimate/utils/videoreader.py
@@ -0,0 +1,157 @@
+# *************************************************************************
+# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s
+# Modifications”). All Bytedance Inc.'s Modifications are Copyright (2023)
+# Bytedance Inc.
+# *************************************************************************
+
+# Copyright 2022 ByteDance and/or its affiliates.
+#
+# Copyright (2022) PV3D Authors
+#
+# ByteDance, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from ByteDance or
+# its affiliates is strictly prohibited.
+import av, gc
+import torch
+import warnings
+import numpy as np
+
+
+_CALLED_TIMES = 0
+_GC_COLLECTION_INTERVAL = 20
+
+
+# remove warnings
+av.logging.set_level(av.logging.ERROR)
+
+
+class VideoReader():
+ """
+ Simple wrapper around PyAV that exposes a few useful functions for
+ dealing with video reading. PyAV is a pythonic binding for the ffmpeg libraries.
+ Acknowledgement: Codes are borrowed from Bruno Korbar
+ """
+ def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False):
+ """
+ Arguments:
+ video_path (str): path or byte of the video to be loaded
+ """
+ self.container = av.open(video)
+ self.num_frames = num_frames
+ self.bi_frame = bi_frame
+
+ self.resampler = None
+ if audio_resample_rate is not None:
+ self.resampler = av.AudioResampler(rate=audio_resample_rate)
+
+ if self.container.streams.video:
+ # enable multi-threaded video decoding
+ if decode_lossy:
+ warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning)
+ self.container.streams.video[0].thread_type = 'AUTO'
+ self.video_stream = self.container.streams.video[0]
+ else:
+ self.video_stream = None
+
+ self.fps = self._get_video_frame_rate()
+
+ def seek(self, pts, backward=True, any_frame=False):
+ stream = self.video_stream
+ self.container.seek(pts, any_frame=any_frame, backward=backward, stream=stream)
+
+ def _occasional_gc(self):
+ # there are a lot of reference cycles in PyAV, so need to manually call
+ # the garbage collector from time to time
+ global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
+ _CALLED_TIMES += 1
+ if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
+ gc.collect()
+
+ def _read_video(self, offset):
+ self._occasional_gc()
+
+ pts = self.container.duration * offset
+ time_ = pts / float(av.time_base)
+ self.container.seek(int(pts))
+
+ video_frames = []
+ count = 0
+ for _, frame in enumerate(self._iter_frames()):
+ if frame.pts * frame.time_base >= time_:
+ video_frames.append(frame)
+ if count >= self.num_frames - 1:
+ break
+ count += 1
+ return video_frames
+
+ def _iter_frames(self):
+ for packet in self.container.demux(self.video_stream):
+ for frame in packet.decode():
+ yield frame
+
+ def _compute_video_stats(self):
+ if self.video_stream is None or self.container is None:
+ # return the (start_pts, time_base, num_of_frames) tuple expected by callers
+ return 0, 0, 0
+ num_of_frames = self.container.streams.video[0].frames
+ if num_of_frames == 0:
+ num_of_frames = self.fps * float(self.container.streams.video[0].duration*self.video_stream.time_base)
+ self.seek(0, backward=False)
+ count = 0
+ start_pts = 0 # avoid UnboundLocalError if the stream yields no frames
+ time_base = 512
+ for p in self.container.decode(video=0):
+ count = count + 1
+ if count == 1:
+ start_pts = p.pts
+ elif count == 2:
+ time_base = p.pts - start_pts
+ break
+ return start_pts, time_base, num_of_frames
+
+ def _get_video_frame_rate(self):
+ return float(self.container.streams.video[0].guessed_rate)
+
+ def sample(self, debug=False):
+
+ if self.container is None:
+ raise RuntimeError('video stream not found')
+ sample = dict()
+ _, _, total_num_frames = self._compute_video_stats()
+ offset = torch.randint(max(1, total_num_frames-self.num_frames-1), [1]).item()
+ video_frames = self._read_video(offset/total_num_frames)
+ video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames])
+ sample["frames"] = video_frames
+ sample["frame_idx"] = [offset]
+
+ if self.bi_frame:
+ frames = [np.random.beta(2, 1, size=1), np.random.beta(1, 2, size=1)]
+ frames = [int(frames[0] * self.num_frames), int(frames[1] * self.num_frames)]
+ frames.sort()
+ video_frames = np.array([video_frames[min(frames)], video_frames[max(frames)]])
+ Ts = [min(frames) / (self.num_frames - 1), max(frames) / (self.num_frames - 1)]
+ sample["frames"] = video_frames
+ sample["real_t"] = torch.tensor(Ts, dtype=torch.float32)
+ sample["frame_idx"] = [offset+min(frames), offset+max(frames)]
+ return sample
+
+ return sample
+
+ def read_frames(self, frame_indices):
+ self.num_frames = frame_indices[1] - frame_indices[0]
+ video_frames = self._read_video(frame_indices[0]/self.get_num_frames())
+ video_frames = np.array([
+ np.uint8(video_frames[0].to_rgb().to_ndarray()),
+ np.uint8(video_frames[-1].to_rgb().to_ndarray())
+ ])
+ return video_frames
+
+ def read(self):
+ video_frames = self._read_video(0)
+ video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames])
+ return video_frames
+
+ def get_num_frames(self):
+ _, _, total_num_frames = self._compute_video_stats()
+ return total_num_frames
\ No newline at end of file
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..20645e641240cb419f5fc66c14c1447e91daf669
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1 @@
+ffmpeg
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..82f59c87c395ea9454a2578f8b5e84cd361f4342
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,127 @@
+absl-py==1.4.0
+accelerate==0.22.0
+aiofiles==23.2.1
+aiohttp==3.8.5
+aiosignal==1.3.1
+altair==5.0.1
+annotated-types==0.5.0
+antlr4-python3-runtime==4.9.3
+anyio==3.7.1
+async-timeout==4.0.3
+attrs==23.1.0
+cachetools==5.3.1
+certifi==2023.7.22
+charset-normalizer==3.2.0
+click==8.1.7
+cmake==3.27.2
+contourpy==1.1.0
+cycler==0.11.0
+datasets==2.14.4
+dill==0.3.7
+einops==0.6.1
+exceptiongroup==1.1.3
+fastapi==0.103.0
+ffmpy==0.3.1
+filelock==3.12.2
+fonttools==4.42.1
+frozenlist==1.4.0
+fsspec==2023.6.0
+google-auth==2.22.0
+google-auth-oauthlib==1.0.0
+gradio==3.41.2
+gradio-client==0.5.0
+grpcio==1.57.0
+h11==0.14.0
+httpcore==0.17.3
+httpx==0.24.1
+huggingface-hub==0.16.4
+idna==3.4
+importlib-metadata==6.8.0
+importlib-resources==6.0.1
+jinja2==3.1.2
+joblib==1.3.2
+jsonschema==4.19.0
+jsonschema-specifications==2023.7.1
+kiwisolver==1.4.5
+lightning-utilities==0.9.0
+lit==16.0.6
+markdown==3.4.4
+markupsafe==2.1.3
+matplotlib==3.7.2
+mpmath==1.3.0
+multidict==6.0.4
+multiprocess==0.70.15
+networkx==3.1
+numpy==1.24.4
+nvidia-cublas-cu11==11.10.3.66
+nvidia-cuda-cupti-cu11==11.7.101
+nvidia-cuda-nvrtc-cu11==11.7.99
+nvidia-cuda-runtime-cu11==11.7.99
+nvidia-cudnn-cu11==8.5.0.96
+nvidia-cufft-cu11==10.9.0.58
+nvidia-curand-cu11==10.2.10.91
+nvidia-cusolver-cu11==11.4.0.1
+nvidia-cusparse-cu11==11.7.4.91
+nvidia-nccl-cu11==2.14.3
+nvidia-nvtx-cu11==11.7.91
+oauthlib==3.2.2
+omegaconf==2.3.0
+opencv-python==4.8.0.76
+orjson==3.9.5
+pandas==2.0.3
+pillow==9.5.0
+pkgutil-resolve-name==1.3.10
+protobuf==4.24.2
+psutil==5.9.5
+pyarrow==13.0.0
+pyasn1==0.5.0
+pyasn1-modules==0.3.0
+pydantic==2.3.0
+pydantic-core==2.6.3
+pydub==0.25.1
+pyparsing==3.0.9
+python-multipart==0.0.6
+pytorch-lightning==2.0.7
+pytz==2023.3
+pyyaml==6.0.1
+referencing==0.30.2
+regex==2023.8.8
+requests==2.31.0
+requests-oauthlib==1.3.1
+rpds-py==0.9.2
+rsa==4.9
+safetensors==0.3.3
+semantic-version==2.10.0
+sniffio==1.3.0
+starlette==0.27.0
+sympy==1.12
+tensorboard==2.14.0
+tensorboard-data-server==0.7.1
+tokenizers==0.13.3
+toolz==0.12.0
+torchmetrics==1.1.0
+tqdm==4.66.1
+transformers==4.32.0
+triton==2.0.0
+tzdata==2023.3
+urllib3==1.26.16
+uvicorn==0.23.2
+websockets==11.0.3
+werkzeug==2.3.7
+xxhash==3.3.0
+yarl==1.9.2
+zipp==3.16.2
+decord
+imageio==2.9.0
+imageio-ffmpeg==0.4.3
+timm
+scipy
+scikit-image
+av
+imgaug
+lpips
+ffmpeg-python
+torch==2.0.1
+torchvision==0.15.2
+xformers==0.0.22
+diffusers==0.21.4
diff --git a/samples/animation-2023-12-05T00-24-12/videos/0002_demo4.mp4 b/samples/animation-2023-12-05T00-24-12/videos/0002_demo4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a9f0602367e1a8732fb0be34b8ab48ce22f0b769
Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/0002_demo4.mp4 differ
diff --git a/samples/animation-2023-12-05T00-24-12/videos/0002_demo4/grid.mp4 b/samples/animation-2023-12-05T00-24-12/videos/0002_demo4/grid.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8c03c1eb2db19845476cc0ea99502a6ea5a582f5
Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/0002_demo4/grid.mp4 differ
diff --git a/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4.mp4 b/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8f84ada3cbf1efa43b23138505387260ad642c83
Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4.mp4 differ
diff --git a/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4/grid.mp4 b/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4/grid.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..462f7a7a5675834ed62e590b5102b3e78b563bee
Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/demo4_demo4/grid.mp4 differ
diff --git a/samples/animation-2023-12-05T00-24-12/videos/monalisa_running.mp4 b/samples/animation-2023-12-05T00-24-12/videos/monalisa_running.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4080a0155411670dc3839a7b8fd76517e3212e49
Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/monalisa_running.mp4 differ
diff --git a/samples/animation-2023-12-05T00-24-12/videos/monalisa_running/grid.mp4 b/samples/animation-2023-12-05T00-24-12/videos/monalisa_running/grid.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..665b46f9551b4f4c8cb6253662ae7b5546d11f51
Binary files /dev/null and b/samples/animation-2023-12-05T00-24-12/videos/monalisa_running/grid.mp4 differ
diff --git a/samples/animation-2023-12-05T00-37-05/videos/monalisa_running.mp4 b/samples/animation-2023-12-05T00-37-05/videos/monalisa_running.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..98e8cf53ecfefda176bb6ce86bab59fa484f141f
Binary files /dev/null and b/samples/animation-2023-12-05T00-37-05/videos/monalisa_running.mp4 differ
diff --git a/samples/animation-2023-12-05T00-37-05/videos/monalisa_running/grid.mp4 b/samples/animation-2023-12-05T00-37-05/videos/monalisa_running/grid.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..5029c554a8aaa1e92d2a41143ae6bd2b2288dbc0
Binary files /dev/null and b/samples/animation-2023-12-05T00-37-05/videos/monalisa_running/grid.mp4 differ