svjack committed on
Commit ce68674 · verified · 1 Parent(s): d102b36

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +22 -0
  2. README.md +168 -12
  3. __assets__/feature_visualization.png +3 -0
  4. __assets__/pipeline.png +3 -0
  5. __assets__/teaser.gif +3 -0
  6. __assets__/teaser.mp4 +3 -0
  7. condition_images/rgb/dog_on_grass.png +3 -0
  8. condition_images/scribble/lion_forest.png +0 -0
  9. configs/i2v_rgb.jsonl +1 -0
  10. configs/i2v_rgb.yaml +20 -0
  11. configs/i2v_sketch.jsonl +1 -0
  12. configs/i2v_sketch.yaml +20 -0
  13. configs/model_config/inference-v1.yaml +25 -0
  14. configs/model_config/inference-v2.yaml +24 -0
  15. configs/model_config/inference-v3.yaml +22 -0
  16. configs/model_config/model_config copy.yaml +22 -0
  17. configs/model_config/model_config.yaml +21 -0
  18. configs/model_config/model_config_public.yaml +25 -0
  19. configs/sparsectrl/image_condition.yaml +17 -0
  20. configs/sparsectrl/latent_condition.yaml +17 -0
  21. configs/t2v_camera.jsonl +12 -0
  22. configs/t2v_camera.yaml +19 -0
  23. configs/t2v_object.jsonl +6 -0
  24. configs/t2v_object.yaml +19 -0
  25. environment.yaml +25 -0
  26. generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 +3 -0
  27. generated_videos/inference_config.json +21 -0
  28. generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 +3 -0
  29. i2v_video_sample.py +157 -0
  30. models/Motion_Module/Put motion module checkpoints here.txt +0 -0
  31. motionclone/models/__pycache__/attention.cpython-310.pyc +0 -0
  32. motionclone/models/__pycache__/attention.cpython-38.pyc +0 -0
  33. motionclone/models/__pycache__/motion_module.cpython-310.pyc +0 -0
  34. motionclone/models/__pycache__/motion_module.cpython-38.pyc +0 -0
  35. motionclone/models/__pycache__/resnet.cpython-310.pyc +0 -0
  36. motionclone/models/__pycache__/resnet.cpython-38.pyc +0 -0
  37. motionclone/models/__pycache__/sparse_controlnet.cpython-38.pyc +0 -0
  38. motionclone/models/__pycache__/unet.cpython-310.pyc +0 -0
  39. motionclone/models/__pycache__/unet.cpython-38.pyc +0 -0
  40. motionclone/models/__pycache__/unet_blocks.cpython-310.pyc +0 -0
  41. motionclone/models/__pycache__/unet_blocks.cpython-38.pyc +0 -0
  42. motionclone/models/attention.py +611 -0
  43. motionclone/models/motion_module.py +347 -0
  44. motionclone/models/resnet.py +218 -0
  45. motionclone/models/scheduler.py +155 -0
  46. motionclone/models/sparse_controlnet.py +593 -0
  47. motionclone/models/unet.py +515 -0
  48. motionclone/models/unet_blocks.py +760 -0
  49. motionclone/pipelines/__pycache__/pipeline_animation.cpython-310.pyc +0 -0
  50. motionclone/pipelines/__pycache__/pipeline_animation.cpython-38.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ __assets__/feature_visualization.png filter=lfs diff=lfs merge=lfs -text
37
+ __assets__/pipeline.png filter=lfs diff=lfs merge=lfs -text
38
+ __assets__/teaser.gif filter=lfs diff=lfs merge=lfs -text
39
+ __assets__/teaser.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ condition_images/rgb/dog_on_grass.png filter=lfs diff=lfs merge=lfs -text
41
+ generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 filter=lfs diff=lfs merge=lfs -text
42
+ generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 filter=lfs diff=lfs merge=lfs -text
43
+ reference_videos/camera_1.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ reference_videos/camera_pan_down.mp4 filter=lfs diff=lfs merge=lfs -text
45
+ reference_videos/camera_pan_up.mp4 filter=lfs diff=lfs merge=lfs -text
46
+ reference_videos/camera_translation_1.mp4 filter=lfs diff=lfs merge=lfs -text
47
+ reference_videos/camera_translation_2.mp4 filter=lfs diff=lfs merge=lfs -text
48
+ reference_videos/camera_zoom_in.mp4 filter=lfs diff=lfs merge=lfs -text
49
+ reference_videos/camera_zoom_out.mp4 filter=lfs diff=lfs merge=lfs -text
50
+ reference_videos/sample_astronaut.mp4 filter=lfs diff=lfs merge=lfs -text
51
+ reference_videos/sample_blackswan.mp4 filter=lfs diff=lfs merge=lfs -text
52
+ reference_videos/sample_cat.mp4 filter=lfs diff=lfs merge=lfs -text
53
+ reference_videos/sample_cow.mp4 filter=lfs diff=lfs merge=lfs -text
54
+ reference_videos/sample_fox.mp4 filter=lfs diff=lfs merge=lfs -text
55
+ reference_videos/sample_leaves.mp4 filter=lfs diff=lfs merge=lfs -text
56
+ reference_videos/sample_white_tiger.mp4 filter=lfs diff=lfs merge=lfs -text
57
+ reference_videos/sample_wolf.mp4 filter=lfs diff=lfs merge=lfs -text
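
These rules follow the standard Git LFS attribute format. For anyone extending the list, `git lfs track` writes identical lines; a minimal sketch, assuming `git-lfs` is installed and the command is run from the repository root (the patterns below are illustrative, not taken from this commit):

```python
import subprocess

for pattern in ["reference_videos/*.mp4", "__assets__/*.png"]:
    # Appends "<pattern> filter=lfs diff=lfs merge=lfs -text" to .gitattributes,
    # matching the entries added above.
    subprocess.run(["git", "lfs", "track", pattern], check=True)
```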
README.md CHANGED
@@ -1,12 +1,168 @@
1
- ---
2
- title: MotionClone
3
- emoji: 🐠
4
- colorFrom: red
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.16.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # MotionClone
2
+ This repository is the official implementation of [MotionClone](https://arxiv.org/abs/2406.05338). It is a **training-free framework** that enables motion cloning from a reference video for controllable video generation, **without cumbersome video inversion processes**.
3
+ <details><summary>Click for the full abstract of MotionClone</summary>
4
+
5
+ > Motion-based controllable video generation offers the potential for creating captivating visual content. Existing methods typically necessitate model training to encode particular motion cues or incorporate fine-tuning to inject certain motion patterns, resulting in limited flexibility and generalization.
6
+ In this work, we propose **MotionClone**, a training-free framework that enables motion cloning from reference videos for versatile motion-controlled video generation, including text-to-video and image-to-video. Based on the observation that the dominant components in temporal-attention maps drive motion synthesis, while the rest mainly capture noisy or very subtle motions, MotionClone utilizes sparse temporal attention weights as motion representations for motion guidance, facilitating diverse motion transfer across varying scenarios. Meanwhile, MotionClone allows for the direct extraction of motion representations through a single denoising step, bypassing cumbersome inversion processes and thus promoting both efficiency and flexibility.
7
+ Extensive experiments demonstrate that MotionClone exhibits proficiency in both global camera motion and local object motion, with notable superiority in terms of motion fidelity, textual alignment, and temporal consistency.
8
+ </details>
9
+
10
+ **[MotionClone: Training-Free Motion Cloning for Controllable Video Generation](https://arxiv.org/abs/2406.05338)**
11
+ </br>
12
+ [Pengyang Ling*](https://github.com/LPengYang/),
13
+ [Jiazi Bu*](https://github.com/Bujiazi/),
14
+ [Pan Zhang<sup>†</sup>](https://panzhang0212.github.io/),
15
+ [Xiaoyi Dong](https://scholar.google.com/citations?user=FscToE0AAAAJ&hl=en/),
16
+ [Yuhang Zang](https://yuhangzang.github.io/),
17
+ [Tong Wu](https://wutong16.github.io/),
18
+ [Huaian Chen](https://scholar.google.com.hk/citations?hl=zh-CN&user=D6ol9XkAAAAJ),
19
+ [Jiaqi Wang](https://myownskyw7.github.io/),
20
+ [Yi Jin<sup>†</sup>](https://scholar.google.ca/citations?hl=en&user=mAJ1dCYAAAAJ)
21
+ (*Equal Contribution)(<sup>†</sup>Corresponding Author)
22
+
23
+ <!-- [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) -->
24
+ [![arXiv](https://img.shields.io/badge/arXiv-2406.05338-b31b1b.svg)](https://arxiv.org/abs/2406.05338)
25
+ [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://bujiazi.github.io/motionclone.github.io/)
26
+ ![](https://img.shields.io/github/stars/LPengYang/MotionClone?style=social)
27
+ <!-- [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://bujiazi.github.io/motionclone.github.io/) -->
28
+ <!-- [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://bujiazi.github.io/motionclone.github.io/) -->
29
+
30
+ ## Demo
31
+ [![]](https://github.com/user-attachments/assets/d1f1c753-f192-455b-9779-94c925e51aaa)
32
+
33
+ ```bash
34
+ sudo apt-get update && sudo apt-get install git-lfs ffmpeg cbm
35
+
36
+ conda create --name py310 python=3.10
37
+ conda activate py310
38
+ pip install ipykernel
39
+ python -m ipykernel install --user --name py310 --display-name "py310"
40
+
41
+ git clone https://github.com/svjack/MotionClone && cd MotionClone
42
+ pip install -r requirements.txt
43
+
44
+ mkdir -p models
45
+ git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5 models/StableDiffusion/
46
+
47
+ mkdir -p models/DreamBooth_LoRA
48
+ wget https://huggingface.co/svjack/Realistic-Vision-V6.0-B1/resolve/main/realisticVisionV60B1_v51VAE.safetensors -O models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors
49
+
50
+ mkdir -p models/Motion_Module
51
+ wget https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_mm.ckpt -O models/Motion_Module/v3_sd15_mm.ckpt
52
+ wget https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt -O models/Motion_Module/v3_sd15_adapter.ckpt
53
+
54
+ mkdir -p models/SparseCtrl
55
+ wget https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_rgb.ckpt -O models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt
56
+ wget https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_scribble.ckpt -O models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt
57
+ ```
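
A quick way to confirm that the downloads above landed where the configs expect them is to check the paths directly; a minimal sketch using the paths from the commands above (`model_index.json` is assumed to be the diffusers pipeline manifest inside the Stable Diffusion clone):

```python
import os

expected = [
    "models/StableDiffusion/model_index.json",
    "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors",
    "models/Motion_Module/v3_sd15_mm.ckpt",
    "models/Motion_Module/v3_sd15_adapter.ckpt",
    "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt",
    "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt",
]
missing = [p for p in expected if not os.path.exists(p)]
print("all checkpoints in place" if not missing else f"missing: {missing}")
```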
58
+
59
+ ## 🖋 News
60
+ - The latest version of our paper (**v4**) is available on arXiv! (10.08)
61
+ - The latest version of our paper (**v3**) is available on arXiv! (7.2)
62
+ - Code released! (6.29)
63
+
64
+ ## 🏗️ Todo
65
+ - [x] We have updated the latest version of MotionClone, which performs motion transfer **without video inversion** and supports **image-to-video and sketch-to-video**.
66
+ - [x] Release the MotionClone code (we have released **the first version** of our code and will continue to optimize it; questions and issues are welcome and will be addressed promptly)
67
+ - [x] Release paper
68
+
69
+ ## 📚 Gallery
70
+ More results are shown on the [Project Page](https://bujiazi.github.io/motionclone.github.io/).
71
+
72
+ ## 🚀 Method Overview
73
+ ### Feature visualization
74
+ <div align="center">
75
+ <img src='__assets__/feature_visualization.png'/>
76
+ </div>
77
+
78
+ ### Pipeline
79
+ <div align="center">
80
+ <img src='__assets__/pipeline.png'/>
81
+ </div>
82
+
83
+ MotionClone utilizes sparse temporal attention weights as motion representations for motion guidance, facilitating diverse motion transfer across varying scenarios. Meanwhile, MotionClone allows for the direct extraction of motion representation through a single denoising step, bypassing the cumbersome inversion processes and thus promoting both efficiency and flexibility.
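
As a rough, self-contained illustration of the sparse temporal-attention idea described above (tensor shapes and the keep-top-1 choice are illustrative assumptions, not the repository's exact implementation):

```python
import torch

# A toy temporal self-attention map: (heads * spatial tokens, frames, frames).
attn = torch.softmax(torch.randn(8 * 64, 16, 16), dim=-1)

# Keep only the dominant component per row; the paper observes that these
# sparse peaks are what actually drive motion synthesis.
_, top_index = attn.max(dim=-1, keepdim=True)
mask = torch.zeros_like(attn).scatter_(-1, top_index, 1.0)

# Sparse temporal attention weights, usable as a motion-guidance target.
motion_representation = mask * attn
print(motion_representation.shape)  # torch.Size([512, 16, 16])
```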
84
+
85
+ ## 🔧 Installation (python==3.11.3 recommended)
86
+
87
+ ### Setup repository and conda environment
88
+
89
+ ```
90
+ git clone https://github.com/Bujiazi/MotionClone.git
91
+ cd MotionClone
92
+
93
+ conda env create -f environment.yaml
94
+ conda activate motionclone
95
+ ```
96
+
97
+ ## 🔑 Pretrained Model Preparations
98
+
99
+ ### Download Stable Diffusion V1.5
100
+
101
+ ```
102
+ git lfs install
103
+ git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/StableDiffusion/
104
+ ```
105
+
106
+ After downloading Stable Diffusion, save it to `models/StableDiffusion`.
107
+
108
+ ### Prepare Community Models
109
+
110
+ Manually download the community `.safetensors` models from [RealisticVision V5.1](https://civitai.com/models/4201?modelVersionId=130072) and save them to `models/DreamBooth_LoRA`.
111
+
112
+ ### Prepare AnimateDiff Motion Modules
113
+
114
+ Manually download the AnimateDiff modules from [AnimateDiff](https://github.com/guoyww/AnimateDiff); we recommend [`v3_sd15_adapter.ckpt`](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_adapter.ckpt) and [`v3_sd15_mm.ckpt`](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt). Save the modules to `models/Motion_Module`.
115
+
116
+ ### Prepare SparseCtrl for image-to-video and sketch-to-video
117
+ Manually download "v3_sd15_sparsectrl_rgb.ckpt" and "v3_sd15_sparsectrl_scribble.ckpt" from [AnimateDiff](https://huggingface.co/guoyww/animatediff/tree/main). Save the modules to `models/SparseCtrl`.
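
If manual downloading is inconvenient, the same checkpoints can be fetched programmatically; a minimal sketch, assuming `huggingface_hub` is installed:

```python
from huggingface_hub import hf_hub_download

for name in ["v3_sd15_sparsectrl_rgb.ckpt", "v3_sd15_sparsectrl_scribble.ckpt"]:
    # Places the file under models/SparseCtrl/, matching the layout the configs expect.
    hf_hub_download(repo_id="guoyww/animatediff", filename=name, local_dir="models/SparseCtrl")
```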
118
+
119
+ ## 🎈 Quick Start
120
+
121
+ ### Perform Text-to-video generation with customized camera motion
122
+ ```
123
+ python t2v_video_sample.py --inference_config "configs/t2v_camera.yaml" --examples "configs/t2v_camera.jsonl"
124
+ ```
125
+
126
+ https://github.com/user-attachments/assets/2656a49a-c57d-4f89-bc65-5ec09ac037ea
127
+
128
+
129
+
130
+
131
+
132
+ ### Perform Text-to-video generation with customized object motion
133
+ ```
134
+ python t2v_video_sample.py --inference_config "configs/t2v_object.yaml" --examples "configs/t2v_object.jsonl"
135
+ ```
136
+ ### Combine motion cloning with sketch-to-video
137
+ ```
138
+ python i2v_video_sample.py --inference_config "configs/i2v_sketch.yaml" --examples "configs/i2v_sketch.jsonl"
139
+ ```
140
+ ### Combine motion cloning with image-to-video
141
+ ```
142
+ python i2v_video_sample.py --inference_config "configs/i2v_rgb.yaml" --examples "configs/i2v_rgb.jsonl"
143
+ ```
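
Each `--examples` file is plain JSON Lines: one generation case per line with `video_path`, `new_prompt`, an optional `seed`, and, for the i2v scripts, `condition_image_paths`. A minimal sketch for writing a custom examples file (the output path and prompt below are illustrative):

```python
import json

examples = [
    {"video_path": "reference_videos/camera_zoom_in.mp4",
     "new_prompt": "Relics on the seabed",
     "seed": 42},
]
with open("configs/my_examples.jsonl", "w") as f:  # hypothetical output path
    for example in examples:
        f.write(json.dumps(example) + "\n")
```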
144
+
145
+
146
+ ## 📎 Citation
147
+
148
+ If you find this work helpful, please cite the following paper:
149
+
150
+ ```
151
+ @article{ling2024motionclone,
152
+ title={MotionClone: Training-Free Motion Cloning for Controllable Video Generation},
153
+ author={Ling, Pengyang and Bu, Jiazi and Zhang, Pan and Dong, Xiaoyi and Zang, Yuhang and Wu, Tong and Chen, Huaian and Wang, Jiaqi and Jin, Yi},
154
+ journal={arXiv preprint arXiv:2406.05338},
155
+ year={2024}
156
+ }
157
+ ```
158
+
159
+ ## 📣 Disclaimer
160
+
161
+ This is the official code of MotionClone.
162
+ The copyrights of the demo images and audio belong to community users.
163
+ Feel free to contact us if you would like them removed.
164
+
165
+ ## 💞 Acknowledgements
166
+ The code is built upon the repositories below; we thank all the contributors for open-sourcing their work.
167
+ * [AnimateDiff](https://github.com/guoyww/AnimateDiff)
168
+ * [FreeControl](https://github.com/genforce/freecontrol)
__assets__/feature_visualization.png ADDED

Git LFS Details

  • SHA256: 4c0891fbfe56b1650d6c65dac700d02faee46cff0cc56515c8a23a8be0c9a46b
  • Pointer size: 131 Bytes
  • Size of remote file: 944 kB
__assets__/pipeline.png ADDED

Git LFS Details

  • SHA256: bc9926f5f4a746475cb1963a4e908671db82d0cc630c8a5e9cd43f78885fd82d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
__assets__/teaser.gif ADDED

Git LFS Details

  • SHA256: 2ee4ff21495ae52ff2c9f4ff9ad5406c3f4445633a437664f9cc20277460ea6f
  • Pointer size: 133 Bytes
  • Size of remote file: 14.6 MB
__assets__/teaser.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:201747f42691e708b9efe48ea054961fd82cf54b83ac43e0d97a43f81779c00b
3
+ size 4957080
condition_images/rgb/dog_on_grass.png ADDED

Git LFS Details

  • SHA256: 1b3ead35573919274f59d763c5085608ca78a993bf508448ca22af31ebcab113
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
condition_images/scribble/lion_forest.png ADDED
configs/i2v_rgb.jsonl ADDED
@@ -0,0 +1 @@
1
+ {"video_path":"reference_videos/camera_zoom_out.mp4", "condition_image_paths":["condition_images/rgb/dog_on_grass.png"], "new_prompt": "Dog, lying on the grass"}
configs/i2v_rgb.yaml ADDED
@@ -0,0 +1,20 @@
1
+ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
2
+ dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
3
+ model_config: "configs/model_config/model_config.yaml"
4
+ controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt"
5
+ controlnet_config: "configs/sparsectrl/latent_condition.yaml"
6
+ adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt"
7
+
8
+ cfg_scale: 7.5 # default scale for classifier-free guidance
9
+ negative_prompt: "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers"
10
+
11
+ inference_steps: 100 # total number of denoising steps for inference
12
+ guidance_scale: 0.3 # the time-step scale at which guidance ends
13
+ guidance_steps: 40 # number of guided steps in inference, no more than 1000*guidance_scale; the remaining (inference_steps - guidance_steps) steps are performed without guidance
14
+ warm_up_steps: 10
15
+ cool_up_steps: 10
16
+
17
+ motion_guidance_weight: 2000
18
+ motion_guidance_blocks: ['up_blocks.1']
19
+
20
+ add_noise_step: 400
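
A small worked example of how the guidance knobs above interact, using only the values and inline comments from this file:

```python
inference_steps = 100  # total denoising steps
guidance_scale = 0.3   # guidance ends at this fraction of the 1000-step range
guidance_steps = 40    # guided steps; the comment above requires <= 1000 * guidance_scale = 300

assert guidance_steps <= 1000 * guidance_scale
unguided_steps = inference_steps - guidance_steps
print(f"{guidance_steps} guided steps, then {unguided_steps} steps without guidance")
```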
configs/i2v_sketch.jsonl ADDED
@@ -0,0 +1 @@
1
+ {"video_path":"reference_videos/sample_white_tiger.mp4", "condition_image_paths":["condition_images/scribble/lion_forest.png"], "new_prompt": "Lion, walks in the forest"}
configs/i2v_sketch.yaml ADDED
@@ -0,0 +1,20 @@
1
+ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
2
+ dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
3
+ model_config: "configs/model_config/model_config.yaml"
4
+ controlnet_config: "configs/sparsectrl/image_condition.yaml"
5
+ controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt"
6
+ adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt"
7
+
8
+ cfg_scale: 7.5 # default scale for classifier-free guidance
9
+ negative_prompt: "ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy, extra limbs, poorly drawn face, poorly drawn hands, missing fingers"
10
+
11
+ inference_steps: 200 # total number of denoising steps for inference
12
+ guidance_scale: 0.4 # the time-step scale at which guidance ends
13
+ guidance_steps: 120 # number of guided steps in inference, no more than 1000*guidance_scale; the remaining (inference_steps - guidance_steps) steps are performed without guidance
14
+ warm_up_steps: 10
15
+ cool_up_steps: 10
16
+
17
+ motion_guidance_weight: 2000
18
+ motion_guidance_blocks: ['up_blocks.1']
19
+
20
+ add_noise_step: 400
configs/model_config/inference-v1.yaml ADDED
@@ -0,0 +1,25 @@
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true # from config v3
3
+
4
+
5
+ use_motion_module: true
6
+ motion_module_resolutions: [1,2,4,8]
7
+ motion_module_mid_block: false
8
+ motion_module_decoder_only: false
9
+ motion_module_type: "Vanilla"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 32
17
+ temporal_attention_dim_div: 1
18
+ zero_initialize: true # from config v3
19
+
20
+ noise_scheduler_kwargs:
21
+ beta_start: 0.00085
22
+ beta_end: 0.012
23
+ beta_schedule: "linear"
24
+ steps_offset: 1
25
+ clip_sample: False
configs/model_config/inference-v2.yaml ADDED
@@ -0,0 +1,24 @@
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ unet_use_cross_frame_attention: false
4
+ unet_use_temporal_attention: false
5
+ use_motion_module: true
6
+ motion_module_resolutions: [1,2,4,8]
7
+ motion_module_mid_block: true
8
+ motion_module_decoder_only: false
9
+ motion_module_type: "Vanilla"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 32
17
+ temporal_attention_dim_div: 1
18
+
19
+ noise_scheduler_kwargs:
20
+ beta_start: 0.00085
21
+ beta_end: 0.012
22
+ beta_schedule: "linear"
23
+ steps_offset: 1
24
+ clip_sample: False
configs/model_config/inference-v3.yaml ADDED
@@ -0,0 +1,22 @@
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ use_motion_module: true
4
+ motion_module_resolutions: [1,2,4,8]
5
+ motion_module_mid_block: false
6
+ motion_module_type: Vanilla
7
+
8
+ motion_module_kwargs:
9
+ num_attention_heads: 8
10
+ num_transformer_block: 1
11
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
12
+ temporal_position_encoding: true
13
+ temporal_position_encoding_max_len: 32
14
+ temporal_attention_dim_div: 1
15
+ zero_initialize: true
16
+
17
+ noise_scheduler_kwargs:
18
+ beta_start: 0.00085
19
+ beta_end: 0.012
20
+ beta_schedule: "linear"
21
+ steps_offset: 1
22
+ clip_sample: False
configs/model_config/model_config copy.yaml ADDED
@@ -0,0 +1,22 @@
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true # from config v3
3
+ use_motion_module: true
4
+ motion_module_resolutions: [1,2,4,8]
5
+ motion_module_mid_block: false
6
+ motion_module_type: "Vanilla"
7
+
8
+ motion_module_kwargs:
9
+ num_attention_heads: 8
10
+ num_transformer_block: 1
11
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
12
+ temporal_position_encoding: true
13
+ temporal_position_encoding_max_len: 32
14
+ temporal_attention_dim_div: 1
15
+ zero_initialize: true # from config v3
16
+
17
+ noise_scheduler_kwargs:
18
+ beta_start: 0.00085
19
+ beta_end: 0.012
20
+ beta_schedule: "linear"
21
+ steps_offset: 1
22
+ clip_sample: False
configs/model_config/model_config.yaml ADDED
@@ -0,0 +1,21 @@
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ use_motion_module: true
4
+ motion_module_resolutions: [1,2,4,8]
5
+ motion_module_mid_block: false
6
+ motion_module_type: "Vanilla"
7
+
8
+ motion_module_kwargs:
9
+ num_attention_heads: 8
10
+ num_transformer_block: 1
11
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
12
+ temporal_position_encoding: true
13
+ temporal_attention_dim_div: 1
14
+ zero_initialize: true
15
+
16
+ noise_scheduler_kwargs:
17
+ beta_start: 0.00085
18
+ beta_end: 0.012
19
+ beta_schedule: "linear"
20
+ steps_offset: 1
21
+ clip_sample: false
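
The `noise_scheduler_kwargs` block is passed straight into a diffusers `DDIMScheduler`, as `i2v_video_sample.py` does further down; a condensed sketch, assuming `omegaconf` and `diffusers` are installed:

```python
from omegaconf import OmegaConf
from diffusers import DDIMScheduler

model_config = OmegaConf.load("configs/model_config/model_config.yaml")
scheduler = DDIMScheduler(**OmegaConf.to_container(model_config.noise_scheduler_kwargs))
print(scheduler.config.beta_schedule)  # "linear"
```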
configs/model_config/model_config_public.yaml ADDED
@@ -0,0 +1,25 @@
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true # from config v3
3
+ unet_use_cross_frame_attention: false
4
+ unet_use_temporal_attention: false
5
+ use_motion_module: true
6
+ motion_module_resolutions: [1,2,4,8]
7
+ motion_module_mid_block: false
8
+ motion_module_decoder_only: false
9
+ motion_module_type: "Vanilla"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 32
17
+ temporal_attention_dim_div: 1
18
+ zero_initialize: true # from config v3
19
+
20
+ noise_scheduler_kwargs:
21
+ beta_start: 0.00085
22
+ beta_end: 0.012
23
+ beta_schedule: "linear"
24
+ steps_offset: 1
25
+ clip_sample: False
configs/sparsectrl/image_condition.yaml ADDED
@@ -0,0 +1,17 @@
1
+ controlnet_additional_kwargs:
2
+ set_noisy_sample_input_to_zero: true
3
+ use_simplified_condition_embedding: false
4
+ conditioning_channels: 3
5
+
6
+ use_motion_module: true
7
+ motion_module_resolutions: [1,2,4,8]
8
+ motion_module_mid_block: false
9
+ motion_module_type: "Vanilla"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 32
17
+ temporal_attention_dim_div: 1
configs/sparsectrl/latent_condition.yaml ADDED
@@ -0,0 +1,17 @@
1
+ controlnet_additional_kwargs:
2
+ set_noisy_sample_input_to_zero: true
3
+ use_simplified_condition_embedding: true
4
+ conditioning_channels: 4
5
+
6
+ use_motion_module: true
7
+ motion_module_resolutions: [1,2,4,8]
8
+ motion_module_mid_block: false
9
+ motion_module_type: "Vanilla"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 32
17
+ temporal_attention_dim_div: 1
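
The two SparseCtrl configs above differ only in `conditioning_channels` (3 for image/scribble conditions, 4 for latent conditions) and `use_simplified_condition_embedding`; `i2v_video_sample.py` forwards whichever file the inference config names as `controlnet_additional_kwargs`. A minimal sketch of reading them, assuming `omegaconf` is installed:

```python
from omegaconf import OmegaConf

for path in ["configs/sparsectrl/image_condition.yaml",
             "configs/sparsectrl/latent_condition.yaml"]:
    kwargs = OmegaConf.to_container(OmegaConf.load(path).controlnet_additional_kwargs)
    # These dicts are what get passed to SparseControlNetModel.from_unet(...)
    print(path, kwargs["conditioning_channels"], kwargs["use_simplified_condition_embedding"])
```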
configs/t2v_camera.jsonl ADDED
@@ -0,0 +1,12 @@
1
+ {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Relics on the seabed", "seed": 42}
2
+ {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "A road in the mountain", "seed": 42}
3
+ {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Caves, a path for exploration", "seed": 2026}
4
+ {"video_path":"reference_videos/camera_zoom_in.mp4", "new_prompt": "Railway for train"}
5
+ {"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Tree, in the mountain", "seed": 2026}
6
+ {"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Red car on the track", "seed": 2026}
7
+ {"video_path":"reference_videos/camera_zoom_out.mp4", "new_prompt": "Man, standing in his garden.", "seed": 2026}
8
+ {"video_path":"reference_videos/camera_1.mp4", "new_prompt": "A island, on the ocean, sunny day"}
9
+ {"video_path":"reference_videos/camera_1.mp4", "new_prompt": "A tower, with fireworks"}
10
+ {"video_path":"reference_videos/camera_pan_up.mp4", "new_prompt": "Beautiful house, around with flowers", "seed": 42}
11
+ {"video_path":"reference_videos/camera_translation_2.mp4", "new_prompt": "Forest, in winter", "seed": 2028}
12
+ {"video_path":"reference_videos/camera_pan_down.mp4", "new_prompt": "Eagle, standing in the tree", "seed": 2026}
configs/t2v_camera.yaml ADDED
@@ -0,0 +1,19 @@
1
+
2
+ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
3
+ dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
4
+ model_config: "configs/model_config/model_config.yaml"
5
+
6
+ cfg_scale: 7.5 # default scale for classifier-free guidance
7
+ negative_prompt: "bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers"
8
+ postive_prompt: " 8k, high detailed, best quality, film grain, Fujifilm XT3"
9
+
10
+ inference_steps: 100 # total number of denoising steps for inference
11
+ guidance_scale: 0.3 # the time-step scale at which guidance ends (0.2/40)
12
+ guidance_steps: 50 # number of guided steps in inference, no more than 1000*guidance_scale; the remaining (inference_steps - guidance_steps) steps are performed without guidance
13
+ warm_up_steps: 10
14
+ cool_up_steps: 10
15
+
16
+ motion_guidance_weight: 2000
17
+ motion_guidance_blocks: ['up_blocks.1']
18
+
19
+ add_noise_step: 400
configs/t2v_object.jsonl ADDED
@@ -0,0 +1,6 @@
1
+ {"video_path":"reference_videos/sample_astronaut.mp4", "new_prompt": "Robot, walks in the street.","seed":59}
2
+ {"video_path":"reference_videos/sample_cat.mp4", "new_prompt": "Tiger, raises its head.", "seed": 2025}
3
+ {"video_path":"reference_videos/sample_leaves.mp4", "new_prompt": "Petals falling in the wind.","seed":3407}
4
+ {"video_path":"reference_videos/sample_fox.mp4", "new_prompt": "Cat, turns its head in the living room."}
5
+ {"video_path":"reference_videos/sample_blackswan.mp4", "new_prompt": "Duck, swims in the river.","seed":3407}
6
+ {"video_path":"reference_videos/sample_cow.mp4", "new_prompt": "Pig, drinks water on beach.","seed":3407}
configs/t2v_object.yaml ADDED
@@ -0,0 +1,19 @@
1
+
2
+ motion_module: "models/Motion_Module/v3_sd15_mm.ckpt"
3
+ dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors"
4
+ model_config: "configs/model_config/model_config.yaml"
5
+
6
+ cfg_scale: 7.5 # default scale for classifier-free guidance
7
+ negative_prompt: "bad anatomy, extra limbs, ugly, deformed, noisy, blurry, distorted, out of focus, poorly drawn face, poorly drawn hands, missing fingers"
8
+ postive_prompt: "8k, high detailed, best quality, film grain, Fujifilm XT3"
9
+
10
+ inference_steps: 300 # total number of denoising steps for inference
11
+ guidance_scale: 0.4 # the time-step scale at which guidance ends
12
+ guidance_steps: 180 # number of guided steps in inference, no more than 1000*guidance_scale; the remaining (inference_steps - guidance_steps) steps are performed without guidance
13
+ warm_up_steps: 10
14
+ cool_up_steps: 10
15
+
16
+ motion_guidance_weight: 2000
17
+ motion_guidance_blocks: ['up_blocks.1',]
18
+
19
+ add_noise_step: 400
environment.yaml ADDED
@@ -0,0 +1,25 @@
1
+ name: motionclone
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python=3.11.3
7
+ - pytorch=2.0.1
8
+ - torchvision=0.15.2
9
+ - pytorch-cuda=11.8
10
+ - pip
11
+ - pip:
12
+ - accelerate
13
+ - diffusers==0.16.0
14
+ - transformers==4.28.1
15
+ - xformers==0.0.20
16
+ - imageio[ffmpeg]
17
+ - decord==0.6.0
18
+ - gdown
19
+ - einops
20
+ - omegaconf
21
+ - safetensors
22
+ - gradio
23
+ - wandb
24
+ - triton
25
+ - opencv-python
generated_videos/camera_zoom_out_Dog,_lying_on_the_grass76739_76739.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63ecf6f1250b83d71b50352a020c97eb60223ee33813219b2bd8d7588f1ecfec
3
+ size 285735
generated_videos/inference_config.json ADDED
@@ -0,0 +1,21 @@
1
+ motion_module: models/Motion_Module/v3_sd15_mm.ckpt
2
+ dreambooth_path: models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors
3
+ model_config: configs/model_config/model_config.yaml
4
+ controlnet_config: configs/sparsectrl/image_condition.yaml
5
+ controlnet_path: models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt
6
+ adapter_lora_path: models/Motion_Module/v3_sd15_adapter.ckpt
7
+ cfg_scale: 7.5
8
+ negative_prompt: ugly, deformed, noisy, blurry, distorted, out of focus, bad anatomy,
9
+ extra limbs, poorly drawn face, poorly drawn hands, missing fingers
10
+ inference_steps: 200
11
+ guidance_scale: 0.4
12
+ guidance_steps: 120
13
+ warm_up_steps: 10
14
+ cool_up_steps: 10
15
+ motion_guidance_weight: 2000
16
+ motion_guidance_blocks:
17
+ - up_blocks.1
18
+ add_noise_step: 400
19
+ width: 512
20
+ height: 512
21
+ video_length: 16
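
Despite the `.json` extension, the content above is YAML: it is the snapshot that `i2v_video_sample.py` writes with `OmegaConf.save` at the start of a run. A condensed reproduction, assuming `omegaconf` is installed:

```python
from omegaconf import OmegaConf

config = OmegaConf.load("configs/i2v_sketch.yaml")
# OmegaConf.save serializes to YAML regardless of the target file extension,
# which is why generated_videos/inference_config.json reads as YAML.
OmegaConf.save(config, "generated_videos/inference_config.json")
```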
generated_videos/sample_white_tiger_Lion,_walks_in_the_forest76739_76739.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ae68b549f1c6541417009d1cdd35d01286876bada07fb53a3354ad9225856cf
3
+ size 538343
i2v_video_sample.py ADDED
@@ -0,0 +1,157 @@
1
+ import argparse
2
+ from omegaconf import OmegaConf
3
+ import torch
4
+ from diffusers import AutoencoderKL, DDIMScheduler
5
+ from transformers import CLIPTextModel, CLIPTokenizer
6
+ from motionclone.models.unet import UNet3DConditionModel
7
+ from motionclone.models.sparse_controlnet import SparseControlNetModel
8
+ from motionclone.pipelines.pipeline_animation import AnimationPipeline
9
+ from motionclone.utils.util import load_weights, auto_download
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from motionclone.utils.motionclone_functions import *
12
+ import json
13
+ from motionclone.utils.xformer_attention import *
14
+
15
+
16
+ def main(args):
17
+
18
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0))
19
+
20
+ config = OmegaConf.load(args.inference_config)
21
+ adopted_dtype = torch.float16
22
+ device = "cuda"
23
+ set_all_seed(42)
24
+
25
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
26
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype)
27
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype)
28
+
29
+ config.width = config.get("W", args.W)
30
+ config.height = config.get("H", args.H)
31
+ config.video_length = config.get("L", args.L)
32
+
33
+ if not os.path.exists(args.generated_videos_save_dir):
34
+ os.makedirs(args.generated_videos_save_dir)
35
+ OmegaConf.save(config, os.path.join(args.generated_videos_save_dir,"inference_config.json"))
36
+
37
+ model_config = OmegaConf.load(config.get("model_config", ""))
38
+ unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs),).to(device).to(dtype=adopted_dtype)
39
+
40
+ # load controlnet model
41
+ controlnet = None
42
+ if config.get("controlnet_path", "") != "":
43
+ # assert model_config.get("controlnet_images", "") != ""
44
+ assert config.get("controlnet_config", "") != ""
45
+
46
+ unet.config.num_attention_heads = 8
47
+ unet.config.projection_class_embeddings_input_dim = None
48
+
49
+ controlnet_config = OmegaConf.load(config.controlnet_config)
50
+ controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})).to(device).to(dtype=adopted_dtype)
51
+
52
+ auto_download(config.controlnet_path, is_dreambooth_lora=False)
53
+ print(f"loading controlnet checkpoint from {config.controlnet_path} ...")
54
+ controlnet_state_dict = torch.load(config.controlnet_path, map_location="cpu")
55
+ controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
56
+ controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name}
57
+ controlnet_state_dict.pop("animatediff_config", "")
58
+ controlnet.load_state_dict(controlnet_state_dict)
59
+ del controlnet_state_dict
60
+
61
+ # set xformers
62
+ if is_xformers_available() and (not args.without_xformers):
63
+ unet.enable_xformers_memory_efficient_attention()
64
+
65
+ pipeline = AnimationPipeline(
66
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
67
+ controlnet=controlnet,
68
+ scheduler=DDIMScheduler(**OmegaConf.to_container(model_config.noise_scheduler_kwargs)),
69
+ ).to(device)
70
+
71
+ pipeline = load_weights(
72
+ pipeline,
73
+ # motion module
74
+ motion_module_path = config.get("motion_module", ""),
75
+ # domain adapter
76
+ adapter_lora_path = config.get("adapter_lora_path", ""),
77
+ adapter_lora_scale = config.get("adapter_lora_scale", 1.0),
78
+ # image layer
79
+ dreambooth_model_path = config.get("dreambooth_path", ""),
80
+ ).to(device)
81
+ pipeline.text_encoder.to(dtype=adopted_dtype)
82
+
83
+ # customized functions in motionclone_functions
84
+ pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler)
85
+ pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler)
86
+ pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet)
87
+ pipeline.sample_video = sample_video.__get__(pipeline)
88
+ pipeline.single_step_video = single_step_video.__get__(pipeline)
89
+ pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline)
90
+ pipeline.add_noise = add_noise.__get__(pipeline)
91
+ pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline)
92
+ pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline)
93
+
94
+ for param in pipeline.unet.parameters():
95
+ param.requires_grad = False
96
+ for param in pipeline.controlnet.parameters():
97
+ param.requires_grad = False
98
+
99
+ pipeline.input_config, pipeline.unet.input_config = config, config
100
+ pipeline.unet = prep_unet_attention(pipeline.unet,pipeline.input_config.motion_guidance_blocks)
101
+ pipeline.unet = prep_unet_conv(pipeline.unet)
102
+ pipeline.scheduler.customized_set_timesteps(config.inference_steps, config.guidance_steps,config.guidance_scale,device=device,timestep_spacing_type = "uneven")
103
+
104
+ with open(args.examples, 'r') as files:
105
+ for line in files:
106
+ # prepare infor of each case
107
+ example_infor = json.loads(line)
108
+ config.video_path = example_infor["video_path"]
109
+ config.condition_image_path_list = example_infor["condition_image_paths"]
110
+ config.image_index = example_infor.get("image_index",[0])
111
+ assert len(config.image_index) == len(config.condition_image_path_list)
112
+ config.new_prompt = example_infor["new_prompt"] + config.get("positive_prompt", "")
113
+ config.controlnet_scale = example_infor.get("controlnet_scale", 1.0)
114
+ pipeline.input_config, pipeline.unet.input_config = config, config # update config
115
+
116
+ # perform motion representation extraction
117
+ seed_motion = example_infor.get("seed", args.default_seed)  # seed used for motion extraction
118
+ generator = torch.Generator(device=pipeline.device)
119
+ generator.manual_seed(seed_motion)
120
+ if not os.path.exists(args.motion_representation_save_dir):
121
+ os.makedirs(args.motion_representation_save_dir)
122
+ motion_representation_path = os.path.join(args.motion_representation_save_dir, os.path.splitext(os.path.basename(config.video_path))[0] + '.pt')
123
+ pipeline.obtain_motion_representation(generator= generator, motion_representation_path = motion_representation_path, use_controlnet=True,)
124
+
125
+ # perform video generation
126
+ seed = seed_motion # can assign other seed here
127
+ generator = torch.Generator(device=pipeline.device)
128
+ generator.manual_seed(seed)
129
+ pipeline.input_config.seed = seed
130
+ videos = pipeline.sample_video(generator = generator, add_controlnet=True,)
131
+
132
+ videos = rearrange(videos, "b c f h w -> b f h w c")
133
+ save_path = os.path.join(args.generated_videos_save_dir, os.path.splitext(os.path.basename(config.video_path))[0]
134
+ + "_" + config.new_prompt.strip().replace(' ', '_') + str(seed_motion) + "_" +str(seed)+'.mp4')
135
+ videos_uint8 = (videos[0] * 255).astype(np.uint8)
136
+ imageio.mimwrite(save_path, videos_uint8, fps=8)
137
+ print(save_path,"is done")
138
+
139
+ if __name__ == "__main__":
140
+ parser = argparse.ArgumentParser()
141
+ parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion",)
142
+
143
+ parser.add_argument("--inference_config", type=str, default="configs/i2v_sketch.yaml")
144
+ parser.add_argument("--examples", type=str, default="configs/i2v_sketch.jsonl")
145
+ parser.add_argument("--motion-representation-save-dir", type=str, default="motion_representation/")
146
+ parser.add_argument("--generated-videos-save-dir", type=str, default="generated_videos/")
147
+
148
+ parser.add_argument("--visible_gpu", type=str, default=None)
149
+ parser.add_argument("--default-seed", type=int, default=76739)
150
+ parser.add_argument("--L", type=int, default=16)
151
+ parser.add_argument("--W", type=int, default=512)
152
+ parser.add_argument("--H", type=int, default=512)
153
+
154
+ parser.add_argument("--without-xformers", action="store_true")
155
+
156
+ args = parser.parse_args()
157
+ main(args)
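
The `function.__get__(instance)` assignments in the script above are standard Python descriptor binding, used to attach the customized MotionClone functions to an existing pipeline, scheduler, and UNet without subclassing; a stripped-down illustration of the pattern:

```python
class Scheduler:
    def step(self):
        return "default step"

def customized_step(self):
    # Replacement behaviour; binding via __get__ attaches it to one instance only.
    return "customized step"

scheduler = Scheduler()
scheduler.step = customized_step.__get__(scheduler)
print(scheduler.step())  # -> "customized step"
```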
models/Motion_Module/Put motion module checkpoints here.txt ADDED
File without changes
motionclone/models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (13.7 kB).
 
motionclone/models/__pycache__/attention.cpython-38.pyc ADDED
Binary file (13.6 kB).
 
motionclone/models/__pycache__/motion_module.cpython-310.pyc ADDED
Binary file (8.71 kB).
 
motionclone/models/__pycache__/motion_module.cpython-38.pyc ADDED
Binary file (8.67 kB).
 
motionclone/models/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (5.31 kB).
 
motionclone/models/__pycache__/resnet.cpython-38.pyc ADDED
Binary file (5.41 kB).
 
motionclone/models/__pycache__/sparse_controlnet.cpython-38.pyc ADDED
Binary file (14 kB).
 
motionclone/models/__pycache__/unet.cpython-310.pyc ADDED
Binary file (12.7 kB).
 
motionclone/models/__pycache__/unet.cpython-38.pyc ADDED
Binary file (12.4 kB).
 
motionclone/models/__pycache__/unet_blocks.cpython-310.pyc ADDED
Binary file (12.8 kB).
 
motionclone/models/__pycache__/unet_blocks.cpython-38.pyc ADDED
Binary file (12.1 kB).
 
motionclone/models/attention.py ADDED
@@ -0,0 +1,611 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
15
+
16
+ from einops import rearrange, repeat
17
+ import pdb
18
+
19
+ @dataclass
20
+ class Transformer3DModelOutput(BaseOutput):
21
+ sample: torch.FloatTensor
22
+
23
+
24
+ if is_xformers_available():
25
+ import xformers
26
+ import xformers.ops
27
+ else:
28
+ xformers = None
29
+
30
+
31
+ class Transformer3DModel(ModelMixin, ConfigMixin):
32
+ @register_to_config
33
+ def __init__(
34
+ self,
35
+ num_attention_heads: int = 16,
36
+ attention_head_dim: int = 88,
37
+ in_channels: Optional[int] = None,
38
+ num_layers: int = 1,
39
+ dropout: float = 0.0,
40
+ norm_num_groups: int = 32,
41
+ cross_attention_dim: Optional[int] = None,
42
+ attention_bias: bool = False,
43
+ activation_fn: str = "geglu",
44
+ num_embeds_ada_norm: Optional[int] = None,
45
+ use_linear_projection: bool = False,
46
+ only_cross_attention: bool = False,
47
+ upcast_attention: bool = False,
48
+
49
+ unet_use_cross_frame_attention=None,
50
+ unet_use_temporal_attention=None,
51
+ ):
52
+ super().__init__()
53
+ self.use_linear_projection = use_linear_projection
54
+ self.num_attention_heads = num_attention_heads
55
+ self.attention_head_dim = attention_head_dim
56
+ inner_dim = num_attention_heads * attention_head_dim
57
+
58
+ # Define input layers
59
+ self.in_channels = in_channels
60
+
61
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
62
+ if use_linear_projection:
63
+ self.proj_in = nn.Linear(in_channels, inner_dim)
64
+ else:
65
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
66
+
67
+ # Define transformers blocks
68
+ self.transformer_blocks = nn.ModuleList(
69
+ [
70
+ BasicTransformerBlock(
71
+ inner_dim,
72
+ num_attention_heads,
73
+ attention_head_dim,
74
+ dropout=dropout,
75
+ cross_attention_dim=cross_attention_dim,
76
+ activation_fn=activation_fn,
77
+ num_embeds_ada_norm=num_embeds_ada_norm,
78
+ attention_bias=attention_bias,
79
+ only_cross_attention=only_cross_attention,
80
+ upcast_attention=upcast_attention,
81
+
82
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83
+ unet_use_temporal_attention=unet_use_temporal_attention,
84
+ )
85
+ for d in range(num_layers)
86
+ ]
87
+ )
88
+
89
+ # 4. Define output layers
90
+ if use_linear_projection:
91
+ self.proj_out = nn.Linear(in_channels, inner_dim)
92
+ else:
93
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
94
+
95
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
96
+ # Input
97
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
98
+ video_length = hidden_states.shape[2]
99
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
100
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
101
+
102
+ batch, channel, height, weight = hidden_states.shape
103
+ residual = hidden_states
104
+
105
+ hidden_states = self.norm(hidden_states)
106
+ if not self.use_linear_projection:
107
+ hidden_states = self.proj_in(hidden_states)
108
+ inner_dim = hidden_states.shape[1]
109
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
110
+ else:
111
+ inner_dim = hidden_states.shape[1]
112
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
113
+ hidden_states = self.proj_in(hidden_states)
114
+
115
+ # Blocks
116
+ for block in self.transformer_blocks:
117
+ hidden_states = block(
118
+ hidden_states,
119
+ encoder_hidden_states=encoder_hidden_states,
120
+ timestep=timestep,
121
+ video_length=video_length
122
+ )
123
+
124
+ # Output
125
+ if not self.use_linear_projection:
126
+ hidden_states = (
127
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
128
+ )
129
+ hidden_states = self.proj_out(hidden_states)
130
+ else:
131
+ hidden_states = self.proj_out(hidden_states)
132
+ hidden_states = (
133
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
134
+ )
135
+
136
+ output = hidden_states + residual
137
+
138
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
139
+ if not return_dict:
140
+ return (output,)
141
+
142
+ return Transformer3DModelOutput(sample=output)
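
A minimal, illustrative way to exercise this block on dummy video latents; all sizes below are assumptions chosen to satisfy the constructor defaults, not values taken from the repository's configs:

```python
import torch
from motionclone.models.attention import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8, attention_head_dim=40, in_channels=320,
    cross_attention_dim=768,
    unet_use_cross_frame_attention=False, unet_use_temporal_attention=False,
)
video_latents = torch.randn(1, 320, 16, 32, 32)  # (batch, channels, frames, height, width)
text_embeddings = torch.randn(1, 77, 768)        # (batch, tokens, dim)
out = model(video_latents, encoder_hidden_states=text_embeddings).sample
print(out.shape)                                  # same 5-D shape as the input latents
```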
143
+
144
+
145
+ class BasicTransformerBlock(nn.Module):
146
+ def __init__(
147
+ self,
148
+ dim: int,
149
+ num_attention_heads: int,
150
+ attention_head_dim: int,
151
+ dropout=0.0,
152
+ cross_attention_dim: Optional[int] = None,
153
+ activation_fn: str = "geglu",
154
+ num_embeds_ada_norm: Optional[int] = None,
155
+ attention_bias: bool = False,
156
+ only_cross_attention: bool = False,
157
+ upcast_attention: bool = False,
158
+
159
+ unet_use_cross_frame_attention = None,
160
+ unet_use_temporal_attention = None,
161
+ ):
162
+ super().__init__()
163
+ self.only_cross_attention = only_cross_attention
164
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
165
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
166
+ self.unet_use_temporal_attention = unet_use_temporal_attention
167
+
168
+ # SC-Attn
169
+ assert unet_use_cross_frame_attention is not None
170
+ if unet_use_cross_frame_attention:
171
+ self.attn1 = SparseCausalAttention2D(
172
+ query_dim=dim,
173
+ heads=num_attention_heads,
174
+ dim_head=attention_head_dim,
175
+ dropout=dropout,
176
+ bias=attention_bias,
177
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
178
+ upcast_attention=upcast_attention,
179
+ )
180
+ else:
181
+ self.attn1 = CrossAttention(
182
+ query_dim=dim,
183
+ heads=num_attention_heads,
184
+ dim_head=attention_head_dim,
185
+ dropout=dropout,
186
+ bias=attention_bias,
187
+ upcast_attention=upcast_attention,
188
+ )
189
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
190
+
191
+ # Cross-Attn
192
+ if cross_attention_dim is not None:
193
+ self.attn2 = CrossAttention(
194
+ query_dim=dim,
195
+ cross_attention_dim=cross_attention_dim,
196
+ heads=num_attention_heads,
197
+ dim_head=attention_head_dim,
198
+ dropout=dropout,
199
+ bias=attention_bias,
200
+ upcast_attention=upcast_attention,
201
+ )
202
+ else:
203
+ self.attn2 = None
204
+
205
+ if cross_attention_dim is not None:
206
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
207
+ else:
208
+ self.norm2 = None
209
+
210
+ # Feed-forward
211
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
212
+ self.norm3 = nn.LayerNorm(dim)
213
+
214
+ # Temp-Attn
215
+ assert unet_use_temporal_attention is not None
216
+ if unet_use_temporal_attention:
217
+ self.attn_temp = CrossAttention(
218
+ query_dim=dim,
219
+ heads=num_attention_heads,
220
+ dim_head=attention_head_dim,
221
+ dropout=dropout,
222
+ bias=attention_bias,
223
+ upcast_attention=upcast_attention,
224
+ )
225
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
226
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
227
+
228
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None):
229
+ if not is_xformers_available():
230
+ print("Here is how to install it")
231
+ raise ModuleNotFoundError(
232
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
233
+ " xformers",
234
+ name="xformers",
235
+ )
236
+ elif not torch.cuda.is_available():
237
+ raise ValueError(
238
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
239
+ " available for GPU "
240
+ )
241
+ else:
242
+ try:
243
+ # Make sure we can run the memory efficient attention
244
+ _ = xformers.ops.memory_efficient_attention(
245
+ torch.randn((1, 2, 40), device="cuda"),
246
+ torch.randn((1, 2, 40), device="cuda"),
247
+ torch.randn((1, 2, 40), device="cuda"),
248
+ )
249
+ except Exception as e:
250
+ raise e
251
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
252
+ if self.attn2 is not None:
253
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
254
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
255
+
256
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
257
+ # SparseCausal-Attention
258
+ norm_hidden_states = (
259
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
260
+ )
261
+
262
+ # if self.only_cross_attention:
263
+ # hidden_states = (
264
+ # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
265
+ # )
266
+ # else:
267
+ # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
268
+
269
+ # pdb.set_trace()
270
+ if self.unet_use_cross_frame_attention:
271
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
272
+ else:
273
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
274
+
275
+ if self.attn2 is not None:
276
+ # Cross-Attention
277
+ norm_hidden_states = (
278
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
279
+ )
280
+ hidden_states = (
281
+ self.attn2(
282
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
283
+ )
284
+ + hidden_states
285
+ )
286
+
287
+ # Feed-forward
288
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
289
+
290
+ # Temporal-Attention
291
+ if self.unet_use_temporal_attention:
292
+ d = hidden_states.shape[1]
293
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
294
+ norm_hidden_states = (
295
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
296
+ )
297
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
298
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
299
+
300
+ return hidden_states
301
+
302
+ class CrossAttention(nn.Module):
303
+ r"""
304
+ A cross attention layer.
305
+
306
+ Parameters:
307
+ query_dim (`int`): The number of channels in the query.
308
+ cross_attention_dim (`int`, *optional*):
309
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
310
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
311
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
312
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
313
+ bias (`bool`, *optional*, defaults to False):
314
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
315
+ """
316
+
317
+ def __init__(
318
+ self,
319
+ query_dim: int,
320
+ cross_attention_dim: Optional[int] = None,
321
+ heads: int = 8,
322
+ dim_head: int = 64,
323
+ dropout: float = 0.0,
324
+ bias=False,
325
+ upcast_attention: bool = False,
326
+ upcast_softmax: bool = False,
327
+ added_kv_proj_dim: Optional[int] = None,
328
+ norm_num_groups: Optional[int] = None,
329
+ ):
330
+ super().__init__()
331
+ inner_dim = dim_head * heads
332
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
333
+ self.upcast_attention = upcast_attention
334
+ self.upcast_softmax = upcast_softmax
335
+
336
+ self.scale = dim_head**-0.5
337
+
338
+ self.heads = heads
339
+ # for slice_size > 0 the attention score computation
340
+ # is split across the batch axis to save memory
341
+ # You can set slice_size with `set_attention_slice`
342
+ self.sliceable_head_dim = heads
343
+ self._slice_size = None
344
+ self._use_memory_efficient_attention_xformers = False
345
+ self.added_kv_proj_dim = added_kv_proj_dim
346
+
347
+ #### add processer
348
+ self.processor = None
349
+
350
+ if norm_num_groups is not None:
351
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
352
+ else:
353
+ self.group_norm = None
354
+
355
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
356
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
357
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
358
+
359
+ if self.added_kv_proj_dim is not None:
360
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
361
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
362
+
363
+ self.to_out = nn.ModuleList([])
364
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
365
+ self.to_out.append(nn.Dropout(dropout))
366
+
367
+ def reshape_heads_to_batch_dim(self, tensor):
368
+ batch_size, seq_len, dim = tensor.shape
369
+ head_size = self.heads
370
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
371
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
372
+ return tensor
373
+
374
+ def reshape_batch_dim_to_heads(self, tensor):
375
+ batch_size, seq_len, dim = tensor.shape
376
+ head_size = self.heads
377
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
378
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
379
+ return tensor
380
+
381
+ def set_attention_slice(self, slice_size):
382
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
383
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
384
+
385
+ self._slice_size = slice_size
386
+
387
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
388
+ batch_size, sequence_length, _ = hidden_states.shape
389
+
390
+ encoder_hidden_states = encoder_hidden_states
391
+
392
+ if self.group_norm is not None:
393
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
394
+
395
+ query = self.to_q(hidden_states)
396
+ dim = query.shape[-1]
397
+ # query = self.reshape_heads_to_batch_dim(query) # move backwards
398
+
399
+ if self.added_kv_proj_dim is not None:
400
+ key = self.to_k(hidden_states)
401
+ value = self.to_v(hidden_states)
402
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
403
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
404
+
405
+ ######record###### record before reshape heads to batch dim
406
+ if self.processor is not None:
407
+ self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask)
408
+ ##################
409
+
410
+ key = self.reshape_heads_to_batch_dim(key)
411
+ value = self.reshape_heads_to_batch_dim(value)
412
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
413
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
414
+
415
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
416
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
417
+ else:
418
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
419
+ key = self.to_k(encoder_hidden_states)
420
+ value = self.to_v(encoder_hidden_states)
421
+
422
+ ######record######
423
+ if self.processor is not None:
424
+ self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask)
425
+ ##################
426
+
427
+ key = self.reshape_heads_to_batch_dim(key)
428
+ value = self.reshape_heads_to_batch_dim(value)
429
+
430
+ query = self.reshape_heads_to_batch_dim(query) # reshape query
431
+
432
+ if attention_mask is not None:
433
+ if attention_mask.shape[-1] != query.shape[1]:
434
+ target_length = query.shape[1]
435
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
436
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
437
+
438
+ ######record######
439
+ if self.processor is not None:
440
+ self.processor.record_attn_mask(self, hidden_states, query, key, value, attention_mask)
441
+ ##################
442
+
443
+ # attention, what we cannot get enough of
444
+ if self._use_memory_efficient_attention_xformers:
445
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
446
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
447
+ hidden_states = hidden_states.to(query.dtype)
448
+ else:
449
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
450
+ hidden_states = self._attention(query, key, value, attention_mask)
451
+ else:
452
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
453
+
454
+ # linear proj
455
+ hidden_states = self.to_out[0](hidden_states)
456
+
457
+ # dropout
458
+ hidden_states = self.to_out[1](hidden_states)
459
+ return hidden_states
460
+
461
+ def _attention(self, query, key, value, attention_mask=None):
462
+ if self.upcast_attention:
463
+ query = query.float()
464
+ key = key.float()
465
+
466
+ attention_scores = torch.baddbmm(
467
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
468
+ query,
469
+ key.transpose(-1, -2),
470
+ beta=0,
471
+ alpha=self.scale,
472
+ )
473
+
474
+ if attention_mask is not None:
475
+ attention_scores = attention_scores + attention_mask
476
+
477
+ if self.upcast_softmax:
478
+ attention_scores = attention_scores.float()
479
+
480
+ attention_probs = attention_scores.softmax(dim=-1)
481
+
482
+ # cast back to the original dtype
483
+ attention_probs = attention_probs.to(value.dtype)
484
+
485
+ # compute attention output
486
+ hidden_states = torch.bmm(attention_probs, value)
487
+
488
+ # reshape hidden_states
489
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
490
+ return hidden_states
491
+
492
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
493
+ batch_size_attention = query.shape[0]
494
+ hidden_states = torch.zeros(
495
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
496
+ )
497
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
498
+ for i in range(hidden_states.shape[0] // slice_size):
499
+ start_idx = i * slice_size
500
+ end_idx = (i + 1) * slice_size
501
+
502
+ query_slice = query[start_idx:end_idx]
503
+ key_slice = key[start_idx:end_idx]
504
+
505
+ if self.upcast_attention:
506
+ query_slice = query_slice.float()
507
+ key_slice = key_slice.float()
508
+
509
+ attn_slice = torch.baddbmm(
510
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
511
+ query_slice,
512
+ key_slice.transpose(-1, -2),
513
+ beta=0,
514
+ alpha=self.scale,
515
+ )
516
+
517
+ if attention_mask is not None:
518
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
519
+
520
+ if self.upcast_softmax:
521
+ attn_slice = attn_slice.float()
522
+
523
+ attn_slice = attn_slice.softmax(dim=-1)
524
+
525
+ # cast back to the original dtype
526
+ attn_slice = attn_slice.to(value.dtype)
527
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
528
+
529
+ hidden_states[start_idx:end_idx] = attn_slice
530
+
531
+ # reshape hidden_states
532
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
533
+ return hidden_states
534
+
535
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
536
+ # TODO attention_mask
537
+ query = query.contiguous()
538
+ key = key.contiguous()
539
+ value = value.contiguous()
540
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
541
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
542
+ return hidden_states
543
+
544
+ def set_processor(self, processor: "AttnProcessor") -> None:
545
+ r"""
546
+ Set the attention processor to use.
547
+
548
+ Args:
549
+ processor (`AttnProcessor`):
550
+ The attention processor to use.
551
+ """
552
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
553
+ # pop `processor` from `self._modules`
554
+ if (
555
+ hasattr(self, "processor")
556
+ and isinstance(self.processor, torch.nn.Module)
557
+ and not isinstance(processor, torch.nn.Module)
558
+ ):
559
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
560
+ self._modules.pop("processor")
561
+
562
+ self.processor = processor
563
+
564
+ def get_attention_scores(
565
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
566
+ ) -> torch.Tensor:
567
+ r"""
568
+ Compute the attention scores.
569
+
570
+ Args:
571
+ query (`torch.Tensor`): The query tensor.
572
+ key (`torch.Tensor`): The key tensor.
573
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
574
+
575
+ Returns:
576
+ `torch.Tensor`: The attention probabilities/scores.
577
+ """
578
+ dtype = query.dtype
579
+ if self.upcast_attention:
580
+ query = query.float()
581
+ key = key.float()
582
+
583
+ if attention_mask is None:
584
+ baddbmm_input = torch.empty(
585
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
586
+ )
587
+ beta = 0
588
+ else:
589
+ baddbmm_input = attention_mask
590
+ beta = 1
591
+
592
+
593
+
594
+ attention_scores = torch.baddbmm(
595
+ baddbmm_input,
596
+ query,
597
+ key.transpose(-1, -2),
598
+ beta=beta,
599
+ alpha=self.scale,
600
+ )
601
+ del baddbmm_input
602
+
603
+ if self.upcast_softmax:
604
+ attention_scores = attention_scores.float()
605
+
606
+ attention_probs = attention_scores.softmax(dim=-1)
607
+ del attention_scores
608
+
609
+ attention_probs = attention_probs.to(dtype)
610
+
611
+ return attention_probs
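For reference, a minimal usage sketch (illustrative only, not part of this commit) of the `processor` hook that this `CrossAttention` adds on top of the diffusers layer: `forward` calls `record_qkv`/`record_attn_mask` on `self.processor` when it is set, so any object implementing those two methods can capture queries, keys and values before they are reshaped into the head-batched layout. The layer sizes below are assumptions.

```python
import torch
from motionclone.models.attention import CrossAttention

class QKVRecorder:
    """Caches the most recent query/key/value of each layer (before head reshaping)."""
    def __init__(self):
        self.records = {}

    def record_qkv(self, attn, hidden_states, query, key, value, attention_mask):
        self.records[id(attn)] = (query.detach(), key.detach(), value.detach())

    def record_attn_mask(self, attn, hidden_states, query, key, value, attention_mask):
        pass  # the mask is not needed for this sketch

attn = CrossAttention(query_dim=320, heads=8, dim_head=40)  # assumed dimensions
attn.processor = QKVRecorder()                              # hook in the recorder
out = attn(torch.randn(2, 64, 320))                         # self-attention; q/k/v cached as a side effect
```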
motionclone/models/motion_module.py ADDED
@@ -0,0 +1,347 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ import torchvision
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import FeedForward
15
+ from .attention import CrossAttention
16
+
17
+ from einops import rearrange, repeat
18
+ import math
19
+
20
+
21
+ def zero_module(module):
22
+ # Zero out the parameters of a module and return it.
23
+ for p in module.parameters():
24
+ p.detach().zero_()
25
+ return module
26
+
27
+
28
+ @dataclass
29
+ class TemporalTransformer3DModelOutput(BaseOutput):
30
+ sample: torch.FloatTensor
31
+
32
+
33
+ if is_xformers_available():
34
+ import xformers
35
+ import xformers.ops
36
+ else:
37
+ xformers = None
38
+
39
+
40
+ def get_motion_module( # can only return a VanillaTemporalModule
41
+ in_channels,
42
+ motion_module_type: str,
43
+ motion_module_kwargs: dict
44
+ ):
45
+ if motion_module_type == "Vanilla":
46
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
47
+ else:
48
+ raise ValueError
49
+
50
+
51
+ class VanillaTemporalModule(nn.Module):
52
+ def __init__(
53
+ self,
54
+ in_channels,
55
+ num_attention_heads = 8,
56
+ num_transformer_block = 2,
57
+ attention_block_types =( "Temporal_Self", "Temporal_Self" ),
58
+ cross_frame_attention_mode = None,
59
+ temporal_position_encoding = False,
60
+ temporal_position_encoding_max_len = 32,
61
+ temporal_attention_dim_div = 1,
62
+ zero_initialize = True,
63
+ ):
64
+ super().__init__()
65
+
66
+ self.temporal_transformer = TemporalTransformer3DModel(
67
+ in_channels=in_channels,
68
+ num_attention_heads=num_attention_heads,
69
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
70
+ num_layers=num_transformer_block,
71
+ attention_block_types=attention_block_types,
72
+ cross_frame_attention_mode=cross_frame_attention_mode,
73
+ temporal_position_encoding=temporal_position_encoding,
74
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
75
+ )
76
+
77
+ if zero_initialize:
78
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
79
+
80
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
81
+ hidden_states = input_tensor
82
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
83
+
84
+ output = hidden_states
85
+ return output
86
+
87
+
88
+ class TemporalTransformer3DModel(nn.Module):
89
+ def __init__(
90
+ self,
91
+ in_channels,
92
+ num_attention_heads,
93
+ attention_head_dim,
94
+
95
+ num_layers,
96
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ), # two temporal self-attention layers
97
+ dropout = 0.0,
98
+ norm_num_groups = 32,
99
+ cross_attention_dim = 768,
100
+ activation_fn = "geglu",
101
+ attention_bias = False,
102
+ upcast_attention = False,
103
+
104
+ cross_frame_attention_mode = None,
105
+ temporal_position_encoding = False,
106
+ temporal_position_encoding_max_len = 24,
107
+ ):
108
+ super().__init__()
109
+
110
+ inner_dim = num_attention_heads * attention_head_dim
111
+
112
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
113
+ self.proj_in = nn.Linear(in_channels, inner_dim)
114
+
115
+ self.transformer_blocks = nn.ModuleList(
116
+ [
117
+ TemporalTransformerBlock(
118
+ dim=inner_dim,
119
+ num_attention_heads=num_attention_heads,
120
+ attention_head_dim=attention_head_dim,
121
+ attention_block_types=attention_block_types,
122
+ dropout=dropout,
123
+ norm_num_groups=norm_num_groups,
124
+ cross_attention_dim=cross_attention_dim,
125
+ activation_fn=activation_fn,
126
+ attention_bias=attention_bias,
127
+ upcast_attention=upcast_attention,
128
+ cross_frame_attention_mode=cross_frame_attention_mode,
129
+ temporal_position_encoding=temporal_position_encoding,
130
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
131
+ )
132
+ for d in range(num_layers)
133
+ ]
134
+ )
135
+ self.proj_out = nn.Linear(inner_dim, in_channels)
136
+
137
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
138
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
139
+ video_length = hidden_states.shape[2]
140
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
141
+
142
+ batch, channel, height, weight = hidden_states.shape
143
+ residual = hidden_states
144
+
145
+ hidden_states = self.norm(hidden_states)
146
+ inner_dim = hidden_states.shape[1]
147
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
148
+ hidden_states = self.proj_in(hidden_states)
149
+
150
+ # Transformer Blocks
151
+ for block in self.transformer_blocks:
152
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
153
+
154
+ # output
155
+ hidden_states = self.proj_out(hidden_states)
156
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
157
+
158
+ output = hidden_states + residual
159
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
160
+
161
+ return output
162
+
163
+
164
+ class TemporalTransformerBlock(nn.Module):
165
+ def __init__(
166
+ self,
167
+ dim,
168
+ num_attention_heads,
169
+ attention_head_dim,
170
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
171
+ dropout = 0.0,
172
+ norm_num_groups = 32,
173
+ cross_attention_dim = 768,
174
+ activation_fn = "geglu",
175
+ attention_bias = False,
176
+ upcast_attention = False,
177
+ cross_frame_attention_mode = None,
178
+ temporal_position_encoding = False,
179
+ temporal_position_encoding_max_len = 24,
180
+ ):
181
+ super().__init__()
182
+
183
+ attention_blocks = []
184
+ norms = []
185
+
186
+ for block_name in attention_block_types:
187
+ attention_blocks.append(
188
+ VersatileAttention(
189
+ attention_mode=block_name.split("_")[0],
190
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
191
+
192
+ query_dim=dim,
193
+ heads=num_attention_heads,
194
+ dim_head=attention_head_dim,
195
+ dropout=dropout,
196
+ bias=attention_bias,
197
+ upcast_attention=upcast_attention,
198
+
199
+ cross_frame_attention_mode=cross_frame_attention_mode,
200
+ temporal_position_encoding=temporal_position_encoding,
201
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
202
+ )
203
+ )
204
+ norms.append(nn.LayerNorm(dim))
205
+
206
+ self.attention_blocks = nn.ModuleList(attention_blocks)
207
+ self.norms = nn.ModuleList(norms)
208
+
209
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
210
+ self.ff_norm = nn.LayerNorm(dim)
211
+
212
+
213
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
214
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
215
+ norm_hidden_states = norm(hidden_states)
216
+ hidden_states = attention_block(
217
+ norm_hidden_states,
218
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
219
+ video_length=video_length,
220
+ ) + hidden_states
221
+
222
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
223
+
224
+ output = hidden_states
225
+ return output
226
+
227
+
228
+ class PositionalEncoding(nn.Module):
229
+ def __init__(
230
+ self,
231
+ d_model,
232
+ dropout = 0.,
233
+ max_len = 24
234
+ ):
235
+ super().__init__()
236
+ self.dropout = nn.Dropout(p=dropout)
237
+ position = torch.arange(max_len).unsqueeze(1)
238
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
239
+ pe = torch.zeros(1, max_len, d_model)
240
+ pe[0, :, 0::2] = torch.sin(position * div_term)
241
+ pe[0, :, 1::2] = torch.cos(position * div_term)
242
+ # self.register_buffer('pe', pe)
243
+ self.register_buffer('pe', pe, persistent=False)
244
+
245
+ def forward(self, x):
246
+ x = x + self.pe[:, :x.size(1)]
247
+ return self.dropout(x)
248
+
249
+
250
+ class VersatileAttention(CrossAttention): # inherits from CrossAttention, so set_processor does not need to be re-implemented
251
+ def __init__(
252
+ self,
253
+ attention_mode = None,
254
+ cross_frame_attention_mode = None,
255
+ temporal_position_encoding = False,
256
+ temporal_position_encoding_max_len = 24,
257
+ *args, **kwargs
258
+ ):
259
+ super().__init__(*args, **kwargs)
260
+ assert attention_mode == "Temporal"
261
+
262
+ self.attention_mode = attention_mode
263
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
264
+
265
+ self.pos_encoder = PositionalEncoding(
266
+ kwargs["query_dim"],
267
+ dropout=0.,
268
+ max_len=temporal_position_encoding_max_len
269
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
270
+
271
+ def extra_repr(self):
272
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
273
+
274
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
275
+ batch_size, sequence_length, _ = hidden_states.shape
276
+
277
+ if self.attention_mode == "Temporal":
278
+ d = hidden_states.shape[1]
279
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
280
+
281
+ if self.pos_encoder is not None:
282
+ hidden_states = self.pos_encoder(hidden_states)
283
+
284
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
285
+ else:
286
+ raise NotImplementedError
287
+
288
+ encoder_hidden_states = encoder_hidden_states
289
+
290
+ if self.group_norm is not None:
291
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
292
+
293
+ query = self.to_q(hidden_states)
294
+ dim = query.shape[-1]
295
+ # query = self.reshape_heads_to_batch_dim(query) # move backwards
296
+
297
+ if self.added_kv_proj_dim is not None:
298
+ raise NotImplementedError
299
+
300
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
301
+ key = self.to_k(encoder_hidden_states)
302
+ value = self.to_v(encoder_hidden_states)
303
+
304
+ ######record###### record before reshape heads to batch dim
305
+ if self.processor is not None:
306
+ self.processor.record_qkv(self, hidden_states, query, key, value, attention_mask)
307
+ ##################
308
+
309
+ key = self.reshape_heads_to_batch_dim(key)
310
+ value = self.reshape_heads_to_batch_dim(value)
311
+
312
+ query = self.reshape_heads_to_batch_dim(query) # reshape query here
313
+
314
+ if attention_mask is not None:
315
+ if attention_mask.shape[-1] != query.shape[1]:
316
+ target_length = query.shape[1]
317
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
318
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
319
+
320
+ ######record######
321
+ # if self.processor is not None:
322
+ # self.processor.record_attn_mask(self, hidden_states, query, key, value, attention_mask)
323
+ ##################
324
+
325
+ # attention, what we cannot get enough of
326
+ if self._use_memory_efficient_attention_xformers:
327
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
328
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
329
+ hidden_states = hidden_states.to(query.dtype)
330
+ else:
331
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
332
+ hidden_states = self._attention(query, key, value, attention_mask)
333
+ else:
334
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
335
+
336
+ # linear proj
337
+ hidden_states = self.to_out[0](hidden_states)
338
+
339
+ # dropout
340
+ hidden_states = self.to_out[1](hidden_states)
341
+
342
+ if self.attention_mode == "Temporal":
343
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
344
+
345
+ return hidden_states
346
+
347
+
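A minimal sketch (channel count, kwargs and shapes are assumptions) of how the motion module above is instantiated and applied to a 5-D feature map of shape `(batch, channels, frames, height, width)`. Because `proj_out` is zero-initialized by default, a freshly created module starts out as an identity mapping.

```python
import torch
from motionclone.models.motion_module import get_motion_module

motion_module = get_motion_module(
    in_channels=320,                      # assumed channel count of one UNet block
    motion_module_type="Vanilla",
    motion_module_kwargs={
        "num_attention_heads": 8,
        "num_transformer_block": 1,
        "attention_block_types": ["Temporal_Self"],
        "temporal_position_encoding": True,
        "temporal_position_encoding_max_len": 32,
    },
)

x = torch.randn(1, 320, 16, 32, 32)       # (b, c, f, h, w): 16 frames of 32x32 features
out = motion_module(x, temb=None, encoder_hidden_states=None)
print(out.shape)                          # torch.Size([1, 320, 16, 32, 32])
```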
motionclone/models/resnet.py ADDED
@@ -0,0 +1,218 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+
21
+ class InflatedGroupNorm(nn.GroupNorm):
22
+ def forward(self, x):
23
+ video_length = x.shape[2]
24
+
25
+ x = rearrange(x, "b c f h w -> (b f) c h w")
26
+ x = super().forward(x)
27
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
28
+
29
+ return x
30
+
31
+
32
+ class Upsample3D(nn.Module):
33
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
34
+ super().__init__()
35
+ self.channels = channels
36
+ self.out_channels = out_channels or channels
37
+ self.use_conv = use_conv
38
+ self.use_conv_transpose = use_conv_transpose
39
+ self.name = name
40
+
41
+ conv = None
42
+ if use_conv_transpose:
43
+ raise NotImplementedError
44
+ elif use_conv:
45
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
46
+
47
+ def forward(self, hidden_states, output_size=None):
48
+ assert hidden_states.shape[1] == self.channels
49
+
50
+ if self.use_conv_transpose:
51
+ raise NotImplementedError
52
+
53
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
54
+ dtype = hidden_states.dtype
55
+ if dtype == torch.bfloat16:
56
+ hidden_states = hidden_states.to(torch.float32)
57
+
58
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
59
+ if hidden_states.shape[0] >= 64:
60
+ hidden_states = hidden_states.contiguous()
61
+
62
+ # if `output_size` is passed we force the interpolation output
63
+ # size and do not make use of `scale_factor=2`
64
+ if output_size is None:
65
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
66
+ else:
67
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
68
+
69
+ # If the input is bfloat16, we cast back to bfloat16
70
+ if dtype == torch.bfloat16:
71
+ hidden_states = hidden_states.to(dtype)
72
+
73
+ # if self.use_conv:
74
+ # if self.name == "conv":
75
+ # hidden_states = self.conv(hidden_states)
76
+ # else:
77
+ # hidden_states = self.Conv2d_0(hidden_states)
78
+ hidden_states = self.conv(hidden_states)
79
+
80
+ return hidden_states
81
+
82
+
83
+ class Downsample3D(nn.Module):
84
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
85
+ super().__init__()
86
+ self.channels = channels
87
+ self.out_channels = out_channels or channels
88
+ self.use_conv = use_conv
89
+ self.padding = padding
90
+ stride = 2
91
+ self.name = name
92
+
93
+ if use_conv:
94
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ def forward(self, hidden_states):
99
+ assert hidden_states.shape[1] == self.channels
100
+ if self.use_conv and self.padding == 0:
101
+ raise NotImplementedError
102
+
103
+ assert hidden_states.shape[1] == self.channels
104
+ hidden_states = self.conv(hidden_states)
105
+
106
+ return hidden_states
107
+
108
+
109
+ class ResnetBlock3D(nn.Module):
110
+ def __init__(
111
+ self,
112
+ *,
113
+ in_channels,
114
+ out_channels=None,
115
+ conv_shortcut=False,
116
+ dropout=0.0,
117
+ temb_channels=512,
118
+ groups=32,
119
+ groups_out=None,
120
+ pre_norm=True,
121
+ eps=1e-6,
122
+ non_linearity="swish",
123
+ time_embedding_norm="default",
124
+ output_scale_factor=1.0,
125
+ use_in_shortcut=None,
126
+ use_inflated_groupnorm=False,
127
+ ):
128
+ super().__init__()
129
+ self.pre_norm = pre_norm
130
+ self.pre_norm = True
131
+ self.in_channels = in_channels
132
+ out_channels = in_channels if out_channels is None else out_channels
133
+ self.out_channels = out_channels
134
+ self.use_conv_shortcut = conv_shortcut
135
+ self.time_embedding_norm = time_embedding_norm
136
+ self.output_scale_factor = output_scale_factor
137
+ self.upsample = self.downsample = None
138
+
139
+ if groups_out is None:
140
+ groups_out = groups
141
+
142
+ assert use_inflated_groupnorm is not None
143
+ if use_inflated_groupnorm:
144
+ self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
145
+ else:
146
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
147
+
148
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
149
+
150
+ if temb_channels is not None:
151
+ if self.time_embedding_norm == "default":
152
+ time_emb_proj_out_channels = out_channels
153
+ elif self.time_embedding_norm == "scale_shift":
154
+ time_emb_proj_out_channels = out_channels * 2
155
+ else:
156
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
157
+
158
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
159
+ else:
160
+ self.time_emb_proj = None
161
+
162
+ if use_inflated_groupnorm:
163
+ self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
164
+ else:
165
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
166
+
167
+ self.dropout = torch.nn.Dropout(dropout)
168
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
169
+
170
+ if non_linearity == "swish":
171
+ self.nonlinearity = lambda x: F.silu(x)
172
+ elif non_linearity == "mish":
173
+ self.nonlinearity = Mish()
174
+ elif non_linearity == "silu":
175
+ self.nonlinearity = nn.SiLU()
176
+
177
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
178
+
179
+ self.conv_shortcut = None
180
+ if self.use_in_shortcut:
181
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
182
+
183
+ def forward(self, input_tensor, temb):
184
+ hidden_states = input_tensor
185
+
186
+ hidden_states = self.norm1(hidden_states)
187
+ hidden_states = self.nonlinearity(hidden_states)
188
+
189
+ hidden_states = self.conv1(hidden_states)
190
+
191
+ if temb is not None:
192
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
193
+
194
+ if temb is not None and self.time_embedding_norm == "default":
195
+ hidden_states = hidden_states + temb
196
+
197
+ hidden_states = self.norm2(hidden_states)
198
+
199
+ if temb is not None and self.time_embedding_norm == "scale_shift":
200
+ scale, shift = torch.chunk(temb, 2, dim=1)
201
+ hidden_states = hidden_states * (1 + scale) + shift
202
+
203
+ hidden_states = self.nonlinearity(hidden_states)
204
+
205
+ hidden_states = self.dropout(hidden_states)
206
+ hidden_states = self.conv2(hidden_states)
207
+
208
+ if self.conv_shortcut is not None:
209
+ input_tensor = self.conv_shortcut(input_tensor)
210
+
211
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
212
+
213
+ return output_tensor
214
+
215
+
216
+ class Mish(torch.nn.Module):
217
+ def forward(self, hidden_states):
218
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
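A minimal sketch (shapes are assumptions) of the "inflation" trick used throughout this file: `InflatedConv3d` folds the frame axis into the batch, runs an ordinary 2D convolution on every frame, and restores the `(b, c, f, h, w)` layout, so pretrained 2D weights can be reused on video tensors.

```python
import torch
from motionclone.models.resnet import InflatedConv3d

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)  # a frame-wise 2D convolution
video = torch.randn(2, 4, 16, 64, 64)                  # batch=2, 4 channels, 16 frames, 64x64
features = conv(video)
print(features.shape)                                  # torch.Size([2, 8, 16, 64, 64])
```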
motionclone/models/scheduler.py ADDED
@@ -0,0 +1,155 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from diffusers import DDIMScheduler
5
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
6
+ from diffusers.utils.torch_utils import randn_tensor
7
+
8
+
9
+ class CustomDDIMScheduler(DDIMScheduler):
10
+ @torch.no_grad()
11
+ def step(
12
+ self,
13
+ model_output: torch.FloatTensor,
14
+ timestep: int,
15
+ sample: torch.FloatTensor,
16
+ eta: float = 0.0,
17
+ use_clipped_model_output: bool = False,
18
+ generator=None,
19
+ variance_noise: Optional[torch.FloatTensor] = None,
20
+ return_dict: bool = True,
21
+
22
+ # Guidance parameters
23
+ score=None,
24
+ guidance_scale=0.0,
25
+ indices=None, # [0]
26
+
27
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
28
+ """
29
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
30
+ process from the learned model outputs (most often the predicted noise).
31
+
32
+ Args:
33
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
34
+ timestep (`int`): current discrete timestep in the diffusion chain.
35
+ sample (`torch.FloatTensor`):
36
+ current instance of sample being created by diffusion process.
37
+ eta (`float`): weight of noise for added noise in diffusion step.
38
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
39
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
40
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
41
+ coincide with the one provided as input and `use_clipped_model_output` will have not effect.
42
+ generator: random number generator.
43
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
44
+ can directly provide the noise for the variance itself. This is useful for methods such as
45
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
46
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
47
+
48
+ Returns:
49
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
50
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
51
+ returning a tuple, the first element is the sample tensor.
52
+
53
+ """
54
+ if self.num_inference_steps is None:
55
+ raise ValueError(
56
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
57
+ )
58
+
59
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
60
+ # Ideally, read DDIM paper in-detail understanding
61
+
62
+ # Notation (<variable name> -> <name in paper>
63
+ # - pred_noise_t -> e_theta(x_t, t)
64
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
65
+ # - std_dev_t -> sigma_t
66
+ # - eta -> η
67
+ # - pred_sample_direction -> "direction pointing to x_t"
68
+ # - pred_prev_sample -> "x_t-1"
69
+
70
+
71
+ # Support IF models
72
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
73
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
74
+ else:
75
+ predicted_variance = None
76
+
77
+ # 1. get previous step value (=t-1)
78
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
79
+
80
+ # 2. compute alphas, betas
81
+ alpha_prod_t = self.alphas_cumprod[timestep]
82
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
83
+
84
+ beta_prod_t = 1 - alpha_prod_t
85
+
86
+ # 3. compute predicted original sample from predicted noise also called
87
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
88
+ if self.config.prediction_type == "epsilon":
89
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
90
+ pred_epsilon = model_output
91
+ elif self.config.prediction_type == "sample":
92
+ pred_original_sample = model_output
93
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
94
+ elif self.config.prediction_type == "v_prediction":
95
+ pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
96
+ pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
97
+ else:
98
+ raise ValueError(
99
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
100
+ " `v_prediction`"
101
+ )
102
+
103
+ # 4. Clip or threshold "predicted x_0"
104
+ if self.config.thresholding:
105
+ pred_original_sample = self._threshold_sample(pred_original_sample)
106
+ elif self.config.clip_sample:
107
+ pred_original_sample = pred_original_sample.clamp(
108
+ -self.config.clip_sample_range, self.config.clip_sample_range
109
+ )
110
+
111
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
112
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
113
+ variance = self._get_variance(timestep, prev_timestep)
114
+ std_dev_t = eta * variance ** (0.5)
115
+
116
+ if use_clipped_model_output:
117
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
118
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # [2, 4, 64, 64]
119
+
120
+ # 6. apply guidance following the formula (14) from https://arxiv.org/pdf/2105.05233.pdf
121
+ if score is not None and guidance_scale > 0.0: # indices指定了应用guidance的位置,此处indices = [0]
122
+ if indices is not None:
123
+ # import pdb; pdb.set_trace()
124
+ assert pred_epsilon[indices].shape == score.shape, "pred_epsilon[indices].shape != score.shape"
125
+ pred_epsilon[indices] = pred_epsilon[indices] - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score # 只修改了其中第一个[1, 4, 64, 64]的部分
126
+ else:
127
+ assert pred_epsilon.shape == score.shape
128
+ pred_epsilon = pred_epsilon - guidance_scale * (1 - alpha_prod_t) ** (0.5) * score
129
+ #
130
+
131
+ # 7. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
132
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon # [2, 4, 64, 64]
133
+
134
+ # 8. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
135
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction # [2, 4, 64, 64]
136
+
137
+ if eta > 0:
138
+ if variance_noise is not None and generator is not None:
139
+ raise ValueError(
140
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
141
+ " `variance_noise` stays `None`."
142
+ )
143
+
144
+ if variance_noise is None:
145
+ variance_noise = randn_tensor(
146
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
147
+ )
148
+ variance = std_dev_t * variance_noise # 最后还要再加一些随机噪声
149
+
150
+ prev_sample = prev_sample + variance # [2, 4, 64, 64]
151
+ self.pred_epsilon = pred_epsilon
152
+ if not return_dict:
153
+ return (prev_sample,)
154
+
155
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
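A minimal sketch (scheduler settings, shapes and values are assumptions) of the guided step above: when `score` and a positive `guidance_scale` are passed, the predicted noise of the batch entries selected by `indices` is shifted by `guidance_scale * (1 - alpha_prod_t) ** 0.5 * score` before the standard DDIM update.

```python
import torch
from motionclone.models.scheduler import CustomDDIMScheduler

scheduler = CustomDDIMScheduler(
    num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
    beta_schedule="linear", clip_sample=False,   # assumed, typical latent-diffusion settings
)
scheduler.set_timesteps(50)

sample = torch.randn(2, 4, 16, 64, 64)           # e.g. [guided, reference] video latents
model_output = torch.randn_like(sample)          # noise predicted by the denoising UNet
score = torch.randn(1, 4, 16, 64, 64)            # gradient of some guidance energy (assumed)

out = scheduler.step(
    model_output, scheduler.timesteps[0], sample,
    score=score, guidance_scale=2.0, indices=[0],  # guide only batch index 0
)
prev_sample = out.prev_sample                      # latents for the next denoising step
```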
motionclone/models/sparse_controlnet.py ADDED
@@ -0,0 +1,593 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Changes were made to this source code by Yuwei Guo.
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.nn import functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+
28
+
29
+ from .unet_blocks import (
30
+ CrossAttnDownBlock3D,
31
+ DownBlock3D,
32
+ UNetMidBlock3DCrossAttn,
33
+ get_down_block,
34
+ )
35
+ from einops import repeat, rearrange
36
+ from .resnet import InflatedConv3d
37
+
38
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ @dataclass
44
+ class SparseControlNetOutput(BaseOutput):
45
+ down_block_res_samples: Tuple[torch.Tensor]
46
+ mid_block_res_sample: torch.Tensor
47
+
48
+
49
+ class SparseControlNetConditioningEmbedding(nn.Module):
50
+ def __init__(
51
+ self,
52
+ conditioning_embedding_channels: int,
53
+ conditioning_channels: int = 3,
54
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
55
+ ):
56
+ super().__init__()
57
+
58
+ self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
59
+
60
+ self.blocks = nn.ModuleList([])
61
+
62
+ for i in range(len(block_out_channels) - 1):
63
+ channel_in = block_out_channels[i]
64
+ channel_out = block_out_channels[i + 1]
65
+ self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
66
+ self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
67
+
68
+ self.conv_out = zero_module(
69
+ InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
70
+ )
71
+
72
+ def forward(self, conditioning):
73
+ embedding = self.conv_in(conditioning)
74
+ embedding = F.silu(embedding)
75
+
76
+ for block in self.blocks:
77
+ embedding = block(embedding)
78
+ embedding = F.silu(embedding)
79
+
80
+ embedding = self.conv_out(embedding)
81
+
82
+ return embedding
83
+
84
+
85
+ class SparseControlNetModel(ModelMixin, ConfigMixin):
86
+ _supports_gradient_checkpointing = True
87
+
88
+ @register_to_config
89
+ def __init__(
90
+ self,
91
+ in_channels: int = 4,
92
+ conditioning_channels: int = 3,
93
+ flip_sin_to_cos: bool = True,
94
+ freq_shift: int = 0,
95
+ down_block_types: Tuple[str] = (
96
+ "CrossAttnDownBlock2D",
97
+ "CrossAttnDownBlock2D",
98
+ "CrossAttnDownBlock2D",
99
+ "DownBlock2D",
100
+ ),
101
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
102
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
103
+ layers_per_block: int = 2,
104
+ downsample_padding: int = 1,
105
+ mid_block_scale_factor: float = 1,
106
+ act_fn: str = "silu",
107
+ norm_num_groups: Optional[int] = 32,
108
+ norm_eps: float = 1e-5,
109
+ cross_attention_dim: int = 1280,
110
+ attention_head_dim: Union[int, Tuple[int]] = 8,
111
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
112
+ use_linear_projection: bool = False,
113
+ class_embed_type: Optional[str] = None,
114
+ num_class_embeds: Optional[int] = None,
115
+ upcast_attention: bool = False,
116
+ resnet_time_scale_shift: str = "default",
117
+ projection_class_embeddings_input_dim: Optional[int] = None,
118
+ controlnet_conditioning_channel_order: str = "rgb",
119
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
120
+ global_pool_conditions: bool = False,
121
+
122
+ use_motion_module = True,
123
+ motion_module_resolutions = ( 1,2,4,8 ),
124
+ motion_module_mid_block = False,
125
+ motion_module_type = "Vanilla",
126
+ motion_module_kwargs = {
127
+ "num_attention_heads": 8,
128
+ "num_transformer_block": 1,
129
+ "attention_block_types": ["Temporal_Self"],
130
+ "temporal_position_encoding": True,
131
+ "temporal_position_encoding_max_len": 32,
132
+ "temporal_attention_dim_div": 1,
133
+ "causal_temporal_attention": False,
134
+ },
135
+
136
+ concate_conditioning_mask: bool = True,
137
+ use_simplified_condition_embedding: bool = False,
138
+
139
+ set_noisy_sample_input_to_zero: bool = False,
140
+ ):
141
+ super().__init__()
142
+
143
+ # If `num_attention_heads` is not defined (which is the case for most models)
144
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
145
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
146
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
147
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
148
+ # which is why we correct for the naming here.
149
+ num_attention_heads = num_attention_heads or attention_head_dim
150
+
151
+ # Check inputs
152
+ if len(block_out_channels) != len(down_block_types):
153
+ raise ValueError(
154
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
155
+ )
156
+
157
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
158
+ raise ValueError(
159
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
160
+ )
161
+
162
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
163
+ raise ValueError(
164
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
165
+ )
166
+
167
+ # input
168
+ self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero
169
+
170
+ conv_in_kernel = 3
171
+ conv_in_padding = (conv_in_kernel - 1) // 2
172
+ self.conv_in = InflatedConv3d(
173
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
174
+ )
175
+
176
+ if concate_conditioning_mask:
177
+ conditioning_channels = conditioning_channels + 1
178
+ self.concate_conditioning_mask = concate_conditioning_mask
179
+
180
+ # control net conditioning embedding
181
+ if use_simplified_condition_embedding:
182
+ self.controlnet_cond_embedding = zero_module(
183
+ InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
184
+ ).to(torch.float16)
185
+ else:
186
+ self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
187
+ conditioning_embedding_channels=block_out_channels[0],
188
+ block_out_channels=conditioning_embedding_out_channels,
189
+ conditioning_channels=conditioning_channels,
190
+ ).to(torch.float16)
191
+ self.use_simplified_condition_embedding = use_simplified_condition_embedding
192
+
193
+ # time
194
+ time_embed_dim = block_out_channels[0] * 4
195
+
196
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
197
+ timestep_input_dim = block_out_channels[0]
198
+
199
+ self.time_embedding = TimestepEmbedding(
200
+ timestep_input_dim,
201
+ time_embed_dim,
202
+ act_fn=act_fn,
203
+ )
204
+
205
+ # class embedding
206
+ if class_embed_type is None and num_class_embeds is not None:
207
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
208
+ elif class_embed_type == "timestep":
209
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
210
+ elif class_embed_type == "identity":
211
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
212
+ elif class_embed_type == "projection":
213
+ if projection_class_embeddings_input_dim is None:
214
+ raise ValueError(
215
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
216
+ )
217
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
218
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
219
+ # 2. it projects from an arbitrary input dimension.
220
+ #
221
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
222
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
223
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
224
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
225
+ else:
226
+ self.class_embedding = None
227
+
228
+
229
+ self.down_blocks = nn.ModuleList([])
230
+ self.controlnet_down_blocks = nn.ModuleList([])
231
+
232
+ if isinstance(only_cross_attention, bool):
233
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
234
+
235
+ if isinstance(attention_head_dim, int):
236
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
237
+
238
+ if isinstance(num_attention_heads, int):
239
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
240
+
241
+ # down
242
+ output_channel = block_out_channels[0]
243
+
244
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
245
+ controlnet_block = zero_module(controlnet_block)
246
+ self.controlnet_down_blocks.append(controlnet_block)
247
+
248
+ for i, down_block_type in enumerate(down_block_types):
249
+ res = 2 ** i
250
+ input_channel = output_channel
251
+ output_channel = block_out_channels[i]
252
+ is_final_block = i == len(block_out_channels) - 1
253
+
254
+ down_block = get_down_block(
255
+ down_block_type,
256
+ num_layers=layers_per_block,
257
+ in_channels=input_channel,
258
+ out_channels=output_channel,
259
+ temb_channels=time_embed_dim,
260
+ add_downsample=not is_final_block,
261
+ resnet_eps=norm_eps,
262
+ resnet_act_fn=act_fn,
263
+ resnet_groups=norm_num_groups,
264
+ cross_attention_dim=cross_attention_dim,
265
+ attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
266
+ downsample_padding=downsample_padding,
267
+ use_linear_projection=use_linear_projection,
268
+ only_cross_attention=only_cross_attention[i],
269
+ upcast_attention=upcast_attention,
270
+ resnet_time_scale_shift=resnet_time_scale_shift,
271
+
272
+ use_inflated_groupnorm=True,
273
+
274
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
275
+ motion_module_type=motion_module_type,
276
+ motion_module_kwargs=motion_module_kwargs,
277
+ )
278
+ self.down_blocks.append(down_block)
279
+
280
+ for _ in range(layers_per_block):
281
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
282
+ controlnet_block = zero_module(controlnet_block)
283
+ self.controlnet_down_blocks.append(controlnet_block)
284
+
285
+ if not is_final_block:
286
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
287
+ controlnet_block = zero_module(controlnet_block)
288
+ self.controlnet_down_blocks.append(controlnet_block)
289
+
290
+ # mid
291
+ mid_block_channel = block_out_channels[-1]
292
+
293
+ controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1)
294
+ controlnet_block = zero_module(controlnet_block)
295
+ self.controlnet_mid_block = controlnet_block
296
+
297
+ self.mid_block = UNetMidBlock3DCrossAttn(
298
+ in_channels=mid_block_channel,
299
+ temb_channels=time_embed_dim,
300
+ resnet_eps=norm_eps,
301
+ resnet_act_fn=act_fn,
302
+ output_scale_factor=mid_block_scale_factor,
303
+ resnet_time_scale_shift=resnet_time_scale_shift,
304
+ cross_attention_dim=cross_attention_dim,
305
+ attn_num_head_channels=num_attention_heads[-1],
306
+ resnet_groups=norm_num_groups,
307
+ use_linear_projection=use_linear_projection,
308
+ upcast_attention=upcast_attention,
309
+
310
+ use_inflated_groupnorm=True,
311
+ use_motion_module=use_motion_module and motion_module_mid_block,
312
+ motion_module_type=motion_module_type,
313
+ motion_module_kwargs=motion_module_kwargs,
314
+ )
315
+
316
+ @classmethod
317
+ def from_unet(
318
+ cls,
319
+ unet: UNet2DConditionModel,
320
+ controlnet_conditioning_channel_order: str = "rgb",
321
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
322
+ load_weights_from_unet: bool = True,
323
+
324
+ controlnet_additional_kwargs: dict = {},
325
+ ):
326
+ controlnet = cls(
327
+ in_channels=unet.config.in_channels,
328
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
329
+ freq_shift=unet.config.freq_shift,
330
+ down_block_types=unet.config.down_block_types,
331
+ only_cross_attention=unet.config.only_cross_attention,
332
+ block_out_channels=unet.config.block_out_channels,
333
+ layers_per_block=unet.config.layers_per_block,
334
+ downsample_padding=unet.config.downsample_padding,
335
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
336
+ act_fn=unet.config.act_fn,
337
+ norm_num_groups=unet.config.norm_num_groups,
338
+ norm_eps=unet.config.norm_eps,
339
+ cross_attention_dim=unet.config.cross_attention_dim,
340
+ attention_head_dim=unet.config.attention_head_dim,
341
+ num_attention_heads=unet.config.num_attention_heads,
342
+ use_linear_projection=unet.config.use_linear_projection,
343
+ class_embed_type=unet.config.class_embed_type,
344
+ num_class_embeds=unet.config.num_class_embeds,
345
+ upcast_attention=unet.config.upcast_attention,
346
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
347
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
348
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
349
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
350
+
351
+ **controlnet_additional_kwargs,
352
+ )
353
+
354
+ if load_weights_from_unet:
355
+ m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False)
356
+ assert len(u) == 0
357
+ m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False)
358
+ assert len(u) == 0
359
+ m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False)
360
+ assert len(u) == 0
361
+
362
+ if controlnet.class_embedding:
363
+ m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False)
364
+ assert len(u) == 0
365
+ m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False)
366
+ assert len(u) == 0
367
+ m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False)
368
+ assert len(u) == 0
369
+
370
+ return controlnet
371
+
372
+ @staticmethod
373
+ def image_layer_filter(state_dict):
374
+ new_state_dict = {}
375
+ for name, param in state_dict.items():
376
+ if "motion_modules." in name or "lora" in name: continue
377
+ new_state_dict[name] = param
378
+ return new_state_dict
379
+
380
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
381
+ def set_attention_slice(self, slice_size):
382
+ r"""
383
+ Enable sliced attention computation.
384
+
385
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
386
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
387
+
388
+ Args:
389
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
390
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
391
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
392
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
393
+ must be a multiple of `slice_size`.
394
+ """
395
+ sliceable_head_dims = []
396
+
397
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
398
+ if hasattr(module, "set_attention_slice"):
399
+ sliceable_head_dims.append(module.sliceable_head_dim)
400
+
401
+ for child in module.children():
402
+ fn_recursive_retrieve_sliceable_dims(child)
403
+
404
+ # retrieve number of attention layers
405
+ for module in self.children():
406
+ fn_recursive_retrieve_sliceable_dims(module)
407
+
408
+ num_sliceable_layers = len(sliceable_head_dims)
409
+
410
+ if slice_size == "auto":
411
+ # half the attention head size is usually a good trade-off between
412
+ # speed and memory
413
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
414
+ elif slice_size == "max":
415
+ # make smallest slice possible
416
+ slice_size = num_sliceable_layers * [1]
417
+
418
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
419
+
420
+ if len(slice_size) != len(sliceable_head_dims):
421
+ raise ValueError(
422
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
423
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
424
+ )
425
+
426
+ for i in range(len(slice_size)):
427
+ size = slice_size[i]
428
+ dim = sliceable_head_dims[i]
429
+ if size is not None and size > dim:
430
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
431
+
432
+ # Recursively walk through all the children.
433
+ # Any children which exposes the set_attention_slice method
434
+ # gets the message
435
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
436
+ if hasattr(module, "set_attention_slice"):
437
+ module.set_attention_slice(slice_size.pop())
438
+
439
+ for child in module.children():
440
+ fn_recursive_set_attention_slice(child, slice_size)
441
+
442
+ reversed_slice_size = list(reversed(slice_size))
443
+ for module in self.children():
444
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
445
+
446
+ def _set_gradient_checkpointing(self, module, value=False):
447
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
448
+ module.gradient_checkpointing = value
449
+
450
+ def forward(
451
+ self,
452
+ sample: torch.FloatTensor,
453
+ timestep: Union[torch.Tensor, float, int],
454
+ encoder_hidden_states: torch.Tensor,
455
+
456
+ controlnet_cond: torch.FloatTensor,
457
+ conditioning_mask: Optional[torch.FloatTensor] = None,
458
+
459
+ conditioning_scale: float = 1.0,
460
+ class_labels: Optional[torch.Tensor] = None,
461
+ attention_mask: Optional[torch.Tensor] = None,
462
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
463
+ guess_mode: bool = False,
464
+ return_dict: bool = True,
465
+ ) -> Union[SparseControlNetOutput, Tuple]:
466
+
467
+ # set input noise to zero
468
+ # if self.set_noisy_sample_input_to_zero:
469
+ # sample = torch.zeros_like(sample).to(sample.device)
470
+
471
+ # prepare attention_mask
472
+ if attention_mask is not None:
473
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
474
+ attention_mask = attention_mask.unsqueeze(1)
475
+
476
+ # 1. time
477
+ timesteps = timestep
478
+ if not torch.is_tensor(timesteps):
479
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
480
+ # This would be a good case for the `match` statement (Python 3.10+)
481
+ is_mps = sample.device.type == "mps"
482
+ if isinstance(timestep, float):
483
+ dtype = torch.float32 if is_mps else torch.float64
484
+ else:
485
+ dtype = torch.int32 if is_mps else torch.int64
486
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
487
+ elif len(timesteps.shape) == 0:
488
+ timesteps = timesteps[None].to(sample.device)
489
+
490
+ timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0])
491
+ encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1)
492
+
493
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
494
+ timesteps = timesteps.expand(sample.shape[0])
495
+
496
+ t_emb = self.time_proj(timesteps)
497
+
498
+ # timesteps does not contain any weights and will always return f32 tensors
499
+ # but time_embedding might actually be running in fp16. so we need to cast here.
500
+ # there might be better ways to encapsulate this.
501
+ t_emb = t_emb.to(dtype=self.dtype)
502
+ emb = self.time_embedding(t_emb)
503
+
504
+ if self.class_embedding is not None:
505
+ if class_labels is None:
506
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
507
+
508
+ if self.config.class_embed_type == "timestep":
509
+ class_labels = self.time_proj(class_labels)
510
+
511
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
512
+ emb = emb + class_emb
513
+
514
+ # 2. pre-process
515
+ # equal to set input noise to zero
516
+ if self.set_noisy_sample_input_to_zero:
517
+ shape = sample.shape
518
+ sample = self.conv_in.bias.reshape(1,-1,1,1,1).expand(shape[0],-1,shape[2],shape[3],shape[4])
519
+ else:
520
+ sample = self.conv_in(sample)
521
+
522
+ if self.concate_conditioning_mask:
523
+ controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1).to(torch.float16)
524
+ # import pdb; pdb.set_trace()
525
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
526
+
527
+ sample = sample + controlnet_cond
528
+
529
+ # 3. down
530
+ down_block_res_samples = (sample,)
531
+ for downsample_block in self.down_blocks:
532
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
533
+ sample, res_samples = downsample_block(
534
+ hidden_states=sample,
535
+ temb=emb,
536
+ encoder_hidden_states=encoder_hidden_states,
537
+ attention_mask=attention_mask,
538
+ # cross_attention_kwargs=cross_attention_kwargs,
539
+ )
540
+ else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
541
+
542
+ down_block_res_samples += res_samples
543
+
544
+ # 4. mid
545
+ if self.mid_block is not None:
546
+ sample = self.mid_block(
547
+ sample,
548
+ emb,
549
+ encoder_hidden_states=encoder_hidden_states,
550
+ attention_mask=attention_mask,
551
+ # cross_attention_kwargs=cross_attention_kwargs,
552
+ )
553
+
554
+ # 5. controlnet blocks
555
+ controlnet_down_block_res_samples = ()
556
+
557
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
558
+ down_block_res_sample = controlnet_block(down_block_res_sample)
559
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
560
+
561
+ down_block_res_samples = controlnet_down_block_res_samples
562
+
563
+ mid_block_res_sample = self.controlnet_mid_block(sample)
564
+
565
+ # 6. scaling
566
+ if guess_mode and not self.config.global_pool_conditions:
567
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
568
+
569
+ scales = scales * conditioning_scale
570
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
571
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
572
+ else:
573
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
574
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
575
+
576
+ if self.config.global_pool_conditions:
577
+ down_block_res_samples = [
578
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
579
+ ]
580
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
581
+
582
+ if not return_dict:
583
+ return (down_block_res_samples, mid_block_res_sample)
584
+
585
+ return SparseControlNetOutput(
586
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
587
+ )
588
+
589
+
590
+ def zero_module(module):
591
+ for p in module.parameters():
592
+ nn.init.zeros_(p)
593
+ return module
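Note: the SparseControlNet forward above and the UNet3DConditionModel forward added below are designed to be chained: the controlnet emits one residual per skip connection plus a mid-block residual, and the 3D UNet accepts them as `down_block_additional_residuals` / `mid_block_additional_residual`. The sketch below shows that hand-off in outline only; the `controlnet` and `unet` instances, the tensor shapes, and the text-embedding dimension are illustrative assumptions, not values taken from this repository.

```python
import torch

# Hedged sketch: assumes `controlnet` (SparseControlNetModel) and `unet`
# (UNet3DConditionModel) are already constructed and on the same device.
sample = torch.randn(1, 4, 16, 64, 64)        # (batch, latent ch, frames, h, w) - illustrative
timestep = torch.tensor([500])
text_emb = torch.randn(1, 77, 768)            # text encoder hidden states (dim assumed)
cond = torch.randn(1, 3, 16, 512, 512)        # sparse image condition in pixel space
cond_mask = torch.zeros(1, 1, 16, 512, 512)   # 1 marks the conditioned frames
cond_mask[:, :, 0] = 1.0                      # e.g. condition only on the first frame

with torch.no_grad():
    # return_dict=False yields (down_block_res_samples, mid_block_res_sample)
    down_res, mid_res = controlnet(
        sample, timestep, text_emb,
        controlnet_cond=cond,
        conditioning_mask=cond_mask,          # concatenated when concate_conditioning_mask is set
        conditioning_scale=1.0,
        return_dict=False,
    )
    # The UNet adds each residual to the matching skip feature (4-D residuals are
    # broadcast over the frame axis) and adds the mid residual after its mid block.
    noise_pred = unet(
        sample, timestep, text_emb,
        down_block_additional_residuals=down_res,
        mid_block_additional_residual=mid_res,
    ).sample
```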
motionclone/models/unet.py ADDED
@@ -0,0 +1,515 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+ import pdb
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.checkpoint
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import BaseOutput, logging
17
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
18
+ from .unet_blocks import (
19
+ CrossAttnDownBlock3D,
20
+ CrossAttnUpBlock3D,
21
+ DownBlock3D,
22
+ UNetMidBlock3DCrossAttn,
23
+ UpBlock3D,
24
+ get_down_block,
25
+ get_up_block,
26
+ )
27
+ from .resnet import InflatedConv3d, InflatedGroupNorm
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ @dataclass
34
+ class UNet3DConditionOutput(BaseOutput):
35
+ sample: torch.FloatTensor
36
+
37
+
38
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
39
+ _supports_gradient_checkpointing = True
40
+
41
+ @register_to_config
42
+ def __init__(
43
+ self,
44
+ sample_size: Optional[int] = None,
45
+ in_channels: int = 4,
46
+ out_channels: int = 4,
47
+ center_input_sample: bool = False,
48
+ flip_sin_to_cos: bool = True,
49
+ freq_shift: int = 0,
50
+ down_block_types: Tuple[str] = (
51
+ "CrossAttnDownBlock3D",
52
+ "CrossAttnDownBlock3D",
53
+ "CrossAttnDownBlock3D",
54
+ "DownBlock3D",
55
+ ),
56
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
57
+ up_block_types: Tuple[str] = ( # the first up block has no cross-attention; the last three do
58
+ "UpBlock3D",
59
+ "CrossAttnUpBlock3D",
60
+ "CrossAttnUpBlock3D",
61
+ "CrossAttnUpBlock3D"
62
+ ),
63
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
64
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
65
+ layers_per_block: int = 2,
66
+ downsample_padding: int = 1,
67
+ mid_block_scale_factor: float = 1,
68
+ act_fn: str = "silu",
69
+ norm_num_groups: int = 32,
70
+ norm_eps: float = 1e-5,
71
+ cross_attention_dim: int = 1280,
72
+ attention_head_dim: Union[int, Tuple[int]] = 8,
73
+ dual_cross_attention: bool = False,
74
+ use_linear_projection: bool = False,
75
+ class_embed_type: Optional[str] = None,
76
+ num_class_embeds: Optional[int] = None,
77
+ upcast_attention: bool = False,
78
+ resnet_time_scale_shift: str = "default",
79
+
80
+ use_inflated_groupnorm=False,
81
+
82
+ # Additional
83
+ use_motion_module = False,
84
+ motion_module_resolutions = ( 1,2,4,8 ),
85
+ motion_module_mid_block = False,
86
+ motion_module_decoder_only = False,
87
+ motion_module_type = None,
88
+ motion_module_kwargs = {},
89
+ unet_use_cross_frame_attention = False,
90
+ unet_use_temporal_attention = False,
91
+ ):
92
+ super().__init__()
93
+
94
+ self.sample_size = sample_size
95
+ time_embed_dim = block_out_channels[0] * 4
96
+
97
+ # input
98
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
99
+
100
+ # time
101
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
102
+ timestep_input_dim = block_out_channels[0]
103
+
104
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
105
+
106
+ # class embedding
107
+ if class_embed_type is None and num_class_embeds is not None:
108
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
109
+ elif class_embed_type == "timestep":
110
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
111
+ elif class_embed_type == "identity":
112
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
113
+ else:
114
+ self.class_embedding = None
115
+
116
+ self.down_blocks = nn.ModuleList([])
117
+ self.mid_block = None
118
+ self.up_blocks = nn.ModuleList([])
119
+
120
+ if isinstance(only_cross_attention, bool):
121
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
122
+
123
+ if isinstance(attention_head_dim, int):
124
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
125
+
126
+ # down
127
+ output_channel = block_out_channels[0]
128
+ for i, down_block_type in enumerate(down_block_types):
129
+ res = 2 ** i
130
+ input_channel = output_channel
131
+ output_channel = block_out_channels[i]
132
+ is_final_block = i == len(block_out_channels) - 1
133
+
134
+ down_block = get_down_block(
135
+ down_block_type,
136
+ num_layers=layers_per_block,
137
+ in_channels=input_channel,
138
+ out_channels=output_channel,
139
+ temb_channels=time_embed_dim,
140
+ add_downsample=not is_final_block,
141
+ resnet_eps=norm_eps,
142
+ resnet_act_fn=act_fn,
143
+ resnet_groups=norm_num_groups,
144
+ cross_attention_dim=cross_attention_dim,
145
+ attn_num_head_channels=attention_head_dim[i],
146
+ downsample_padding=downsample_padding,
147
+ dual_cross_attention=dual_cross_attention,
148
+ use_linear_projection=use_linear_projection,
149
+ only_cross_attention=only_cross_attention[i],
150
+ upcast_attention=upcast_attention,
151
+ resnet_time_scale_shift=resnet_time_scale_shift,
152
+
153
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
154
+ unet_use_temporal_attention=unet_use_temporal_attention,
155
+ use_inflated_groupnorm=use_inflated_groupnorm,
156
+
157
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
158
+ motion_module_type=motion_module_type,
159
+ motion_module_kwargs=motion_module_kwargs,
160
+ )
161
+ self.down_blocks.append(down_block)
162
+
163
+ # mid
164
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
165
+ self.mid_block = UNetMidBlock3DCrossAttn(
166
+ in_channels=block_out_channels[-1],
167
+ temb_channels=time_embed_dim,
168
+ resnet_eps=norm_eps,
169
+ resnet_act_fn=act_fn,
170
+ output_scale_factor=mid_block_scale_factor,
171
+ resnet_time_scale_shift=resnet_time_scale_shift,
172
+ cross_attention_dim=cross_attention_dim,
173
+ attn_num_head_channels=attention_head_dim[-1],
174
+ resnet_groups=norm_num_groups,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ upcast_attention=upcast_attention,
178
+
179
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
180
+ unet_use_temporal_attention=unet_use_temporal_attention,
181
+ use_inflated_groupnorm=use_inflated_groupnorm,
182
+
183
+ use_motion_module=use_motion_module and motion_module_mid_block,
184
+ motion_module_type=motion_module_type,
185
+ motion_module_kwargs=motion_module_kwargs,
186
+ )
187
+ else:
188
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
189
+
190
+ # count how many layers upsample the videos
191
+ self.num_upsamplers = 0
192
+
193
+ # up
194
+ reversed_block_out_channels = list(reversed(block_out_channels))
195
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
196
+ only_cross_attention = list(reversed(only_cross_attention))
197
+ output_channel = reversed_block_out_channels[0]
198
+ for i, up_block_type in enumerate(up_block_types):
199
+ res = 2 ** (3 - i)
200
+ is_final_block = i == len(block_out_channels) - 1
201
+
202
+ prev_output_channel = output_channel
203
+ output_channel = reversed_block_out_channels[i]
204
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
205
+
206
+ # add upsample block for all BUT final layer
207
+ if not is_final_block:
208
+ add_upsample = True
209
+ self.num_upsamplers += 1
210
+ else:
211
+ add_upsample = False
212
+
213
+ up_block = get_up_block(
214
+ up_block_type,
215
+ num_layers=layers_per_block + 1,
216
+ in_channels=input_channel,
217
+ out_channels=output_channel,
218
+ prev_output_channel=prev_output_channel,
219
+ temb_channels=time_embed_dim,
220
+ add_upsample=add_upsample,
221
+ resnet_eps=norm_eps,
222
+ resnet_act_fn=act_fn,
223
+ resnet_groups=norm_num_groups,
224
+ cross_attention_dim=cross_attention_dim,
225
+ attn_num_head_channels=reversed_attention_head_dim[i],
226
+ dual_cross_attention=dual_cross_attention,
227
+ use_linear_projection=use_linear_projection,
228
+ only_cross_attention=only_cross_attention[i],
229
+ upcast_attention=upcast_attention,
230
+ resnet_time_scale_shift=resnet_time_scale_shift,
231
+
232
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
233
+ unet_use_temporal_attention=unet_use_temporal_attention,
234
+ use_inflated_groupnorm=use_inflated_groupnorm,
235
+
236
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
237
+ motion_module_type=motion_module_type,
238
+ motion_module_kwargs=motion_module_kwargs,
239
+ )
240
+ self.up_blocks.append(up_block)
241
+ prev_output_channel = output_channel
242
+
243
+ # out
244
+ if use_inflated_groupnorm:
245
+ self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
246
+ else:
247
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
248
+ self.conv_act = nn.SiLU()
249
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
250
+
251
+ def set_attention_slice(self, slice_size):
252
+ r"""
253
+ Enable sliced attention computation.
254
+
255
+ When this option is enabled, the attention module splits the input tensor into slices to compute attention
256
+ in several steps. This is useful for saving some memory in exchange for a small decrease in speed.
257
+
258
+ Args:
259
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
260
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
261
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
262
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
263
+ must be a multiple of `slice_size`.
264
+ """
265
+ sliceable_head_dims = []
266
+
267
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
268
+ if hasattr(module, "set_attention_slice"):
269
+ sliceable_head_dims.append(module.sliceable_head_dim)
270
+
271
+ for child in module.children():
272
+ fn_recursive_retrieve_slicable_dims(child)
273
+
274
+ # retrieve number of attention layers
275
+ for module in self.children():
276
+ fn_recursive_retrieve_slicable_dims(module)
277
+
278
+ num_slicable_layers = len(sliceable_head_dims)
279
+
280
+ if slice_size == "auto":
281
+ # half the attention head size is usually a good trade-off between
282
+ # speed and memory
283
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
284
+ elif slice_size == "max":
285
+ # make smallest slice possible
286
+ slice_size = num_slicable_layers * [1]
287
+
288
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
289
+
290
+ if len(slice_size) != len(sliceable_head_dims):
291
+ raise ValueError(
292
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
293
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
294
+ )
295
+
296
+ for i in range(len(slice_size)):
297
+ size = slice_size[i]
298
+ dim = sliceable_head_dims[i]
299
+ if size is not None and size > dim:
300
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
301
+
302
+ # Recursively walk through all the children.
303
+ # Any children which exposes the set_attention_slice method
304
+ # gets the message
305
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
306
+ if hasattr(module, "set_attention_slice"):
307
+ module.set_attention_slice(slice_size.pop())
308
+
309
+ for child in module.children():
310
+ fn_recursive_set_attention_slice(child, slice_size)
311
+
312
+ reversed_slice_size = list(reversed(slice_size))
313
+ for module in self.children():
314
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
315
+
316
+ def _set_gradient_checkpointing(self, module, value=False):
317
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
318
+ module.gradient_checkpointing = value
319
+
320
+ def forward(
321
+ self,
322
+ sample: torch.FloatTensor,
323
+ timestep: Union[torch.Tensor, float, int],
324
+ encoder_hidden_states: torch.Tensor,
325
+ class_labels: Optional[torch.Tensor] = None,
326
+ attention_mask: Optional[torch.Tensor] = None,
327
+
328
+ # support controlnet
329
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
330
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
331
+
332
+ return_dict: bool = True,
333
+ ) -> Union[UNet3DConditionOutput, Tuple]:
334
+ r"""
335
+ Args:
336
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
337
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
338
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
339
+ return_dict (`bool`, *optional*, defaults to `True`):
340
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
341
+
342
+ Returns:
343
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
344
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
345
+ returning a tuple, the first element is the sample tensor.
346
+ """
347
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
348
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
349
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
350
+ # on the fly if necessary.
351
+ default_overall_up_factor = 2**self.num_upsamplers
352
+
353
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
354
+ forward_upsample_size = False
355
+ upsample_size = None
356
+
357
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
358
+ logger.info("Forward upsample size to force interpolation output size.")
359
+ forward_upsample_size = True
360
+
361
+ # prepare attention_mask
362
+ if attention_mask is not None:
363
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
364
+ attention_mask = attention_mask.unsqueeze(1)
365
+
366
+ # center input if necessary
367
+ if self.config.center_input_sample:
368
+ sample = 2 * sample - 1.0
369
+
370
+ # time
371
+ timesteps = timestep
372
+ if not torch.is_tensor(timesteps):
373
+ # This would be a good case for the `match` statement (Python 3.10+)
374
+ is_mps = sample.device.type == "mps"
375
+ if isinstance(timestep, float):
376
+ dtype = torch.float32 if is_mps else torch.float64
377
+ else:
378
+ dtype = torch.int32 if is_mps else torch.int64
379
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
380
+ elif len(timesteps.shape) == 0:
381
+ timesteps = timesteps[None].to(sample.device)
382
+
383
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
384
+ timesteps = timesteps.expand(sample.shape[0])
385
+
386
+ t_emb = self.time_proj(timesteps)
387
+
388
+ # timesteps does not contain any weights and will always return f32 tensors
389
+ # but time_embedding might actually be running in fp16. so we need to cast here.
390
+ # there might be better ways to encapsulate this.
391
+ t_emb = t_emb.to(dtype=self.dtype)
392
+ emb = self.time_embedding(t_emb)
393
+
394
+ if self.class_embedding is not None:
395
+ if class_labels is None:
396
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
397
+
398
+ if self.config.class_embed_type == "timestep":
399
+ class_labels = self.time_proj(class_labels)
400
+
401
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
402
+ emb = emb + class_emb
403
+
404
+ # pre-process
405
+ sample = self.conv_in(sample)
406
+
407
+ # down
408
+ down_block_res_samples = (sample,)
409
+ for downsample_block in self.down_blocks:
410
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
411
+ sample, res_samples = downsample_block(
412
+ hidden_states=sample,
413
+ temb=emb,
414
+ encoder_hidden_states=encoder_hidden_states,
415
+ attention_mask=attention_mask,
416
+ )
417
+ else:
418
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
419
+
420
+ down_block_res_samples += res_samples
421
+
422
+ # support controlnet
423
+ down_block_res_samples = list(down_block_res_samples)
424
+ if down_block_additional_residuals is not None:
425
+ for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
426
+ if down_block_additional_residual.dim() == 4: # broadcast
427
+ down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
428
+ down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
429
+
430
+ # mid
431
+ sample = self.mid_block(
432
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
433
+ )
434
+
435
+ # support controlnet
436
+ if mid_block_additional_residual is not None:
437
+ if mid_block_additional_residual.dim() == 4: # broadcast
438
+ mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
439
+ sample = sample + mid_block_additional_residual
440
+
441
+ # up
442
+ for i, upsample_block in enumerate(self.up_blocks):
443
+ is_final_block = i == len(self.up_blocks) - 1
444
+
445
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
446
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
447
+
448
+ # if we have not reached the final block and need to forward the
449
+ # upsample size, we do it here
450
+ if not is_final_block and forward_upsample_size:
451
+ upsample_size = down_block_res_samples[-1].shape[2:]
452
+
453
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
454
+ sample = upsample_block(
455
+ hidden_states=sample,
456
+ temb=emb,
457
+ res_hidden_states_tuple=res_samples,
458
+ encoder_hidden_states=encoder_hidden_states,
459
+ upsample_size=upsample_size,
460
+ attention_mask=attention_mask,
461
+ )
462
+ else:
463
+ sample = upsample_block(
464
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
465
+ )
466
+
467
+ # post-process
468
+ sample = self.conv_norm_out(sample)
469
+ sample = self.conv_act(sample)
470
+ sample = self.conv_out(sample)
471
+
472
+ if not return_dict:
473
+ return (sample,)
474
+
475
+ return UNet3DConditionOutput(sample=sample)
476
+
477
+ @classmethod
478
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
479
+ if subfolder is not None:
480
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
481
+ print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
482
+
483
+ config_file = os.path.join(pretrained_model_path, 'config.json')
484
+ if not os.path.isfile(config_file):
485
+ raise RuntimeError(f"{config_file} does not exist")
486
+ with open(config_file, "r") as f:
487
+ config = json.load(f)
488
+ config["_class_name"] = cls.__name__
489
+ config["down_block_types"] = [
490
+ "CrossAttnDownBlock3D",
491
+ "CrossAttnDownBlock3D",
492
+ "CrossAttnDownBlock3D",
493
+ "DownBlock3D"
494
+ ]
495
+ config["up_block_types"] = [
496
+ "UpBlock3D",
497
+ "CrossAttnUpBlock3D",
498
+ "CrossAttnUpBlock3D",
499
+ "CrossAttnUpBlock3D"
500
+ ]
501
+
502
+ from diffusers.utils import WEIGHTS_NAME
503
+ model = cls.from_config(config, **unet_additional_kwargs)
504
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
505
+ if not os.path.isfile(model_file):
506
+ raise RuntimeError(f"{model_file} does not exist")
507
+ state_dict = torch.load(model_file, map_location="cpu")
508
+
509
+ m, u = model.load_state_dict(state_dict, strict=False)
510
+ print(f"### motion keys will be loaded: {len(m)}; \n### unexpected keys: {len(u)};")
511
+
512
+ params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
513
+ print(f"### Motion Module Parameters: {sum(params) / 1e6} M")
514
+
515
+ return model
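Note: `from_pretrained_2d` swaps the 2D block types in the checkpoint's `config.json` for their 3D counterparts and then loads the 2D weights with `strict=False`, leaving motion-module parameters to be filled from a separate checkpoint. A hedged usage sketch follows; the model path, motion checkpoint file name, and `unet_additional_kwargs` values are placeholders, not the settings shipped in this repository's configs.

```python
import torch
from motionclone.models.unet import UNet3DConditionModel

# Placeholder path to a locally stored Stable Diffusion checkpoint directory.
unet = UNet3DConditionModel.from_pretrained_2d(
    "path/to/StableDiffusion",
    subfolder="unet",
    unet_additional_kwargs={
        "use_motion_module": True,
        "motion_module_resolutions": [1, 2, 4, 8],
        "motion_module_type": "Vanilla",   # assumed type name
        "motion_module_kwargs": {},
    },
)

# Motion-module weights come from a separate checkpoint (file name is a placeholder;
# some checkpoints nest the weights under a "state_dict" key).
motion_ckpt = torch.load("models/Motion_Module/motion_module.ckpt", map_location="cpu")
motion_ckpt = motion_ckpt.get("state_dict", motion_ckpt)
missing, unexpected = unet.load_state_dict(motion_ckpt, strict=False)
```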
motionclone/models/unet_blocks.py ADDED
@@ -0,0 +1,760 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+ from .motion_module import get_motion_module
9
+
10
+ import pdb
11
+
12
+ def get_down_block(
13
+ down_block_type,
14
+ num_layers,
15
+ in_channels,
16
+ out_channels,
17
+ temb_channels,
18
+ add_downsample,
19
+ resnet_eps,
20
+ resnet_act_fn,
21
+ attn_num_head_channels,
22
+ resnet_groups=None,
23
+ cross_attention_dim=None,
24
+ downsample_padding=None,
25
+ dual_cross_attention=False,
26
+ use_linear_projection=False,
27
+ only_cross_attention=False,
28
+ upcast_attention=False,
29
+ resnet_time_scale_shift="default",
30
+
31
+ unet_use_cross_frame_attention=False,
32
+ unet_use_temporal_attention=False,
33
+ use_inflated_groupnorm=False,
34
+
35
+ use_motion_module=None,
36
+
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ ):
40
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
41
+ if down_block_type == "DownBlock3D":
42
+ return DownBlock3D(
43
+ num_layers=num_layers,
44
+ in_channels=in_channels,
45
+ out_channels=out_channels,
46
+ temb_channels=temb_channels,
47
+ add_downsample=add_downsample,
48
+ resnet_eps=resnet_eps,
49
+ resnet_act_fn=resnet_act_fn,
50
+ resnet_groups=resnet_groups,
51
+ downsample_padding=downsample_padding,
52
+ resnet_time_scale_shift=resnet_time_scale_shift,
53
+
54
+ use_inflated_groupnorm=use_inflated_groupnorm,
55
+
56
+ use_motion_module=use_motion_module,
57
+ motion_module_type=motion_module_type,
58
+ motion_module_kwargs=motion_module_kwargs,
59
+ )
60
+ elif down_block_type == "CrossAttnDownBlock3D":
61
+ if cross_attention_dim is None:
62
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
63
+ return CrossAttnDownBlock3D(
64
+ num_layers=num_layers,
65
+ in_channels=in_channels,
66
+ out_channels=out_channels,
67
+ temb_channels=temb_channels,
68
+ add_downsample=add_downsample,
69
+ resnet_eps=resnet_eps,
70
+ resnet_act_fn=resnet_act_fn,
71
+ resnet_groups=resnet_groups,
72
+ downsample_padding=downsample_padding,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attn_num_head_channels=attn_num_head_channels,
75
+ dual_cross_attention=dual_cross_attention,
76
+ use_linear_projection=use_linear_projection,
77
+ only_cross_attention=only_cross_attention,
78
+ upcast_attention=upcast_attention,
79
+ resnet_time_scale_shift=resnet_time_scale_shift,
80
+
81
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
82
+ unet_use_temporal_attention=unet_use_temporal_attention,
83
+ use_inflated_groupnorm=use_inflated_groupnorm,
84
+
85
+ use_motion_module=use_motion_module,
86
+ motion_module_type=motion_module_type,
87
+ motion_module_kwargs=motion_module_kwargs,
88
+ )
89
+ raise ValueError(f"{down_block_type} does not exist.")
90
+
91
+
92
+ def get_up_block(
93
+ up_block_type,
94
+ num_layers,
95
+ in_channels,
96
+ out_channels,
97
+ prev_output_channel,
98
+ temb_channels,
99
+ add_upsample,
100
+ resnet_eps,
101
+ resnet_act_fn,
102
+ attn_num_head_channels,
103
+ resnet_groups=None,
104
+ cross_attention_dim=None,
105
+ dual_cross_attention=False,
106
+ use_linear_projection=False,
107
+ only_cross_attention=False,
108
+ upcast_attention=False,
109
+ resnet_time_scale_shift="default",
110
+
111
+ unet_use_cross_frame_attention=False,
112
+ unet_use_temporal_attention=False,
113
+ use_inflated_groupnorm=False,
114
+
115
+ use_motion_module=None,
116
+ motion_module_type=None,
117
+ motion_module_kwargs=None,
118
+ ):
119
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
120
+ if up_block_type == "UpBlock3D":
121
+ return UpBlock3D(
122
+ num_layers=num_layers,
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ prev_output_channel=prev_output_channel,
126
+ temb_channels=temb_channels,
127
+ add_upsample=add_upsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ resnet_time_scale_shift=resnet_time_scale_shift,
132
+
133
+ use_inflated_groupnorm=use_inflated_groupnorm,
134
+
135
+ use_motion_module=use_motion_module,
136
+ motion_module_type=motion_module_type,
137
+ motion_module_kwargs=motion_module_kwargs,
138
+ )
139
+ elif up_block_type == "CrossAttnUpBlock3D":
140
+ if cross_attention_dim is None:
141
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
142
+ return CrossAttnUpBlock3D(
143
+ num_layers=num_layers,
144
+ in_channels=in_channels,
145
+ out_channels=out_channels,
146
+ prev_output_channel=prev_output_channel,
147
+ temb_channels=temb_channels,
148
+ add_upsample=add_upsample,
149
+ resnet_eps=resnet_eps,
150
+ resnet_act_fn=resnet_act_fn,
151
+ resnet_groups=resnet_groups,
152
+ cross_attention_dim=cross_attention_dim,
153
+ attn_num_head_channels=attn_num_head_channels,
154
+ dual_cross_attention=dual_cross_attention,
155
+ use_linear_projection=use_linear_projection,
156
+ only_cross_attention=only_cross_attention,
157
+ upcast_attention=upcast_attention,
158
+ resnet_time_scale_shift=resnet_time_scale_shift,
159
+
160
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
161
+ unet_use_temporal_attention=unet_use_temporal_attention,
162
+ use_inflated_groupnorm=use_inflated_groupnorm,
163
+
164
+ use_motion_module=use_motion_module,
165
+ motion_module_type=motion_module_type,
166
+ motion_module_kwargs=motion_module_kwargs,
167
+ )
168
+ raise ValueError(f"{up_block_type} does not exist.")
169
+
170
+
171
+ class UNetMidBlock3DCrossAttn(nn.Module):
172
+ def __init__(
173
+ self,
174
+ in_channels: int,
175
+ temb_channels: int,
176
+ dropout: float = 0.0,
177
+ num_layers: int = 1,
178
+ resnet_eps: float = 1e-6,
179
+ resnet_time_scale_shift: str = "default",
180
+ resnet_act_fn: str = "swish",
181
+ resnet_groups: int = 32,
182
+ resnet_pre_norm: bool = True,
183
+ attn_num_head_channels=1,
184
+ output_scale_factor=1.0,
185
+ cross_attention_dim=1280,
186
+ dual_cross_attention=False,
187
+ use_linear_projection=False,
188
+ upcast_attention=False,
189
+
190
+ unet_use_cross_frame_attention=False,
191
+ unet_use_temporal_attention=False,
192
+ use_inflated_groupnorm=False,
193
+
194
+ use_motion_module=None,
195
+
196
+ motion_module_type=None,
197
+ motion_module_kwargs=None,
198
+ ):
199
+ super().__init__()
200
+
201
+ self.has_cross_attention = True
202
+ self.attn_num_head_channels = attn_num_head_channels
203
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
204
+
205
+ # there is always at least one resnet
206
+ resnets = [
207
+ ResnetBlock3D(
208
+ in_channels=in_channels,
209
+ out_channels=in_channels,
210
+ temb_channels=temb_channels,
211
+ eps=resnet_eps,
212
+ groups=resnet_groups,
213
+ dropout=dropout,
214
+ time_embedding_norm=resnet_time_scale_shift,
215
+ non_linearity=resnet_act_fn,
216
+ output_scale_factor=output_scale_factor,
217
+ pre_norm=resnet_pre_norm,
218
+
219
+ use_inflated_groupnorm=use_inflated_groupnorm,
220
+ )
221
+ ]
222
+ attentions = []
223
+ motion_modules = []
224
+
225
+ for _ in range(num_layers):
226
+ if dual_cross_attention:
227
+ raise NotImplementedError
228
+ attentions.append(
229
+ Transformer3DModel(
230
+ attn_num_head_channels,
231
+ in_channels // attn_num_head_channels,
232
+ in_channels=in_channels,
233
+ num_layers=1,
234
+ cross_attention_dim=cross_attention_dim,
235
+ norm_num_groups=resnet_groups,
236
+ use_linear_projection=use_linear_projection,
237
+ upcast_attention=upcast_attention,
238
+
239
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
240
+ unet_use_temporal_attention=unet_use_temporal_attention,
241
+ )
242
+ )
243
+ motion_modules.append(
244
+ get_motion_module(
245
+ in_channels=in_channels,
246
+ motion_module_type=motion_module_type,
247
+ motion_module_kwargs=motion_module_kwargs,
248
+ ) if use_motion_module else None
249
+ )
250
+ resnets.append(
251
+ ResnetBlock3D(
252
+ in_channels=in_channels,
253
+ out_channels=in_channels,
254
+ temb_channels=temb_channels,
255
+ eps=resnet_eps,
256
+ groups=resnet_groups,
257
+ dropout=dropout,
258
+ time_embedding_norm=resnet_time_scale_shift,
259
+ non_linearity=resnet_act_fn,
260
+ output_scale_factor=output_scale_factor,
261
+ pre_norm=resnet_pre_norm,
262
+
263
+ use_inflated_groupnorm=use_inflated_groupnorm,
264
+ )
265
+ )
266
+
267
+ self.attentions = nn.ModuleList(attentions)
268
+ self.resnets = nn.ModuleList(resnets)
269
+ self.motion_modules = nn.ModuleList(motion_modules)
270
+
271
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
272
+ hidden_states = self.resnets[0](hidden_states, temb)
273
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
274
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
275
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
276
+ hidden_states = resnet(hidden_states, temb)
277
+
278
+ return hidden_states
279
+
280
+
281
+ class CrossAttnDownBlock3D(nn.Module):
282
+ def __init__(
283
+ self,
284
+ in_channels: int,
285
+ out_channels: int,
286
+ temb_channels: int,
287
+ dropout: float = 0.0,
288
+ num_layers: int = 1,
289
+ resnet_eps: float = 1e-6,
290
+ resnet_time_scale_shift: str = "default",
291
+ resnet_act_fn: str = "swish",
292
+ resnet_groups: int = 32,
293
+ resnet_pre_norm: bool = True,
294
+ attn_num_head_channels=1,
295
+ cross_attention_dim=1280,
296
+ output_scale_factor=1.0,
297
+ downsample_padding=1,
298
+ add_downsample=True,
299
+ dual_cross_attention=False,
300
+ use_linear_projection=False,
301
+ only_cross_attention=False,
302
+ upcast_attention=False,
303
+
304
+ unet_use_cross_frame_attention=False,
305
+ unet_use_temporal_attention=False,
306
+ use_inflated_groupnorm=False,
307
+
308
+ use_motion_module=None,
309
+
310
+ motion_module_type=None,
311
+ motion_module_kwargs=None,
312
+ ):
313
+ super().__init__()
314
+ resnets = []
315
+ attentions = []
316
+ motion_modules = []
317
+
318
+ self.has_cross_attention = True
319
+ self.attn_num_head_channels = attn_num_head_channels
320
+
321
+ for i in range(num_layers):
322
+ in_channels = in_channels if i == 0 else out_channels
323
+ resnets.append(
324
+ ResnetBlock3D(
325
+ in_channels=in_channels,
326
+ out_channels=out_channels,
327
+ temb_channels=temb_channels,
328
+ eps=resnet_eps,
329
+ groups=resnet_groups,
330
+ dropout=dropout,
331
+ time_embedding_norm=resnet_time_scale_shift,
332
+ non_linearity=resnet_act_fn,
333
+ output_scale_factor=output_scale_factor,
334
+ pre_norm=resnet_pre_norm,
335
+
336
+ use_inflated_groupnorm=use_inflated_groupnorm,
337
+ )
338
+ )
339
+ if dual_cross_attention:
340
+ raise NotImplementedError
341
+ attentions.append(
342
+ Transformer3DModel(
343
+ attn_num_head_channels,
344
+ out_channels // attn_num_head_channels,
345
+ in_channels=out_channels,
346
+ num_layers=1,
347
+ cross_attention_dim=cross_attention_dim,
348
+ norm_num_groups=resnet_groups,
349
+ use_linear_projection=use_linear_projection,
350
+ only_cross_attention=only_cross_attention,
351
+ upcast_attention=upcast_attention,
352
+
353
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
354
+ unet_use_temporal_attention=unet_use_temporal_attention,
355
+ )
356
+ )
357
+ motion_modules.append(
358
+ get_motion_module(
359
+ in_channels=out_channels,
360
+ motion_module_type=motion_module_type,
361
+ motion_module_kwargs=motion_module_kwargs,
362
+ ) if use_motion_module else None
363
+ )
364
+
365
+ self.attentions = nn.ModuleList(attentions)
366
+ self.resnets = nn.ModuleList(resnets)
367
+ self.motion_modules = nn.ModuleList(motion_modules)
368
+
369
+ if add_downsample:
370
+ self.downsamplers = nn.ModuleList(
371
+ [
372
+ Downsample3D(
373
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
374
+ )
375
+ ]
376
+ )
377
+ else:
378
+ self.downsamplers = None
379
+
380
+ self.gradient_checkpointing = False
381
+
382
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
383
+ output_states = ()
384
+
385
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
386
+ if self.training and self.gradient_checkpointing:
387
+
388
+ def create_custom_forward(module, return_dict=None):
389
+ def custom_forward(*inputs):
390
+ if return_dict is not None:
391
+ return module(*inputs, return_dict=return_dict)
392
+ else:
393
+ return module(*inputs)
394
+
395
+ return custom_forward
396
+
397
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
398
+ hidden_states = torch.utils.checkpoint.checkpoint(
399
+ create_custom_forward(attn, return_dict=False),
400
+ hidden_states,
401
+ encoder_hidden_states,
402
+ )[0]
403
+ if motion_module is not None:
404
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
405
+
406
+ else:
407
+ hidden_states = resnet(hidden_states, temb)
408
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
409
+
410
+ # add motion module
411
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
412
+
413
+ output_states += (hidden_states,)
414
+
415
+ if self.downsamplers is not None:
416
+ for downsampler in self.downsamplers:
417
+ hidden_states = downsampler(hidden_states)
418
+
419
+ output_states += (hidden_states,)
420
+
421
+ return hidden_states, output_states
422
+
423
+
424
+ class DownBlock3D(nn.Module):
425
+ def __init__(
426
+ self,
427
+ in_channels: int,
428
+ out_channels: int,
429
+ temb_channels: int,
430
+ dropout: float = 0.0,
431
+ num_layers: int = 1,
432
+ resnet_eps: float = 1e-6,
433
+ resnet_time_scale_shift: str = "default",
434
+ resnet_act_fn: str = "swish",
435
+ resnet_groups: int = 32,
436
+ resnet_pre_norm: bool = True,
437
+ output_scale_factor=1.0,
438
+ add_downsample=True,
439
+ downsample_padding=1,
440
+
441
+ use_inflated_groupnorm=False,
442
+
443
+ use_motion_module=None,
444
+ motion_module_type=None,
445
+ motion_module_kwargs=None,
446
+ ):
447
+ super().__init__()
448
+ resnets = []
449
+ motion_modules = []
450
+
451
+ for i in range(num_layers):
452
+ in_channels = in_channels if i == 0 else out_channels
453
+ resnets.append(
454
+ ResnetBlock3D(
455
+ in_channels=in_channels,
456
+ out_channels=out_channels,
457
+ temb_channels=temb_channels,
458
+ eps=resnet_eps,
459
+ groups=resnet_groups,
460
+ dropout=dropout,
461
+ time_embedding_norm=resnet_time_scale_shift,
462
+ non_linearity=resnet_act_fn,
463
+ output_scale_factor=output_scale_factor,
464
+ pre_norm=resnet_pre_norm,
465
+
466
+ use_inflated_groupnorm=use_inflated_groupnorm,
467
+ )
468
+ )
469
+ motion_modules.append(
470
+ get_motion_module(
471
+ in_channels=out_channels,
472
+ motion_module_type=motion_module_type,
473
+ motion_module_kwargs=motion_module_kwargs,
474
+ ) if use_motion_module else None
475
+ )
476
+
477
+ self.resnets = nn.ModuleList(resnets)
478
+ self.motion_modules = nn.ModuleList(motion_modules)
479
+
480
+ if add_downsample:
481
+ self.downsamplers = nn.ModuleList(
482
+ [
483
+ Downsample3D(
484
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
485
+ )
486
+ ]
487
+ )
488
+ else:
489
+ self.downsamplers = None
490
+
491
+ self.gradient_checkpointing = False
492
+
493
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
494
+ output_states = ()
495
+
496
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
497
+ if self.training and self.gradient_checkpointing:
498
+ def create_custom_forward(module):
499
+ def custom_forward(*inputs):
500
+ return module(*inputs)
501
+
502
+ return custom_forward
503
+
504
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
505
+ if motion_module is not None:
506
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
507
+ else:
508
+ hidden_states = resnet(hidden_states, temb)
509
+
510
+ # add motion module
511
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
512
+
513
+ output_states += (hidden_states,)
514
+
515
+ if self.downsamplers is not None:
516
+ for downsampler in self.downsamplers:
517
+ hidden_states = downsampler(hidden_states)
518
+
519
+ output_states += (hidden_states,)
520
+
521
+ return hidden_states, output_states
522
+
523
+
524
+ class CrossAttnUpBlock3D(nn.Module):
525
+ def __init__(
526
+ self,
527
+ in_channels: int,
528
+ out_channels: int,
529
+ prev_output_channel: int,
530
+ temb_channels: int,
531
+ dropout: float = 0.0,
532
+ num_layers: int = 1,
533
+ resnet_eps: float = 1e-6,
534
+ resnet_time_scale_shift: str = "default",
535
+ resnet_act_fn: str = "swish",
536
+ resnet_groups: int = 32,
537
+ resnet_pre_norm: bool = True,
538
+ attn_num_head_channels=1,
539
+ cross_attention_dim=1280,
540
+ output_scale_factor=1.0,
541
+ add_upsample=True,
542
+ dual_cross_attention=False,
543
+ use_linear_projection=False,
544
+ only_cross_attention=False,
545
+ upcast_attention=False,
546
+
547
+ unet_use_cross_frame_attention=False,
548
+ unet_use_temporal_attention=False,
549
+ use_inflated_groupnorm=False,
550
+
551
+ use_motion_module=None,
552
+
553
+ motion_module_type=None,
554
+ motion_module_kwargs=None,
555
+ ):
556
+ super().__init__()
557
+ resnets = []
558
+ attentions = []
559
+ motion_modules = []
560
+
561
+ self.has_cross_attention = True
562
+ self.attn_num_head_channels = attn_num_head_channels
563
+
564
+ for i in range(num_layers):
565
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
566
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
567
+
568
+ resnets.append(
569
+ ResnetBlock3D(
570
+ in_channels=resnet_in_channels + res_skip_channels,
571
+ out_channels=out_channels,
572
+ temb_channels=temb_channels,
573
+ eps=resnet_eps,
574
+ groups=resnet_groups,
575
+ dropout=dropout,
576
+ time_embedding_norm=resnet_time_scale_shift,
577
+ non_linearity=resnet_act_fn,
578
+ output_scale_factor=output_scale_factor,
579
+ pre_norm=resnet_pre_norm,
580
+
581
+ use_inflated_groupnorm=use_inflated_groupnorm,
582
+ )
583
+ )
584
+ if dual_cross_attention:
585
+ raise NotImplementedError
586
+ attentions.append(
587
+ Transformer3DModel(
588
+ attn_num_head_channels,
589
+ out_channels // attn_num_head_channels,
590
+ in_channels=out_channels,
591
+ num_layers=1,
592
+ cross_attention_dim=cross_attention_dim,
593
+ norm_num_groups=resnet_groups,
594
+ use_linear_projection=use_linear_projection,
595
+ only_cross_attention=only_cross_attention,
596
+ upcast_attention=upcast_attention,
597
+
598
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
599
+ unet_use_temporal_attention=unet_use_temporal_attention,
600
+ )
601
+ )
602
+ motion_modules.append(
603
+ get_motion_module(
604
+ in_channels=out_channels,
605
+ motion_module_type=motion_module_type,
606
+ motion_module_kwargs=motion_module_kwargs,
607
+ ) if use_motion_module else None
608
+ )
609
+
610
+ self.attentions = nn.ModuleList(attentions)
611
+ self.resnets = nn.ModuleList(resnets)
612
+ self.motion_modules = nn.ModuleList(motion_modules)
613
+
614
+ if add_upsample:
615
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
616
+ else:
617
+ self.upsamplers = None
618
+
619
+ self.gradient_checkpointing = False
620
+
621
+ def forward(
622
+ self,
623
+ hidden_states,
624
+ res_hidden_states_tuple,
625
+ temb=None,
626
+ encoder_hidden_states=None,
627
+ upsample_size=None,
628
+ attention_mask=None,
629
+ ):
630
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
631
+ # pop res hidden states
632
+ res_hidden_states = res_hidden_states_tuple[-1]
633
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
634
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
635
+
636
+ if self.training and self.gradient_checkpointing:
637
+
638
+ def create_custom_forward(module, return_dict=None):
639
+ def custom_forward(*inputs):
640
+ if return_dict is not None:
641
+ return module(*inputs, return_dict=return_dict)
642
+ else:
643
+ return module(*inputs)
644
+
645
+ return custom_forward
646
+
647
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
648
+ hidden_states = torch.utils.checkpoint.checkpoint(
649
+ create_custom_forward(attn, return_dict=False),
650
+ hidden_states,
651
+ encoder_hidden_states,
652
+ )[0]
653
+ if motion_module is not None:
654
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
655
+
656
+ else:
657
+ hidden_states = resnet(hidden_states, temb)
658
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
659
+
660
+ # add motion module
661
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
662
+
663
+ if self.upsamplers is not None:
664
+ for upsampler in self.upsamplers:
665
+ hidden_states = upsampler(hidden_states, upsample_size)
666
+
667
+ return hidden_states
668
+
669
+
670
+ class UpBlock3D(nn.Module):
671
+ def __init__(
672
+ self,
673
+ in_channels: int,
674
+ prev_output_channel: int,
675
+ out_channels: int,
676
+ temb_channels: int,
677
+ dropout: float = 0.0,
678
+ num_layers: int = 1,
679
+ resnet_eps: float = 1e-6,
680
+ resnet_time_scale_shift: str = "default",
681
+ resnet_act_fn: str = "swish",
682
+ resnet_groups: int = 32,
683
+ resnet_pre_norm: bool = True,
684
+ output_scale_factor=1.0,
685
+ add_upsample=True,
686
+
687
+ use_inflated_groupnorm=False,
688
+
689
+ use_motion_module=None,
690
+ motion_module_type=None,
691
+ motion_module_kwargs=None,
692
+ ):
693
+ super().__init__()
694
+ resnets = []
695
+ motion_modules = []
696
+
697
+ for i in range(num_layers):
698
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
699
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
700
+
701
+ resnets.append(
702
+ ResnetBlock3D(
703
+ in_channels=resnet_in_channels + res_skip_channels,
704
+ out_channels=out_channels,
705
+ temb_channels=temb_channels,
706
+ eps=resnet_eps,
707
+ groups=resnet_groups,
708
+ dropout=dropout,
709
+ time_embedding_norm=resnet_time_scale_shift,
710
+ non_linearity=resnet_act_fn,
711
+ output_scale_factor=output_scale_factor,
712
+ pre_norm=resnet_pre_norm,
713
+
714
+ use_inflated_groupnorm=use_inflated_groupnorm,
715
+ )
716
+ )
717
+ motion_modules.append(
718
+ get_motion_module(
719
+ in_channels=out_channels,
720
+ motion_module_type=motion_module_type,
721
+ motion_module_kwargs=motion_module_kwargs,
722
+ ) if use_motion_module else None
723
+ )
724
+
725
+ self.resnets = nn.ModuleList(resnets)
726
+ self.motion_modules = nn.ModuleList(motion_modules)
727
+
728
+ if add_upsample:
729
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
730
+ else:
731
+ self.upsamplers = None
732
+
733
+ self.gradient_checkpointing = False
734
+
735
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
736
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
737
+ # pop res hidden states
738
+ res_hidden_states = res_hidden_states_tuple[-1]
739
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
740
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
741
+
742
+ if self.training and self.gradient_checkpointing:
743
+ def create_custom_forward(module):
744
+ def custom_forward(*inputs):
745
+ return module(*inputs)
746
+
747
+ return custom_forward
748
+
749
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
750
+ if motion_module is not None:
751
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
752
+ else:
753
+ hidden_states = resnet(hidden_states, temb)
754
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
755
+
756
+ if self.upsamplers is not None:
757
+ for upsampler in self.upsamplers:
758
+ hidden_states = upsampler(hidden_states, upsample_size)
759
+
760
+ return hidden_states
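Note: both `set_attention_slice` methods in these files resolve the `"auto"` and `"max"` keywords into one slice size per sliceable attention layer, validate the list length, and then push the values down the module tree in reverse order. The standalone snippet below replays just that resolution step on made-up head dimensions to show what each keyword expands to; it is an illustration, not code from the repository.

```python
# Made-up sliceable head dims, one per attention layer found in the module tree.
sliceable_head_dims = [8, 8, 16, 16, 8]

def resolve_slice_size(slice_size, dims):
    if slice_size == "auto":
        return [d // 2 for d in dims]     # halve each head dim: two slices per layer
    if slice_size == "max":
        return [1] * len(dims)            # one slice at a time: most memory saved
    if not isinstance(slice_size, list):
        return [slice_size] * len(dims)   # same integer slice size everywhere
    return slice_size                     # already a per-layer list

print(resolve_slice_size("auto", sliceable_head_dims))  # [4, 4, 8, 8, 4]
print(resolve_slice_size("max", sliceable_head_dims))   # [1, 1, 1, 1, 1]
print(resolve_slice_size(2, sliceable_head_dims))       # [2, 2, 2, 2, 2]
```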
motionclone/pipelines/__pycache__/pipeline_animation.cpython-310.pyc ADDED
Binary file (13.7 kB).
 
motionclone/pipelines/__pycache__/pipeline_animation.cpython-38.pyc ADDED
Binary file (13.4 kB).