Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +22 -0
- README.md +168 -12
- __assets__/feature_visualization.png +3 -0
- __assets__/pipeline.png +3 -0
- __assets__/teaser.gif +3 -0
- __assets__/teaser.mp4 +3 -0
- condition_images/rgb/dog_on_grass.png +3 -0
- condition_images/scribble/lion_forest.png +0 -0
- configs/i2v_rgb.jsonl +1 -0
- configs/i2v_rgb.yaml +20 -0
- configs/i2v_sketch.jsonl +1 -0
- configs/i2v_sketch.yaml +20 -0
- configs/model_config/inference-v1.yaml +25 -0
- configs/model_config/inference-v2.yaml +24 -0
- configs/model_config/inference-v3.yaml +22 -0
- configs/model_config/model_config copy.yaml +22 -0
- configs/model_config/model_config.yaml +21 -0
- configs/model_config/model_config_public.yaml +25 -0
- configs/sparsectrl/image_condition.yaml +17 -0
- configs/sparsectrl/latent_condition.yaml +17 -0
- configs/t2v_camera.jsonl +12 -0
- configs/t2v_camera.yaml +19 -0
- configs/t2v_object.jsonl +6 -0
- configs/t2v_object.yaml +19 -0
- environment.yaml +25 -0
- generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 +3 -0
- generated_videos/inference_config.json +21 -0
- generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 +3 -0
- i2v_video_sample.py +157 -0
- models/Motion_Module/Put motion module checkpoints here.txt +0 -0
- motionclone/models/__pycache__/attention.cpython-310.pyc +0 -0
- motionclone/models/__pycache__/attention.cpython-38.pyc +0 -0
- motionclone/models/__pycache__/motion_module.cpython-310.pyc +0 -0
- motionclone/models/__pycache__/motion_module.cpython-38.pyc +0 -0
- motionclone/models/__pycache__/resnet.cpython-310.pyc +0 -0
- motionclone/models/__pycache__/resnet.cpython-38.pyc +0 -0
- motionclone/models/__pycache__/sparse_controlnet.cpython-38.pyc +0 -0
- motionclone/models/__pycache__/unet.cpython-310.pyc +0 -0
- motionclone/models/__pycache__/unet.cpython-38.pyc +0 -0
- motionclone/models/__pycache__/unet_blocks.cpython-310.pyc +0 -0
- motionclone/models/__pycache__/unet_blocks.cpython-38.pyc +0 -0
- motionclone/models/attention.py +611 -0
- motionclone/models/motion_module.py +347 -0
- motionclone/models/resnet.py +218 -0
- motionclone/models/scheduler.py +155 -0
- motionclone/models/sparse_controlnet.py +593 -0
- motionclone/models/unet.py +515 -0
- motionclone/models/unet_blocks.py +760 -0
- motionclone/pipelines/__pycache__/pipeline_animation.cpython-310.pyc +0 -0
- motionclone/pipelines/__pycache__/pipeline_animation.cpython-38.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
__assets__/feature_visualization.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
__assets__/pipeline.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
__assets__/teaser.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
__assets__/teaser.mp4 filter=lfs diff=lfs merge=lfs -text
|
40 |
+
condition_images/rgb/dog_on_grass.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 filter=lfs diff=lfs merge=lfs -text
|
42 |
+
generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 filter=lfs diff=lfs merge=lfs -text
|
43 |
+
reference_videos/camera_1.mp4 filter=lfs diff=lfs merge=lfs -text
|
44 |
+
reference_videos/camera_pan_down.mp4 filter=lfs diff=lfs merge=lfs -text
|
45 |
+
reference_videos/camera_pan_up.mp4 filter=lfs diff=lfs merge=lfs -text
|
46 |
+
reference_videos/camera_translation_1.mp4 filter=lfs diff=lfs merge=lfs -text
|
47 |
+
reference_videos/camera_translation_2.mp4 filter=lfs diff=lfs merge=lfs -text
|
48 |
+
reference_videos/camera_zoom_in.mp4 filter=lfs diff=lfs merge=lfs -text
|
49 |
+
reference_videos/camera_zoom_out.mp4 filter=lfs diff=lfs merge=lfs -text
|
50 |
+
reference_videos/sample_astronaut.mp4 filter=lfs diff=lfs merge=lfs -text
|
51 |
+
reference_videos/sample_blackswan.mp4 filter=lfs diff=lfs merge=lfs -text
|
52 |
+
reference_videos/sample_cat.mp4 filter=lfs diff=lfs merge=lfs -text
|
53 |
+
reference_videos/sample_cow.mp4 filter=lfs diff=lfs merge=lfs -text
|
54 |
+
reference_videos/sample_fox.mp4 filter=lfs diff=lfs merge=lfs -text
|
55 |
+
reference_videos/sample_leaves.mp4 filter=lfs diff=lfs merge=lfs -text
|
56 |
+
reference_videos/sample_white_tiger.mp4 filter=lfs diff=lfs merge=lfs -text
|
57 |
+
reference_videos/sample_wolf.mp4 filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,12 +1,168 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MotionClone
|
2 |
+
This repository is the official implementation of [MotionClone](https://arxiv.org/abs/2406.05338). It is a **training-free framework** that enables motion cloning from a reference video for controllable video generation, **without cumbersome video inversion processes**.
|
3 |
+
<details><summary>Click for the full abstract of MotionClone</summary>
|
4 |
+
|
5 |
+
> Motion-based controllable video generation offers the potential for creating captivating visual content. Existing methods typically necessitate model training to encode particular motion cues or incorporate fine-tuning to inject certain motion patterns, resulting in limited flexibility and generalization.
|
6 |
+
In this work, we propose **MotionClone** a training-free framework that enables motion cloning from reference videos to versatile motion-controlled video generation, including text-to-video and image-to-video. Based on the observation that the dominant components in temporal-attention maps drive motion synthesis, while the rest mainly capture noisy or very subtle motions, MotionClone utilizes sparse temporal attention weights as motion representations for motion guidance, facilitating diverse motion transfer across varying scenarios. Meanwhile, MotionClone allows for the direct extraction of motion representation through a single denoising step, bypassing the cumbersome inversion processes and thus promoting both efficiency and flexibility.
|
7 |
+
Extensive experiments demonstrate that MotionClone exhibits proficiency in both global camera motion and local object motion, with notable superiority in terms of motion fidelity, textual alignment, and temporal consistency.
|
8 |
+
</details>
|
9 |
+
|
10 |
+
**[MotionClone: Training-Free Motion Cloning for Controllable Video Generation](https://arxiv.org/abs/2406.05338)**
|
11 |
+
</br>
|
12 |
+
[Pengyang Ling*](https://github.com/LPengYang/),
|
13 |
+
[Jiazi Bu*](https://github.com/Bujiazi/),
|
14 |
+
[Pan Zhang<sup>†</sup>](https://panzhang0212.github.io/),
|
15 |
+
[Xiaoyi Dong](https://scholar.google.com/citations?user=FscToE0AAAAJ&hl=en/),
|
16 |
+
[Yuhang Zang](https://yuhangzang.github.io/),
|
17 |
+
[Tong Wu](https://wutong16.github.io/),
|
18 |
+
[Huaian Chen](https://scholar.google.com.hk/citations?hl=zh-CN&user=D6ol9XkAAAAJ),
|
19 |
+
[Jiaqi Wang](https://myownskyw7.github.io/),
|
20 |
+
[Yi Jin<sup>†</sup>](https://scholar.google.ca/citations?hl=en&user=mAJ1dCYAAAAJ)
|
21 |
+
(*Equal Contribution)(<sup>†</sup>Corresponding Author)
|
22 |
+
|
23 |
+
<!-- [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) -->
|
24 |
+
[](https://arxiv.org/abs/2406.05338)
|
25 |
+
[](https://bujiazi.github.io/motionclone.github.io/)
|
26 |
+

|
27 |
+
<!-- [](https://bujiazi.github.io/motionclone.github.io/) -->
|
28 |
+
<!-- [](https://bujiazi.github.io/motionclone.github.io/) -->
|
29 |
+
|
30 |
+
## Demo
|
31 |
+
[![]](https://github.com/user-attachments/assets/d1f1c753-f192-455b-9779-94c925e51aaa)
|
32 |
+
|
33 |
+
```bash
|
34 |
+
sudo apt-get update && sudo apt-get install git-lfs ffmpeg cbm
|
35 |
+
|
36 |
+
conda create --name py310 python=3.10
|
37 |
+
conda activate py310
|
38 |
+
pip install ipykernel
|
39 |
+
python -m ipykernel install --user --name py310 --display-name "py310"
|
40 |
+
|
41 |
+
git clone https://github.com/svjack/MotionClone && cd MotionClone
|
42 |
+
pip install -r requirements.txt
|
43 |
+
|
44 |
+
mkdir -p models
|
45 |
+
git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5 models/StableDiffusion/
|
46 |
+
|
47 |
+
mkdir -p models/DreamBooth_LoRA
|
48 |
+
wget https://huggingface.co/svjack/Realistic-Vision-V6.0-B1/resolve/main/realisticVisionV60B1_v51VAE.safetensors -O models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors
|
49 |
+
|
50 |
+
mkdir -p models/Motion_Module
|
51 |
+
wget https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_mm.ckpt -O models/Motion_Module/v3_sd15_mm.ckpt
|
52 |
+
wget https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt -O models/Motion_Module/v3_sd15_adapter.ckpt
|
53 |
+
|
54 |
+
mkdir -p models/SparseCtrl
|
55 |
+
wget https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_rgb.ckpt -O models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt
|
56 |
+
wget https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_scribble.ckpt -O models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt
|
57 |
+
```
|
58 |
+
|
59 |
+
## 🖋 News
|
60 |
+
- The latest version of our paper (**v4**) is available on arXiv! (10.08)
|
61 |
+
- The latest version of our paper (**v3**) is available on arXiv! (7.2)
|
62 |
+
- Code released! (6.29)
|
63 |
+
|
64 |
+
## 🏗️ Todo
|
65 |
+
- [x] We have updated the latest version of MotionCloning, which performs motion transfer **without video inversion** and supports **image-to-video and sketch-to-video**.
|
66 |
+
- [x] Release the MotionClone code (We have released **the first version** of our code and will continue to optimize it. We welcome any questions or issues you may have and will address them promptly.)
|
67 |
+
- [x] Release paper
|
68 |
+
|
69 |
+
## 📚 Gallery
|
70 |
+
We show more results in the [Project Page](https://bujiazi.github.io/motionclone.github.io/).
|
71 |
+
|
72 |
+
## 🚀 Method Overview
|
73 |
+
### Feature visualization
|
74 |
+
<div align="center">
|
75 |
+
<img src='__assets__/feature_visualization.png'/>
|
76 |
+
</div>
|
77 |
+
|
78 |
+
### Pipeline
|
79 |
+
<div align="center">
|
80 |
+
<img src='__assets__/pipeline.png'/>
|
81 |
+
</div>
|
82 |
+
|
83 |
+
MotionClone utilizes sparse temporal attention weights as motion representations for motion guidance, facilitating diverse motion transfer across varying scenarios. Meanwhile, MotionClone allows for the direct extraction of motion representation through a single denoising step, bypassing the cumbersome inversion processes and thus promoting both efficiency and flexibility.
|
84 |
+
|
85 |
+
## 🔧 Installations (python==3.11.3 recommended)
|
86 |
+
|
87 |
+
### Setup repository and conda environment
|
88 |
+
|
89 |
+
```
|
90 |
+
git clone https://github.com/Bujiazi/MotionClone.git
|
91 |
+
cd MotionClone
|
92 |
+
|
93 |
+
conda env create -f environment.yaml
|
94 |
+
conda activate motionclone
|
95 |
+
```
|
96 |
+
|
97 |
+
## 🔑 Pretrained Model Preparations
|
98 |
+
|
99 |
+
### Download Stable Diffusion V1.5
|
100 |
+
|
101 |
+
```
|
102 |
+
git lfs install
|
103 |
+
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/StableDiffusion/
|
104 |
+
```
|
105 |
+
|
106 |
+
After downloading Stable Diffusion, save them to `models/StableDiffusion`.
|
107 |
+
|
108 |
+
### Prepare Community Models
|
109 |
+
|
110 |
+
Manually download the community `.safetensors` models from [RealisticVision V5.1](https://civitai.com/models/4201?modelVersionId=130072) and save them to `models/DreamBooth_LoRA`.
|
111 |
+
|
112 |
+
### Prepare AnimateDiff Motion Modules
|
113 |
+
|
114 |
+
Manually download the AnimateDiff modules from [AnimateDiff](https://github.com/guoyww/AnimateDiff), we recommend [`v3_adapter_sd_v15.ckpt`](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_adapter.ckpt) and [`v3_sd15_mm.ckpt.ckpt`](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt). Save the modules to `models/Motion_Module`.
|
115 |
+
|
116 |
+
### Prepare SparseCtrl for image-to-video and sketch-to-video
|
117 |
+
Manually download "v3_sd15_sparsectrl_rgb.ckpt" and "v3_sd15_sparsectrl_scribble.ckpt" from [AnimateDiff](https://huggingface.co/guoyww/animatediff/tree/main). Save the modules to `models/SparseCtrl`.
|
118 |
+
|
119 |
+
## 🎈 Quick Start
|
120 |
+
|
121 |
+
### Perform Text-to-video generation with customized camera motion
|
122 |
+
```
|
123 |
+
python t2v_video_sample.py --inference_config "configs/t2v_camera.yaml" --examples "configs/t2v_camera.jsonl"
|
124 |
+
```
|
125 |
+
|
126 |
+
https://github.com/user-attachments/assets/2656a49a-c57d-4f89-bc65-5ec09ac037ea
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
### Perform Text-to-video generation with customized object motion
|
133 |
+
```
|
134 |
+
python t2v_video_sample.py --inference_config "configs/t2v_object.yaml" --examples "configs/t2v_object.jsonl"
|
135 |
+
```
|
136 |
+
### Combine motion cloning with sketch-to-video
|
137 |
+
```
|
138 |
+
python i2v_video_sample.py --inference_config "configs/i2v_sketch.yaml" --examples "configs/i2v_sketch.jsonl"
|
139 |
+
```
|
140 |
+
### Combine motion cloning with image-to-video
|
141 |
+
```
|
142 |
+
python i2v_video_sample.py --inference_config "configs/i2v_rgb.yaml" --examples "configs/i2v_rgb.jsonl"
|
143 |
+
```
|
144 |
+
|
145 |
+
|
146 |
+
## 📎 Citation
|
147 |
+
|
148 |
+
If you find this work helpful, please cite the following paper:
|
149 |
+
|
150 |
+
```
|
151 |
+
@article{ling2024motionclone,
|
152 |
+
title={MotionClone: Training-Free Motion Cloning for Controllable Video Generation},
|
153 |
+
author={Ling, Pengyang and Bu, Jiazi and Zhang, Pan and Dong, Xiaoyi and Zang, Yuhang and Wu, Tong and Chen, Huaian and Wang, Jiaqi and Jin, Yi},
|
154 |
+
journal={arXiv preprint arXiv:2406.05338},
|
155 |
+
year={2024}
|
156 |
+
}
|
157 |
+
```
|
158 |
+
|
159 |
+
## 📣 Disclaimer
|
160 |
+
|
161 |
+
This is official code of MotionClone.
|
162 |
+
All the copyrights of the demo images and audio are from community users.
|
163 |
+
Feel free to contact us if you would like remove them.
|
164 |
+
|
165 |
+
## 💞 Acknowledgements
|
166 |
+
The code is built upon the below repositories, we thank all the contributors for open-sourcing.
|
167 |
+
* [AnimateDiff](https://github.com/guoyww/AnimateDiff)
|
168 |
+
* [FreeControl](https://github.com/genforce/freecontrol)
|
__assets__/feature_visualization.png
ADDED
![]() |
Git LFS Details
|
__assets__/pipeline.png
ADDED
![]() |
Git LFS Details
|
__assets__/teaser.gif
ADDED
![]() |
Git LFS Details
|
__assets__/teaser.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:201747f42691e708b9efe48ea054961fd82cf54b83ac43e0d97a43f81779c00b
|
3 |
+
size 4957080
|
condition_images/rgb/dog_on_grass.png
ADDED
![]() |
Git LFS Details
|
condition_images/scribble/lion_forest.png
ADDED
![]() |
configs/i2v_rgb.jsonl
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"video_path":"reference_videos/camera_zoom_out.mp4", "condition_image_paths":["condition_images/rgb/dog_on_grass.png"], "new_prompt": "Dog, lying on the grass"}
|
configs/i2v_rgb.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
|
2 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
|
3 |
+
model_config: "configs/model_config/model_config.yaml"
|
4 |
+
controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt"
|
5 |
+
controlnet_config: "configs/sparsectrl/latent_condition.yaml"
|
6 |
+
adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt"
|
7 |
+
|
8 |
+
cfg_scale: 7.5 # in default realistic classifer-free guidance
|
9 |
+
negative_prompt: "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers"
|
10 |
+
|
11 |
+
inference_steps: 100 # the total denosing step for inference
|
12 |
+
guidance_scale: 0.3 # which scale of time step to end guidance
|
13 |
+
guidance_steps: 40 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance
|
14 |
+
warm_up_steps: 10
|
15 |
+
cool_up_steps: 10
|
16 |
+
|
17 |
+
motion_guidance_weight: 2000
|
18 |
+
motion_guidance_blocks: ['up_blocks.1']
|
19 |
+
|
20 |
+
add_noise_step: 400
|
configs/i2v_sketch.jsonl
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"video_path":"reference_videos/sample_white_tiger.mp4", "condition_image_paths":["condition_images/scribble/lion_forest.png"], "new_prompt": "Lion, walks in the forest"}
|
configs/i2v_sketch.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
|
2 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
|
3 |
+
model_config: "configs/model_config/model_config.yaml"
|
4 |
+
controlnet_config: "configs/sparsectrl/image_condition.yaml"
|
5 |
+
controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt"
|
6 |
+
adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt"
|
7 |
+
|
8 |
+
cfg_scale: 7.5 # in default realistic classifer-free guidance
|
9 |
+
negative_prompt: "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers"
|
10 |
+
|
11 |
+
inference_steps: 200 # the total denosing step for inference
|
12 |
+
guidance_scale: 0.4 # which scale of time step to end guidance
|
13 |
+
guidance_steps: 120 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance
|
14 |
+
warm_up_steps: 10
|
15 |
+
cool_up_steps: 10
|
16 |
+
|
17 |
+
motion_guidance_weight: 2000
|
18 |
+
motion_guidance_blocks: ['up_blocks.1']
|
19 |
+
|
20 |
+
add_noise_step: 400
|
configs/model_config/inference-v1.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unet_additional_kwargs:
|
2 |
+
use_inflated_groupnorm: true # from config v3
|
3 |
+
|
4 |
+
|
5 |
+
use_motion_module: true
|
6 |
+
motion_module_resolutions: [1,2,4,8]
|
7 |
+
motion_module_mid_block: false
|
8 |
+
motion_module_decoder_only: false
|
9 |
+
motion_module_type: "Vanilla"
|
10 |
+
|
11 |
+
motion_module_kwargs:
|
12 |
+
num_attention_heads: 8
|
13 |
+
num_transformer_block: 1
|
14 |
+
attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
|
15 |
+
temporal_position_encoding: true
|
16 |
+
temporal_position_encoding_max_len: 32
|
17 |
+
temporal_attention_dim_div: 1
|
18 |
+
zero_initialize: true # from config v3
|
19 |
+
|
20 |
+
noise_scheduler_kwargs:
|
21 |
+
beta_start: 0.00085
|
22 |
+
beta_end: 0.012
|
23 |
+
beta_schedule: "linear"
|
24 |
+
steps_offset: 1
|
25 |
+
clip_sample: False
|
configs/model_config/inference-v2.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unet_additional_kwargs:
|
2 |
+
use_inflated_groupnorm: true
|
3 |
+
unet_use_cross_frame_attention: false
|
4 |
+
unet_use_temporal_attention: false
|
5 |
+
use_motion_module: true
|
6 |
+
motion_module_resolutions: [1,2,4,8]
|
7 |
+
motion_module_mid_block: true
|
8 |
+
motion_module_decoder_only: false
|
9 |
+
motion_module_type: "Vanilla"
|
10 |
+
|
11 |
+
motion_module_kwargs:
|
12 |
+
num_attention_heads: 8
|
13 |
+
num_transformer_block: 1
|
14 |
+
attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
|
15 |
+
temporal_position_encoding: true
|
16 |
+
temporal_position_encoding_max_len: 32
|
17 |
+
temporal_attention_dim_div: 1
|
18 |
+
|
19 |
+
noise_scheduler_kwargs:
|
20 |
+
beta_start: 0.00085
|
21 |
+
beta_end: 0.012
|
22 |
+
beta_schedule: "linear"
|
23 |
+
steps_offset: 1
|
24 |
+
clip_sample: False
|
configs/model_config/inference-v3.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unet_additional_kwargs:
|
2 |
+
use_inflated_groupnorm: true
|
3 |
+
use_motion_module: true
|
4 |
+
motion_module_resolutions: [1,2,4,8]
|
5 |
+
motion_module_mid_block: false
|
6 |
+
motion_module_type: Vanilla
|
7 |
+
|
8 |
+
motion_module_kwargs:
|
9 |
+
num_attention_heads: 8
|
10 |
+
num_transformer_block: 1
|
11 |
+
attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
|
12 |
+
temporal_position_encoding: true
|
13 |
+
temporal_position_encoding_max_len: 32
|
14 |
+
temporal_attention_dim_div: 1
|
15 |
+
zero_initialize: true
|
16 |
+
|
17 |
+
noise_scheduler_kwargs:
|
18 |
+
beta_start: 0.00085
|
19 |
+
beta_end: 0.012
|
20 |
+
beta_schedule: "linear"
|
21 |
+
steps_offset: 1
|
22 |
+
clip_sample: False
|
configs/model_config/model_config copy.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unet_additional_kwargs:
|
2 |
+
use_inflated_groupnorm: true # from config v3
|
3 |
+
use_motion_module: true
|
4 |
+
motion_module_resolutions: [1,2,4,8]
|
5 |
+
motion_module_mid_block: false
|
6 |
+
motion_module_type: "Vanilla"
|
7 |
+
|
8 |
+
motion_module_kwargs:
|
9 |
+
num_attention_heads: 8
|
10 |
+
num_transformer_block: 1
|
11 |
+
attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
|
12 |
+
temporal_position_encoding: true
|
13 |
+
temporal_position_encoding_max_len: 32
|
14 |
+
temporal_attention_dim_div: 1
|
15 |
+
zero_initialize: true # from config v3
|
16 |
+
|
17 |
+
noise_scheduler_kwargs:
|
18 |
+
beta_start: 0.00085
|
19 |
+
beta_end: 0.012
|
20 |
+
beta_schedule: "linear"
|
21 |
+
steps_offset: 1
|
22 |
+
clip_sample: False
|
configs/model_config/model_config.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unet_additional_kwargs:
|
2 |
+
use_inflated_groupnorm: true
|
3 |
+
use_motion_module: true
|
4 |
+
motion_module_resolutions: [1,2,4,8]
|
5 |
+
motion_module_mid_block: false
|
6 |
+
motion_module_type: "Vanilla"
|
7 |
+
|
8 |
+
motion_module_kwargs:
|
9 |
+
num_attention_heads: 8
|
10 |
+
num_transformer_block: 1
|
11 |
+
attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
|
12 |
+
temporal_position_encoding: true
|
13 |
+
temporal_attention_dim_div: 1
|
14 |
+
zero_initialize: true
|
15 |
+
|
16 |
+
noise_scheduler_kwargs:
|
17 |
+
beta_start: 0.00085
|
18 |
+
beta_end: 0.012
|
19 |
+
beta_schedule: "linear"
|
20 |
+
steps_offset: 1
|
21 |
+
clip_sample: false
|
configs/model_config/model_config_public.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unet_additional_kwargs:
|
2 |
+
use_inflated_groupnorm: true # from config v3
|
3 |
+
unet_use_cross_frame_attention: false
|
4 |
+
unet_use_temporal_attention: false
|
5 |
+
use_motion_module: true
|
6 |
+
motion_module_resolutions: [1,2,4,8]
|
7 |
+
motion_module_mid_block: false
|
8 |
+
motion_module_decoder_only: false
|
9 |
+
motion_module_type: "Vanilla"
|
10 |
+
|
11 |
+
motion_module_kwargs:
|
12 |
+
num_attention_heads: 8
|
13 |
+
num_transformer_block: 1
|
14 |
+
attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
|
15 |
+
temporal_position_encoding: true
|
16 |
+
temporal_position_encoding_max_len: 32
|
17 |
+
temporal_attention_dim_div: 1
|
18 |
+
zero_initialize: true # from config v3
|
19 |
+
|
20 |
+
noise_scheduler_kwargs:
|
21 |
+
beta_start: 0.00085
|
22 |
+
beta_end: 0.012
|
23 |
+
beta_schedule: "linear"
|
24 |
+
steps_offset: 1
|
25 |
+
clip_sample: False
|
configs/sparsectrl/image_condition.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
controlnet_additional_kwargs:
|
2 |
+
set_noisy_sample_input_to_zero: true
|
3 |
+
use_simplified_condition_embedding: false
|
4 |
+
conditioning_channels: 3
|
5 |
+
|
6 |
+
use_motion_module: true
|
7 |
+
motion_module_resolutions: [1,2,4,8]
|
8 |
+
motion_module_mid_block: false
|
9 |
+
motion_module_type: "Vanilla"
|
10 |
+
|
11 |
+
motion_module_kwargs:
|
12 |
+
num_attention_heads: 8
|
13 |
+
num_transformer_block: 1
|
14 |
+
attention_block_types: [ "Temporal_Self" ]
|
15 |
+
temporal_position_encoding: true
|
16 |
+
temporal_position_encoding_max_len: 32
|
17 |
+
temporal_attention_dim_div: 1
|
configs/sparsectrl/latent_condition.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
controlnet_additional_kwargs:
|
2 |
+
set_noisy_sample_input_to_zero: true
|
3 |
+
use_simplified_condition_embedding: true
|
4 |
+
conditioning_channels: 4
|
5 |
+
|
6 |
+
use_motion_module: true
|
7 |
+
motion_module_resolutions: [1,2,4,8]
|
8 |
+
motion_module_mid_block: false
|
9 |
+
motion_module_type: "Vanilla"
|
10 |
+
|
11 |
+
motion_module_kwargs:
|
12 |
+
num_attention_heads: 8
|
13 |
+
num_transformer_block: 1
|
14 |
+
attention_block_types: [ "Temporal_Self" ]
|
15 |
+
temporal_position_encoding: true
|
16 |
+
temporal_position_encoding_max_len: 32
|
17 |
+
temporal_attention_dim_div: 1
|
configs/t2v_camera.jsonl
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Relics on the seabed", "seed": 42}
|
2 |
+
{"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "A road in the mountain", "seed": 42}
|
3 |
+
{"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Caves, a path for exploration", "seed": 2026}
|
4 |
+
{"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Railway for train"}
|
5 |
+
{"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Tree, in the mountain", "seed": 2026}
|
6 |
+
{"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Red car on the track", "seed": 2026}
|
7 |
+
{"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Man, standing in his garden.", "seed": 2026}
|
8 |
+
{"video_path":"reference_videos/camera_1.mp4", "new_prompt": "A island, on the ocean, sunny day"}
|
9 |
+
{"video_path":"reference_videos/camera_1.mp4", "new_prompt": "A tower, with fireworks"}
|
10 |
+
{"video_path":"reference_videos/camera_pan_up.mp4", "new_prompt": "Beautiful house, around with flowers", "seed": 42}
|
11 |
+
{"video_path":"reference_videos/camera_translation_2.mp4", "new_prompt": "Forest, in winter", "seed": 2028}
|
12 |
+
{"video_path":"reference_videos/camera_pan_down.mp4", "new_prompt": "Eagle, standing in the tree", "seed": 2026}
|
configs/t2v_camera.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
|
3 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
|
4 |
+
model_config: "configs/model_config/model_config.yaml"
|
5 |
+
|
6 |
+
cfg_scale: 7.5 # in default realistic classifer-free guidance
|
7 |
+
negative_prompt: "bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers"
|
8 |
+
postive_prompt: " 8k, high detailed, best quality, film grain, Fujifilm XT3"
|
9 |
+
|
10 |
+
inference_steps: 100 # the total denosing step for inference
|
11 |
+
guidance_scale: 0.3 # which scale of time step to end guidance 0.2/40
|
12 |
+
guidance_steps: 50 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance
|
13 |
+
warm_up_steps: 10
|
14 |
+
cool_up_steps: 10
|
15 |
+
|
16 |
+
motion_guidance_weight: 2000
|
17 |
+
motion_guidance_blocks: ['up_blocks.1']
|
18 |
+
|
19 |
+
add_noise_step: 400
|
configs/t2v_object.jsonl
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"video_path":"reference_videos/sample_astronaut.mp4", "new_prompt": "Robot, walks in the street.","seed":59}
|
2 |
+
{"video_path":"reference_videos/sample_cat.mp4", "new_prompt": "Tiger, raises its head.", "seed": 2025}
|
3 |
+
{"video_path":"reference_videos/sample_leaves.mp4", "new_prompt": "Petals falling in the wind.","seed":3407}
|
4 |
+
{"video_path":"reference_videos/sample_fox.mp4", "new_prompt": "Cat, turns its head in the living room."}
|
5 |
+
{"video_path":"reference_videos/sample_blackswan.mp4", "new_prompt": "Duck, swims in the river.","seed":3407}
|
6 |
+
{"video_path":"reference_videos/sample_cow.mp4", "new_prompt": "Pig, drinks water on beach.","seed":3407}
|
configs/t2v_object.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
|
3 |
+
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
|
4 |
+
model_config: "configs/model_config/model_config.yaml"
|
5 |
+
|
6 |
+
cfg_scale: 7.5 # in default realistic classifer-free guidance
|
7 |
+
negative_prompt: "bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers"
|
8 |
+
postive_prompt: "8k, high detailed, best quality, film grain, Fujifilm XT3"
|
9 |
+
|
10 |
+
inference_steps: 300 # the total denosing step for inference
|
11 |
+
guidance_scale: 0.4 # which scale of time step to end guidance
|
12 |
+
guidance_steps: 180 # the step for guidance in inference, no more than 1000*guidance_scale, the remaining steps (inference_steps-guidance_steps) is performed without gudiance
|
13 |
+
warm_up_steps: 10
|
14 |
+
cool_up_steps: 10
|
15 |
+
|
16 |
+
motion_guidance_weight: 2000
|
17 |
+
motion_guidance_blocks: ['up_blocks.1',]
|
18 |
+
|
19 |
+
add_noise_step: 400
|
environment.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: motionclone
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
dependencies:
|
6 |
+
- python=3.11.3
|
7 |
+
- pytorch=2.0.1
|
8 |
+
- torchvision=0.15.2
|
9 |
+
- pytorch-cuda=11.8
|
10 |
+
- pip
|
11 |
+
- pip:
|
12 |
+
- accelerate
|
13 |
+
- diffusers==0.16.0
|
14 |
+
- transformers==4.28.1
|
15 |
+
- xformers==0.0.20
|
16 |
+
- imageio[ffmpeg]
|
17 |
+
- decord==0.6.0
|
18 |
+
- gdown
|
19 |
+
- einops
|
20 |
+
- omegaconf
|
21 |
+
- safetensors
|
22 |
+
- gradio
|
23 |
+
- wandb
|
24 |
+
- triton
|
25 |
+
- opencv-python
|
generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:63ecf6f1250b83d71b50352a020c97eb60223ee33813219b2bd8d7588f1ecfec
|
3 |
+
size 285735
|
generated_videos/inference_config.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
motion_module: models/Motion_Module/v3_sd15_mm.ckpt
|
2 |
+
dreambooth_path: models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors
|
3 |
+
model_config: configs/model_config/model_config.yaml
|
4 |
+
controlnet_config: configs/sparsectrl/image_condition.yaml
|
5 |
+
controlnet_path: models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt
|
6 |
+
adapter_lora_path: models/Motion_Module/v3_sd15_adapter.ckpt
|
7 |
+
cfg_scale: 7.5
|
8 |
+
negative_prompt: ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy,
|
9 |
+
extra limbs, poorly drawn face, poorly drawn hands, missing fingers
|
10 |
+
inference_steps: 200
|
11 |
+
guidance_scale: 0.4
|
12 |
+
guidance_steps: 120
|
13 |
+
warm_up_steps: 10
|
14 |
+
cool_up_steps: 10
|
15 |
+
motion_guidance_weight: 2000
|
16 |
+
motion_guidance_blocks:
|
17 |
+
- up_blocks.1
|
18 |
+
add_noise_step: 400
|
19 |
+
width: 512
|
20 |
+
height: 512
|
21 |
+
video_length: 16
|
generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3ae68b549f1c6541417009d1cdd35d01286876bada07fb53a3354ad9225856cf
|
3 |
+
size 538343
|
i2v_video_sample.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from omegaconf import OmegaConf
|
3 |
+
import torch
|
4 |
+
from diffusers import AutoencoderKL, DDIMScheduler
|
5 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
6 |
+
from motionclone.models.unet import UNet3DConditionModel
|
7 |
+
from motionclone.models.sparse_controlnet import SparseControlNetModel
|
8 |
+
from motionclone.pipelines.pipeline_animation import AnimationPipeline
|
9 |
+
from motionclone.utils.util import load_weights, auto_download
|
10 |
+
from diffusers.utils.import_utils import is_xformers_available
|
11 |
+
from motionclone.utils.motionclone_functions import *
|
12 |
+
import json
|
13 |
+
from motionclone.utils.xformer_attention import *
|
14 |
+
|
15 |
+
|
16 |
+
def main(args):
|
17 |
+
|
18 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0))
|
19 |
+
|
20 |
+
config = OmegaConf.load(args.inference_config)
|
21 |
+
adopted_dtype = torch.float16
|
22 |
+
device = "cuda"
|
23 |
+
set_all_seed(42)
|
24 |
+
|
25 |
+
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
|
26 |
+
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype)
|
27 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype)
|
28 |
+
|
29 |
+
config.width = config.get("W", args.W)
|
30 |
+
config.height = config.get("H", args.H)
|
31 |
+
config.video_length = config.get("L", args.L)
|
32 |
+
|
33 |
+
if not os.path.exists(args.generated_videos_save_dir):
|
34 |
+
os.makedirs(args.generated_videos_save_dir)
|
35 |
+
OmegaConf.save(config, os.path.join(args.generated_videos_save_dir,"inference_config.json"))
|
36 |
+
|
37 |
+
model_config = OmegaConf.load(config.get("model_config", ""))
|
38 |
+
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs),).to(device).to(dtype=adopted_dtype)
|
39 |
+
|
40 |
+
# load controlnet model
|
41 |
+
controlnet = None
|
42 |
+
if config.get("controlnet_path", "") != "":
|
43 |
+
# assert model_config.get("controlnet_images", "") != ""
|
44 |
+
assert config.get("controlnet_config", "") != ""
|
45 |
+
|
46 |
+
unet.config.num_attention_heads = 8
|
47 |
+
unet.config.projection_class_embeddings_input_dim = None
|
48 |
+
|
49 |
+
controlnet_config = OmegaConf.load(config.controlnet_config)
|
50 |
+
controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})).to(device).to(dtype=adopted_dtype)
|
51 |
+
|
52 |
+
auto_download(config.controlnet_path, is_dreambooth_lora=False)
|
53 |
+
print(f"loading controlnet checkpoint from {config.controlnet_path} ...")
|
54 |
+
controlnet_state_dict = torch.load(config.controlnet_path, map_location="cpu")
|
55 |
+
controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
|
56 |
+
controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name}
|
57 |
+
controlnet_state_dict.pop("animatediff_config", "")
|
58 |
+
controlnet.load_state_dict(controlnet_state_dict)
|
59 |
+
del controlnet_state_dict
|
60 |
+
|
61 |
+
# set xformers
|
62 |
+
if is_xformers_available() and (not args.without_xformers):
|
63 |
+
unet.enable_xformers_memory_efficient_attention()
|
64 |
+
|
65 |
+
pipeline = AnimationPipeline(
|
66 |
+
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
|
67 |
+
controlnet=controlnet,
|
68 |
+
scheduler=DDIMScheduler(**OmegaConf.to_container(model_config.noise_scheduler_kwargs)),
|
69 |
+
).to(device)
|
70 |
+
|
71 |
+
pipeline = load_weights(
|
72 |
+
pipeline,
|
73 |
+
# motion module
|
74 |
+
motion_module_path = config.get("motion_module", ""),
|
75 |
+
# domain adapter
|
76 |
+
adapter_lora_path = config.get("adapter_lora_path", ""),
|
77 |
+
adapter_lora_scale = config.get("adapter_lora_scale", 1.0),
|
78 |
+
# image layer
|
79 |
+
dreambooth_model_path = config.get("dreambooth_path", ""),
|
80 |
+
).to(device)
|
81 |
+
pipeline.text_encoder.to(dtype=adopted_dtype)
|
82 |
+
|
83 |
+
# customized functions in motionclone_functions
|
84 |
+
pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler)
|
85 |
+
pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler)
|
86 |
+
pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet)
|
87 |
+
pipeline.sample_video = sample_video.__get__(pipeline)
|
88 |
+
pipeline.single_step_video = single_step_video.__get__(pipeline)
|
89 |
+
pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline)
|
90 |
+
pipeline.add_noise = add_noise.__get__(pipeline)
|
91 |
+
pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline)
|
92 |
+
pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline)
|
93 |
+
|
94 |
+
for param in pipeline.unet.parameters():
|
95 |
+
param.requires_grad = False
|
96 |
+
for param in pipeline.controlnet.parameters():
|
97 |
+
param.requires_grad = False
|
98 |
+
|
99 |
+
pipeline.input_config, pipeline.unet.input_config = config, config
|
100 |
+
pipeline.unet = prep_unet_attention(pipeline.unet,pipeline.input_config.motion_guidance_blocks)
|
101 |
+
pipeline.unet = prep_unet_conv(pipeline.unet)
|
102 |
+
pipeline.scheduler.customized_set_timesteps(config.inference_steps, config.guidance_steps,config.guidance_scale,device=device,timestep_spacing_type = "uneven")
|
103 |
+
|
104 |
+
with open(args.examples, 'r') as files:
|
105 |
+
for line in files:
|
106 |
+
# prepare infor of each case
|
107 |
+
example_infor = json.loads(line)
|
108 |
+
config.video_path = example_infor["video_path"]
|
109 |
+
config.condition_image_path_list = example_infor["condition_image_paths"]
|
110 |
+
config.image_index = example_infor.get("image_index",[0])
|
111 |
+
assert len(config.image_index) == len(config.condition_image_path_list)
|
112 |
+
config.new_prompt = example_infor["new_prompt"] + config.get("positive_prompt", "")
|
113 |
+
config.controlnet_scale = example_infor.get("controlnet_scale", 1.0)
|
114 |
+
pipeline.input_config, pipeline.unet.input_config = config, config # update config
|
115 |
+
|
116 |
+
# perform motion representation extraction
|
117 |
+
seed_motion = seed_motion = example_infor.get("seed", args.default_seed)
|
118 |
+
generator = torch.Generator(device=pipeline.device)
|
119 |
+
generator.manual_seed(seed_motion)
|
120 |
+
if not os.path.exists(args.motion_representation_save_dir):
|
121 |
+
os.makedirs(args.motion_representation_save_dir)
|
122 |
+
motion_representation_path = os.path.join(args.motion_representation_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] + '.pt')
|
123 |
+
pipeline.obtain_motion_representation(generator= generator, motion_representation_path = motion_representation_path, use_controlnet=True,)
|
124 |
+
|
125 |
+
# perform video generation
|
126 |
+
seed = seed_motion # can assign other seed here
|
127 |
+
generator = torch.Generator(device=pipeline.device)
|
128 |
+
generator.manual_seed(seed)
|
129 |
+
pipeline.input_config.seed = seed
|
130 |
+
videos = pipeline.sample_video(generator = generator, add_controlnet=True,)
|
131 |
+
|
132 |
+
videos = rearrange(videos, "b c f h w -> b f h w c")
|
133 |
+
save_path = os.path.join(args.generated_videos_save_dir, os.path.splitext(os.path.basename(config.video_path))[0]
|
134 |
+
+ "_" + config.new_prompt.strip().replace(' ', '_') + str(seed_motion) + "_" +str(seed)+'.mp4')
|
135 |
+
videos_uint8 = (videos[0] * 255).astype(np.uint8)
|
136 |
+
imageio.mimwrite(save_path, videos_uint8, fps=8)
|
137 |
+
print(save_path,"is done")
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
parser = argparse.ArgumentParser()
|
141 |
+
parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion",)
|
142 |
+
|
143 |
+
parser.add_argument("--inference_config", type=str, default="configs/i2v_sketch.yaml")
|
144 |
+
parser.add_argument("--examples", type=str, default="configs/i2v_sketch.jsonl")
|
145 |
+
parser.add_argument("--motion-representation-save-dir", type=str, default="motion_representation/")
|
146 |
+
parser.add_argument("--generated-videos-save-dir", type=str, default="generated_videos/")
|
147 |
+
|
148 |
+
parser.add_argument("--visible_gpu", type=str, default=None)
|
149 |
+
parser.add_argument("--default-seed", type=int, default=76739)
|
150 |
+
parser.add_argument("--L", type=int, default=16)
|
151 |
+
parser.add_argument("--W", type=int, default=512)
|
152 |
+
parser.add_argument("--H", type=int, default=512)
|
153 |
+
|
154 |
+
parser.add_argument("--without-xformers", action="store_true")
|
155 |
+
|
156 |
+
args = parser.parse_args()
|
157 |
+
main(args)
|
models/Motion_Module/Put motion module checkpoints here.txt
ADDED
File without changes
|
motionclone/models/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (13.7 kB). View file
|
|
motionclone/models/__pycache__/attention.cpython-38.pyc
ADDED
Binary file (13.6 kB). View file
|
|
motionclone/models/__pycache__/motion_module.cpython-310.pyc
ADDED
Binary file (8.71 kB). View file
|
|
motionclone/models/__pycache__/motion_module.cpython-38.pyc
ADDED
Binary file (8.67 kB). View file
|
|
motionclone/models/__pycache__/resnet.cpython-310.pyc
ADDED
Binary file (5.31 kB). View file
|
|
motionclone/models/__pycache__/resnet.cpython-38.pyc
ADDED
Binary file (5.41 kB). View file
|
|
motionclone/models/__pycache__/sparse_controlnet.cpython-38.pyc
ADDED
Binary file (14 kB). View file
|
|
motionclone/models/__pycache__/unet.cpython-310.pyc
ADDED
Binary file (12.7 kB). View file
|
|
motionclone/models/__pycache__/unet.cpython-38.pyc
ADDED
Binary file (12.4 kB). View file
|
|
motionclone/models/__pycache__/unet_blocks.cpython-310.pyc
ADDED
Binary file (12.8 kB). View file
|
|
motionclone/models/__pycache__/unet_blocks.cpython-38.pyc
ADDED
Binary file (12.1 kB). View file
|
|
motionclone/models/attention.py
ADDED
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers.models.modeling_utils import ModelMixin
|
12 |
+
from diffusers.utils import BaseOutput
|
13 |
+
from diffusers.utils.import_utils import is_xformers_available
|
14 |
+
from diffusers.models.attention import FeedForward, AdaLayerNorm
|
15 |
+
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
import pdb
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class Transformer3DModelOutput(BaseOutput):
|
21 |
+
sample: torch.FloatTensor
|
22 |
+
|
23 |
+
|
24 |
+
if is_xformers_available():
|
25 |
+
import xformers
|
26 |
+
import xformers.ops
|
27 |
+
else:
|
28 |
+
xformers = None
|
29 |
+
|
30 |
+
|
31 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
32 |
+
@register_to_config
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
num_attention_heads: int = 16,
|
36 |
+
attention_head_dim: int = 88,
|
37 |
+
in_channels: Optional[int] = None,
|
38 |
+
num_layers: int = 1,
|
39 |
+
dropout: float = 0.0,
|
40 |
+
norm_num_groups: int = 32,
|
41 |
+
cross_attention_dim: Optional[int] = None,
|
42 |
+
attention_bias: bool = False,
|
43 |
+
activation_fn: str = "geglu",
|
44 |
+
num_embeds_ada_norm: Optional[int] = None,
|
45 |
+
use_linear_projection: bool = False,
|
46 |
+
only_cross_attention: bool = False,
|
47 |
+
upcast_attention: bool = False,
|
48 |
+
|
49 |
+
unet_use_cross_frame_attention=None,
|
50 |
+
unet_use_temporal_attention=None,
|
51 |
+
):
|
52 |
+
super().__init__()
|
53 |
+
self.use_linear_projection = use_linear_projection
|
54 |
+
self.num_attention_heads = num_attention_heads
|
55 |
+
self.attention_head_dim = attention_head_dim
|
56 |
+
inner_dim = num_attention_heads * attention_head_dim
|
57 |
+
|
58 |
+
# Define input layers
|
59 |
+
self.in_channels = in_channels
|
60 |
+
|
61 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
62 |
+
if use_linear_projection:
|
63 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
64 |
+
else:
|
65 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
66 |
+
|
67 |
+
# Define transformers blocks
|
68 |
+
self.transformer_blocks = nn.ModuleList(
|
69 |
+
[
|
70 |
+
BasicTransformerBlock(
|
71 |
+
inner_dim,
|
72 |
+
num_attention_heads,
|
73 |
+
attention_head_dim,
|
74 |
+
dropout=dropout,
|
75 |
+
cross_attention_dim=cross_attention_dim,
|
76 |
+
activation_fn=activation_fn,
|
77 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
78 |
+
attention_bias=attention_bias,
|
79 |
+
only_cross_attention=only_cross_attention,
|
80 |
+
upcast_attention=upcast_attention,
|
81 |
+
|
82 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
83 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
84 |
+
)
|
85 |
+
for d in range(num_layers)
|
86 |
+
]
|
87 |
+
)
|
88 |
+
|
89 |
+
# 4. Define output layers
|
90 |
+
if use_linear_projection:
|
91 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
92 |
+
else:
|
93 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
94 |
+
|
95 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
96 |
+
# Input
|
97 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
98 |
+
video_length = hidden_states.shape[2]
|
99 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
100 |
+
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
|
101 |
+
|
102 |
+
batch, channel, height, weight = hidden_states.shape
|
103 |
+
residual = hidden_states
|
104 |
+
|
105 |
+
hidden_states = self.norm(hidden_states)
|
106 |
+
if not self.use_linear_projection:
|
107 |
+
hidden_states = self.proj_in(hidden_states)
|
108 |
+
inner_dim = hidden_states.shape[1]
|
109 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
110 |
+
else:
|
111 |
+
inner_dim = hidden_states.shape[1]
|
112 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
113 |
+
hidden_states = self.proj_in(hidden_states)
|
114 |
+
|
115 |
+
# Blocks
|
116 |
+
for block in self.transformer_blocks:
|
117 |
+
hidden_states = block(
|
118 |
+
hidden_states,
|
119 |
+
encoder_hidden_states=encoder_hidden_states,
|
120 |
+
timestep=timestep,
|
121 |
+
video_length=video_length
|
122 |
+
)
|
123 |
+
|
124 |
+
# Output
|
125 |
+
if not self.use_linear_projection:
|
126 |
+
hidden_states = (
|
127 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
128 |
+
)
|
129 |
+
hidden_states = self.proj_out(hidden_states)
|
130 |
+
else:
|
131 |
+
hidden_states = self.proj_out(hidden_states)
|
132 |
+
hidden_states = (
|
133 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
134 |
+
)
|
135 |
+
|
136 |
+
output = hidden_states + residual
|
137 |
+
|
138 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
139 |
+
if not return_dict:
|
140 |
+
return (output,)
|
141 |
+
|
142 |
+
return Transformer3DModelOutput(sample=output)
|
143 |
+
|
144 |
+
|
145 |
+
class BasicTransformerBlock(nn.Module):
|
146 |
+
def __init__(
|
147 |
+
self,
|
148 |
+
dim: int,
|
149 |
+
num_attention_heads: int,
|
150 |
+
attention_head_dim: int,
|
151 |
+
dropout=0.0,
|
152 |
+
cross_attention_dim: Optional[int] = None,
|
153 |
+
activation_fn: str = "geglu",
|
154 |
+
num_embeds_ada_norm: Optional[int] = None,
|
155 |
+
attention_bias: bool = False,
|
156 |
+
only_cross_attention: bool = False,
|
157 |
+
upcast_attention: bool = False,
|
158 |
+
|
159 |
+
unet_use_cross_frame_attention = None,
|
160 |
+
unet_use_temporal_attention = None,
|
161 |
+
):
|
162 |
+
super().__init__()
|
163 |
+
self.only_cross_attention = only_cross_attention
|
164 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
165 |
+
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
|
166 |
+
self.unet_use_temporal_attention = unet_use_temporal_attention
|
167 |
+
|
168 |
+
# SC-Attn
|
169 |
+
assert unet_use_cross_frame_attention is not None
|
170 |
+
if unet_use_cross_frame_attention:
|
171 |
+
self.attn1 = SparseCausalAttention2D(
|
172 |
+
query_dim=dim,
|
173 |
+
heads=num_attention_heads,
|
174 |
+
dim_head=attention_head_dim,
|
175 |
+
dropout=dropout,
|
176 |
+
bias=attention_bias,
|
177 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
178 |
+
upcast_attention=upcast_attention,
|
179 |
+
)
|
180 |
+
else:
|
181 |
+
self.attn1 = CrossAttention(
|
182 |
+
query_dim=dim,
|
183 |
+
heads=num_attention_heads,
|
184 |
+
dim_head=attention_head_dim,
|
185 |
+
dropout=dropout,
|
186 |
+
bias=attention_bias,
|
187 |
+
upcast_attention=upcast_attention,
|
188 |
+
)
|
189 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
190 |
+
|
191 |
+
# Cross-Attn
|
192 |
+
if cross_attention_dim is not None:
|
193 |
+
self.attn2 = CrossAttention(
|
194 |
+
query_dim=dim,
|
195 |
+
cross_attention_dim=cross_attention_dim,
|
196 |
+
heads=num_attention_heads,
|
197 |
+
dim_head=attention_head_dim,
|
198 |
+
dropout=dropout,
|
199 |
+
bias=attention_bias,
|
200 |
+
upcast_attention=upcast_attention,
|
201 |
+
)
|
202 |
+
else:
|
203 |
+
self.attn2 = None
|
204 |
+
|
205 |
+
if cross_attention_dim is not None:
|
206 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
207 |
+
else:
|
208 |
+
self.norm2 = None
|
209 |
+
|
210 |
+
# Feed-forward
|
211 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
212 |
+
self.norm3 = nn.LayerNorm(dim)
|
213 |
+
|
214 |
+
# Temp-Attn
|
215 |
+
assert unet_use_temporal_attention is not None
|
216 |
+
if unet_use_temporal_attention:
|
217 |
+
self.attn_temp = CrossAttention(
|
218 |
+
query_dim=dim,
|
219 |
+
heads=num_attention_heads,
|
220 |
+
dim_head=attention_head_dim,
|
221 |
+
dropout=dropout,
|
222 |
+
bias=attention_bias,
|
223 |
+
upcast_attention=upcast_attention,
|
224 |
+
)
|
225 |
+
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
|
226 |
+
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
227 |
+
|
228 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None):
|
229 |
+
if not is_xformers_available():
|
230 |
+
print("Here is how to install it")
|
231 |
+
raise ModuleNotFoundError(
|
232 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
233 |
+
" xformers",
|
234 |
+
name="xformers",
|
235 |
+
)
|
236 |
+
elif not torch.cuda.is_available():
|
237 |
+
raise ValueError(
|
238 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
239 |
+
" available for GPU "
|
240 |
+
)
|
241 |
+
else:
|
242 |
+
try:
|
243 |
+
# Make sure we can run the memory efficient attention
|
244 |
+
_ = xformers.ops.memory_efficient_attention(
|
245 |
+
torch.randn((1, 2, 40), device="cuda"),
|
246 |
+
torch.randn((1, 2, 40), device="cuda"),
|
247 |
+
torch.randn((1, 2, 40), device="cuda"),
|
248 |
+
)
|
249 |
+
except Exception as e:
|
250 |
+
raise e
|
251 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
252 |
+
if self.attn2 is not None:
|
253 |
+
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
254 |
+
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
255 |
+
|
256 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
|
257 |
+
# SparseCausal-Attention
|
258 |
+
norm_hidden_states = (
|
259 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
260 |
+
)
|
261 |
+
|
262 |
+
# if self.only_cross_attention:
|
263 |
+
# hidden_states = (
|
264 |
+
# self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
265 |
+
# )
|
266 |
+
# else:
|
267 |
+
# hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
268 |
+
|
269 |
+
# pdb.set_trace()
|
270 |
+
if self.unet_use_cross_frame_attention:
|
271 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
272 |
+
else:
|
273 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
274 |
+
|
275 |
+
if self.attn2 is not None:
|
276 |
+
# Cross-Attention
|
277 |
+
norm_hidden_states = (
|
278 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
279 |
+
)
|
280 |
+
hidden_states = (
|
281 |
+
self.attn2(
|
282 |
+
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
283 |
+
)
|
284 |
+
+ hidden_states
|
285 |
+
)
|
286 |
+
|
287 |
+
# Feed-forward
|
288 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
289 |
+
|
290 |
+
# Temporal-Attention
|
291 |
+
if self.unet_use_temporal_attention:
|
292 |
+
d = hidden_states.shape[1]
|
293 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
294 |
+
norm_hidden_states = (
|
295 |
+
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
|
296 |
+
)
|
297 |
+
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
|
298 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
299 |
+
|
300 |
+
return hidden_states
|
301 |
+
|
302 |
+
class CrossAttention(nn.Module):
|
303 |
+
r"""
|
304 |
+
A cross attention layer.
|
305 |
+
|
306 |
+
Parameters:
|
307 |
+
query_dim (`int`): The number of channels in the query.
|
308 |
+
cross_attention_dim (`int`, *optional*):
|
309 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
310 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
311 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
312 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
313 |
+
bias (`bool`, *optional*, defaults to False):
|
314 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
315 |
+
"""
|
316 |
+
|
317 |
+
def __init__(
|
318 |
+
self,
|
319 |
+
query_dim: int,
|
320 |
+
cross_attention_dim: Optional[int] = None,
|
321 |
+
heads: int = 8,
|
322 |
+
dim_head: int = 64,
|
323 |
+
dropout: float = 0.0,
|
324 |
+
bias=False,
|
325 |
+
upcast_attention: bool = False,
|
326 |
+
upcast_softmax: bool = False,
|
327 |
+
added_kv_proj_dim: Optional[int] = None,
|
328 |
+
norm_num_groups: Optional[int] = None,
|
329 |
+
):
|
330 |
+
super().__init__()
|
331 |
+
inner_dim = dim_head * heads
|
332 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
333 |
+
self.upcast_attention = upcast_attention
|
334 |
+
self.upcast_softmax = upcast_softmax
|
335 |
+
|
336 |
+
self.scale = dim_head**-0.5
|
337 |
+
|
338 |
+
self.heads = heads
|
339 |
+
# for slice_size > 0 the attention score computation
|
340 |
+
# is split across the batch axis to save memory
|
341 |
+
# You can set slice_size with `set_attention_slice`
|
342 |
+
self.sliceable_head_dim = heads
|
343 |
+
self._slice_size = None
|
344 |
+
self._use_memory_efficient_attention_xformers = False
|
345 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
346 |
+
|
347 |
+
#### add processer
|
348 |
+
self.processor = None
|
349 |
+
|
350 |
+
if norm_num_groups is not None:
|
351 |
+
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
352 |
+
else:
|
353 |
+
self.group_norm = None
|
354 |
+
|
355 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
356 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
357 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
358 |
+
|
359 |
+
if self.added_kv_proj_dim is not None:
|
360 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
361 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
362 |
+
|
363 |
+
self.to_out = nn.ModuleList([])
|
364 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
365 |
+
self.to_out.append(nn.Dropout(dropout))
|
366 |
+
|
367 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
368 |
+
batch_size, seq_len, dim = tensor.shape
|
369 |
+
head_size = self.heads
|
370 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
371 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
372 |
+
return tensor
|
373 |
+
|
374 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
375 |
+
batch_size, seq_len, dim = tensor.shape
|
376 |
+
head_size = self.heads
|
377 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
378 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
379 |
+
return tensor
|
380 |
+
|
381 |
+
def set_attention_slice(self, slice_size):
|
382 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
383 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
384 |
+
|
385 |
+
self._slice_size = slice_size
|
386 |
+
|
387 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
388 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
389 |
+
|
390 |
+
encoder_hidden_states = encoder_hidden_states
|
391 |
+
|
392 |
+
if self.group_norm is not None:
|
393 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
394 |
+
|
395 |
+
query = self.to_q(hidden_states)
|
396 |
+
dim = query.shape[-1]
|
397 |
+
# query = self.reshape_heads_to_batch_dim(query) # move backwards
|
398 |
+
|
399 |
+
if self.added_kv_proj_dim is not None:
|
400 |
+
key = self.to_k(hidden_states)
|
401 |
+
value = self.to_v(hidden_states)
|
402 |
+
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
403 |
+
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
404 |
+
|
405 |
+
######record###### record before reshape heads to batch dim
|
406 |
+
if self.processor is not None:
|
407 |
+
self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask)
|
408 |
+
##################
|
409 |
+
|
410 |
+
key = self.reshape_heads_to_batch_dim(key)
|
411 |
+
value = self.reshape_heads_to_batch_dim(value)
|
412 |
+
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
413 |
+
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
414 |
+
|
415 |
+
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
416 |
+
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
417 |
+
else:
|
418 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
419 |
+
key = self.to_k(encoder_hidden_states)
|
420 |
+
value = self.to_v(encoder_hidden_states)
|
421 |
+
|
422 |
+
######record######
|
423 |
+
if self.processor is not None:
|
424 |
+
self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask)
|
425 |
+
##################
|
426 |
+
|
427 |
+
key = self.reshape_heads_to_batch_dim(key)
|
428 |
+
value = self.reshape_heads_to_batch_dim(value)
|
429 |
+
|
430 |
+
query = self.reshape_heads_to_batch_dim(query) # reshape query
|
431 |
+
|
432 |
+
if attention_mask is not None:
|
433 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
434 |
+
target_length = query.shape[1]
|
435 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
436 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
437 |
+
|
438 |
+
######record######
|
439 |
+
if self.processor is not None:
|
440 |
+
self.processor.record_attn_mask(self, hidden_states, query, key, value, attention_mask)
|
441 |
+
##################
|
442 |
+
|
443 |
+
# attention, what we cannot get enough of
|
444 |
+
if self._use_memory_efficient_attention_xformers:
|
445 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
446 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
447 |
+
hidden_states = hidden_states.to(query.dtype)
|
448 |
+
else:
|
449 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
450 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
451 |
+
else:
|
452 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
453 |
+
|
454 |
+
# linear proj
|
455 |
+
hidden_states = self.to_out[0](hidden_states)
|
456 |
+
|
457 |
+
# dropout
|
458 |
+
hidden_states = self.to_out[1](hidden_states)
|
459 |
+
return hidden_states
|
460 |
+
|
461 |
+
def _attention(self, query, key, value, attention_mask=None):
|
462 |
+
if self.upcast_attention:
|
463 |
+
query = query.float()
|
464 |
+
key = key.float()
|
465 |
+
|
466 |
+
attention_scores = torch.baddbmm(
|
467 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
468 |
+
query,
|
469 |
+
key.transpose(-1, -2),
|
470 |
+
beta=0,
|
471 |
+
alpha=self.scale,
|
472 |
+
)
|
473 |
+
|
474 |
+
if attention_mask is not None:
|
475 |
+
attention_scores = attention_scores + attention_mask
|
476 |
+
|
477 |
+
if self.upcast_softmax:
|
478 |
+
attention_scores = attention_scores.float()
|
479 |
+
|
480 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
481 |
+
|
482 |
+
# cast back to the original dtype
|
483 |
+
attention_probs = attention_probs.to(value.dtype)
|
484 |
+
|
485 |
+
# compute attention output
|
486 |
+
hidden_states = torch.bmm(attention_probs, value)
|
487 |
+
|
488 |
+
# reshape hidden_states
|
489 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
490 |
+
return hidden_states
|
491 |
+
|
492 |
+
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
|
493 |
+
batch_size_attention = query.shape[0]
|
494 |
+
hidden_states = torch.zeros(
|
495 |
+
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
496 |
+
)
|
497 |
+
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
498 |
+
for i in range(hidden_states.shape[0] // slice_size):
|
499 |
+
start_idx = i * slice_size
|
500 |
+
end_idx = (i + 1) * slice_size
|
501 |
+
|
502 |
+
query_slice = query[start_idx:end_idx]
|
503 |
+
key_slice = key[start_idx:end_idx]
|
504 |
+
|
505 |
+
if self.upcast_attention:
|
506 |
+
query_slice = query_slice.float()
|
507 |
+
key_slice = key_slice.float()
|
508 |
+
|
509 |
+
attn_slice = torch.baddbmm(
|
510 |
+
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
511 |
+
query_slice,
|
512 |
+
key_slice.transpose(-1, -2),
|
513 |
+
beta=0,
|
514 |
+
alpha=self.scale,
|
515 |
+
)
|
516 |
+
|
517 |
+
if attention_mask is not None:
|
518 |
+
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
519 |
+
|
520 |
+
if self.upcast_softmax:
|
521 |
+
attn_slice = attn_slice.float()
|
522 |
+
|
523 |
+
attn_slice = attn_slice.softmax(dim=-1)
|
524 |
+
|
525 |
+
# cast back to the original dtype
|
526 |
+
attn_slice = attn_slice.to(value.dtype)
|
527 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
528 |
+
|
529 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
530 |
+
|
531 |
+
# reshape hidden_states
|
532 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
533 |
+
return hidden_states
|
534 |
+
|
535 |
+
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
536 |
+
# TODO attention_mask
|
537 |
+
query = query.contiguous()
|
538 |
+
key = key.contiguous()
|
539 |
+
value = value.contiguous()
|
540 |
+
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
541 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
542 |
+
return hidden_states
|
543 |
+
|
544 |
+
def set_processor(self, processor: "AttnProcessor") -> None:
|
545 |
+
r"""
|
546 |
+
Set the attention processor to use.
|
547 |
+
|
548 |
+
Args:
|
549 |
+
processor (`AttnProcessor`):
|
550 |
+
The attention processor to use.
|
551 |
+
"""
|
552 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
553 |
+
# pop `processor` from `self._modules`
|
554 |
+
if (
|
555 |
+
hasattr(self, "processor")
|
556 |
+
and isinstance(self.processor, torch.nn.Module)
|
557 |
+
and not isinstance(processor, torch.nn.Module)
|
558 |
+
):
|
559 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
560 |
+
self._modules.pop("processor")
|
561 |
+
|
562 |
+
self.processor = processor
|
563 |
+
|
564 |
+
def get_attention_scores(
|
565 |
+
self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
|
566 |
+
) -> torch.Tensor:
|
567 |
+
r"""
|
568 |
+
Compute the attention scores.
|
569 |
+
|
570 |
+
Args:
|
571 |
+
query (`torch.Tensor`): The query tensor.
|
572 |
+
key (`torch.Tensor`): The key tensor.
|
573 |
+
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
574 |
+
|
575 |
+
Returns:
|
576 |
+
`torch.Tensor`: The attention probabilities/scores.
|
577 |
+
"""
|
578 |
+
dtype = query.dtype
|
579 |
+
if self.upcast_attention:
|
580 |
+
query = query.float()
|
581 |
+
key = key.float()
|
582 |
+
|
583 |
+
if attention_mask is None:
|
584 |
+
baddbmm_input = torch.empty(
|
585 |
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
586 |
+
)
|
587 |
+
beta = 0
|
588 |
+
else:
|
589 |
+
baddbmm_input = attention_mask
|
590 |
+
beta = 1
|
591 |
+
|
592 |
+
|
593 |
+
|
594 |
+
attention_scores = torch.baddbmm(
|
595 |
+
baddbmm_input,
|
596 |
+
query,
|
597 |
+
key.transpose(-1, -2),
|
598 |
+
beta=beta,
|
599 |
+
alpha=self.scale,
|
600 |
+
)
|
601 |
+
del baddbmm_input
|
602 |
+
|
603 |
+
if self.upcast_softmax:
|
604 |
+
attention_scores = attention_scores.float()
|
605 |
+
|
606 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
607 |
+
del attention_scores
|
608 |
+
|
609 |
+
attention_probs = attention_probs.to(dtype)
|
610 |
+
|
611 |
+
return attention_probs
|
motionclone/models/motion_module.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
import torchvision
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers.models.modeling_utils import ModelMixin
|
12 |
+
from diffusers.utils import BaseOutput
|
13 |
+
from diffusers.utils.import_utils import is_xformers_available
|
14 |
+
from diffusers.models.attention import FeedForward
|
15 |
+
from .attention import CrossAttention
|
16 |
+
|
17 |
+
from einops import rearrange, repeat
|
18 |
+
import math
|
19 |
+
|
20 |
+
|
21 |
+
def zero_module(module):
|
22 |
+
# Zero out the parameters of a module and return it.
|
23 |
+
for p in module.parameters():
|
24 |
+
p.detach().zero_()
|
25 |
+
return module
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class TemporalTransformer3DModelOutput(BaseOutput):
|
30 |
+
sample: torch.FloatTensor
|
31 |
+
|
32 |
+
|
33 |
+
if is_xformers_available():
|
34 |
+
import xformers
|
35 |
+
import xformers.ops
|
36 |
+
else:
|
37 |
+
xformers = None
|
38 |
+
|
39 |
+
|
40 |
+
def get_motion_module( # 只能返回VanillaTemporalModule类
|
41 |
+
in_channels,
|
42 |
+
motion_module_type: str,
|
43 |
+
motion_module_kwargs: dict
|
44 |
+
):
|
45 |
+
if motion_module_type == "Vanilla":
|
46 |
+
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
|
47 |
+
else:
|
48 |
+
raise ValueError
|
49 |
+
|
50 |
+
|
51 |
+
class VanillaTemporalModule(nn.Module):
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
in_channels,
|
55 |
+
num_attention_heads = 8,
|
56 |
+
num_transformer_block = 2,
|
57 |
+
attention_block_types =( "Temporal_Self", "Temporal_Self" ),
|
58 |
+
cross_frame_attention_mode = None,
|
59 |
+
temporal_position_encoding = False,
|
60 |
+
temporal_position_encoding_max_len = 32,
|
61 |
+
temporal_attention_dim_div = 1,
|
62 |
+
zero_initialize = True,
|
63 |
+
):
|
64 |
+
super().__init__()
|
65 |
+
|
66 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
67 |
+
in_channels=in_channels,
|
68 |
+
num_attention_heads=num_attention_heads,
|
69 |
+
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
|
70 |
+
num_layers=num_transformer_block,
|
71 |
+
attention_block_types=attention_block_types,
|
72 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
73 |
+
temporal_position_encoding=temporal_position_encoding,
|
74 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
75 |
+
)
|
76 |
+
|
77 |
+
if zero_initialize:
|
78 |
+
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
79 |
+
|
80 |
+
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
|
81 |
+
hidden_states = input_tensor
|
82 |
+
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
|
83 |
+
|
84 |
+
output = hidden_states
|
85 |
+
return output
|
86 |
+
|
87 |
+
|
88 |
+
class TemporalTransformer3DModel(nn.Module):
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
in_channels,
|
92 |
+
num_attention_heads,
|
93 |
+
attention_head_dim,
|
94 |
+
|
95 |
+
num_layers,
|
96 |
+
attention_block_types = ( "Temporal_Self", "Temporal_Self", ), # 两个TempAttn
|
97 |
+
dropout = 0.0,
|
98 |
+
norm_num_groups = 32,
|
99 |
+
cross_attention_dim = 768,
|
100 |
+
activation_fn = "geglu",
|
101 |
+
attention_bias = False,
|
102 |
+
upcast_attention = False,
|
103 |
+
|
104 |
+
cross_frame_attention_mode = None,
|
105 |
+
temporal_position_encoding = False,
|
106 |
+
temporal_position_encoding_max_len = 24,
|
107 |
+
):
|
108 |
+
super().__init__()
|
109 |
+
|
110 |
+
inner_dim = num_attention_heads * attention_head_dim
|
111 |
+
|
112 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
113 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
114 |
+
|
115 |
+
self.transformer_blocks = nn.ModuleList(
|
116 |
+
[
|
117 |
+
TemporalTransformerBlock(
|
118 |
+
dim=inner_dim,
|
119 |
+
num_attention_heads=num_attention_heads,
|
120 |
+
attention_head_dim=attention_head_dim,
|
121 |
+
attention_block_types=attention_block_types,
|
122 |
+
dropout=dropout,
|
123 |
+
norm_num_groups=norm_num_groups,
|
124 |
+
cross_attention_dim=cross_attention_dim,
|
125 |
+
activation_fn=activation_fn,
|
126 |
+
attention_bias=attention_bias,
|
127 |
+
upcast_attention=upcast_attention,
|
128 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
129 |
+
temporal_position_encoding=temporal_position_encoding,
|
130 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
131 |
+
)
|
132 |
+
for d in range(num_layers)
|
133 |
+
]
|
134 |
+
)
|
135 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
136 |
+
|
137 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
138 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
139 |
+
video_length = hidden_states.shape[2]
|
140 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
141 |
+
|
142 |
+
batch, channel, height, weight = hidden_states.shape
|
143 |
+
residual = hidden_states
|
144 |
+
|
145 |
+
hidden_states = self.norm(hidden_states)
|
146 |
+
inner_dim = hidden_states.shape[1]
|
147 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
148 |
+
hidden_states = self.proj_in(hidden_states)
|
149 |
+
|
150 |
+
# Transformer Blocks
|
151 |
+
for block in self.transformer_blocks:
|
152 |
+
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
|
153 |
+
|
154 |
+
# output
|
155 |
+
hidden_states = self.proj_out(hidden_states)
|
156 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
157 |
+
|
158 |
+
output = hidden_states + residual
|
159 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
160 |
+
|
161 |
+
return output
|
162 |
+
|
163 |
+
|
164 |
+
class TemporalTransformerBlock(nn.Module):
|
165 |
+
def __init__(
|
166 |
+
self,
|
167 |
+
dim,
|
168 |
+
num_attention_heads,
|
169 |
+
attention_head_dim,
|
170 |
+
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
|
171 |
+
dropout = 0.0,
|
172 |
+
norm_num_groups = 32,
|
173 |
+
cross_attention_dim = 768,
|
174 |
+
activation_fn = "geglu",
|
175 |
+
attention_bias = False,
|
176 |
+
upcast_attention = False,
|
177 |
+
cross_frame_attention_mode = None,
|
178 |
+
temporal_position_encoding = False,
|
179 |
+
temporal_position_encoding_max_len = 24,
|
180 |
+
):
|
181 |
+
super().__init__()
|
182 |
+
|
183 |
+
attention_blocks = []
|
184 |
+
norms = []
|
185 |
+
|
186 |
+
for block_name in attention_block_types:
|
187 |
+
attention_blocks.append(
|
188 |
+
VersatileAttention(
|
189 |
+
attention_mode=block_name.split("_")[0],
|
190 |
+
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
|
191 |
+
|
192 |
+
query_dim=dim,
|
193 |
+
heads=num_attention_heads,
|
194 |
+
dim_head=attention_head_dim,
|
195 |
+
dropout=dropout,
|
196 |
+
bias=attention_bias,
|
197 |
+
upcast_attention=upcast_attention,
|
198 |
+
|
199 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
200 |
+
temporal_position_encoding=temporal_position_encoding,
|
201 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
202 |
+
)
|
203 |
+
)
|
204 |
+
norms.append(nn.LayerNorm(dim))
|
205 |
+
|
206 |
+
self.attention_blocks = nn.ModuleList(attention_blocks)
|
207 |
+
self.norms = nn.ModuleList(norms)
|
208 |
+
|
209 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
210 |
+
self.ff_norm = nn.LayerNorm(dim)
|
211 |
+
|
212 |
+
|
213 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
214 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
215 |
+
norm_hidden_states = norm(hidden_states)
|
216 |
+
hidden_states = attention_block(
|
217 |
+
norm_hidden_states,
|
218 |
+
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
|
219 |
+
video_length=video_length,
|
220 |
+
) + hidden_states
|
221 |
+
|
222 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
223 |
+
|
224 |
+
output = hidden_states
|
225 |
+
return output
|
226 |
+
|
227 |
+
|
228 |
+
class PositionalEncoding(nn.Module):
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
d_model,
|
232 |
+
dropout = 0.,
|
233 |
+
max_len = 24
|
234 |
+
):
|
235 |
+
super().__init__()
|
236 |
+
self.dropout = nn.Dropout(p=dropout)
|
237 |
+
position = torch.arange(max_len).unsqueeze(1)
|
238 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
239 |
+
pe = torch.zeros(1, max_len, d_model)
|
240 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
241 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
242 |
+
# self.register_buffer('pe', pe)
|
243 |
+
self.register_buffer('pe', pe, persistent=False)
|
244 |
+
|
245 |
+
def forward(self, x):
|
246 |
+
x = x + self.pe[:, :x.size(1)]
|
247 |
+
return self.dropout(x)
|
248 |
+
|
249 |
+
|
250 |
+
class VersatileAttention(CrossAttention): # 继承CrossAttention类,不需要在额外写set_processor功能
|
251 |
+
def __init__(
|
252 |
+
self,
|
253 |
+
attention_mode = None,
|
254 |
+
cross_frame_attention_mode = None,
|
255 |
+
temporal_position_encoding = False,
|
256 |
+
temporal_position_encoding_max_len = 24,
|
257 |
+
*args, **kwargs
|
258 |
+
):
|
259 |
+
super().__init__(*args, **kwargs)
|
260 |
+
assert attention_mode == "Temporal"
|
261 |
+
|
262 |
+
self.attention_mode = attention_mode
|
263 |
+
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
|
264 |
+
|
265 |
+
self.pos_encoder = PositionalEncoding(
|
266 |
+
kwargs["query_dim"],
|
267 |
+
dropout=0.,
|
268 |
+
max_len=temporal_position_encoding_max_len
|
269 |
+
) if (temporal_position_encoding and attention_mode == "Temporal") else None
|
270 |
+
|
271 |
+
def extra_repr(self):
|
272 |
+
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
273 |
+
|
274 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
275 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
276 |
+
|
277 |
+
if self.attention_mode == "Temporal":
|
278 |
+
d = hidden_states.shape[1]
|
279 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
280 |
+
|
281 |
+
if self.pos_encoder is not None:
|
282 |
+
hidden_states = self.pos_encoder(hidden_states)
|
283 |
+
|
284 |
+
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
|
285 |
+
else:
|
286 |
+
raise NotImplementedError
|
287 |
+
|
288 |
+
encoder_hidden_states = encoder_hidden_states
|
289 |
+
|
290 |
+
if self.group_norm is not None:
|
291 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
292 |
+
|
293 |
+
query = self.to_q(hidden_states)
|
294 |
+
dim = query.shape[-1]
|
295 |
+
# query = self.reshape_heads_to_batch_dim(query) # move backwards
|
296 |
+
|
297 |
+
if self.added_kv_proj_dim is not None:
|
298 |
+
raise NotImplementedError
|
299 |
+
|
300 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
301 |
+
key = self.to_k(encoder_hidden_states)
|
302 |
+
value = self.to_v(encoder_hidden_states)
|
303 |
+
|
304 |
+
######record###### record before reshape heads to batch dim
|
305 |
+
if self.processor is not None:
|
306 |
+
self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask)
|
307 |
+
##################
|
308 |
+
|
309 |
+
key = self.reshape_heads_to_batch_dim(key)
|
310 |
+
value = self.reshape_heads_to_batch_dim(value)
|
311 |
+
|
312 |
+
query = self.reshape_heads_to_batch_dim(query) # reshape query here
|
313 |
+
|
314 |
+
if attention_mask is not None:
|
315 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
316 |
+
target_length = query.shape[1]
|
317 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
318 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
319 |
+
|
320 |
+
######record######
|
321 |
+
# if self.processor is not None:
|
322 |
+
# self.processor.record_attn_mask(self, hidden_states, query, key, value, attention_mask)
|
323 |
+
##################
|
324 |
+
|
325 |
+
# attention, what we cannot get enough of
|
326 |
+
if self._use_memory_efficient_attention_xformers:
|
327 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
328 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
329 |
+
hidden_states = hidden_states.to(query.dtype)
|
330 |
+
else:
|
331 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
332 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
333 |
+
else:
|
334 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
335 |
+
|
336 |
+
# linear proj
|
337 |
+
hidden_states = self.to_out[0](hidden_states)
|
338 |
+
|
339 |
+
# dropout
|
340 |
+
hidden_states = self.to_out[1](hidden_states)
|
341 |
+
|
342 |
+
if self.attention_mode == "Temporal":
|
343 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
344 |
+
|
345 |
+
return hidden_states
|
346 |
+
|
347 |
+
|
motionclone/models/resnet.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
class InflatedConv3d(nn.Conv2d):
|
11 |
+
def forward(self, x):
|
12 |
+
video_length = x.shape[2]
|
13 |
+
|
14 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
15 |
+
x = super().forward(x)
|
16 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
17 |
+
|
18 |
+
return x
|
19 |
+
|
20 |
+
|
21 |
+
class InflatedGroupNorm(nn.GroupNorm):
|
22 |
+
def forward(self, x):
|
23 |
+
video_length = x.shape[2]
|
24 |
+
|
25 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
26 |
+
x = super().forward(x)
|
27 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
28 |
+
|
29 |
+
return x
|
30 |
+
|
31 |
+
|
32 |
+
class Upsample3D(nn.Module):
|
33 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
34 |
+
super().__init__()
|
35 |
+
self.channels = channels
|
36 |
+
self.out_channels = out_channels or channels
|
37 |
+
self.use_conv = use_conv
|
38 |
+
self.use_conv_transpose = use_conv_transpose
|
39 |
+
self.name = name
|
40 |
+
|
41 |
+
conv = None
|
42 |
+
if use_conv_transpose:
|
43 |
+
raise NotImplementedError
|
44 |
+
elif use_conv:
|
45 |
+
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
|
46 |
+
|
47 |
+
def forward(self, hidden_states, output_size=None):
|
48 |
+
assert hidden_states.shape[1] == self.channels
|
49 |
+
|
50 |
+
if self.use_conv_transpose:
|
51 |
+
raise NotImplementedError
|
52 |
+
|
53 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
54 |
+
dtype = hidden_states.dtype
|
55 |
+
if dtype == torch.bfloat16:
|
56 |
+
hidden_states = hidden_states.to(torch.float32)
|
57 |
+
|
58 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
59 |
+
if hidden_states.shape[0] >= 64:
|
60 |
+
hidden_states = hidden_states.contiguous()
|
61 |
+
|
62 |
+
# if `output_size` is passed we force the interpolation output
|
63 |
+
# size and do not make use of `scale_factor=2`
|
64 |
+
if output_size is None:
|
65 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
|
66 |
+
else:
|
67 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
68 |
+
|
69 |
+
# If the input is bfloat16, we cast back to bfloat16
|
70 |
+
if dtype == torch.bfloat16:
|
71 |
+
hidden_states = hidden_states.to(dtype)
|
72 |
+
|
73 |
+
# if self.use_conv:
|
74 |
+
# if self.name == "conv":
|
75 |
+
# hidden_states = self.conv(hidden_states)
|
76 |
+
# else:
|
77 |
+
# hidden_states = self.Conv2d_0(hidden_states)
|
78 |
+
hidden_states = self.conv(hidden_states)
|
79 |
+
|
80 |
+
return hidden_states
|
81 |
+
|
82 |
+
|
83 |
+
class Downsample3D(nn.Module):
|
84 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
85 |
+
super().__init__()
|
86 |
+
self.channels = channels
|
87 |
+
self.out_channels = out_channels or channels
|
88 |
+
self.use_conv = use_conv
|
89 |
+
self.padding = padding
|
90 |
+
stride = 2
|
91 |
+
self.name = name
|
92 |
+
|
93 |
+
if use_conv:
|
94 |
+
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
95 |
+
else:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
def forward(self, hidden_states):
|
99 |
+
assert hidden_states.shape[1] == self.channels
|
100 |
+
if self.use_conv and self.padding == 0:
|
101 |
+
raise NotImplementedError
|
102 |
+
|
103 |
+
assert hidden_states.shape[1] == self.channels
|
104 |
+
hidden_states = self.conv(hidden_states)
|
105 |
+
|
106 |
+
return hidden_states
|
107 |
+
|
108 |
+
|
109 |
+
class ResnetBlock3D(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
*,
|
113 |
+
in_channels,
|
114 |
+
out_channels=None,
|
115 |
+
conv_shortcut=False,
|
116 |
+
dropout=0.0,
|
117 |
+
temb_channels=512,
|
118 |
+
groups=32,
|
119 |
+
groups_out=None,
|
120 |
+
pre_norm=True,
|
121 |
+
eps=1e-6,
|
122 |
+
non_linearity="swish",
|
123 |
+
time_embedding_norm="default",
|
124 |
+
output_scale_factor=1.0,
|
125 |
+
use_in_shortcut=None,
|
126 |
+
use_inflated_groupnorm=False,
|
127 |
+
):
|
128 |
+
super().__init__()
|
129 |
+
self.pre_norm = pre_norm
|
130 |
+
self.pre_norm = True
|
131 |
+
self.in_channels = in_channels
|
132 |
+
out_channels = in_channels if out_channels is None else out_channels
|
133 |
+
self.out_channels = out_channels
|
134 |
+
self.use_conv_shortcut = conv_shortcut
|
135 |
+
self.time_embedding_norm = time_embedding_norm
|
136 |
+
self.output_scale_factor = output_scale_factor
|
137 |
+
self.upsample = self.downsample = None
|
138 |
+
|
139 |
+
if groups_out is None:
|
140 |
+
groups_out = groups
|
141 |
+
|
142 |
+
assert use_inflated_groupnorm != None
|
143 |
+
if use_inflated_groupnorm:
|
144 |
+
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
145 |
+
else:
|
146 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
147 |
+
|
148 |
+
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
149 |
+
|
150 |
+
if temb_channels is not None:
|
151 |
+
if self.time_embedding_norm == "default":
|
152 |
+
time_emb_proj_out_channels = out_channels
|
153 |
+
elif self.time_embedding_norm == "scale_shift":
|
154 |
+
time_emb_proj_out_channels = out_channels * 2
|
155 |
+
else:
|
156 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
157 |
+
|
158 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
|
159 |
+
else:
|
160 |
+
self.time_emb_proj = None
|
161 |
+
|
162 |
+
if use_inflated_groupnorm:
|
163 |
+
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
164 |
+
else:
|
165 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
166 |
+
|
167 |
+
self.dropout = torch.nn.Dropout(dropout)
|
168 |
+
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
169 |
+
|
170 |
+
if non_linearity == "swish":
|
171 |
+
self.nonlinearity = lambda x: F.silu(x)
|
172 |
+
elif non_linearity == "mish":
|
173 |
+
self.nonlinearity = Mish()
|
174 |
+
elif non_linearity == "silu":
|
175 |
+
self.nonlinearity = nn.SiLU()
|
176 |
+
|
177 |
+
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
178 |
+
|
179 |
+
self.conv_shortcut = None
|
180 |
+
if self.use_in_shortcut:
|
181 |
+
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
182 |
+
|
183 |
+
def forward(self, input_tensor, temb):
|
184 |
+
hidden_states = input_tensor
|
185 |
+
|
186 |
+
hidden_states = self.norm1(hidden_states)
|
187 |
+
hidden_states = self.nonlinearity(hidden_states)
|
188 |
+
|
189 |
+
hidden_states = self.conv1(hidden_states)
|
190 |
+
|
191 |
+
if temb is not None:
|
192 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
193 |
+
|
194 |
+
if temb is not None and self.time_embedding_norm == "default":
|
195 |
+
hidden_states = hidden_states + temb
|
196 |
+
|
197 |
+
hidden_states = self.norm2(hidden_states)
|
198 |
+
|
199 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
200 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
201 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
202 |
+
|
203 |
+
hidden_states = self.nonlinearity(hidden_states)
|
204 |
+
|
205 |
+
hidden_states = self.dropout(hidden_states)
|
206 |
+
hidden_states = self.conv2(hidden_states)
|
207 |
+
|
208 |
+
if self.conv_shortcut is not None:
|
209 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
210 |
+
|
211 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
212 |
+
|
213 |
+
return output_tensor
|
214 |
+
|
215 |
+
|
216 |
+
class Mish(torch.nn.Module):
|
217 |
+
def forward(self, hidden_states):
|
218 |
+
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
motionclone/models/scheduler.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers import DDIMScheduler
|
5 |
+
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
|
6 |
+
from diffusers.utils.torch_utils import randn_tensor
|
7 |
+
|
8 |
+
|
9 |
+
class CustomDDIMScheduler(DDIMScheduler):
|
10 |
+
@torch.no_grad()
|
11 |
+
def step(
|
12 |
+
self,
|
13 |
+
model_output: torch.FloatTensor,
|
14 |
+
timestep: int,
|
15 |
+
sample: torch.FloatTensor,
|
16 |
+
eta: float = 0.0,
|
17 |
+
use_clipped_model_output: bool = False,
|
18 |
+
generator=None,
|
19 |
+
variance_noise: Optional[torch.FloatTensor] = None,
|
20 |
+
return_dict: bool = True,
|
21 |
+
|
22 |
+
# Guidance parameters
|
23 |
+
score=None,
|
24 |
+
guidance_scale=0.0,
|
25 |
+
indices=None, # [0]
|
26 |
+
|
27 |
+
) -> Union[DDIMSchedulerOutput, Tuple]:
|
28 |
+
"""
|
29 |
+
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
30 |
+
process from the learned model outputs (most often the predicted noise).
|
31 |
+
|
32 |
+
Args:
|
33 |
+
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
34 |
+
timestep (`int`): current discrete timestep in the diffusion chain.
|
35 |
+
sample (`torch.FloatTensor`):
|
36 |
+
current instance of sample being created by diffusion process.
|
37 |
+
eta (`float`): weight of noise for added noise in diffusion step.
|
38 |
+
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
|
39 |
+
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
|
40 |
+
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
|
41 |
+
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
|
42 |
+
generator: random number generator.
|
43 |
+
variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
|
44 |
+
can directly provide the noise for the variance itself. This is useful for methods such as
|
45 |
+
CycleDiffusion. (https://arxiv.org/abs/2210.05559)
|
46 |
+
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
|
50 |
+
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
51 |
+
returning a tuple, the first element is the sample tensor.
|
52 |
+
|
53 |
+
"""
|
54 |
+
if self.num_inference_steps is None:
|
55 |
+
raise ValueError(
|
56 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
57 |
+
)
|
58 |
+
|
59 |
+
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
60 |
+
# Ideally, read DDIM paper in-detail understanding
|
61 |
+
|
62 |
+
# Notation (<variable name> -> <name in paper>
|
63 |
+
# - pred_noise_t -> e_theta(x_t, t)
|
64 |
+
# - pred_original_sample -> f_theta(x_t, t) or x_0
|
65 |
+
# - std_dev_t -> sigma_t
|
66 |
+
# - eta -> η
|
67 |
+
# - pred_sample_direction -> "direction pointing to x_t"
|
68 |
+
# - pred_prev_sample -> "x_t-1"
|
69 |
+
|
70 |
+
|
71 |
+
# Support IF models
|
72 |
+
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
73 |
+
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
|
74 |
+
else:
|
75 |
+
predicted_variance = None
|
76 |
+
|
77 |
+
# 1. get previous step value (=t-1)
|
78 |
+
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
79 |
+
|
80 |
+
# 2. compute alphas, betas
|
81 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
82 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
83 |
+
|
84 |
+
beta_prod_t = 1 - alpha_prod_t
|
85 |
+
|
86 |
+
# 3. compute predicted original sample from predicted noise also called
|
87 |
+
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
88 |
+
if self.config.prediction_type == "epsilon":
|
89 |
+
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
90 |
+
pred_epsilon = model_output
|
91 |
+
elif self.config.prediction_type == "sample":
|
92 |
+
pred_original_sample = model_output
|
93 |
+
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
94 |
+
elif self.config.prediction_type == "v_prediction":
|
95 |
+
pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
|
96 |
+
pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
|
97 |
+
else:
|
98 |
+
raise ValueError(
|
99 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
100 |
+
" `v_prediction`"
|
101 |
+
)
|
102 |
+
|
103 |
+
# 4. Clip or threshold "predicted x_0"
|
104 |
+
if self.config.thresholding:
|
105 |
+
pred_original_sample = self._threshold_sample(pred_original_sample)
|
106 |
+
elif self.config.clip_sample:
|
107 |
+
pred_original_sample = pred_original_sample.clamp(
|
108 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
109 |
+
)
|
110 |
+
|
111 |
+
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
112 |
+
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
113 |
+
variance = self._get_variance(timestep, prev_timestep)
|
114 |
+
std_dev_t = eta * variance ** (0.5)
|
115 |
+
|
116 |
+
if use_clipped_model_output:
|
117 |
+
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
|
118 |
+
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # [2, 4, 64, 64]
|
119 |
+
|
120 |
+
# 6. apply guidance following the formula (14) from https://arxiv.org/pdf/2105.05233.pdf
|
121 |
+
if score is not None and guidance_scale > 0.0: # indices指定了应用guidance的位置,此处indices = [0]
|
122 |
+
if indices is not None:
|
123 |
+
# import pdb; pdb.set_trace()
|
124 |
+
assert pred_epsilon[indices].shape == score.shape, "pred_epsilon[indices].shape != score.shape"
|
125 |
+
pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score # 只修改了其中第一个[1, 4, 64, 64]的部分
|
126 |
+
else:
|
127 |
+
assert pred_epsilon.shape == score.shape
|
128 |
+
pred_epsilon = pred_epsilon - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score
|
129 |
+
#
|
130 |
+
|
131 |
+
# 7. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
132 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon # [2, 4, 64, 64]
|
133 |
+
|
134 |
+
# 8. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
135 |
+
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction # [2, 4, 64, 64]
|
136 |
+
|
137 |
+
if eta > 0:
|
138 |
+
if variance_noise is not None and generator is not None:
|
139 |
+
raise ValueError(
|
140 |
+
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
|
141 |
+
" `variance_noise` stays `None`."
|
142 |
+
)
|
143 |
+
|
144 |
+
if variance_noise is None:
|
145 |
+
variance_noise = randn_tensor(
|
146 |
+
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
147 |
+
)
|
148 |
+
variance = std_dev_t * variance_noise # 最后还要再加一些随机噪声
|
149 |
+
|
150 |
+
prev_sample = prev_sample + variance # [2, 4, 64, 64]
|
151 |
+
self.pred_epsilon = pred_epsilon
|
152 |
+
if not return_dict:
|
153 |
+
return (prev_sample,)
|
154 |
+
|
155 |
+
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
motionclone/models/sparse_controlnet.py
ADDED
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# Changes were made to this source code by Yuwei Guo.
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import functional as F
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.utils import BaseOutput, logging
|
25 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
26 |
+
from diffusers.models.modeling_utils import ModelMixin
|
27 |
+
|
28 |
+
|
29 |
+
from .unet_blocks import (
|
30 |
+
CrossAttnDownBlock3D,
|
31 |
+
DownBlock3D,
|
32 |
+
UNetMidBlock3DCrossAttn,
|
33 |
+
get_down_block,
|
34 |
+
)
|
35 |
+
from einops import repeat, rearrange
|
36 |
+
from .resnet import InflatedConv3d
|
37 |
+
|
38 |
+
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
41 |
+
|
42 |
+
|
43 |
+
@dataclass
|
44 |
+
class SparseControlNetOutput(BaseOutput):
|
45 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
46 |
+
mid_block_res_sample: torch.Tensor
|
47 |
+
|
48 |
+
|
49 |
+
class SparseControlNetConditioningEmbedding(nn.Module):
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
conditioning_embedding_channels: int,
|
53 |
+
conditioning_channels: int = 3,
|
54 |
+
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
55 |
+
):
|
56 |
+
super().__init__()
|
57 |
+
|
58 |
+
self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
59 |
+
|
60 |
+
self.blocks = nn.ModuleList([])
|
61 |
+
|
62 |
+
for i in range(len(block_out_channels) - 1):
|
63 |
+
channel_in = block_out_channels[i]
|
64 |
+
channel_out = block_out_channels[i + 1]
|
65 |
+
self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
|
66 |
+
self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
67 |
+
|
68 |
+
self.conv_out = zero_module(
|
69 |
+
InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
70 |
+
)
|
71 |
+
|
72 |
+
def forward(self, conditioning):
|
73 |
+
embedding = self.conv_in(conditioning)
|
74 |
+
embedding = F.silu(embedding)
|
75 |
+
|
76 |
+
for block in self.blocks:
|
77 |
+
embedding = block(embedding)
|
78 |
+
embedding = F.silu(embedding)
|
79 |
+
|
80 |
+
embedding = self.conv_out(embedding)
|
81 |
+
|
82 |
+
return embedding
|
83 |
+
|
84 |
+
|
85 |
+
class SparseControlNetModel(ModelMixin, ConfigMixin):
|
86 |
+
_supports_gradient_checkpointing = True
|
87 |
+
|
88 |
+
@register_to_config
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
in_channels: int = 4,
|
92 |
+
conditioning_channels: int = 3,
|
93 |
+
flip_sin_to_cos: bool = True,
|
94 |
+
freq_shift: int = 0,
|
95 |
+
down_block_types: Tuple[str] = (
|
96 |
+
"CrossAttnDownBlock2D",
|
97 |
+
"CrossAttnDownBlock2D",
|
98 |
+
"CrossAttnDownBlock2D",
|
99 |
+
"DownBlock2D",
|
100 |
+
),
|
101 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
102 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
103 |
+
layers_per_block: int = 2,
|
104 |
+
downsample_padding: int = 1,
|
105 |
+
mid_block_scale_factor: float = 1,
|
106 |
+
act_fn: str = "silu",
|
107 |
+
norm_num_groups: Optional[int] = 32,
|
108 |
+
norm_eps: float = 1e-5,
|
109 |
+
cross_attention_dim: int = 1280,
|
110 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
111 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
112 |
+
use_linear_projection: bool = False,
|
113 |
+
class_embed_type: Optional[str] = None,
|
114 |
+
num_class_embeds: Optional[int] = None,
|
115 |
+
upcast_attention: bool = False,
|
116 |
+
resnet_time_scale_shift: str = "default",
|
117 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
118 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
119 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
120 |
+
global_pool_conditions: bool = False,
|
121 |
+
|
122 |
+
use_motion_module = True,
|
123 |
+
motion_module_resolutions = ( 1,2,4,8 ),
|
124 |
+
motion_module_mid_block = False,
|
125 |
+
motion_module_type = "Vanilla",
|
126 |
+
motion_module_kwargs = {
|
127 |
+
"num_attention_heads": 8,
|
128 |
+
"num_transformer_block": 1,
|
129 |
+
"attention_block_types": ["Temporal_Self"],
|
130 |
+
"temporal_position_encoding": True,
|
131 |
+
"temporal_position_encoding_max_len": 32,
|
132 |
+
"temporal_attention_dim_div": 1,
|
133 |
+
"causal_temporal_attention": False,
|
134 |
+
},
|
135 |
+
|
136 |
+
concate_conditioning_mask: bool = True,
|
137 |
+
use_simplified_condition_embedding: bool = False,
|
138 |
+
|
139 |
+
set_noisy_sample_input_to_zero: bool = False,
|
140 |
+
):
|
141 |
+
super().__init__()
|
142 |
+
|
143 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
144 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
145 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
146 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
147 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
148 |
+
# which is why we correct for the naming here.
|
149 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
150 |
+
|
151 |
+
# Check inputs
|
152 |
+
if len(block_out_channels) != len(down_block_types):
|
153 |
+
raise ValueError(
|
154 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
155 |
+
)
|
156 |
+
|
157 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
158 |
+
raise ValueError(
|
159 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
160 |
+
)
|
161 |
+
|
162 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
163 |
+
raise ValueError(
|
164 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
165 |
+
)
|
166 |
+
|
167 |
+
# input
|
168 |
+
self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero
|
169 |
+
|
170 |
+
conv_in_kernel = 3
|
171 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
172 |
+
self.conv_in = InflatedConv3d(
|
173 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
174 |
+
)
|
175 |
+
|
176 |
+
if concate_conditioning_mask:
|
177 |
+
conditioning_channels = conditioning_channels + 1
|
178 |
+
self.concate_conditioning_mask = concate_conditioning_mask
|
179 |
+
|
180 |
+
# control net conditioning embedding
|
181 |
+
if use_simplified_condition_embedding:
|
182 |
+
self.controlnet_cond_embedding = zero_module(
|
183 |
+
InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
|
184 |
+
).to(torch.float16)
|
185 |
+
else:
|
186 |
+
self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
|
187 |
+
conditioning_embedding_channels=block_out_channels[0],
|
188 |
+
block_out_channels=conditioning_embedding_out_channels,
|
189 |
+
conditioning_channels=conditioning_channels,
|
190 |
+
).to(torch.float16)
|
191 |
+
self.use_simplified_condition_embedding = use_simplified_condition_embedding
|
192 |
+
|
193 |
+
# time
|
194 |
+
time_embed_dim = block_out_channels[0] * 4
|
195 |
+
|
196 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
197 |
+
timestep_input_dim = block_out_channels[0]
|
198 |
+
|
199 |
+
self.time_embedding = TimestepEmbedding(
|
200 |
+
timestep_input_dim,
|
201 |
+
time_embed_dim,
|
202 |
+
act_fn=act_fn,
|
203 |
+
)
|
204 |
+
|
205 |
+
# class embedding
|
206 |
+
if class_embed_type is None and num_class_embeds is not None:
|
207 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
208 |
+
elif class_embed_type == "timestep":
|
209 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
210 |
+
elif class_embed_type == "identity":
|
211 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
212 |
+
elif class_embed_type == "projection":
|
213 |
+
if projection_class_embeddings_input_dim is None:
|
214 |
+
raise ValueError(
|
215 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
216 |
+
)
|
217 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
218 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
219 |
+
# 2. it projects from an arbitrary input dimension.
|
220 |
+
#
|
221 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
222 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
223 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
224 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
225 |
+
else:
|
226 |
+
self.class_embedding = None
|
227 |
+
|
228 |
+
|
229 |
+
self.down_blocks = nn.ModuleList([])
|
230 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
231 |
+
|
232 |
+
if isinstance(only_cross_attention, bool):
|
233 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
234 |
+
|
235 |
+
if isinstance(attention_head_dim, int):
|
236 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
237 |
+
|
238 |
+
if isinstance(num_attention_heads, int):
|
239 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
240 |
+
|
241 |
+
# down
|
242 |
+
output_channel = block_out_channels[0]
|
243 |
+
|
244 |
+
controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
|
245 |
+
controlnet_block = zero_module(controlnet_block)
|
246 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
247 |
+
|
248 |
+
for i, down_block_type in enumerate(down_block_types):
|
249 |
+
res = 2 ** i
|
250 |
+
input_channel = output_channel
|
251 |
+
output_channel = block_out_channels[i]
|
252 |
+
is_final_block = i == len(block_out_channels) - 1
|
253 |
+
|
254 |
+
down_block = get_down_block(
|
255 |
+
down_block_type,
|
256 |
+
num_layers=layers_per_block,
|
257 |
+
in_channels=input_channel,
|
258 |
+
out_channels=output_channel,
|
259 |
+
temb_channels=time_embed_dim,
|
260 |
+
add_downsample=not is_final_block,
|
261 |
+
resnet_eps=norm_eps,
|
262 |
+
resnet_act_fn=act_fn,
|
263 |
+
resnet_groups=norm_num_groups,
|
264 |
+
cross_attention_dim=cross_attention_dim,
|
265 |
+
attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
266 |
+
downsample_padding=downsample_padding,
|
267 |
+
use_linear_projection=use_linear_projection,
|
268 |
+
only_cross_attention=only_cross_attention[i],
|
269 |
+
upcast_attention=upcast_attention,
|
270 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
271 |
+
|
272 |
+
use_inflated_groupnorm=True,
|
273 |
+
|
274 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
275 |
+
motion_module_type=motion_module_type,
|
276 |
+
motion_module_kwargs=motion_module_kwargs,
|
277 |
+
)
|
278 |
+
self.down_blocks.append(down_block)
|
279 |
+
|
280 |
+
for _ in range(layers_per_block):
|
281 |
+
controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
|
282 |
+
controlnet_block = zero_module(controlnet_block)
|
283 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
284 |
+
|
285 |
+
if not is_final_block:
|
286 |
+
controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
|
287 |
+
controlnet_block = zero_module(controlnet_block)
|
288 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
289 |
+
|
290 |
+
# mid
|
291 |
+
mid_block_channel = block_out_channels[-1]
|
292 |
+
|
293 |
+
controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1)
|
294 |
+
controlnet_block = zero_module(controlnet_block)
|
295 |
+
self.controlnet_mid_block = controlnet_block
|
296 |
+
|
297 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
298 |
+
in_channels=mid_block_channel,
|
299 |
+
temb_channels=time_embed_dim,
|
300 |
+
resnet_eps=norm_eps,
|
301 |
+
resnet_act_fn=act_fn,
|
302 |
+
output_scale_factor=mid_block_scale_factor,
|
303 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
304 |
+
cross_attention_dim=cross_attention_dim,
|
305 |
+
attn_num_head_channels=num_attention_heads[-1],
|
306 |
+
resnet_groups=norm_num_groups,
|
307 |
+
use_linear_projection=use_linear_projection,
|
308 |
+
upcast_attention=upcast_attention,
|
309 |
+
|
310 |
+
use_inflated_groupnorm=True,
|
311 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
312 |
+
motion_module_type=motion_module_type,
|
313 |
+
motion_module_kwargs=motion_module_kwargs,
|
314 |
+
)
|
315 |
+
|
316 |
+
@classmethod
|
317 |
+
def from_unet(
|
318 |
+
cls,
|
319 |
+
unet: UNet2DConditionModel,
|
320 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
321 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
322 |
+
load_weights_from_unet: bool = True,
|
323 |
+
|
324 |
+
controlnet_additional_kwargs: dict = {},
|
325 |
+
):
|
326 |
+
controlnet = cls(
|
327 |
+
in_channels=unet.config.in_channels,
|
328 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
329 |
+
freq_shift=unet.config.freq_shift,
|
330 |
+
down_block_types=unet.config.down_block_types,
|
331 |
+
only_cross_attention=unet.config.only_cross_attention,
|
332 |
+
block_out_channels=unet.config.block_out_channels,
|
333 |
+
layers_per_block=unet.config.layers_per_block,
|
334 |
+
downsample_padding=unet.config.downsample_padding,
|
335 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
336 |
+
act_fn=unet.config.act_fn,
|
337 |
+
norm_num_groups=unet.config.norm_num_groups,
|
338 |
+
norm_eps=unet.config.norm_eps,
|
339 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
340 |
+
attention_head_dim=unet.config.attention_head_dim,
|
341 |
+
num_attention_heads=unet.config.num_attention_heads,
|
342 |
+
use_linear_projection=unet.config.use_linear_projection,
|
343 |
+
class_embed_type=unet.config.class_embed_type,
|
344 |
+
num_class_embeds=unet.config.num_class_embeds,
|
345 |
+
upcast_attention=unet.config.upcast_attention,
|
346 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
347 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
348 |
+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
349 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
350 |
+
|
351 |
+
**controlnet_additional_kwargs,
|
352 |
+
)
|
353 |
+
|
354 |
+
if load_weights_from_unet:
|
355 |
+
m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False)
|
356 |
+
assert len(u) == 0
|
357 |
+
m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False)
|
358 |
+
assert len(u) == 0
|
359 |
+
m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False)
|
360 |
+
assert len(u) == 0
|
361 |
+
|
362 |
+
if controlnet.class_embedding:
|
363 |
+
m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False)
|
364 |
+
assert len(u) == 0
|
365 |
+
m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False)
|
366 |
+
assert len(u) == 0
|
367 |
+
m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False)
|
368 |
+
assert len(u) == 0
|
369 |
+
|
370 |
+
return controlnet
|
371 |
+
|
372 |
+
@staticmethod
|
373 |
+
def image_layer_filter(state_dict):
|
374 |
+
new_state_dict = {}
|
375 |
+
for name, param in state_dict.items():
|
376 |
+
if "motion_modules." in name or "lora" in name: continue
|
377 |
+
new_state_dict[name] = param
|
378 |
+
return new_state_dict
|
379 |
+
|
380 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
381 |
+
def set_attention_slice(self, slice_size):
|
382 |
+
r"""
|
383 |
+
Enable sliced attention computation.
|
384 |
+
|
385 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
386 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
387 |
+
|
388 |
+
Args:
|
389 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
390 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
391 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
392 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
393 |
+
must be a multiple of `slice_size`.
|
394 |
+
"""
|
395 |
+
sliceable_head_dims = []
|
396 |
+
|
397 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
398 |
+
if hasattr(module, "set_attention_slice"):
|
399 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
400 |
+
|
401 |
+
for child in module.children():
|
402 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
403 |
+
|
404 |
+
# retrieve number of attention layers
|
405 |
+
for module in self.children():
|
406 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
407 |
+
|
408 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
409 |
+
|
410 |
+
if slice_size == "auto":
|
411 |
+
# half the attention head size is usually a good trade-off between
|
412 |
+
# speed and memory
|
413 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
414 |
+
elif slice_size == "max":
|
415 |
+
# make smallest slice possible
|
416 |
+
slice_size = num_sliceable_layers * [1]
|
417 |
+
|
418 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
419 |
+
|
420 |
+
if len(slice_size) != len(sliceable_head_dims):
|
421 |
+
raise ValueError(
|
422 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
423 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
424 |
+
)
|
425 |
+
|
426 |
+
for i in range(len(slice_size)):
|
427 |
+
size = slice_size[i]
|
428 |
+
dim = sliceable_head_dims[i]
|
429 |
+
if size is not None and size > dim:
|
430 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
431 |
+
|
432 |
+
# Recursively walk through all the children.
|
433 |
+
# Any children which exposes the set_attention_slice method
|
434 |
+
# gets the message
|
435 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
436 |
+
if hasattr(module, "set_attention_slice"):
|
437 |
+
module.set_attention_slice(slice_size.pop())
|
438 |
+
|
439 |
+
for child in module.children():
|
440 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
441 |
+
|
442 |
+
reversed_slice_size = list(reversed(slice_size))
|
443 |
+
for module in self.children():
|
444 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
445 |
+
|
446 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
447 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
448 |
+
module.gradient_checkpointing = value
|
449 |
+
|
450 |
+
def forward(
|
451 |
+
self,
|
452 |
+
sample: torch.FloatTensor,
|
453 |
+
timestep: Union[torch.Tensor, float, int],
|
454 |
+
encoder_hidden_states: torch.Tensor,
|
455 |
+
|
456 |
+
controlnet_cond: torch.FloatTensor,
|
457 |
+
conditioning_mask: Optional[torch.FloatTensor] = None,
|
458 |
+
|
459 |
+
conditioning_scale: float = 1.0,
|
460 |
+
class_labels: Optional[torch.Tensor] = None,
|
461 |
+
attention_mask: Optional[torch.Tensor] = None,
|
462 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
463 |
+
guess_mode: bool = False,
|
464 |
+
return_dict: bool = True,
|
465 |
+
) -> Union[SparseControlNetOutput, Tuple]:
|
466 |
+
|
467 |
+
# set input noise to zero
|
468 |
+
# if self.set_noisy_sample_input_to_zero:
|
469 |
+
# sample = torch.zeros_like(sample).to(sample.device)
|
470 |
+
|
471 |
+
# prepare attention_mask
|
472 |
+
if attention_mask is not None:
|
473 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
474 |
+
attention_mask = attention_mask.unsqueeze(1)
|
475 |
+
|
476 |
+
# 1. time
|
477 |
+
timesteps = timestep
|
478 |
+
if not torch.is_tensor(timesteps):
|
479 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
480 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
481 |
+
is_mps = sample.device.type == "mps"
|
482 |
+
if isinstance(timestep, float):
|
483 |
+
dtype = torch.float32 if is_mps else torch.float64
|
484 |
+
else:
|
485 |
+
dtype = torch.int32 if is_mps else torch.int64
|
486 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
487 |
+
elif len(timesteps.shape) == 0:
|
488 |
+
timesteps = timesteps[None].to(sample.device)
|
489 |
+
|
490 |
+
timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0])
|
491 |
+
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1)
|
492 |
+
|
493 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
494 |
+
timesteps = timesteps.expand(sample.shape[0])
|
495 |
+
|
496 |
+
t_emb = self.time_proj(timesteps)
|
497 |
+
|
498 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
499 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
500 |
+
# there might be better ways to encapsulate this.
|
501 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
502 |
+
emb = self.time_embedding(t_emb)
|
503 |
+
|
504 |
+
if self.class_embedding is not None:
|
505 |
+
if class_labels is None:
|
506 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
507 |
+
|
508 |
+
if self.config.class_embed_type == "timestep":
|
509 |
+
class_labels = self.time_proj(class_labels)
|
510 |
+
|
511 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
512 |
+
emb = emb + class_emb
|
513 |
+
|
514 |
+
# 2. pre-process
|
515 |
+
# equal to set input noise to zero
|
516 |
+
if self.set_noisy_sample_input_to_zero:
|
517 |
+
shape = sample.shape
|
518 |
+
sample = self.conv_in.bias.reshape(1,-1,1,1,1).expand(shape[0],-1,shape[2],shape[3],shape[4])
|
519 |
+
else:
|
520 |
+
sample = self.conv_in(sample)
|
521 |
+
|
522 |
+
if self.concate_conditioning_mask:
|
523 |
+
controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1).to(torch.float16)
|
524 |
+
# import pdb; pdb.set_trace()
|
525 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
526 |
+
|
527 |
+
sample = sample + controlnet_cond
|
528 |
+
|
529 |
+
# 3. down
|
530 |
+
down_block_res_samples = (sample,)
|
531 |
+
for downsample_block in self.down_blocks:
|
532 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
533 |
+
sample, res_samples = downsample_block(
|
534 |
+
hidden_states=sample,
|
535 |
+
temb=emb,
|
536 |
+
encoder_hidden_states=encoder_hidden_states,
|
537 |
+
attention_mask=attention_mask,
|
538 |
+
# cross_attention_kwargs=cross_attention_kwargs,
|
539 |
+
)
|
540 |
+
else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
541 |
+
|
542 |
+
down_block_res_samples += res_samples
|
543 |
+
|
544 |
+
# 4. mid
|
545 |
+
if self.mid_block is not None:
|
546 |
+
sample = self.mid_block(
|
547 |
+
sample,
|
548 |
+
emb,
|
549 |
+
encoder_hidden_states=encoder_hidden_states,
|
550 |
+
attention_mask=attention_mask,
|
551 |
+
# cross_attention_kwargs=cross_attention_kwargs,
|
552 |
+
)
|
553 |
+
|
554 |
+
# 5. controlnet blocks
|
555 |
+
controlnet_down_block_res_samples = ()
|
556 |
+
|
557 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
558 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
559 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
560 |
+
|
561 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
562 |
+
|
563 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
564 |
+
|
565 |
+
# 6. scaling
|
566 |
+
if guess_mode and not self.config.global_pool_conditions:
|
567 |
+
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
568 |
+
|
569 |
+
scales = scales * conditioning_scale
|
570 |
+
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
571 |
+
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
572 |
+
else:
|
573 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
574 |
+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
575 |
+
|
576 |
+
if self.config.global_pool_conditions:
|
577 |
+
down_block_res_samples = [
|
578 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
579 |
+
]
|
580 |
+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
581 |
+
|
582 |
+
if not return_dict:
|
583 |
+
return (down_block_res_samples, mid_block_res_sample)
|
584 |
+
|
585 |
+
return SparseControlNetOutput(
|
586 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
587 |
+
)
|
588 |
+
|
589 |
+
|
590 |
+
def zero_module(module):
|
591 |
+
for p in module.parameters():
|
592 |
+
nn.init.zeros_(p)
|
593 |
+
return module
|
motionclone/models/unet.py
ADDED
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
import pdb
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.utils.checkpoint
|
13 |
+
|
14 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
15 |
+
from diffusers.models.modeling_utils import ModelMixin
|
16 |
+
from diffusers.utils import BaseOutput, logging
|
17 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
18 |
+
from .unet_blocks import (
|
19 |
+
CrossAttnDownBlock3D,
|
20 |
+
CrossAttnUpBlock3D,
|
21 |
+
DownBlock3D,
|
22 |
+
UNetMidBlock3DCrossAttn,
|
23 |
+
UpBlock3D,
|
24 |
+
get_down_block,
|
25 |
+
get_up_block,
|
26 |
+
)
|
27 |
+
from .resnet import InflatedConv3d, InflatedGroupNorm
|
28 |
+
|
29 |
+
|
30 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class UNet3DConditionOutput(BaseOutput):
|
35 |
+
sample: torch.FloatTensor
|
36 |
+
|
37 |
+
|
38 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
39 |
+
_supports_gradient_checkpointing = True
|
40 |
+
|
41 |
+
@register_to_config
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
sample_size: Optional[int] = None,
|
45 |
+
in_channels: int = 4,
|
46 |
+
out_channels: int = 4,
|
47 |
+
center_input_sample: bool = False,
|
48 |
+
flip_sin_to_cos: bool = True,
|
49 |
+
freq_shift: int = 0,
|
50 |
+
down_block_types: Tuple[str] = (
|
51 |
+
"CrossAttnDownBlock3D",
|
52 |
+
"CrossAttnDownBlock3D",
|
53 |
+
"CrossAttnDownBlock3D",
|
54 |
+
"DownBlock3D",
|
55 |
+
),
|
56 |
+
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
57 |
+
up_block_types: Tuple[str] = ( # 第一个不带有CrossAttn,后面三个带有CrossAttn
|
58 |
+
"UpBlock3D",
|
59 |
+
"CrossAttnUpBlock3D",
|
60 |
+
"CrossAttnUpBlock3D",
|
61 |
+
"CrossAttnUpBlock3D"
|
62 |
+
),
|
63 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
64 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
65 |
+
layers_per_block: int = 2,
|
66 |
+
downsample_padding: int = 1,
|
67 |
+
mid_block_scale_factor: float = 1,
|
68 |
+
act_fn: str = "silu",
|
69 |
+
norm_num_groups: int = 32,
|
70 |
+
norm_eps: float = 1e-5,
|
71 |
+
cross_attention_dim: int = 1280,
|
72 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
73 |
+
dual_cross_attention: bool = False,
|
74 |
+
use_linear_projection: bool = False,
|
75 |
+
class_embed_type: Optional[str] = None,
|
76 |
+
num_class_embeds: Optional[int] = None,
|
77 |
+
upcast_attention: bool = False,
|
78 |
+
resnet_time_scale_shift: str = "default",
|
79 |
+
|
80 |
+
use_inflated_groupnorm=False,
|
81 |
+
|
82 |
+
# Additional
|
83 |
+
use_motion_module = False,
|
84 |
+
motion_module_resolutions = ( 1,2,4,8 ),
|
85 |
+
motion_module_mid_block = False,
|
86 |
+
motion_module_decoder_only = False,
|
87 |
+
motion_module_type = None,
|
88 |
+
motion_module_kwargs = {},
|
89 |
+
unet_use_cross_frame_attention = False,
|
90 |
+
unet_use_temporal_attention = False,
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
|
94 |
+
self.sample_size = sample_size
|
95 |
+
time_embed_dim = block_out_channels[0] * 4
|
96 |
+
|
97 |
+
# input
|
98 |
+
self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
99 |
+
|
100 |
+
# time
|
101 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
102 |
+
timestep_input_dim = block_out_channels[0]
|
103 |
+
|
104 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
105 |
+
|
106 |
+
# class embedding
|
107 |
+
if class_embed_type is None and num_class_embeds is not None:
|
108 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
109 |
+
elif class_embed_type == "timestep":
|
110 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
111 |
+
elif class_embed_type == "identity":
|
112 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
113 |
+
else:
|
114 |
+
self.class_embedding = None
|
115 |
+
|
116 |
+
self.down_blocks = nn.ModuleList([])
|
117 |
+
self.mid_block = None
|
118 |
+
self.up_blocks = nn.ModuleList([])
|
119 |
+
|
120 |
+
if isinstance(only_cross_attention, bool):
|
121 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
122 |
+
|
123 |
+
if isinstance(attention_head_dim, int):
|
124 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
125 |
+
|
126 |
+
# down
|
127 |
+
output_channel = block_out_channels[0]
|
128 |
+
for i, down_block_type in enumerate(down_block_types):
|
129 |
+
res = 2 ** i
|
130 |
+
input_channel = output_channel
|
131 |
+
output_channel = block_out_channels[i]
|
132 |
+
is_final_block = i == len(block_out_channels) - 1
|
133 |
+
|
134 |
+
down_block = get_down_block(
|
135 |
+
down_block_type,
|
136 |
+
num_layers=layers_per_block,
|
137 |
+
in_channels=input_channel,
|
138 |
+
out_channels=output_channel,
|
139 |
+
temb_channels=time_embed_dim,
|
140 |
+
add_downsample=not is_final_block,
|
141 |
+
resnet_eps=norm_eps,
|
142 |
+
resnet_act_fn=act_fn,
|
143 |
+
resnet_groups=norm_num_groups,
|
144 |
+
cross_attention_dim=cross_attention_dim,
|
145 |
+
attn_num_head_channels=attention_head_dim[i],
|
146 |
+
downsample_padding=downsample_padding,
|
147 |
+
dual_cross_attention=dual_cross_attention,
|
148 |
+
use_linear_projection=use_linear_projection,
|
149 |
+
only_cross_attention=only_cross_attention[i],
|
150 |
+
upcast_attention=upcast_attention,
|
151 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
152 |
+
|
153 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
154 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
155 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
156 |
+
|
157 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
|
158 |
+
motion_module_type=motion_module_type,
|
159 |
+
motion_module_kwargs=motion_module_kwargs,
|
160 |
+
)
|
161 |
+
self.down_blocks.append(down_block)
|
162 |
+
|
163 |
+
# mid
|
164 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
165 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
166 |
+
in_channels=block_out_channels[-1],
|
167 |
+
temb_channels=time_embed_dim,
|
168 |
+
resnet_eps=norm_eps,
|
169 |
+
resnet_act_fn=act_fn,
|
170 |
+
output_scale_factor=mid_block_scale_factor,
|
171 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
172 |
+
cross_attention_dim=cross_attention_dim,
|
173 |
+
attn_num_head_channels=attention_head_dim[-1],
|
174 |
+
resnet_groups=norm_num_groups,
|
175 |
+
dual_cross_attention=dual_cross_attention,
|
176 |
+
use_linear_projection=use_linear_projection,
|
177 |
+
upcast_attention=upcast_attention,
|
178 |
+
|
179 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
180 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
181 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
182 |
+
|
183 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
184 |
+
motion_module_type=motion_module_type,
|
185 |
+
motion_module_kwargs=motion_module_kwargs,
|
186 |
+
)
|
187 |
+
else:
|
188 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
189 |
+
|
190 |
+
# count how many layers upsample the videos
|
191 |
+
self.num_upsamplers = 0
|
192 |
+
|
193 |
+
# up
|
194 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
195 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
196 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
197 |
+
output_channel = reversed_block_out_channels[0]
|
198 |
+
for i, up_block_type in enumerate(up_block_types):
|
199 |
+
res = 2 ** (3 - i)
|
200 |
+
is_final_block = i == len(block_out_channels) - 1
|
201 |
+
|
202 |
+
prev_output_channel = output_channel
|
203 |
+
output_channel = reversed_block_out_channels[i]
|
204 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
205 |
+
|
206 |
+
# add upsample block for all BUT final layer
|
207 |
+
if not is_final_block:
|
208 |
+
add_upsample = True
|
209 |
+
self.num_upsamplers += 1
|
210 |
+
else:
|
211 |
+
add_upsample = False
|
212 |
+
|
213 |
+
up_block = get_up_block(
|
214 |
+
up_block_type,
|
215 |
+
num_layers=layers_per_block + 1,
|
216 |
+
in_channels=input_channel,
|
217 |
+
out_channels=output_channel,
|
218 |
+
prev_output_channel=prev_output_channel,
|
219 |
+
temb_channels=time_embed_dim,
|
220 |
+
add_upsample=add_upsample,
|
221 |
+
resnet_eps=norm_eps,
|
222 |
+
resnet_act_fn=act_fn,
|
223 |
+
resnet_groups=norm_num_groups,
|
224 |
+
cross_attention_dim=cross_attention_dim,
|
225 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
226 |
+
dual_cross_attention=dual_cross_attention,
|
227 |
+
use_linear_projection=use_linear_projection,
|
228 |
+
only_cross_attention=only_cross_attention[i],
|
229 |
+
upcast_attention=upcast_attention,
|
230 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
231 |
+
|
232 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
233 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
234 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
235 |
+
|
236 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
237 |
+
motion_module_type=motion_module_type,
|
238 |
+
motion_module_kwargs=motion_module_kwargs,
|
239 |
+
)
|
240 |
+
self.up_blocks.append(up_block)
|
241 |
+
prev_output_channel = output_channel
|
242 |
+
|
243 |
+
# out
|
244 |
+
if use_inflated_groupnorm:
|
245 |
+
self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
246 |
+
else:
|
247 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
248 |
+
self.conv_act = nn.SiLU()
|
249 |
+
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
250 |
+
|
251 |
+
def set_attention_slice(self, slice_size):
|
252 |
+
r"""
|
253 |
+
Enable sliced attention computation.
|
254 |
+
|
255 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
256 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
260 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
261 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
262 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
263 |
+
must be a multiple of `slice_size`.
|
264 |
+
"""
|
265 |
+
sliceable_head_dims = []
|
266 |
+
|
267 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
268 |
+
if hasattr(module, "set_attention_slice"):
|
269 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
270 |
+
|
271 |
+
for child in module.children():
|
272 |
+
fn_recursive_retrieve_slicable_dims(child)
|
273 |
+
|
274 |
+
# retrieve number of attention layers
|
275 |
+
for module in self.children():
|
276 |
+
fn_recursive_retrieve_slicable_dims(module)
|
277 |
+
|
278 |
+
num_slicable_layers = len(sliceable_head_dims)
|
279 |
+
|
280 |
+
if slice_size == "auto":
|
281 |
+
# half the attention head size is usually a good trade-off between
|
282 |
+
# speed and memory
|
283 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
284 |
+
elif slice_size == "max":
|
285 |
+
# make smallest slice possible
|
286 |
+
slice_size = num_slicable_layers * [1]
|
287 |
+
|
288 |
+
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
289 |
+
|
290 |
+
if len(slice_size) != len(sliceable_head_dims):
|
291 |
+
raise ValueError(
|
292 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
293 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
294 |
+
)
|
295 |
+
|
296 |
+
for i in range(len(slice_size)):
|
297 |
+
size = slice_size[i]
|
298 |
+
dim = sliceable_head_dims[i]
|
299 |
+
if size is not None and size > dim:
|
300 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
301 |
+
|
302 |
+
# Recursively walk through all the children.
|
303 |
+
# Any children which exposes the set_attention_slice method
|
304 |
+
# gets the message
|
305 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
306 |
+
if hasattr(module, "set_attention_slice"):
|
307 |
+
module.set_attention_slice(slice_size.pop())
|
308 |
+
|
309 |
+
for child in module.children():
|
310 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
311 |
+
|
312 |
+
reversed_slice_size = list(reversed(slice_size))
|
313 |
+
for module in self.children():
|
314 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
315 |
+
|
316 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
317 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
318 |
+
module.gradient_checkpointing = value
|
319 |
+
|
320 |
+
def forward(
|
321 |
+
self,
|
322 |
+
sample: torch.FloatTensor,
|
323 |
+
timestep: Union[torch.Tensor, float, int],
|
324 |
+
encoder_hidden_states: torch.Tensor,
|
325 |
+
class_labels: Optional[torch.Tensor] = None,
|
326 |
+
attention_mask: Optional[torch.Tensor] = None,
|
327 |
+
|
328 |
+
# support controlnet
|
329 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
330 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
331 |
+
|
332 |
+
return_dict: bool = True,
|
333 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
334 |
+
r"""
|
335 |
+
Args:
|
336 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
337 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
338 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
339 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
340 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
344 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
345 |
+
returning a tuple, the first element is the sample tensor.
|
346 |
+
"""
|
347 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
348 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
349 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
350 |
+
# on the fly if necessary.
|
351 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
352 |
+
|
353 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
354 |
+
forward_upsample_size = False
|
355 |
+
upsample_size = None
|
356 |
+
|
357 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
358 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
359 |
+
forward_upsample_size = True
|
360 |
+
|
361 |
+
# prepare attention_mask
|
362 |
+
if attention_mask is not None:
|
363 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
364 |
+
attention_mask = attention_mask.unsqueeze(1)
|
365 |
+
|
366 |
+
# center input if necessary
|
367 |
+
if self.config.center_input_sample:
|
368 |
+
sample = 2 * sample - 1.0
|
369 |
+
|
370 |
+
# time
|
371 |
+
timesteps = timestep
|
372 |
+
if not torch.is_tensor(timesteps):
|
373 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
374 |
+
is_mps = sample.device.type == "mps"
|
375 |
+
if isinstance(timestep, float):
|
376 |
+
dtype = torch.float32 if is_mps else torch.float64
|
377 |
+
else:
|
378 |
+
dtype = torch.int32 if is_mps else torch.int64
|
379 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
380 |
+
elif len(timesteps.shape) == 0:
|
381 |
+
timesteps = timesteps[None].to(sample.device)
|
382 |
+
|
383 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
384 |
+
timesteps = timesteps.expand(sample.shape[0])
|
385 |
+
|
386 |
+
t_emb = self.time_proj(timesteps)
|
387 |
+
|
388 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
389 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
390 |
+
# there might be better ways to encapsulate this.
|
391 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
392 |
+
emb = self.time_embedding(t_emb)
|
393 |
+
|
394 |
+
if self.class_embedding is not None:
|
395 |
+
if class_labels is None:
|
396 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
397 |
+
|
398 |
+
if self.config.class_embed_type == "timestep":
|
399 |
+
class_labels = self.time_proj(class_labels)
|
400 |
+
|
401 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
402 |
+
emb = emb + class_emb
|
403 |
+
|
404 |
+
# pre-process
|
405 |
+
sample = self.conv_in(sample)
|
406 |
+
|
407 |
+
# down
|
408 |
+
down_block_res_samples = (sample,)
|
409 |
+
for downsample_block in self.down_blocks:
|
410 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
411 |
+
sample, res_samples = downsample_block(
|
412 |
+
hidden_states=sample,
|
413 |
+
temb=emb,
|
414 |
+
encoder_hidden_states=encoder_hidden_states,
|
415 |
+
attention_mask=attention_mask,
|
416 |
+
)
|
417 |
+
else:
|
418 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
|
419 |
+
|
420 |
+
down_block_res_samples += res_samples
|
421 |
+
|
422 |
+
# support controlnet
|
423 |
+
down_block_res_samples = list(down_block_res_samples)
|
424 |
+
if down_block_additional_residuals is not None:
|
425 |
+
for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
|
426 |
+
if down_block_additional_residual.dim() == 4: # boardcast
|
427 |
+
down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
|
428 |
+
down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
|
429 |
+
|
430 |
+
# mid
|
431 |
+
sample = self.mid_block(
|
432 |
+
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
433 |
+
)
|
434 |
+
|
435 |
+
# support controlnet
|
436 |
+
if mid_block_additional_residual is not None:
|
437 |
+
if mid_block_additional_residual.dim() == 4: # boardcast
|
438 |
+
mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
|
439 |
+
sample = sample + mid_block_additional_residual
|
440 |
+
|
441 |
+
# up
|
442 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
443 |
+
is_final_block = i == len(self.up_blocks) - 1
|
444 |
+
|
445 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
446 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
447 |
+
|
448 |
+
# if we have not reached the final block and need to forward the
|
449 |
+
# upsample size, we do it here
|
450 |
+
if not is_final_block and forward_upsample_size:
|
451 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
452 |
+
|
453 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
454 |
+
sample = upsample_block(
|
455 |
+
hidden_states=sample,
|
456 |
+
temb=emb,
|
457 |
+
res_hidden_states_tuple=res_samples,
|
458 |
+
encoder_hidden_states=encoder_hidden_states,
|
459 |
+
upsample_size=upsample_size,
|
460 |
+
attention_mask=attention_mask,
|
461 |
+
)
|
462 |
+
else:
|
463 |
+
sample = upsample_block(
|
464 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
|
465 |
+
)
|
466 |
+
|
467 |
+
# post-process
|
468 |
+
sample = self.conv_norm_out(sample)
|
469 |
+
sample = self.conv_act(sample)
|
470 |
+
sample = self.conv_out(sample)
|
471 |
+
|
472 |
+
if not return_dict:
|
473 |
+
return (sample,)
|
474 |
+
|
475 |
+
return UNet3DConditionOutput(sample=sample)
|
476 |
+
|
477 |
+
@classmethod
|
478 |
+
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
|
479 |
+
if subfolder is not None:
|
480 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
481 |
+
print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
|
482 |
+
|
483 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
484 |
+
if not os.path.isfile(config_file):
|
485 |
+
raise RuntimeError(f"{config_file} does not exist")
|
486 |
+
with open(config_file, "r") as f:
|
487 |
+
config = json.load(f)
|
488 |
+
config["_class_name"] = cls.__name__
|
489 |
+
config["down_block_types"] = [
|
490 |
+
"CrossAttnDownBlock3D",
|
491 |
+
"CrossAttnDownBlock3D",
|
492 |
+
"CrossAttnDownBlock3D",
|
493 |
+
"DownBlock3D"
|
494 |
+
]
|
495 |
+
config["up_block_types"] = [
|
496 |
+
"UpBlock3D",
|
497 |
+
"CrossAttnUpBlock3D",
|
498 |
+
"CrossAttnUpBlock3D",
|
499 |
+
"CrossAttnUpBlock3D"
|
500 |
+
]
|
501 |
+
|
502 |
+
from diffusers.utils import WEIGHTS_NAME
|
503 |
+
model = cls.from_config(config, **unet_additional_kwargs)
|
504 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
505 |
+
if not os.path.isfile(model_file):
|
506 |
+
raise RuntimeError(f"{model_file} does not exist")
|
507 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
508 |
+
|
509 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
510 |
+
print(f"### motion keys will be loaded: {len(m)}; \n### unexpected keys: {len(u)};")
|
511 |
+
|
512 |
+
params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
|
513 |
+
print(f"### Motion Module Parameters: {sum(params) / 1e6} M")
|
514 |
+
|
515 |
+
return model
|
motionclone/models/unet_blocks.py
ADDED
@@ -0,0 +1,760 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from .attention import Transformer3DModel
|
7 |
+
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
8 |
+
from .motion_module import get_motion_module
|
9 |
+
|
10 |
+
import pdb
|
11 |
+
|
12 |
+
def get_down_block(
|
13 |
+
down_block_type,
|
14 |
+
num_layers,
|
15 |
+
in_channels,
|
16 |
+
out_channels,
|
17 |
+
temb_channels,
|
18 |
+
add_downsample,
|
19 |
+
resnet_eps,
|
20 |
+
resnet_act_fn,
|
21 |
+
attn_num_head_channels,
|
22 |
+
resnet_groups=None,
|
23 |
+
cross_attention_dim=None,
|
24 |
+
downsample_padding=None,
|
25 |
+
dual_cross_attention=False,
|
26 |
+
use_linear_projection=False,
|
27 |
+
only_cross_attention=False,
|
28 |
+
upcast_attention=False,
|
29 |
+
resnet_time_scale_shift="default",
|
30 |
+
|
31 |
+
unet_use_cross_frame_attention=False,
|
32 |
+
unet_use_temporal_attention=False,
|
33 |
+
use_inflated_groupnorm=False,
|
34 |
+
|
35 |
+
use_motion_module=None,
|
36 |
+
|
37 |
+
motion_module_type=None,
|
38 |
+
motion_module_kwargs=None,
|
39 |
+
):
|
40 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
41 |
+
if down_block_type == "DownBlock3D":
|
42 |
+
return DownBlock3D(
|
43 |
+
num_layers=num_layers,
|
44 |
+
in_channels=in_channels,
|
45 |
+
out_channels=out_channels,
|
46 |
+
temb_channels=temb_channels,
|
47 |
+
add_downsample=add_downsample,
|
48 |
+
resnet_eps=resnet_eps,
|
49 |
+
resnet_act_fn=resnet_act_fn,
|
50 |
+
resnet_groups=resnet_groups,
|
51 |
+
downsample_padding=downsample_padding,
|
52 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
53 |
+
|
54 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
55 |
+
|
56 |
+
use_motion_module=use_motion_module,
|
57 |
+
motion_module_type=motion_module_type,
|
58 |
+
motion_module_kwargs=motion_module_kwargs,
|
59 |
+
)
|
60 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
61 |
+
if cross_attention_dim is None:
|
62 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
63 |
+
return CrossAttnDownBlock3D(
|
64 |
+
num_layers=num_layers,
|
65 |
+
in_channels=in_channels,
|
66 |
+
out_channels=out_channels,
|
67 |
+
temb_channels=temb_channels,
|
68 |
+
add_downsample=add_downsample,
|
69 |
+
resnet_eps=resnet_eps,
|
70 |
+
resnet_act_fn=resnet_act_fn,
|
71 |
+
resnet_groups=resnet_groups,
|
72 |
+
downsample_padding=downsample_padding,
|
73 |
+
cross_attention_dim=cross_attention_dim,
|
74 |
+
attn_num_head_channels=attn_num_head_channels,
|
75 |
+
dual_cross_attention=dual_cross_attention,
|
76 |
+
use_linear_projection=use_linear_projection,
|
77 |
+
only_cross_attention=only_cross_attention,
|
78 |
+
upcast_attention=upcast_attention,
|
79 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
80 |
+
|
81 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
82 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
83 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
84 |
+
|
85 |
+
use_motion_module=use_motion_module,
|
86 |
+
motion_module_type=motion_module_type,
|
87 |
+
motion_module_kwargs=motion_module_kwargs,
|
88 |
+
)
|
89 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
90 |
+
|
91 |
+
|
92 |
+
def get_up_block(
|
93 |
+
up_block_type,
|
94 |
+
num_layers,
|
95 |
+
in_channels,
|
96 |
+
out_channels,
|
97 |
+
prev_output_channel,
|
98 |
+
temb_channels,
|
99 |
+
add_upsample,
|
100 |
+
resnet_eps,
|
101 |
+
resnet_act_fn,
|
102 |
+
attn_num_head_channels,
|
103 |
+
resnet_groups=None,
|
104 |
+
cross_attention_dim=None,
|
105 |
+
dual_cross_attention=False,
|
106 |
+
use_linear_projection=False,
|
107 |
+
only_cross_attention=False,
|
108 |
+
upcast_attention=False,
|
109 |
+
resnet_time_scale_shift="default",
|
110 |
+
|
111 |
+
unet_use_cross_frame_attention=False,
|
112 |
+
unet_use_temporal_attention=False,
|
113 |
+
use_inflated_groupnorm=False,
|
114 |
+
|
115 |
+
use_motion_module=None,
|
116 |
+
motion_module_type=None,
|
117 |
+
motion_module_kwargs=None,
|
118 |
+
):
|
119 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
120 |
+
if up_block_type == "UpBlock3D":
|
121 |
+
return UpBlock3D(
|
122 |
+
num_layers=num_layers,
|
123 |
+
in_channels=in_channels,
|
124 |
+
out_channels=out_channels,
|
125 |
+
prev_output_channel=prev_output_channel,
|
126 |
+
temb_channels=temb_channels,
|
127 |
+
add_upsample=add_upsample,
|
128 |
+
resnet_eps=resnet_eps,
|
129 |
+
resnet_act_fn=resnet_act_fn,
|
130 |
+
resnet_groups=resnet_groups,
|
131 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
132 |
+
|
133 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
134 |
+
|
135 |
+
use_motion_module=use_motion_module,
|
136 |
+
motion_module_type=motion_module_type,
|
137 |
+
motion_module_kwargs=motion_module_kwargs,
|
138 |
+
)
|
139 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
140 |
+
if cross_attention_dim is None:
|
141 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
142 |
+
return CrossAttnUpBlock3D(
|
143 |
+
num_layers=num_layers,
|
144 |
+
in_channels=in_channels,
|
145 |
+
out_channels=out_channels,
|
146 |
+
prev_output_channel=prev_output_channel,
|
147 |
+
temb_channels=temb_channels,
|
148 |
+
add_upsample=add_upsample,
|
149 |
+
resnet_eps=resnet_eps,
|
150 |
+
resnet_act_fn=resnet_act_fn,
|
151 |
+
resnet_groups=resnet_groups,
|
152 |
+
cross_attention_dim=cross_attention_dim,
|
153 |
+
attn_num_head_channels=attn_num_head_channels,
|
154 |
+
dual_cross_attention=dual_cross_attention,
|
155 |
+
use_linear_projection=use_linear_projection,
|
156 |
+
only_cross_attention=only_cross_attention,
|
157 |
+
upcast_attention=upcast_attention,
|
158 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
159 |
+
|
160 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
161 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
162 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
163 |
+
|
164 |
+
use_motion_module=use_motion_module,
|
165 |
+
motion_module_type=motion_module_type,
|
166 |
+
motion_module_kwargs=motion_module_kwargs,
|
167 |
+
)
|
168 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
169 |
+
|
170 |
+
|
171 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
in_channels: int,
|
175 |
+
temb_channels: int,
|
176 |
+
dropout: float = 0.0,
|
177 |
+
num_layers: int = 1,
|
178 |
+
resnet_eps: float = 1e-6,
|
179 |
+
resnet_time_scale_shift: str = "default",
|
180 |
+
resnet_act_fn: str = "swish",
|
181 |
+
resnet_groups: int = 32,
|
182 |
+
resnet_pre_norm: bool = True,
|
183 |
+
attn_num_head_channels=1,
|
184 |
+
output_scale_factor=1.0,
|
185 |
+
cross_attention_dim=1280,
|
186 |
+
dual_cross_attention=False,
|
187 |
+
use_linear_projection=False,
|
188 |
+
upcast_attention=False,
|
189 |
+
|
190 |
+
unet_use_cross_frame_attention=False,
|
191 |
+
unet_use_temporal_attention=False,
|
192 |
+
use_inflated_groupnorm=False,
|
193 |
+
|
194 |
+
use_motion_module=None,
|
195 |
+
|
196 |
+
motion_module_type=None,
|
197 |
+
motion_module_kwargs=None,
|
198 |
+
):
|
199 |
+
super().__init__()
|
200 |
+
|
201 |
+
self.has_cross_attention = True
|
202 |
+
self.attn_num_head_channels = attn_num_head_channels
|
203 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
204 |
+
|
205 |
+
# there is always at least one resnet
|
206 |
+
resnets = [
|
207 |
+
ResnetBlock3D(
|
208 |
+
in_channels=in_channels,
|
209 |
+
out_channels=in_channels,
|
210 |
+
temb_channels=temb_channels,
|
211 |
+
eps=resnet_eps,
|
212 |
+
groups=resnet_groups,
|
213 |
+
dropout=dropout,
|
214 |
+
time_embedding_norm=resnet_time_scale_shift,
|
215 |
+
non_linearity=resnet_act_fn,
|
216 |
+
output_scale_factor=output_scale_factor,
|
217 |
+
pre_norm=resnet_pre_norm,
|
218 |
+
|
219 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
220 |
+
)
|
221 |
+
]
|
222 |
+
attentions = []
|
223 |
+
motion_modules = []
|
224 |
+
|
225 |
+
for _ in range(num_layers):
|
226 |
+
if dual_cross_attention:
|
227 |
+
raise NotImplementedError
|
228 |
+
attentions.append(
|
229 |
+
Transformer3DModel(
|
230 |
+
attn_num_head_channels,
|
231 |
+
in_channels // attn_num_head_channels,
|
232 |
+
in_channels=in_channels,
|
233 |
+
num_layers=1,
|
234 |
+
cross_attention_dim=cross_attention_dim,
|
235 |
+
norm_num_groups=resnet_groups,
|
236 |
+
use_linear_projection=use_linear_projection,
|
237 |
+
upcast_attention=upcast_attention,
|
238 |
+
|
239 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
240 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
241 |
+
)
|
242 |
+
)
|
243 |
+
motion_modules.append(
|
244 |
+
get_motion_module(
|
245 |
+
in_channels=in_channels,
|
246 |
+
motion_module_type=motion_module_type,
|
247 |
+
motion_module_kwargs=motion_module_kwargs,
|
248 |
+
) if use_motion_module else None
|
249 |
+
)
|
250 |
+
resnets.append(
|
251 |
+
ResnetBlock3D(
|
252 |
+
in_channels=in_channels,
|
253 |
+
out_channels=in_channels,
|
254 |
+
temb_channels=temb_channels,
|
255 |
+
eps=resnet_eps,
|
256 |
+
groups=resnet_groups,
|
257 |
+
dropout=dropout,
|
258 |
+
time_embedding_norm=resnet_time_scale_shift,
|
259 |
+
non_linearity=resnet_act_fn,
|
260 |
+
output_scale_factor=output_scale_factor,
|
261 |
+
pre_norm=resnet_pre_norm,
|
262 |
+
|
263 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
264 |
+
)
|
265 |
+
)
|
266 |
+
|
267 |
+
self.attentions = nn.ModuleList(attentions)
|
268 |
+
self.resnets = nn.ModuleList(resnets)
|
269 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
270 |
+
|
271 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
272 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
273 |
+
for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
|
274 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
275 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
276 |
+
hidden_states = resnet(hidden_states, temb)
|
277 |
+
|
278 |
+
return hidden_states
|
279 |
+
|
280 |
+
|
281 |
+
class CrossAttnDownBlock3D(nn.Module):
|
282 |
+
def __init__(
|
283 |
+
self,
|
284 |
+
in_channels: int,
|
285 |
+
out_channels: int,
|
286 |
+
temb_channels: int,
|
287 |
+
dropout: float = 0.0,
|
288 |
+
num_layers: int = 1,
|
289 |
+
resnet_eps: float = 1e-6,
|
290 |
+
resnet_time_scale_shift: str = "default",
|
291 |
+
resnet_act_fn: str = "swish",
|
292 |
+
resnet_groups: int = 32,
|
293 |
+
resnet_pre_norm: bool = True,
|
294 |
+
attn_num_head_channels=1,
|
295 |
+
cross_attention_dim=1280,
|
296 |
+
output_scale_factor=1.0,
|
297 |
+
downsample_padding=1,
|
298 |
+
add_downsample=True,
|
299 |
+
dual_cross_attention=False,
|
300 |
+
use_linear_projection=False,
|
301 |
+
only_cross_attention=False,
|
302 |
+
upcast_attention=False,
|
303 |
+
|
304 |
+
unet_use_cross_frame_attention=False,
|
305 |
+
unet_use_temporal_attention=False,
|
306 |
+
use_inflated_groupnorm=False,
|
307 |
+
|
308 |
+
use_motion_module=None,
|
309 |
+
|
310 |
+
motion_module_type=None,
|
311 |
+
motion_module_kwargs=None,
|
312 |
+
):
|
313 |
+
super().__init__()
|
314 |
+
resnets = []
|
315 |
+
attentions = []
|
316 |
+
motion_modules = []
|
317 |
+
|
318 |
+
self.has_cross_attention = True
|
319 |
+
self.attn_num_head_channels = attn_num_head_channels
|
320 |
+
|
321 |
+
for i in range(num_layers):
|
322 |
+
in_channels = in_channels if i == 0 else out_channels
|
323 |
+
resnets.append(
|
324 |
+
ResnetBlock3D(
|
325 |
+
in_channels=in_channels,
|
326 |
+
out_channels=out_channels,
|
327 |
+
temb_channels=temb_channels,
|
328 |
+
eps=resnet_eps,
|
329 |
+
groups=resnet_groups,
|
330 |
+
dropout=dropout,
|
331 |
+
time_embedding_norm=resnet_time_scale_shift,
|
332 |
+
non_linearity=resnet_act_fn,
|
333 |
+
output_scale_factor=output_scale_factor,
|
334 |
+
pre_norm=resnet_pre_norm,
|
335 |
+
|
336 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
337 |
+
)
|
338 |
+
)
|
339 |
+
if dual_cross_attention:
|
340 |
+
raise NotImplementedError
|
341 |
+
attentions.append(
|
342 |
+
Transformer3DModel(
|
343 |
+
attn_num_head_channels,
|
344 |
+
out_channels // attn_num_head_channels,
|
345 |
+
in_channels=out_channels,
|
346 |
+
num_layers=1,
|
347 |
+
cross_attention_dim=cross_attention_dim,
|
348 |
+
norm_num_groups=resnet_groups,
|
349 |
+
use_linear_projection=use_linear_projection,
|
350 |
+
only_cross_attention=only_cross_attention,
|
351 |
+
upcast_attention=upcast_attention,
|
352 |
+
|
353 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
354 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
355 |
+
)
|
356 |
+
)
|
357 |
+
motion_modules.append(
|
358 |
+
get_motion_module(
|
359 |
+
in_channels=out_channels,
|
360 |
+
motion_module_type=motion_module_type,
|
361 |
+
motion_module_kwargs=motion_module_kwargs,
|
362 |
+
) if use_motion_module else None
|
363 |
+
)
|
364 |
+
|
365 |
+
self.attentions = nn.ModuleList(attentions)
|
366 |
+
self.resnets = nn.ModuleList(resnets)
|
367 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
368 |
+
|
369 |
+
if add_downsample:
|
370 |
+
self.downsamplers = nn.ModuleList(
|
371 |
+
[
|
372 |
+
Downsample3D(
|
373 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
374 |
+
)
|
375 |
+
]
|
376 |
+
)
|
377 |
+
else:
|
378 |
+
self.downsamplers = None
|
379 |
+
|
380 |
+
self.gradient_checkpointing = False
|
381 |
+
|
382 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
383 |
+
output_states = ()
|
384 |
+
|
385 |
+
for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
|
386 |
+
if self.training and self.gradient_checkpointing:
|
387 |
+
|
388 |
+
def create_custom_forward(module, return_dict=None):
|
389 |
+
def custom_forward(*inputs):
|
390 |
+
if return_dict is not None:
|
391 |
+
return module(*inputs, return_dict=return_dict)
|
392 |
+
else:
|
393 |
+
return module(*inputs)
|
394 |
+
|
395 |
+
return custom_forward
|
396 |
+
|
397 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
398 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
399 |
+
create_custom_forward(attn, return_dict=False),
|
400 |
+
hidden_states,
|
401 |
+
encoder_hidden_states,
|
402 |
+
)[0]
|
403 |
+
if motion_module is not None:
|
404 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
405 |
+
|
406 |
+
else:
|
407 |
+
hidden_states = resnet(hidden_states, temb)
|
408 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
409 |
+
|
410 |
+
# add motion module
|
411 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
412 |
+
|
413 |
+
output_states += (hidden_states,)
|
414 |
+
|
415 |
+
if self.downsamplers is not None:
|
416 |
+
for downsampler in self.downsamplers:
|
417 |
+
hidden_states = downsampler(hidden_states)
|
418 |
+
|
419 |
+
output_states += (hidden_states,)
|
420 |
+
|
421 |
+
return hidden_states, output_states
|
422 |
+
|
423 |
+
|
424 |
+
class DownBlock3D(nn.Module):
|
425 |
+
def __init__(
|
426 |
+
self,
|
427 |
+
in_channels: int,
|
428 |
+
out_channels: int,
|
429 |
+
temb_channels: int,
|
430 |
+
dropout: float = 0.0,
|
431 |
+
num_layers: int = 1,
|
432 |
+
resnet_eps: float = 1e-6,
|
433 |
+
resnet_time_scale_shift: str = "default",
|
434 |
+
resnet_act_fn: str = "swish",
|
435 |
+
resnet_groups: int = 32,
|
436 |
+
resnet_pre_norm: bool = True,
|
437 |
+
output_scale_factor=1.0,
|
438 |
+
add_downsample=True,
|
439 |
+
downsample_padding=1,
|
440 |
+
|
441 |
+
use_inflated_groupnorm=False,
|
442 |
+
|
443 |
+
use_motion_module=None,
|
444 |
+
motion_module_type=None,
|
445 |
+
motion_module_kwargs=None,
|
446 |
+
):
|
447 |
+
super().__init__()
|
448 |
+
resnets = []
|
449 |
+
motion_modules = []
|
450 |
+
|
451 |
+
for i in range(num_layers):
|
452 |
+
in_channels = in_channels if i == 0 else out_channels
|
453 |
+
resnets.append(
|
454 |
+
ResnetBlock3D(
|
455 |
+
in_channels=in_channels,
|
456 |
+
out_channels=out_channels,
|
457 |
+
temb_channels=temb_channels,
|
458 |
+
eps=resnet_eps,
|
459 |
+
groups=resnet_groups,
|
460 |
+
dropout=dropout,
|
461 |
+
time_embedding_norm=resnet_time_scale_shift,
|
462 |
+
non_linearity=resnet_act_fn,
|
463 |
+
output_scale_factor=output_scale_factor,
|
464 |
+
pre_norm=resnet_pre_norm,
|
465 |
+
|
466 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
467 |
+
)
|
468 |
+
)
|
469 |
+
motion_modules.append(
|
470 |
+
get_motion_module(
|
471 |
+
in_channels=out_channels,
|
472 |
+
motion_module_type=motion_module_type,
|
473 |
+
motion_module_kwargs=motion_module_kwargs,
|
474 |
+
) if use_motion_module else None
|
475 |
+
)
|
476 |
+
|
477 |
+
self.resnets = nn.ModuleList(resnets)
|
478 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
479 |
+
|
480 |
+
if add_downsample:
|
481 |
+
self.downsamplers = nn.ModuleList(
|
482 |
+
[
|
483 |
+
Downsample3D(
|
484 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
485 |
+
)
|
486 |
+
]
|
487 |
+
)
|
488 |
+
else:
|
489 |
+
self.downsamplers = None
|
490 |
+
|
491 |
+
self.gradient_checkpointing = False
|
492 |
+
|
493 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
494 |
+
output_states = ()
|
495 |
+
|
496 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
497 |
+
if self.training and self.gradient_checkpointing:
|
498 |
+
def create_custom_forward(module):
|
499 |
+
def custom_forward(*inputs):
|
500 |
+
return module(*inputs)
|
501 |
+
|
502 |
+
return custom_forward
|
503 |
+
|
504 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
505 |
+
if motion_module is not None:
|
506 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
507 |
+
else:
|
508 |
+
hidden_states = resnet(hidden_states, temb)
|
509 |
+
|
510 |
+
# add motion module
|
511 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
512 |
+
|
513 |
+
output_states += (hidden_states,)
|
514 |
+
|
515 |
+
if self.downsamplers is not None:
|
516 |
+
for downsampler in self.downsamplers:
|
517 |
+
hidden_states = downsampler(hidden_states)
|
518 |
+
|
519 |
+
output_states += (hidden_states,)
|
520 |
+
|
521 |
+
return hidden_states, output_states
|
522 |
+
|
523 |
+
|
524 |
+
class CrossAttnUpBlock3D(nn.Module):
|
525 |
+
def __init__(
|
526 |
+
self,
|
527 |
+
in_channels: int,
|
528 |
+
out_channels: int,
|
529 |
+
prev_output_channel: int,
|
530 |
+
temb_channels: int,
|
531 |
+
dropout: float = 0.0,
|
532 |
+
num_layers: int = 1,
|
533 |
+
resnet_eps: float = 1e-6,
|
534 |
+
resnet_time_scale_shift: str = "default",
|
535 |
+
resnet_act_fn: str = "swish",
|
536 |
+
resnet_groups: int = 32,
|
537 |
+
resnet_pre_norm: bool = True,
|
538 |
+
attn_num_head_channels=1,
|
539 |
+
cross_attention_dim=1280,
|
540 |
+
output_scale_factor=1.0,
|
541 |
+
add_upsample=True,
|
542 |
+
dual_cross_attention=False,
|
543 |
+
use_linear_projection=False,
|
544 |
+
only_cross_attention=False,
|
545 |
+
upcast_attention=False,
|
546 |
+
|
547 |
+
unet_use_cross_frame_attention=False,
|
548 |
+
unet_use_temporal_attention=False,
|
549 |
+
use_inflated_groupnorm=False,
|
550 |
+
|
551 |
+
use_motion_module=None,
|
552 |
+
|
553 |
+
motion_module_type=None,
|
554 |
+
motion_module_kwargs=None,
|
555 |
+
):
|
556 |
+
super().__init__()
|
557 |
+
resnets = []
|
558 |
+
attentions = []
|
559 |
+
motion_modules = []
|
560 |
+
|
561 |
+
self.has_cross_attention = True
|
562 |
+
self.attn_num_head_channels = attn_num_head_channels
|
563 |
+
|
564 |
+
for i in range(num_layers):
|
565 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
566 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
567 |
+
|
568 |
+
resnets.append(
|
569 |
+
ResnetBlock3D(
|
570 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
571 |
+
out_channels=out_channels,
|
572 |
+
temb_channels=temb_channels,
|
573 |
+
eps=resnet_eps,
|
574 |
+
groups=resnet_groups,
|
575 |
+
dropout=dropout,
|
576 |
+
time_embedding_norm=resnet_time_scale_shift,
|
577 |
+
non_linearity=resnet_act_fn,
|
578 |
+
output_scale_factor=output_scale_factor,
|
579 |
+
pre_norm=resnet_pre_norm,
|
580 |
+
|
581 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
582 |
+
)
|
583 |
+
)
|
584 |
+
if dual_cross_attention:
|
585 |
+
raise NotImplementedError
|
586 |
+
attentions.append(
|
587 |
+
Transformer3DModel(
|
588 |
+
attn_num_head_channels,
|
589 |
+
out_channels // attn_num_head_channels,
|
590 |
+
in_channels=out_channels,
|
591 |
+
num_layers=1,
|
592 |
+
cross_attention_dim=cross_attention_dim,
|
593 |
+
norm_num_groups=resnet_groups,
|
594 |
+
use_linear_projection=use_linear_projection,
|
595 |
+
only_cross_attention=only_cross_attention,
|
596 |
+
upcast_attention=upcast_attention,
|
597 |
+
|
598 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
599 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
600 |
+
)
|
601 |
+
)
|
602 |
+
motion_modules.append(
|
603 |
+
get_motion_module(
|
604 |
+
in_channels=out_channels,
|
605 |
+
motion_module_type=motion_module_type,
|
606 |
+
motion_module_kwargs=motion_module_kwargs,
|
607 |
+
) if use_motion_module else None
|
608 |
+
)
|
609 |
+
|
610 |
+
self.attentions = nn.ModuleList(attentions)
|
611 |
+
self.resnets = nn.ModuleList(resnets)
|
612 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
613 |
+
|
614 |
+
if add_upsample:
|
615 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
616 |
+
else:
|
617 |
+
self.upsamplers = None
|
618 |
+
|
619 |
+
self.gradient_checkpointing = False
|
620 |
+
|
621 |
+
def forward(
|
622 |
+
self,
|
623 |
+
hidden_states,
|
624 |
+
res_hidden_states_tuple,
|
625 |
+
temb=None,
|
626 |
+
encoder_hidden_states=None,
|
627 |
+
upsample_size=None,
|
628 |
+
attention_mask=None,
|
629 |
+
):
|
630 |
+
for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
|
631 |
+
# pop res hidden states
|
632 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
633 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
634 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
635 |
+
|
636 |
+
if self.training and self.gradient_checkpointing:
|
637 |
+
|
638 |
+
def create_custom_forward(module, return_dict=None):
|
639 |
+
def custom_forward(*inputs):
|
640 |
+
if return_dict is not None:
|
641 |
+
return module(*inputs, return_dict=return_dict)
|
642 |
+
else:
|
643 |
+
return module(*inputs)
|
644 |
+
|
645 |
+
return custom_forward
|
646 |
+
|
647 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
648 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
649 |
+
create_custom_forward(attn, return_dict=False),
|
650 |
+
hidden_states,
|
651 |
+
encoder_hidden_states,
|
652 |
+
)[0]
|
653 |
+
if motion_module is not None:
|
654 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
655 |
+
|
656 |
+
else:
|
657 |
+
hidden_states = resnet(hidden_states, temb)
|
658 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
659 |
+
|
660 |
+
# add motion module
|
661 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
662 |
+
|
663 |
+
if self.upsamplers is not None:
|
664 |
+
for upsampler in self.upsamplers:
|
665 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
666 |
+
|
667 |
+
return hidden_states
|
668 |
+
|
669 |
+
|
670 |
+
class UpBlock3D(nn.Module):
|
671 |
+
def __init__(
|
672 |
+
self,
|
673 |
+
in_channels: int,
|
674 |
+
prev_output_channel: int,
|
675 |
+
out_channels: int,
|
676 |
+
temb_channels: int,
|
677 |
+
dropout: float = 0.0,
|
678 |
+
num_layers: int = 1,
|
679 |
+
resnet_eps: float = 1e-6,
|
680 |
+
resnet_time_scale_shift: str = "default",
|
681 |
+
resnet_act_fn: str = "swish",
|
682 |
+
resnet_groups: int = 32,
|
683 |
+
resnet_pre_norm: bool = True,
|
684 |
+
output_scale_factor=1.0,
|
685 |
+
add_upsample=True,
|
686 |
+
|
687 |
+
use_inflated_groupnorm=False,
|
688 |
+
|
689 |
+
use_motion_module=None,
|
690 |
+
motion_module_type=None,
|
691 |
+
motion_module_kwargs=None,
|
692 |
+
):
|
693 |
+
super().__init__()
|
694 |
+
resnets = []
|
695 |
+
motion_modules = []
|
696 |
+
|
697 |
+
for i in range(num_layers):
|
698 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
699 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
700 |
+
|
701 |
+
resnets.append(
|
702 |
+
ResnetBlock3D(
|
703 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
704 |
+
out_channels=out_channels,
|
705 |
+
temb_channels=temb_channels,
|
706 |
+
eps=resnet_eps,
|
707 |
+
groups=resnet_groups,
|
708 |
+
dropout=dropout,
|
709 |
+
time_embedding_norm=resnet_time_scale_shift,
|
710 |
+
non_linearity=resnet_act_fn,
|
711 |
+
output_scale_factor=output_scale_factor,
|
712 |
+
pre_norm=resnet_pre_norm,
|
713 |
+
|
714 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
715 |
+
)
|
716 |
+
)
|
717 |
+
motion_modules.append(
|
718 |
+
get_motion_module(
|
719 |
+
in_channels=out_channels,
|
720 |
+
motion_module_type=motion_module_type,
|
721 |
+
motion_module_kwargs=motion_module_kwargs,
|
722 |
+
) if use_motion_module else None
|
723 |
+
)
|
724 |
+
|
725 |
+
self.resnets = nn.ModuleList(resnets)
|
726 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
727 |
+
|
728 |
+
if add_upsample:
|
729 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
730 |
+
else:
|
731 |
+
self.upsamplers = None
|
732 |
+
|
733 |
+
self.gradient_checkpointing = False
|
734 |
+
|
735 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
|
736 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
737 |
+
# pop res hidden states
|
738 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
739 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
740 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
741 |
+
|
742 |
+
if self.training and self.gradient_checkpointing:
|
743 |
+
def create_custom_forward(module):
|
744 |
+
def custom_forward(*inputs):
|
745 |
+
return module(*inputs)
|
746 |
+
|
747 |
+
return custom_forward
|
748 |
+
|
749 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
750 |
+
if motion_module is not None:
|
751 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
752 |
+
else:
|
753 |
+
hidden_states = resnet(hidden_states, temb)
|
754 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
755 |
+
|
756 |
+
if self.upsamplers is not None:
|
757 |
+
for upsampler in self.upsamplers:
|
758 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
759 |
+
|
760 |
+
return hidden_states
|
motionclone/pipelines/__pycache__/pipeline_animation.cpython-310.pyc
ADDED
Binary file (13.7 kB). View file
|
|
motionclone/pipelines/__pycache__/pipeline_animation.cpython-38.pyc
ADDED
Binary file (13.4 kB). View file
|
|