Commit 4058ef5 (verified), committed by John6666
Parent(s): ea1aefb

Upload 45 files

LICENSE CHANGED
@@ -1,28 +1,28 @@
BSD 3-Clause License

Copyright 2023 MagicAnimate Team All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
app.py CHANGED
@@ -8,18 +8,19 @@
  # disclosure or distribution of this material and related documentation
  # without an express license agreement from ByteDance or
  # its affiliates is strictly prohibited.
- import argparse
  import imageio
  import numpy as np
  import gradio as gr
  import os
  from PIL import Image
- from subprocess import PIPE, run

  from demo.animate import MagicAnimate

  from huggingface_hub import snapshot_download

+ import subprocess
+ subprocess.run('pip cache purge', shell=True)
+
  snapshot_download(repo_id="stable-diffusion-v1-5/stable-diffusion-v1-5", local_dir="./stable-diffusion-v1-5", ignore_patterns=["*.safetensors"])
  snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir="./sd-vae-ft-mse")
  snapshot_download(repo_id="zcxu-eric/MagicAnimate", local_dir="./MagicAnimate")
configs/inference/inference.yaml CHANGED
@@ -1,26 +1,26 @@
unet_additional_kwargs:
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  motion_module_resolutions:
    - 1
    - 2
    - 4
    - 8
  motion_module_mid_block: false
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
      - Temporal_Self
      - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 24
    temporal_attention_dim_div: 1

noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
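
The noise_scheduler_kwargs block above carries the usual Stable Diffusion 1.5 beta schedule (beta_start 0.00085, beta_end 0.012, linear). Below is a minimal sketch of how a config shaped like this is typically consumed, assuming an OmegaConf loader and a diffusers DDIMScheduler; both are assumptions for illustration, not something this commit shows.

# Sketch only: load the YAML and build a scheduler from noise_scheduler_kwargs.
from omegaconf import OmegaConf
from diffusers import DDIMScheduler

config = OmegaConf.load("configs/inference/inference.yaml")
scheduler = DDIMScheduler(**OmegaConf.to_container(config.noise_scheduler_kwargs))
print(scheduler.config.beta_schedule)  # "linear", matching the file above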
configs/prompts/animation.yaml CHANGED
@@ -1,40 +1,40 @@
pretrained_model_path: "stable-diffusion-v1-5"
pretrained_vae_path: "sd-vae-ft-mse"
pretrained_controlnet_path: "MagicAnimate/densepose_controlnet"
pretrained_appearance_encoder_path: "MagicAnimate/appearance_encoder"
pretrained_unet_path: ""

motion_module: "MagicAnimate/temporal_attention/temporal_attention.ckpt"

savename: null

fusion_blocks: "midup"

seed: [1]
steps: 25
guidance_scale: 7.5

source_image:
  - "inputs/applications/source_image/monalisa.png"
  - "inputs/applications/source_image/demo4.png"
  - "inputs/applications/source_image/dalle2.jpeg"
  - "inputs/applications/source_image/dalle8.jpeg"
  - "inputs/applications/source_image/multi1_source.png"
video_path:
  - "inputs/applications/driving/densepose/running.mp4"
  - "inputs/applications/driving/densepose/demo4.mp4"
  - "inputs/applications/driving/densepose/running2.mp4"
  - "inputs/applications/driving/densepose/dancing2.mp4"
  - "inputs/applications/driving/densepose/multi_dancing.mp4"

inference_config: "configs/inference/inference.yaml"
size: 512
L: 16
S: 1
I: 0
clip: 0
offset: 0
max_length: null
video_type: "condition"
invert_video: false
save_individual_videos: false
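
source_image and video_path above appear to be parallel lists, pairing each source image with a DensePose driving video by index. A minimal sketch of reading those pairs, assuming an OmegaConf-style loader (an assumption; the loader is not part of this commit):

# Sketch only: iterate the (source image, driving video) pairs defined above.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/prompts/animation.yaml")
for image_path, driving_video in zip(cfg.source_image, cfg.video_path):
    print(f"animate {image_path} with {driving_video}")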
demo/animate.py CHANGED
@@ -62,7 +62,6 @@ class MagicAnimate():
  vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
  else:
  vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
-
  ### Load controlnet
  controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)

@@ -84,7 +83,7 @@ class MagicAnimate():

  # 1. unet ckpt
  # 1.1 motion module
- motion_module_state_dict = torch.load(motion_module, map_location="cpu")
+ motion_module_state_dict = torch.load(motion_module, map_location="cpu", weights_only=True)
  if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
  motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
  try:
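
The only functional change here is passing weights_only=True to torch.load, which restricts unpickling to plain tensors and basic containers when reading the motion-module checkpoint. A standalone sketch of the same loading pattern follows; the checkpoint path is a placeholder.

# Sketch of the pattern above: safe-load a checkpoint and unwrap "state_dict" if present.
import torch

ckpt = torch.load("temporal_attention.ckpt", map_location="cpu", weights_only=True)  # placeholder path
state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
print(f"loaded {len(state_dict)} tensors")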
magicanimate/models/appearance_encoder.py CHANGED
The diff for this file is too large to render. See raw diff
 
magicanimate/models/attention.py CHANGED
@@ -1,320 +1,320 @@
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import FeedForward, AdaLayerNorm
from diffusers.models.attention import Attention as CrossAttention

from einops import rearrange, repeat

@dataclass
class Transformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class Transformer3DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        if use_linear_projection:
            self.proj_out = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        # JH: need not repeat when a list of prompts are given
        if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
            encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention = None,
        unet_use_temporal_attention = None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            self.attn1 = SparseCausalAttention2D(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn2 = None

        if cross_attention_dim is not None:
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.norm2 = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)
        self.use_ada_layer_norm_zero = False

        # Temp-Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.attn2 is not None:
            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # if self.only_cross_attention:
        #     hidden_states = (
        #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
        #     )
        # else:
        #     hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states

        # pdb.set_trace()
        if self.unet_use_cross_frame_attention:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.attn2 is not None:
            # Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            hidden_states = (
                self.attn2(
                    norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                )
                + hidden_states
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
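
Transformer3DModel.forward folds the frame axis into the batch axis with einops before running the 2D transformer blocks, then unfolds it again afterwards. A small self-contained illustration of that reshape round trip (shapes chosen arbitrarily):

# Illustration of the "b c f h w -> (b f) c h w" folding used above.
import torch
from einops import rearrange

x = torch.randn(2, 4, 16, 8, 8)                    # batch, channels, frames, h, w
frames = rearrange(x, "b c f h w -> (b f) c h w")  # (32, 4, 8, 8): frames treated as batch
back = rearrange(frames, "(b f) c h w -> b c f h w", f=16)
assert torch.equal(x, back)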
magicanimate/models/controlnet.py CHANGED
@@ -1,578 +1,578 @@
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from .embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    get_down_block,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class ControlNetOutput(BaseOutput):
    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor


class ControlNetConditioningEmbedding(nn.Module):
    """
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


class ControlNetModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
    ):
        super().__init__()

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        # control net conditioning embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=attention_head_dim[i],
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
    ):
        r"""
        Instantiate Controlnet class from UNet2DConditionModel.

        Parameters:
            unet (`UNet2DConditionModel`):
                UNet model which weights are copied to the ControlNet. Note that all configuration options are also
                copied where applicable.
        """
        controlnet = cls(
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
        )

        if load_weights_from_unet:
            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

            if controlnet.class_embedding:
                controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

        return controlnet

    # @property
    # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    # def attn_processors(self) -> Dict[str, AttentionProcessor]:
    #     r"""
    #     Returns:
    #         `dict` of attention processors: A dictionary containing all attention processors used in the model with
    #         indexed by its weight name.
    #     """
    #     # set recursively
    #     processors = {}

    #     def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
    #         if hasattr(module, "set_processor"):
    #             processors[f"{name}.processor"] = module.processor

    #         for sub_name, child in module.named_children():
    #             fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

    #         return processors

    #     for name, module in self.named_children():
    #         fn_recursive_add_processors(name, module, processors)

    #     return processors

    # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    # def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    #     r"""
    #     Parameters:
    #         `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
    #             The instantiated processor class or a dictionary of processor classes that will be set as the processor
    #             of **all** `Attention` layers.
    #         In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:

    #     """
    #     count = len(self.attn_processors.keys())

    #     if isinstance(processor, dict) and len(processor) != count:
    #         raise ValueError(
    #             f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
    #             f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
    #         )

    #     def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
    #         if hasattr(module, "set_processor"):
    #             if not isinstance(processor, dict):
    #                 module.set_processor(processor)
    #             else:
    #                 module.set_processor(processor.pop(f"{name}.processor"))

    #         for sub_name, child in module.named_children():
    #             fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    #     for name, module in self.named_children():
    #         fn_recursive_attn_processor(name, module, processor)

    # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    # def set_default_attn_processor(self):
    #     """
    #     Disables custom attention processors and sets the default attention implementation.
    #     """
    #     self.set_attn_processor(AttnProcessor())

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        # check channel order
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        elif channel_order == "bgr":
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        else:
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

        sample += controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    # cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                # cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. Control net blocks

        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples += (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
        mid_block_res_sample *= conditioning_scale

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
21
+ from typing import Any, Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ from torch import nn
25
+ from torch.nn import functional as F
26
+
27
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
28
+ from diffusers.utils import BaseOutput, logging
29
+ from .embeddings import TimestepEmbedding, Timesteps
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+ from diffusers.models.unets.unet_2d_blocks import (
32
+ CrossAttnDownBlock2D,
33
+ DownBlock2D,
34
+ UNetMidBlock2DCrossAttn,
35
+ get_down_block,
36
+ )
37
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ @dataclass
44
+ class ControlNetOutput(BaseOutput):
45
+ down_block_res_samples: Tuple[torch.Tensor]
46
+ mid_block_res_sample: torch.Tensor
47
+
48
+
49
+ class ControlNetConditioningEmbedding(nn.Module):
50
+ """
51
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
52
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
53
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
54
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
55
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
56
+ model) to encode image-space conditions ... into feature maps ..."
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ conditioning_embedding_channels: int,
62
+ conditioning_channels: int = 3,
63
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
64
+ ):
65
+ super().__init__()
66
+
67
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
68
+
69
+ self.blocks = nn.ModuleList([])
70
+
71
+ for i in range(len(block_out_channels) - 1):
72
+ channel_in = block_out_channels[i]
73
+ channel_out = block_out_channels[i + 1]
74
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
75
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
76
+
77
+ self.conv_out = zero_module(
78
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
79
+ )
80
+
81
+ def forward(self, conditioning):
82
+ embedding = self.conv_in(conditioning)
83
+ embedding = F.silu(embedding)
84
+
85
+ for block in self.blocks:
86
+ embedding = block(embedding)
87
+ embedding = F.silu(embedding)
88
+
89
+ embedding = self.conv_out(embedding)
90
+
91
+ return embedding
92
+
93
+
94
+ class ControlNetModel(ModelMixin, ConfigMixin):
95
+ _supports_gradient_checkpointing = True
96
+
97
+ @register_to_config
98
+ def __init__(
99
+ self,
100
+ in_channels: int = 4,
101
+ flip_sin_to_cos: bool = True,
102
+ freq_shift: int = 0,
103
+ down_block_types: Tuple[str] = (
104
+ "CrossAttnDownBlock2D",
105
+ "CrossAttnDownBlock2D",
106
+ "CrossAttnDownBlock2D",
107
+ "DownBlock2D",
108
+ ),
109
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
110
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
111
+ layers_per_block: int = 2,
112
+ downsample_padding: int = 1,
113
+ mid_block_scale_factor: float = 1,
114
+ act_fn: str = "silu",
115
+ norm_num_groups: Optional[int] = 32,
116
+ norm_eps: float = 1e-5,
117
+ cross_attention_dim: int = 1280,
118
+ attention_head_dim: Union[int, Tuple[int]] = 8,
119
+ use_linear_projection: bool = False,
120
+ class_embed_type: Optional[str] = None,
121
+ num_class_embeds: Optional[int] = None,
122
+ upcast_attention: bool = False,
123
+ resnet_time_scale_shift: str = "default",
124
+ projection_class_embeddings_input_dim: Optional[int] = None,
125
+ controlnet_conditioning_channel_order: str = "rgb",
126
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
127
+ ):
128
+ super().__init__()
129
+
130
+ # Check inputs
131
+ if len(block_out_channels) != len(down_block_types):
132
+ raise ValueError(
133
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
134
+ )
135
+
136
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
137
+ raise ValueError(
138
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
139
+ )
140
+
141
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
142
+ raise ValueError(
143
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
144
+ )
145
+
146
+ # input
147
+ conv_in_kernel = 3
148
+ conv_in_padding = (conv_in_kernel - 1) // 2
149
+ self.conv_in = nn.Conv2d(
150
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
151
+ )
152
+
153
+ # time
154
+ time_embed_dim = block_out_channels[0] * 4
155
+
156
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
157
+ timestep_input_dim = block_out_channels[0]
158
+
159
+ self.time_embedding = TimestepEmbedding(
160
+ timestep_input_dim,
161
+ time_embed_dim,
162
+ act_fn=act_fn,
163
+ )
164
+
165
+ # class embedding
166
+ if class_embed_type is None and num_class_embeds is not None:
167
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
168
+ elif class_embed_type == "timestep":
169
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
170
+ elif class_embed_type == "identity":
171
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
172
+ elif class_embed_type == "projection":
173
+ if projection_class_embeddings_input_dim is None:
174
+ raise ValueError(
175
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
176
+ )
177
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
178
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
179
+ # 2. it projects from an arbitrary input dimension.
180
+ #
181
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
182
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
183
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
184
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
185
+ else:
186
+ self.class_embedding = None
187
+
188
+ # control net conditioning embedding
189
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
190
+ conditioning_embedding_channels=block_out_channels[0],
191
+ block_out_channels=conditioning_embedding_out_channels,
192
+ )
193
+
194
+ self.down_blocks = nn.ModuleList([])
195
+ self.controlnet_down_blocks = nn.ModuleList([])
196
+
197
+ if isinstance(only_cross_attention, bool):
198
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
199
+
200
+ if isinstance(attention_head_dim, int):
201
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
202
+
203
+ # down
204
+ output_channel = block_out_channels[0]
205
+
206
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
207
+ controlnet_block = zero_module(controlnet_block)
208
+ self.controlnet_down_blocks.append(controlnet_block)
209
+
210
+ for i, down_block_type in enumerate(down_block_types):
211
+ input_channel = output_channel
212
+ output_channel = block_out_channels[i]
213
+ is_final_block = i == len(block_out_channels) - 1
214
+
215
+ down_block = get_down_block(
216
+ down_block_type,
217
+ num_layers=layers_per_block,
218
+ in_channels=input_channel,
219
+ out_channels=output_channel,
220
+ temb_channels=time_embed_dim,
221
+ add_downsample=not is_final_block,
222
+ resnet_eps=norm_eps,
223
+ resnet_act_fn=act_fn,
224
+ resnet_groups=norm_num_groups,
225
+ cross_attention_dim=cross_attention_dim,
226
+ num_attention_heads=attention_head_dim[i],
227
+ downsample_padding=downsample_padding,
228
+ use_linear_projection=use_linear_projection,
229
+ only_cross_attention=only_cross_attention[i],
230
+ upcast_attention=upcast_attention,
231
+ resnet_time_scale_shift=resnet_time_scale_shift,
232
+ )
233
+ self.down_blocks.append(down_block)
234
+
235
+ for _ in range(layers_per_block):
236
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
237
+ controlnet_block = zero_module(controlnet_block)
238
+ self.controlnet_down_blocks.append(controlnet_block)
239
+
240
+ if not is_final_block:
241
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
242
+ controlnet_block = zero_module(controlnet_block)
243
+ self.controlnet_down_blocks.append(controlnet_block)
244
+
245
+ # mid
246
+ mid_block_channel = block_out_channels[-1]
247
+
248
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
249
+ controlnet_block = zero_module(controlnet_block)
250
+ self.controlnet_mid_block = controlnet_block
251
+
252
+ self.mid_block = UNetMidBlock2DCrossAttn(
253
+ in_channels=mid_block_channel,
254
+ temb_channels=time_embed_dim,
255
+ resnet_eps=norm_eps,
256
+ resnet_act_fn=act_fn,
257
+ output_scale_factor=mid_block_scale_factor,
258
+ resnet_time_scale_shift=resnet_time_scale_shift,
259
+ cross_attention_dim=cross_attention_dim,
260
+ num_attention_heads=attention_head_dim[-1],
261
+ resnet_groups=norm_num_groups,
262
+ use_linear_projection=use_linear_projection,
263
+ upcast_attention=upcast_attention,
264
+ )
265
+
266
+ @classmethod
267
+ def from_unet(
268
+ cls,
269
+ unet: UNet2DConditionModel,
270
+ controlnet_conditioning_channel_order: str = "rgb",
271
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
272
+ load_weights_from_unet: bool = True,
273
+ ):
274
+ r"""
275
+ Instantiate Controlnet class from UNet2DConditionModel.
276
+
277
+ Parameters:
278
+ unet (`UNet2DConditionModel`):
279
+ UNet model which weights are copied to the ControlNet. Note that all configuration options are also
280
+ copied where applicable.
281
+ """
282
+ controlnet = cls(
283
+ in_channels=unet.config.in_channels,
284
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
285
+ freq_shift=unet.config.freq_shift,
286
+ down_block_types=unet.config.down_block_types,
287
+ only_cross_attention=unet.config.only_cross_attention,
288
+ block_out_channels=unet.config.block_out_channels,
289
+ layers_per_block=unet.config.layers_per_block,
290
+ downsample_padding=unet.config.downsample_padding,
291
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
292
+ act_fn=unet.config.act_fn,
293
+ norm_num_groups=unet.config.norm_num_groups,
294
+ norm_eps=unet.config.norm_eps,
295
+ cross_attention_dim=unet.config.cross_attention_dim,
296
+ attention_head_dim=unet.config.attention_head_dim,
297
+ use_linear_projection=unet.config.use_linear_projection,
298
+ class_embed_type=unet.config.class_embed_type,
299
+ num_class_embeds=unet.config.num_class_embeds,
300
+ upcast_attention=unet.config.upcast_attention,
301
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
302
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
303
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
304
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
305
+ )
306
+
307
+ if load_weights_from_unet:
308
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
309
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
310
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
311
+
312
+ if controlnet.class_embedding:
313
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
314
+
315
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
316
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
317
+
318
+ return controlnet
319
+
320
+ # @property
321
+ # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
322
+ # def attn_processors(self) -> Dict[str, AttentionProcessor]:
323
+ # r"""
324
+ # Returns:
325
+ # `dict` of attention processors: A dictionary containing all attention processors used in the model with
326
+ # indexed by its weight name.
327
+ # """
328
+ # # set recursively
329
+ # processors = {}
330
+
331
+ # def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
332
+ # if hasattr(module, "set_processor"):
333
+ # processors[f"{name}.processor"] = module.processor
334
+
335
+ # for sub_name, child in module.named_children():
336
+ # fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
337
+
338
+ # return processors
339
+
340
+ # for name, module in self.named_children():
341
+ # fn_recursive_add_processors(name, module, processors)
342
+
343
+ # return processors
344
+
345
+ # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
346
+ # def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
347
+ # r"""
348
+ # Parameters:
349
+ # `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
350
+ # The instantiated processor class or a dictionary of processor classes that will be set as the processor
351
+ # of **all** `Attention` layers.
352
+ # In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
353
+
354
+ # """
355
+ # count = len(self.attn_processors.keys())
356
+
357
+ # if isinstance(processor, dict) and len(processor) != count:
358
+ # raise ValueError(
359
+ # f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
360
+ # f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
361
+ # )
362
+
363
+ # def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
364
+ # if hasattr(module, "set_processor"):
365
+ # if not isinstance(processor, dict):
366
+ # module.set_processor(processor)
367
+ # else:
368
+ # module.set_processor(processor.pop(f"{name}.processor"))
369
+
370
+ # for sub_name, child in module.named_children():
371
+ # fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
372
+
373
+ # for name, module in self.named_children():
374
+ # fn_recursive_attn_processor(name, module, processor)
375
+
376
+ # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
377
+ # def set_default_attn_processor(self):
378
+ # """
379
+ # Disables custom attention processors and sets the default attention implementation.
380
+ # """
381
+ # self.set_attn_processor(AttnProcessor())
382
+
383
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
384
+ def set_attention_slice(self, slice_size):
385
+ r"""
386
+ Enable sliced attention computation.
387
+
388
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
389
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
390
+
391
+ Args:
392
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
393
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
394
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
395
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
396
+ must be a multiple of `slice_size`.
397
+ """
398
+ sliceable_head_dims = []
399
+
400
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
401
+ if hasattr(module, "set_attention_slice"):
402
+ sliceable_head_dims.append(module.sliceable_head_dim)
403
+
404
+ for child in module.children():
405
+ fn_recursive_retrieve_sliceable_dims(child)
406
+
407
+ # retrieve number of attention layers
408
+ for module in self.children():
409
+ fn_recursive_retrieve_sliceable_dims(module)
410
+
411
+ num_sliceable_layers = len(sliceable_head_dims)
412
+
413
+ if slice_size == "auto":
414
+ # half the attention head size is usually a good trade-off between
415
+ # speed and memory
416
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
417
+ elif slice_size == "max":
418
+ # make smallest slice possible
419
+ slice_size = num_sliceable_layers * [1]
420
+
421
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
422
+
423
+ if len(slice_size) != len(sliceable_head_dims):
424
+ raise ValueError(
425
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
426
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
427
+ )
428
+
429
+ for i in range(len(slice_size)):
430
+ size = slice_size[i]
431
+ dim = sliceable_head_dims[i]
432
+ if size is not None and size > dim:
433
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
434
+
435
+ # Recursively walk through all the children.
436
+ # Any children which exposes the set_attention_slice method
437
+ # gets the message
438
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
439
+ if hasattr(module, "set_attention_slice"):
440
+ module.set_attention_slice(slice_size.pop())
441
+
442
+ for child in module.children():
443
+ fn_recursive_set_attention_slice(child, slice_size)
444
+
445
+ reversed_slice_size = list(reversed(slice_size))
446
+ for module in self.children():
447
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
448
+
449
+ def _set_gradient_checkpointing(self, module, value=False):
450
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
451
+ module.gradient_checkpointing = value
452
+
453
+ def forward(
454
+ self,
455
+ sample: torch.FloatTensor,
456
+ timestep: Union[torch.Tensor, float, int],
457
+ encoder_hidden_states: torch.Tensor,
458
+ controlnet_cond: torch.FloatTensor,
459
+ conditioning_scale: float = 1.0,
460
+ class_labels: Optional[torch.Tensor] = None,
461
+ timestep_cond: Optional[torch.Tensor] = None,
462
+ attention_mask: Optional[torch.Tensor] = None,
463
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
464
+ return_dict: bool = True,
465
+ ) -> Union[ControlNetOutput, Tuple]:
466
+ # check channel order
467
+ channel_order = self.config.controlnet_conditioning_channel_order
468
+
469
+ if channel_order == "rgb":
470
+ # in rgb order by default
471
+ ...
472
+ elif channel_order == "bgr":
473
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
474
+ else:
475
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
476
+
477
+ # prepare attention_mask
478
+ if attention_mask is not None:
479
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
480
+ attention_mask = attention_mask.unsqueeze(1)
481
+
482
+ # 1. time
483
+ timesteps = timestep
484
+ if not torch.is_tensor(timesteps):
485
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
486
+ # This would be a good case for the `match` statement (Python 3.10+)
487
+ is_mps = sample.device.type == "mps"
488
+ if isinstance(timestep, float):
489
+ dtype = torch.float32 if is_mps else torch.float64
490
+ else:
491
+ dtype = torch.int32 if is_mps else torch.int64
492
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
493
+ elif len(timesteps.shape) == 0:
494
+ timesteps = timesteps[None].to(sample.device)
495
+
496
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
497
+ timesteps = timesteps.expand(sample.shape[0])
498
+
499
+ t_emb = self.time_proj(timesteps)
500
+
501
+ # timesteps does not contain any weights and will always return f32 tensors
502
+ # but time_embedding might actually be running in fp16. so we need to cast here.
503
+ # there might be better ways to encapsulate this.
504
+ t_emb = t_emb.to(dtype=self.dtype)
505
+
506
+ emb = self.time_embedding(t_emb, timestep_cond)
507
+
508
+ if self.class_embedding is not None:
509
+ if class_labels is None:
510
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
511
+
512
+ if self.config.class_embed_type == "timestep":
513
+ class_labels = self.time_proj(class_labels)
514
+
515
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
516
+ emb = emb + class_emb
517
+
518
+ # 2. pre-process
519
+ sample = self.conv_in(sample)
520
+
521
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
522
+
523
+ sample += controlnet_cond
524
+
525
+ # 3. down
526
+ down_block_res_samples = (sample,)
527
+ for downsample_block in self.down_blocks:
528
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
529
+ sample, res_samples = downsample_block(
530
+ hidden_states=sample,
531
+ temb=emb,
532
+ encoder_hidden_states=encoder_hidden_states,
533
+ attention_mask=attention_mask,
534
+ # cross_attention_kwargs=cross_attention_kwargs,
535
+ )
536
+ else:
537
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
538
+
539
+ down_block_res_samples += res_samples
540
+
541
+ # 4. mid
542
+ if self.mid_block is not None:
543
+ sample = self.mid_block(
544
+ sample,
545
+ emb,
546
+ encoder_hidden_states=encoder_hidden_states,
547
+ attention_mask=attention_mask,
548
+ # cross_attention_kwargs=cross_attention_kwargs,
549
+ )
550
+
551
+ # 5. Control net blocks
552
+
553
+ controlnet_down_block_res_samples = ()
554
+
555
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
556
+ down_block_res_sample = controlnet_block(down_block_res_sample)
557
+ controlnet_down_block_res_samples += (down_block_res_sample,)
558
+
559
+ down_block_res_samples = controlnet_down_block_res_samples
560
+
561
+ mid_block_res_sample = self.controlnet_mid_block(sample)
562
+
563
+ # 6. scaling
564
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
565
+ mid_block_res_sample *= conditioning_scale
566
+
567
+ if not return_dict:
568
+ return (down_block_res_samples, mid_block_res_sample)
569
+
570
+ return ControlNetOutput(
571
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
572
+ )
573
+
574
+
575
+ def zero_module(module):
576
+ for p in module.parameters():
577
+ nn.init.zeros_(p)
578
  return module
magicanimate/models/embeddings.py CHANGED
@@ -1,385 +1,385 @@
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ import math
21
+ from typing import Optional
22
+
23
+ import numpy as np
24
+ import torch
25
+ from torch import nn
26
+
27
+
28
+ def get_timestep_embedding(
29
+ timesteps: torch.Tensor,
30
+ embedding_dim: int,
31
+ flip_sin_to_cos: bool = False,
32
+ downscale_freq_shift: float = 1,
33
+ scale: float = 1,
34
+ max_period: int = 10000,
35
+ ):
36
+ """
37
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
38
+
39
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
40
+ These may be fractional.
41
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
42
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
43
+ """
44
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
45
+
46
+ half_dim = embedding_dim // 2
47
+ exponent = -math.log(max_period) * torch.arange(
48
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
49
+ )
50
+ exponent = exponent / (half_dim - downscale_freq_shift)
51
+
52
+ emb = torch.exp(exponent)
53
+ emb = timesteps[:, None].float() * emb[None, :]
54
+
55
+ # scale embeddings
56
+ emb = scale * emb
57
+
58
+ # concat sine and cosine embeddings
59
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
60
+
61
+ # flip sine and cosine embeddings
62
+ if flip_sin_to_cos:
63
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
64
+
65
+ # zero pad
66
+ if embedding_dim % 2 == 1:
67
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
68
+ return emb
69
+
70
+
71
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
72
+ """
73
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
74
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
75
+ """
76
+ grid_h = np.arange(grid_size, dtype=np.float32)
77
+ grid_w = np.arange(grid_size, dtype=np.float32)
78
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
79
+ grid = np.stack(grid, axis=0)
80
+
81
+ grid = grid.reshape([2, 1, grid_size, grid_size])
82
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
83
+ if cls_token and extra_tokens > 0:
84
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
85
+ return pos_embed
86
+
87
+
88
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
89
+ if embed_dim % 2 != 0:
90
+ raise ValueError("embed_dim must be divisible by 2")
91
+
92
+ # use half of dimensions to encode grid_h
93
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
94
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
95
+
96
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
97
+ return emb
98
+
99
+
100
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
101
+ """
102
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
103
+ """
104
+ if embed_dim % 2 != 0:
105
+ raise ValueError("embed_dim must be divisible by 2")
106
+
107
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
108
+ omega /= embed_dim / 2.0
109
+ omega = 1.0 / 10000**omega # (D/2,)
110
+
111
+ pos = pos.reshape(-1) # (M,)
112
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
113
+
114
+ emb_sin = np.sin(out) # (M, D/2)
115
+ emb_cos = np.cos(out) # (M, D/2)
116
+
117
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
118
+ return emb
119
+
120
+
121
+ class PatchEmbed(nn.Module):
122
+ """2D Image to Patch Embedding"""
123
+
124
+ def __init__(
125
+ self,
126
+ height=224,
127
+ width=224,
128
+ patch_size=16,
129
+ in_channels=3,
130
+ embed_dim=768,
131
+ layer_norm=False,
132
+ flatten=True,
133
+ bias=True,
134
+ ):
135
+ super().__init__()
136
+
137
+ num_patches = (height // patch_size) * (width // patch_size)
138
+ self.flatten = flatten
139
+ self.layer_norm = layer_norm
140
+
141
+ self.proj = nn.Conv2d(
142
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
143
+ )
144
+ if layer_norm:
145
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
146
+ else:
147
+ self.norm = None
148
+
149
+ pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
150
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
151
+
152
+ def forward(self, latent):
153
+ latent = self.proj(latent)
154
+ if self.flatten:
155
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
156
+ if self.layer_norm:
157
+ latent = self.norm(latent)
158
+ return latent + self.pos_embed
159
+
160
+
161
+ class TimestepEmbedding(nn.Module):
162
+ def __init__(
163
+ self,
164
+ in_channels: int,
165
+ time_embed_dim: int,
166
+ act_fn: str = "silu",
167
+ out_dim: int = None,
168
+ post_act_fn: Optional[str] = None,
169
+ cond_proj_dim=None,
170
+ ):
171
+ super().__init__()
172
+
173
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
174
+
175
+ if cond_proj_dim is not None:
176
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
177
+ else:
178
+ self.cond_proj = None
179
+
180
+ if act_fn == "silu":
181
+ self.act = nn.SiLU()
182
+ elif act_fn == "mish":
183
+ self.act = nn.Mish()
184
+ elif act_fn == "gelu":
185
+ self.act = nn.GELU()
186
+ else:
187
+ raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
188
+
189
+ if out_dim is not None:
190
+ time_embed_dim_out = out_dim
191
+ else:
192
+ time_embed_dim_out = time_embed_dim
193
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
194
+
195
+ if post_act_fn is None:
196
+ self.post_act = None
197
+ elif post_act_fn == "silu":
198
+ self.post_act = nn.SiLU()
199
+ elif post_act_fn == "mish":
200
+ self.post_act = nn.Mish()
201
+ elif post_act_fn == "gelu":
202
+ self.post_act = nn.GELU()
203
+ else:
204
+ raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
205
+
206
+ def forward(self, sample, condition=None):
207
+ if condition is not None:
208
+ sample = sample + self.cond_proj(condition)
209
+ sample = self.linear_1(sample)
210
+
211
+ if self.act is not None:
212
+ sample = self.act(sample)
213
+
214
+ sample = self.linear_2(sample)
215
+
216
+ if self.post_act is not None:
217
+ sample = self.post_act(sample)
218
+ return sample
219
+
220
+
221
+ class Timesteps(nn.Module):
222
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
223
+ super().__init__()
224
+ self.num_channels = num_channels
225
+ self.flip_sin_to_cos = flip_sin_to_cos
226
+ self.downscale_freq_shift = downscale_freq_shift
227
+
228
+ def forward(self, timesteps):
229
+ t_emb = get_timestep_embedding(
230
+ timesteps,
231
+ self.num_channels,
232
+ flip_sin_to_cos=self.flip_sin_to_cos,
233
+ downscale_freq_shift=self.downscale_freq_shift,
234
+ )
235
+ return t_emb
236
+
237
+
238
+ class GaussianFourierProjection(nn.Module):
239
+ """Gaussian Fourier embeddings for noise levels."""
240
+
241
+ def __init__(
242
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
243
+ ):
244
+ super().__init__()
245
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
246
+ self.log = log
247
+ self.flip_sin_to_cos = flip_sin_to_cos
248
+
249
+ if set_W_to_weight:
250
+ # to delete later
251
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
252
+
253
+ self.weight = self.W
254
+
255
+ def forward(self, x):
256
+ if self.log:
257
+ x = torch.log(x)
258
+
259
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
260
+
261
+ if self.flip_sin_to_cos:
262
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
263
+ else:
264
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
265
+ return out
266
+
267
+
268
+ class ImagePositionalEmbeddings(nn.Module):
269
+ """
270
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
271
+ height and width of the latent space.
272
+
273
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
274
+
275
+ For VQ-diffusion:
276
+
277
+ Output vector embeddings are used as input for the transformer.
278
+
279
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
280
+
281
+ Args:
282
+ num_embed (`int`):
283
+ Number of embeddings for the latent pixels embeddings.
284
+ height (`int`):
285
+ Height of the latent image i.e. the number of height embeddings.
286
+ width (`int`):
287
+ Width of the latent image i.e. the number of width embeddings.
288
+ embed_dim (`int`):
289
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
290
+ """
291
+
292
+ def __init__(
293
+ self,
294
+ num_embed: int,
295
+ height: int,
296
+ width: int,
297
+ embed_dim: int,
298
+ ):
299
+ super().__init__()
300
+
301
+ self.height = height
302
+ self.width = width
303
+ self.num_embed = num_embed
304
+ self.embed_dim = embed_dim
305
+
306
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
307
+ self.height_emb = nn.Embedding(self.height, embed_dim)
308
+ self.width_emb = nn.Embedding(self.width, embed_dim)
309
+
310
+ def forward(self, index):
311
+ emb = self.emb(index)
312
+
313
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
314
+
315
+ # 1 x H x D -> 1 x H x 1 x D
316
+ height_emb = height_emb.unsqueeze(2)
317
+
318
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
319
+
320
+ # 1 x W x D -> 1 x 1 x W x D
321
+ width_emb = width_emb.unsqueeze(1)
322
+
323
+ pos_emb = height_emb + width_emb
324
+
325
+ # 1 x H x W x D -> 1 x L xD
326
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
327
+
328
+ emb = emb + pos_emb[:, : emb.shape[1], :]
329
+
330
+ return emb
331
+
332
+
333
+ class LabelEmbedding(nn.Module):
334
+ """
335
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
336
+
337
+ Args:
338
+ num_classes (`int`): The number of classes.
339
+ hidden_size (`int`): The size of the vector embeddings.
340
+ dropout_prob (`float`): The probability of dropping a label.
341
+ """
342
+
343
+ def __init__(self, num_classes, hidden_size, dropout_prob):
344
+ super().__init__()
345
+ use_cfg_embedding = dropout_prob > 0
346
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
347
+ self.num_classes = num_classes
348
+ self.dropout_prob = dropout_prob
349
+
350
+ def token_drop(self, labels, force_drop_ids=None):
351
+ """
352
+ Drops labels to enable classifier-free guidance.
353
+ """
354
+ if force_drop_ids is None:
355
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
356
+ else:
357
+ drop_ids = torch.tensor(force_drop_ids == 1)
358
+ labels = torch.where(drop_ids, self.num_classes, labels)
359
+ return labels
360
+
361
+ def forward(self, labels, force_drop_ids=None):
362
+ use_dropout = self.dropout_prob > 0
363
+ if (self.training and use_dropout) or (force_drop_ids is not None):
364
+ labels = self.token_drop(labels, force_drop_ids)
365
+ embeddings = self.embedding_table(labels)
366
+ return embeddings
367
+
368
+
369
+ class CombinedTimestepLabelEmbeddings(nn.Module):
370
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
371
+ super().__init__()
372
+
373
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
374
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
375
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
376
+
377
+ def forward(self, timestep, class_labels, hidden_dtype=None):
378
+ timesteps_proj = self.time_proj(timestep)
379
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
380
+
381
+ class_labels = self.class_embedder(class_labels) # (N, D)
382
+
383
+ conditioning = timesteps_emb + class_labels # (N, D)
384
+
385
  return conditioning
magicanimate/models/motion_module.py CHANGED
@@ -1,334 +1,334 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Adapted from https://github.com/guoyww/AnimateDiff
8
- from dataclasses import dataclass
9
-
10
- import torch
11
- import torch.nn.functional as F
12
- from torch import nn
13
-
14
- from diffusers.utils import BaseOutput
15
- from diffusers.utils.import_utils import is_xformers_available
16
- from diffusers.models.attention import FeedForward
17
- from magicanimate.models.orig_attention import CrossAttention
18
-
19
- from einops import rearrange, repeat
20
- import math
21
-
22
-
23
- def zero_module(module):
24
- # Zero out the parameters of a module and return it.
25
- for p in module.parameters():
26
- p.detach().zero_()
27
- return module
28
-
29
-
30
- @dataclass
31
- class TemporalTransformer3DModelOutput(BaseOutput):
32
- sample: torch.FloatTensor
33
-
34
-
35
- if is_xformers_available():
36
- import xformers
37
- import xformers.ops
38
- else:
39
- xformers = None
40
-
41
-
42
- def get_motion_module(
43
- in_channels,
44
- motion_module_type: str,
45
- motion_module_kwargs: dict
46
- ):
47
- if motion_module_type == "Vanilla":
48
- return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
49
- else:
50
- raise ValueError
51
-
52
-
53
- class VanillaTemporalModule(nn.Module):
54
- def __init__(
55
- self,
56
- in_channels,
57
- num_attention_heads = 8,
58
- num_transformer_block = 2,
59
- attention_block_types =( "Temporal_Self", "Temporal_Self" ),
60
- cross_frame_attention_mode = None,
61
- temporal_position_encoding = False,
62
- temporal_position_encoding_max_len = 24,
63
- temporal_attention_dim_div = 1,
64
- zero_initialize = True,
65
- ):
66
- super().__init__()
67
-
68
- self.temporal_transformer = TemporalTransformer3DModel(
69
- in_channels=in_channels,
70
- num_attention_heads=num_attention_heads,
71
- attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
72
- num_layers=num_transformer_block,
73
- attention_block_types=attention_block_types,
74
- cross_frame_attention_mode=cross_frame_attention_mode,
75
- temporal_position_encoding=temporal_position_encoding,
76
- temporal_position_encoding_max_len=temporal_position_encoding_max_len,
77
- )
78
-
79
- if zero_initialize:
80
- self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
81
-
82
- def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
83
- hidden_states = input_tensor
84
- hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
85
-
86
- output = hidden_states
87
- return output
88
-
89
-
90
- class TemporalTransformer3DModel(nn.Module):
91
- def __init__(
92
- self,
93
- in_channels,
94
- num_attention_heads,
95
- attention_head_dim,
96
-
97
- num_layers,
98
- attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
99
- dropout = 0.0,
100
- norm_num_groups = 32,
101
- cross_attention_dim = 768,
102
- activation_fn = "geglu",
103
- attention_bias = False,
104
- upcast_attention = False,
105
-
106
- cross_frame_attention_mode = None,
107
- temporal_position_encoding = False,
108
- temporal_position_encoding_max_len = 24,
109
- ):
110
- super().__init__()
111
-
112
- inner_dim = num_attention_heads * attention_head_dim
113
-
114
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
115
- self.proj_in = nn.Linear(in_channels, inner_dim)
116
-
117
- self.transformer_blocks = nn.ModuleList(
118
- [
119
- TemporalTransformerBlock(
120
- dim=inner_dim,
121
- num_attention_heads=num_attention_heads,
122
- attention_head_dim=attention_head_dim,
123
- attention_block_types=attention_block_types,
124
- dropout=dropout,
125
- norm_num_groups=norm_num_groups,
126
- cross_attention_dim=cross_attention_dim,
127
- activation_fn=activation_fn,
128
- attention_bias=attention_bias,
129
- upcast_attention=upcast_attention,
130
- cross_frame_attention_mode=cross_frame_attention_mode,
131
- temporal_position_encoding=temporal_position_encoding,
132
- temporal_position_encoding_max_len=temporal_position_encoding_max_len,
133
- )
134
- for d in range(num_layers)
135
- ]
136
- )
137
- self.proj_out = nn.Linear(inner_dim, in_channels)
138
-
139
- def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
140
- assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
141
- video_length = hidden_states.shape[2]
142
- hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
143
-
144
- batch, channel, height, weight = hidden_states.shape
145
- residual = hidden_states
146
-
147
- hidden_states = self.norm(hidden_states)
148
- inner_dim = hidden_states.shape[1]
149
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
150
- hidden_states = self.proj_in(hidden_states)
151
-
152
- # Transformer Blocks
153
- for block in self.transformer_blocks:
154
- hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
155
-
156
- # output
157
- hidden_states = self.proj_out(hidden_states)
158
- hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
159
-
160
- output = hidden_states + residual
161
- output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
162
-
163
- return output
164
-
165
-
166
- class TemporalTransformerBlock(nn.Module):
167
- def __init__(
168
- self,
169
- dim,
170
- num_attention_heads,
171
- attention_head_dim,
172
- attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
173
- dropout = 0.0,
174
- norm_num_groups = 32,
175
- cross_attention_dim = 768,
176
- activation_fn = "geglu",
177
- attention_bias = False,
178
- upcast_attention = False,
179
- cross_frame_attention_mode = None,
180
- temporal_position_encoding = False,
181
- temporal_position_encoding_max_len = 24,
182
- ):
183
- super().__init__()
184
-
185
- attention_blocks = []
186
- norms = []
187
-
188
- for block_name in attention_block_types:
189
- attention_blocks.append(
190
- VersatileAttention(
191
- attention_mode=block_name.split("_")[0],
192
- cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
193
-
194
- query_dim=dim,
195
- heads=num_attention_heads,
196
- dim_head=attention_head_dim,
197
- dropout=dropout,
198
- bias=attention_bias,
199
- upcast_attention=upcast_attention,
200
-
201
- cross_frame_attention_mode=cross_frame_attention_mode,
202
- temporal_position_encoding=temporal_position_encoding,
203
- temporal_position_encoding_max_len=temporal_position_encoding_max_len,
204
- )
205
- )
206
- norms.append(nn.LayerNorm(dim))
207
-
208
- self.attention_blocks = nn.ModuleList(attention_blocks)
209
- self.norms = nn.ModuleList(norms)
210
-
211
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
212
- self.ff_norm = nn.LayerNorm(dim)
213
-
214
-
215
- def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
216
- for attention_block, norm in zip(self.attention_blocks, self.norms):
217
- norm_hidden_states = norm(hidden_states)
218
- hidden_states = attention_block(
219
- norm_hidden_states,
220
- encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
221
- video_length=video_length,
222
- ) + hidden_states
223
-
224
- hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
225
-
226
- output = hidden_states
227
- return output
228
-
229
-
230
- class PositionalEncoding(nn.Module):
231
- def __init__(
232
- self,
233
- d_model,
234
- dropout = 0.,
235
- max_len = 24
236
- ):
237
- super().__init__()
238
- self.dropout = nn.Dropout(p=dropout)
239
- position = torch.arange(max_len).unsqueeze(1)
240
- div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
241
- pe = torch.zeros(1, max_len, d_model)
242
- pe[0, :, 0::2] = torch.sin(position * div_term)
243
- pe[0, :, 1::2] = torch.cos(position * div_term)
244
- self.register_buffer('pe', pe)
245
-
246
- def forward(self, x):
247
- x = x + self.pe[:, :x.size(1)]
248
- return self.dropout(x)
249
-
250
-
251
- class VersatileAttention(CrossAttention):
252
- def __init__(
253
- self,
254
- attention_mode = None,
255
- cross_frame_attention_mode = None,
256
- temporal_position_encoding = False,
257
- temporal_position_encoding_max_len = 24,
258
- *args, **kwargs
259
- ):
260
- super().__init__(*args, **kwargs)
261
- assert attention_mode == "Temporal"
262
-
263
- self.attention_mode = attention_mode
264
- self.is_cross_attention = kwargs["cross_attention_dim"] is not None
265
-
266
- self.pos_encoder = PositionalEncoding(
267
- kwargs["query_dim"],
268
- dropout=0.,
269
- max_len=temporal_position_encoding_max_len
270
- ) if (temporal_position_encoding and attention_mode == "Temporal") else None
271
-
272
- def extra_repr(self):
273
- return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
274
-
275
- def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
276
- batch_size, sequence_length, _ = hidden_states.shape
277
-
278
- if self.attention_mode == "Temporal":
279
- d = hidden_states.shape[1]
280
- hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
281
-
282
- if self.pos_encoder is not None:
283
- hidden_states = self.pos_encoder(hidden_states)
284
-
285
- encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
286
- else:
287
- raise NotImplementedError
288
-
289
- encoder_hidden_states = encoder_hidden_states
290
-
291
- if self.group_norm is not None:
292
- hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
293
-
294
- query = self.to_q(hidden_states)
295
- dim = query.shape[-1]
296
- query = self.reshape_heads_to_batch_dim(query)
297
-
298
- if self.added_kv_proj_dim is not None:
299
- raise NotImplementedError
300
-
301
- encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
302
- key = self.to_k(encoder_hidden_states)
303
- value = self.to_v(encoder_hidden_states)
304
-
305
- key = self.reshape_heads_to_batch_dim(key)
306
- value = self.reshape_heads_to_batch_dim(value)
307
-
308
- if attention_mask is not None:
309
- if attention_mask.shape[-1] != query.shape[1]:
310
- target_length = query.shape[1]
311
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
312
- attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
313
-
314
- # attention, what we cannot get enough of
315
- if self._use_memory_efficient_attention_xformers:
316
- hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
317
- # Some versions of xformers return output in fp32, cast it back to the dtype of the input
318
- hidden_states = hidden_states.to(query.dtype)
319
- else:
320
- if self._slice_size is None or query.shape[0] // self._slice_size == 1:
321
- hidden_states = self._attention(query, key, value, attention_mask)
322
- else:
323
- hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
324
-
325
- # linear proj
326
- hidden_states = self.to_out[0](hidden_states)
327
-
328
- # dropout
329
- hidden_states = self.to_out[1](hidden_states)
330
-
331
- if self.attention_mode == "Temporal":
332
- hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
333
-
334
- return hidden_states
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+ from diffusers.utils import BaseOutput
15
+ from diffusers.utils.import_utils import is_xformers_available
16
+ from diffusers.models.attention import FeedForward
17
+ from magicanimate.models.orig_attention import CrossAttention
18
+
19
+ from einops import rearrange, repeat
20
+ import math
21
+
22
+
23
+ def zero_module(module):
24
+ # Zero out the parameters of a module and return it.
25
+ for p in module.parameters():
26
+ p.detach().zero_()
27
+ return module
28
+
29
+
30
+ @dataclass
31
+ class TemporalTransformer3DModelOutput(BaseOutput):
32
+ sample: torch.FloatTensor
33
+
34
+
35
+ if is_xformers_available():
36
+ import xformers
37
+ import xformers.ops
38
+ else:
39
+ xformers = None
40
+
41
+
42
+ def get_motion_module(
43
+ in_channels,
44
+ motion_module_type: str,
45
+ motion_module_kwargs: dict
46
+ ):
47
+ if motion_module_type == "Vanilla":
48
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
49
+ else:
50
+ raise ValueError(f"unsupported motion_module_type: {motion_module_type}")
51
+
52
+
53
+ class VanillaTemporalModule(nn.Module):
54
+ def __init__(
55
+ self,
56
+ in_channels,
57
+ num_attention_heads = 8,
58
+ num_transformer_block = 2,
59
+ attention_block_types =( "Temporal_Self", "Temporal_Self" ),
60
+ cross_frame_attention_mode = None,
61
+ temporal_position_encoding = False,
62
+ temporal_position_encoding_max_len = 24,
63
+ temporal_attention_dim_div = 1,
64
+ zero_initialize = True,
65
+ ):
66
+ super().__init__()
67
+
68
+ self.temporal_transformer = TemporalTransformer3DModel(
69
+ in_channels=in_channels,
70
+ num_attention_heads=num_attention_heads,
71
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
72
+ num_layers=num_transformer_block,
73
+ attention_block_types=attention_block_types,
74
+ cross_frame_attention_mode=cross_frame_attention_mode,
75
+ temporal_position_encoding=temporal_position_encoding,
76
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
77
+ )
78
+
79
+ if zero_initialize:
80
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
81
+
82
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
83
+ hidden_states = input_tensor
84
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
85
+
86
+ output = hidden_states
87
+ return output
88
+
89
+
90
+ class TemporalTransformer3DModel(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_channels,
94
+ num_attention_heads,
95
+ attention_head_dim,
96
+
97
+ num_layers,
98
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
99
+ dropout = 0.0,
100
+ norm_num_groups = 32,
101
+ cross_attention_dim = 768,
102
+ activation_fn = "geglu",
103
+ attention_bias = False,
104
+ upcast_attention = False,
105
+
106
+ cross_frame_attention_mode = None,
107
+ temporal_position_encoding = False,
108
+ temporal_position_encoding_max_len = 24,
109
+ ):
110
+ super().__init__()
111
+
112
+ inner_dim = num_attention_heads * attention_head_dim
113
+
114
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
115
+ self.proj_in = nn.Linear(in_channels, inner_dim)
116
+
117
+ self.transformer_blocks = nn.ModuleList(
118
+ [
119
+ TemporalTransformerBlock(
120
+ dim=inner_dim,
121
+ num_attention_heads=num_attention_heads,
122
+ attention_head_dim=attention_head_dim,
123
+ attention_block_types=attention_block_types,
124
+ dropout=dropout,
125
+ norm_num_groups=norm_num_groups,
126
+ cross_attention_dim=cross_attention_dim,
127
+ activation_fn=activation_fn,
128
+ attention_bias=attention_bias,
129
+ upcast_attention=upcast_attention,
130
+ cross_frame_attention_mode=cross_frame_attention_mode,
131
+ temporal_position_encoding=temporal_position_encoding,
132
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
133
+ )
134
+ for d in range(num_layers)
135
+ ]
136
+ )
137
+ self.proj_out = nn.Linear(inner_dim, in_channels)
138
+
139
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
140
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
141
+ video_length = hidden_states.shape[2]
142
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
143
+
144
+ batch, channel, height, weight = hidden_states.shape
145
+ residual = hidden_states
146
+
147
+ hidden_states = self.norm(hidden_states)
148
+ inner_dim = hidden_states.shape[1]
149
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
150
+ hidden_states = self.proj_in(hidden_states)
151
+
152
+ # Transformer Blocks
153
+ for block in self.transformer_blocks:
154
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
155
+
156
+ # output
157
+ hidden_states = self.proj_out(hidden_states)
158
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
159
+
160
+ output = hidden_states + residual
161
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
162
+
163
+ return output
164
+
165
+
166
+ class TemporalTransformerBlock(nn.Module):
167
+ def __init__(
168
+ self,
169
+ dim,
170
+ num_attention_heads,
171
+ attention_head_dim,
172
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
173
+ dropout = 0.0,
174
+ norm_num_groups = 32,
175
+ cross_attention_dim = 768,
176
+ activation_fn = "geglu",
177
+ attention_bias = False,
178
+ upcast_attention = False,
179
+ cross_frame_attention_mode = None,
180
+ temporal_position_encoding = False,
181
+ temporal_position_encoding_max_len = 24,
182
+ ):
183
+ super().__init__()
184
+
185
+ attention_blocks = []
186
+ norms = []
187
+
188
+ for block_name in attention_block_types:
189
+ attention_blocks.append(
190
+ VersatileAttention(
191
+ attention_mode=block_name.split("_")[0],
192
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
193
+
194
+ query_dim=dim,
195
+ heads=num_attention_heads,
196
+ dim_head=attention_head_dim,
197
+ dropout=dropout,
198
+ bias=attention_bias,
199
+ upcast_attention=upcast_attention,
200
+
201
+ cross_frame_attention_mode=cross_frame_attention_mode,
202
+ temporal_position_encoding=temporal_position_encoding,
203
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
204
+ )
205
+ )
206
+ norms.append(nn.LayerNorm(dim))
207
+
208
+ self.attention_blocks = nn.ModuleList(attention_blocks)
209
+ self.norms = nn.ModuleList(norms)
210
+
211
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
212
+ self.ff_norm = nn.LayerNorm(dim)
213
+
214
+
215
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
216
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
217
+ norm_hidden_states = norm(hidden_states)
218
+ hidden_states = attention_block(
219
+ norm_hidden_states,
220
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
221
+ video_length=video_length,
222
+ ) + hidden_states
223
+
224
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
225
+
226
+ output = hidden_states
227
+ return output
228
+
229
+
230
+ class PositionalEncoding(nn.Module):
231
+ def __init__(
232
+ self,
233
+ d_model,
234
+ dropout = 0.,
235
+ max_len = 24
236
+ ):
237
+ super().__init__()
238
+ self.dropout = nn.Dropout(p=dropout)
239
+ position = torch.arange(max_len).unsqueeze(1)
240
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
241
+ pe = torch.zeros(1, max_len, d_model)
242
+ pe[0, :, 0::2] = torch.sin(position * div_term)
243
+ pe[0, :, 1::2] = torch.cos(position * div_term)
244
+ self.register_buffer('pe', pe)
245
+
246
+ def forward(self, x):
247
+ x = x + self.pe[:, :x.size(1)]
248
+ return self.dropout(x)
249
+
250
+
251
+ class VersatileAttention(CrossAttention):
252
+ def __init__(
253
+ self,
254
+ attention_mode = None,
255
+ cross_frame_attention_mode = None,
256
+ temporal_position_encoding = False,
257
+ temporal_position_encoding_max_len = 24,
258
+ *args, **kwargs
259
+ ):
260
+ super().__init__(*args, **kwargs)
261
+ assert attention_mode == "Temporal"
262
+
263
+ self.attention_mode = attention_mode
264
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
265
+
266
+ self.pos_encoder = PositionalEncoding(
267
+ kwargs["query_dim"],
268
+ dropout=0.,
269
+ max_len=temporal_position_encoding_max_len
270
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
271
+
272
+ def extra_repr(self):
273
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
274
+
275
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
276
+ batch_size, sequence_length, _ = hidden_states.shape
277
+
278
+ if self.attention_mode == "Temporal":
279
+ d = hidden_states.shape[1]
280
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
281
+
282
+ if self.pos_encoder is not None:
283
+ hidden_states = self.pos_encoder(hidden_states)
284
+
285
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
286
+ else:
287
+ raise NotImplementedError
288
+
289
+ encoder_hidden_states = encoder_hidden_states
290
+
291
+ if self.group_norm is not None:
292
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
293
+
294
+ query = self.to_q(hidden_states)
295
+ dim = query.shape[-1]
296
+ query = self.reshape_heads_to_batch_dim(query)
297
+
298
+ if self.added_kv_proj_dim is not None:
299
+ raise NotImplementedError
300
+
301
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
302
+ key = self.to_k(encoder_hidden_states)
303
+ value = self.to_v(encoder_hidden_states)
304
+
305
+ key = self.reshape_heads_to_batch_dim(key)
306
+ value = self.reshape_heads_to_batch_dim(value)
307
+
308
+ if attention_mask is not None:
309
+ if attention_mask.shape[-1] != query.shape[1]:
310
+ target_length = query.shape[1]
311
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
312
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
313
+
314
+ # attention, what we cannot get enough of
315
+ if self._use_memory_efficient_attention_xformers:
316
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
317
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
318
+ hidden_states = hidden_states.to(query.dtype)
319
+ else:
320
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
321
+ hidden_states = self._attention(query, key, value, attention_mask)
322
+ else:
323
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
324
+
325
+ # linear proj
326
+ hidden_states = self.to_out[0](hidden_states)
327
+
328
+ # dropout
329
+ hidden_states = self.to_out[1](hidden_states)
330
+
331
+ if self.attention_mode == "Temporal":
332
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
333
+
334
+ return hidden_states
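For orientation, the following is a minimal usage sketch of the temporal motion module added above. It assumes the magicanimate package from this commit is importable (with torch, diffusers, and einops installed); the channel, head, and frame counts are illustrative values, not the configuration MagicAnimate ships with.

import torch

from magicanimate.models.motion_module import get_motion_module

# Build a vanilla temporal module for a 320-channel UNet feature map.
motion_module = get_motion_module(
    in_channels=320,
    motion_module_type="Vanilla",
    motion_module_kwargs={"num_attention_heads": 8, "num_transformer_block": 1},
)

# (batch, channels, frames, height, width) feature map.
features = torch.randn(1, 320, 8, 16, 16)

# temb and encoder_hidden_states are not used by the default
# "Temporal_Self" attention blocks, so None is acceptable here.
out = motion_module(features, temb=None, encoder_hidden_states=None)

# proj_out is zero-initialized by default, so before training the module
# behaves as an identity residual: out equals features.
print(out.shape)  # torch.Size([1, 320, 8, 16, 16])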
magicanimate/models/mutual_self_attention.py CHANGED
@@ -1,642 +1,642 @@
1
- # Copyright 2023 ByteDance and/or its affiliates.
2
- #
3
- # Copyright (2023) MagicAnimate Authors
4
- #
5
- # ByteDance, its affiliates and licensors retain all intellectual
6
- # property and proprietary rights in and to this material, related
7
- # documentation and any modifications thereto. Any use, reproduction,
8
- # disclosure or distribution of this material and related documentation
9
- # without an express license agreement from ByteDance or
10
- # its affiliates is strictly prohibited.
11
-
12
- import torch
13
- import torch.nn.functional as F
14
-
15
- from einops import rearrange
16
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
-
18
- from diffusers.models.attention import BasicTransformerBlock
19
- from magicanimate.models.attention import BasicTransformerBlock as _BasicTransformerBlock
20
- from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
21
- from .stable_diffusion_controlnet_reference import torch_dfs
22
-
23
-
24
- class AttentionBase:
25
- def __init__(self):
26
- self.cur_step = 0
27
- self.num_att_layers = -1
28
- self.cur_att_layer = 0
29
-
30
- def after_step(self):
31
- pass
32
-
33
- def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
34
- out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
35
- self.cur_att_layer += 1
36
- if self.cur_att_layer == self.num_att_layers:
37
- self.cur_att_layer = 0
38
- self.cur_step += 1
39
- # after step
40
- self.after_step()
41
- return out
42
-
43
- def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
44
- out = torch.einsum('b i j, b j d -> b i d', attn, v)
45
- out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
46
- return out
47
-
48
- def reset(self):
49
- self.cur_step = 0
50
- self.cur_att_layer = 0
51
-
52
-
53
- class MutualSelfAttentionControl(AttentionBase):
54
-
55
- def __init__(self, total_steps=50, hijack_init_state=True, with_negative_guidance=False, appearance_control_alpha=0.5, mode='enqueue'):
56
- """
57
- Mutual self-attention control for Stable-Diffusion MODEl
58
- Args:
59
- total_steps: the total number of steps
60
- """
61
- super().__init__()
62
- self.total_steps = total_steps
63
- self.hijack = hijack_init_state
64
- self.with_negative_guidance = with_negative_guidance
65
-
66
- # alpha: mutual self attention intensity
67
- # TODO: make alpha learnable
68
- self.alpha = appearance_control_alpha
69
- self.GLOBAL_ATTN_QUEUE = []
70
- assert mode in ['enqueue', 'dequeue']
71
- MODE = mode
72
-
73
- def attn_batch(self, q, k, v, num_heads, **kwargs):
74
- """
75
- Performing attention for a batch of queries, keys, and values
76
- """
77
- b = q.shape[0] // num_heads
78
- q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
79
- k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
80
- v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
81
-
82
- sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
83
- attn = sim.softmax(-1)
84
- out = torch.einsum("h i j, h j d -> h i d", attn, v)
85
- out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
86
- return out
87
-
88
- def mutual_self_attn(self, q, k, v, num_heads, **kwargs):
89
- q_tgt, q_src = q.chunk(2)
90
- k_tgt, k_src = k.chunk(2)
91
- v_tgt, v_src = v.chunk(2)
92
-
93
- # out_tgt = self.attn_batch(q_tgt, k_src, v_src, num_heads, **kwargs) * self.alpha + \
94
- # self.attn_batch(q_tgt, k_tgt, v_tgt, num_heads, **kwargs) * (1 - self.alpha)
95
- out_tgt = self.attn_batch(q_tgt, torch.cat([k_tgt, k_src], dim=1), torch.cat([v_tgt, v_src], dim=1), num_heads, **kwargs)
96
- out_src = self.attn_batch(q_src, k_src, v_src, num_heads, **kwargs)
97
- out = torch.cat([out_tgt, out_src], dim=0)
98
- return out
99
-
100
- def mutual_self_attn_wq(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
101
- if self.MODE == 'dequeue' and len(self.kv_queue) > 0:
102
- k_src, v_src = self.kv_queue.pop(0)
103
- out = self.attn_batch(q, torch.cat([k, k_src], dim=1), torch.cat([v, v_src], dim=1), num_heads, **kwargs)
104
- return out
105
- else:
106
- self.kv_queue.append([k.clone(), v.clone()])
107
- return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
108
-
109
- def get_queue(self):
110
- return self.GLOBAL_ATTN_QUEUE
111
-
112
- def set_queue(self, attn_queue):
113
- self.GLOBAL_ATTN_QUEUE = attn_queue
114
-
115
- def clear_queue(self):
116
- self.GLOBAL_ATTN_QUEUE = []
117
-
118
- def to(self, dtype):
119
- self.GLOBAL_ATTN_QUEUE = [p.to(dtype) for p in self.GLOBAL_ATTN_QUEUE]
120
-
121
- def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
122
- """
123
- Attention forward function
124
- """
125
- return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
126
-
127
-
128
- class ReferenceAttentionControl():
129
-
130
- def __init__(self,
131
- unet,
132
- mode="write",
133
- do_classifier_free_guidance=False,
134
- attention_auto_machine_weight = float('inf'),
135
- gn_auto_machine_weight = 1.0,
136
- style_fidelity = 1.0,
137
- reference_attn=True,
138
- reference_adain=False,
139
- fusion_blocks="midup",
140
- batch_size=1,
141
- ) -> None:
142
- # 10. Modify self attention and group norm
143
- self.unet = unet
144
- assert mode in ["read", "write"]
145
- assert fusion_blocks in ["midup", "full"]
146
- self.reference_attn = reference_attn
147
- self.reference_adain = reference_adain
148
- self.fusion_blocks = fusion_blocks
149
- self.register_reference_hooks(
150
- mode,
151
- do_classifier_free_guidance,
152
- attention_auto_machine_weight,
153
- gn_auto_machine_weight,
154
- style_fidelity,
155
- reference_attn,
156
- reference_adain,
157
- fusion_blocks,
158
- batch_size=batch_size,
159
- )
160
-
161
- def register_reference_hooks(
162
- self,
163
- mode,
164
- do_classifier_free_guidance,
165
- attention_auto_machine_weight,
166
- gn_auto_machine_weight,
167
- style_fidelity,
168
- reference_attn,
169
- reference_adain,
170
- dtype=torch.float16,
171
- batch_size=1,
172
- num_images_per_prompt=1,
173
- device=torch.device("cpu"),
174
- fusion_blocks='midup',
175
- ):
176
- MODE = mode
177
- do_classifier_free_guidance = do_classifier_free_guidance
178
- attention_auto_machine_weight = attention_auto_machine_weight
179
- gn_auto_machine_weight = gn_auto_machine_weight
180
- style_fidelity = style_fidelity
181
- reference_attn = reference_attn
182
- reference_adain = reference_adain
183
- fusion_blocks = fusion_blocks
184
- num_images_per_prompt = num_images_per_prompt
185
- dtype=dtype
186
- if do_classifier_free_guidance:
187
- uc_mask = (
188
- torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16)
189
- .to(device)
190
- .bool()
191
- )
192
- else:
193
- uc_mask = (
194
- torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
195
- .to(device)
196
- .bool()
197
- )
198
-
199
- def hacked_basic_transformer_inner_forward(
200
- self,
201
- hidden_states: torch.FloatTensor,
202
- attention_mask: Optional[torch.FloatTensor] = None,
203
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
204
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
205
- timestep: Optional[torch.LongTensor] = None,
206
- cross_attention_kwargs: Dict[str, Any] = None,
207
- class_labels: Optional[torch.LongTensor] = None,
208
- video_length=None,
209
- ):
210
- if self.use_ada_layer_norm:
211
- norm_hidden_states = self.norm1(hidden_states, timestep)
212
- elif self.use_ada_layer_norm_zero:
213
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
214
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
215
- )
216
- else:
217
- norm_hidden_states = self.norm1(hidden_states)
218
-
219
- # 1. Self-Attention
220
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
221
- if self.only_cross_attention:
222
- attn_output = self.attn1(
223
- norm_hidden_states,
224
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
225
- attention_mask=attention_mask,
226
- **cross_attention_kwargs,
227
- )
228
- else:
229
- if MODE == "write":
230
- self.bank.append(norm_hidden_states.clone())
231
- attn_output = self.attn1(
232
- norm_hidden_states,
233
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
234
- attention_mask=attention_mask,
235
- **cross_attention_kwargs,
236
- )
237
- if MODE == "read":
238
- self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), "b t l c -> (b t) l c")[:hidden_states.shape[0]] for d in self.bank]
239
- hidden_states_uc = self.attn1(norm_hidden_states,
240
- encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
241
- attention_mask=attention_mask) + hidden_states
242
- hidden_states_c = hidden_states_uc.clone()
243
- _uc_mask = uc_mask.clone()
244
- if do_classifier_free_guidance:
245
- if hidden_states.shape[0] != _uc_mask.shape[0]:
246
- _uc_mask = (
247
- torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2))
248
- .to(device)
249
- .bool()
250
- )
251
- hidden_states_c[_uc_mask] = self.attn1(
252
- norm_hidden_states[_uc_mask],
253
- encoder_hidden_states=norm_hidden_states[_uc_mask],
254
- attention_mask=attention_mask,
255
- ) + hidden_states[_uc_mask]
256
- hidden_states = hidden_states_c.clone()
257
-
258
- self.bank.clear()
259
- if self.attn2 is not None:
260
- # Cross-Attention
261
- norm_hidden_states = (
262
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
263
- )
264
- hidden_states = (
265
- self.attn2(
266
- norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
267
- )
268
- + hidden_states
269
- )
270
-
271
- # Feed-forward
272
- hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
273
-
274
- # Temporal-Attention
275
- if self.unet_use_temporal_attention:
276
- d = hidden_states.shape[1]
277
- hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
278
- norm_hidden_states = (
279
- self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
280
- )
281
- hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
282
- hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
283
-
284
- return hidden_states
285
-
286
- if self.use_ada_layer_norm_zero:
287
- attn_output = gate_msa.unsqueeze(1) * attn_output
288
- hidden_states = attn_output + hidden_states
289
-
290
- if self.attn2 is not None:
291
- norm_hidden_states = (
292
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
293
- )
294
-
295
- # 2. Cross-Attention
296
- attn_output = self.attn2(
297
- norm_hidden_states,
298
- encoder_hidden_states=encoder_hidden_states,
299
- attention_mask=encoder_attention_mask,
300
- **cross_attention_kwargs,
301
- )
302
- hidden_states = attn_output + hidden_states
303
-
304
- # 3. Feed-forward
305
- norm_hidden_states = self.norm3(hidden_states)
306
-
307
- if self.use_ada_layer_norm_zero:
308
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
309
-
310
- ff_output = self.ff(norm_hidden_states)
311
-
312
- if self.use_ada_layer_norm_zero:
313
- ff_output = gate_mlp.unsqueeze(1) * ff_output
314
-
315
- hidden_states = ff_output + hidden_states
316
-
317
- return hidden_states
318
-
319
- def hacked_mid_forward(self, *args, **kwargs):
320
- eps = 1e-6
321
- x = self.original_forward(*args, **kwargs)
322
- if MODE == "write":
323
- if gn_auto_machine_weight >= self.gn_weight:
324
- var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
325
- self.mean_bank.append(mean)
326
- self.var_bank.append(var)
327
- if MODE == "read":
328
- if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
329
- var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
330
- std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
331
- mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
332
- var_acc = sum(self.var_bank) / float(len(self.var_bank))
333
- std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
334
- x_uc = (((x - mean) / std) * std_acc) + mean_acc
335
- x_c = x_uc.clone()
336
- if do_classifier_free_guidance and style_fidelity > 0:
337
- x_c[uc_mask] = x[uc_mask]
338
- x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
339
- self.mean_bank = []
340
- self.var_bank = []
341
- return x
342
-
343
- def hack_CrossAttnDownBlock2D_forward(
344
- self,
345
- hidden_states: torch.FloatTensor,
346
- temb: Optional[torch.FloatTensor] = None,
347
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
348
- attention_mask: Optional[torch.FloatTensor] = None,
349
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
350
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
351
- ):
352
- eps = 1e-6
353
-
354
- # TODO(Patrick, William) - attention mask is not used
355
- output_states = ()
356
-
357
- for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
358
- hidden_states = resnet(hidden_states, temb)
359
- hidden_states = attn(
360
- hidden_states,
361
- encoder_hidden_states=encoder_hidden_states,
362
- cross_attention_kwargs=cross_attention_kwargs,
363
- attention_mask=attention_mask,
364
- encoder_attention_mask=encoder_attention_mask,
365
- return_dict=False,
366
- )[0]
367
- if MODE == "write":
368
- if gn_auto_machine_weight >= self.gn_weight:
369
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
370
- self.mean_bank.append([mean])
371
- self.var_bank.append([var])
372
- if MODE == "read":
373
- if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
374
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
375
- std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
376
- mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
377
- var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
378
- std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
379
- hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
380
- hidden_states_c = hidden_states_uc.clone()
381
- if do_classifier_free_guidance and style_fidelity > 0:
382
- hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
383
- hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
384
-
385
- output_states = output_states + (hidden_states,)
386
-
387
- if MODE == "read":
388
- self.mean_bank = []
389
- self.var_bank = []
390
-
391
- if self.downsamplers is not None:
392
- for downsampler in self.downsamplers:
393
- hidden_states = downsampler(hidden_states)
394
-
395
- output_states = output_states + (hidden_states,)
396
-
397
- return hidden_states, output_states
398
-
399
- def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
400
- eps = 1e-6
401
-
402
- output_states = ()
403
-
404
- for i, resnet in enumerate(self.resnets):
405
- hidden_states = resnet(hidden_states, temb)
406
-
407
- if MODE == "write":
408
- if gn_auto_machine_weight >= self.gn_weight:
409
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
410
- self.mean_bank.append([mean])
411
- self.var_bank.append([var])
412
- if MODE == "read":
413
- if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
414
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
415
- std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
416
- mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
417
- var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
418
- std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
419
- hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
420
- hidden_states_c = hidden_states_uc.clone()
421
- if do_classifier_free_guidance and style_fidelity > 0:
422
- hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
423
- hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
424
-
425
- output_states = output_states + (hidden_states,)
426
-
427
- if MODE == "read":
428
- self.mean_bank = []
429
- self.var_bank = []
430
-
431
- if self.downsamplers is not None:
432
- for downsampler in self.downsamplers:
433
- hidden_states = downsampler(hidden_states)
434
-
435
- output_states = output_states + (hidden_states,)
436
-
437
- return hidden_states, output_states
438
-
439
- def hacked_CrossAttnUpBlock2D_forward(
440
- self,
441
- hidden_states: torch.FloatTensor,
442
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
443
- temb: Optional[torch.FloatTensor] = None,
444
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
445
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
446
- upsample_size: Optional[int] = None,
447
- attention_mask: Optional[torch.FloatTensor] = None,
448
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
449
- ):
450
- eps = 1e-6
451
- # TODO(Patrick, William) - attention mask is not used
452
- for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
453
- # pop res hidden states
454
- res_hidden_states = res_hidden_states_tuple[-1]
455
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
456
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
457
- hidden_states = resnet(hidden_states, temb)
458
- hidden_states = attn(
459
- hidden_states,
460
- encoder_hidden_states=encoder_hidden_states,
461
- cross_attention_kwargs=cross_attention_kwargs,
462
- attention_mask=attention_mask,
463
- encoder_attention_mask=encoder_attention_mask,
464
- return_dict=False,
465
- )[0]
466
-
467
- if MODE == "write":
468
- if gn_auto_machine_weight >= self.gn_weight:
469
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
470
- self.mean_bank.append([mean])
471
- self.var_bank.append([var])
472
- if MODE == "read":
473
- if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
474
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
475
- std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
476
- mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
477
- var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
478
- std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
479
- hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
480
- hidden_states_c = hidden_states_uc.clone()
481
- if do_classifier_free_guidance and style_fidelity > 0:
482
- hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
483
- hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
484
-
485
- if MODE == "read":
486
- self.mean_bank = []
487
- self.var_bank = []
488
-
489
- if self.upsamplers is not None:
490
- for upsampler in self.upsamplers:
491
- hidden_states = upsampler(hidden_states, upsample_size)
492
-
493
- return hidden_states
494
-
495
- def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
496
- eps = 1e-6
497
- for i, resnet in enumerate(self.resnets):
498
- # pop res hidden states
499
- res_hidden_states = res_hidden_states_tuple[-1]
500
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
501
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
502
- hidden_states = resnet(hidden_states, temb)
503
-
504
- if MODE == "write":
505
- if gn_auto_machine_weight >= self.gn_weight:
506
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
507
- self.mean_bank.append([mean])
508
- self.var_bank.append([var])
509
- if MODE == "read":
510
- if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
511
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
512
- std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
513
- mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
514
- var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
515
- std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
516
- hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
517
- hidden_states_c = hidden_states_uc.clone()
518
- if do_classifier_free_guidance and style_fidelity > 0:
519
- hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
520
- hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
521
-
522
- if MODE == "read":
523
- self.mean_bank = []
524
- self.var_bank = []
525
-
526
- if self.upsamplers is not None:
527
- for upsampler in self.upsamplers:
528
- hidden_states = upsampler(hidden_states, upsample_size)
529
-
530
- return hidden_states
531
-
532
- if self.reference_attn:
533
- if self.fusion_blocks == "midup":
534
- attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
535
- elif self.fusion_blocks == "full":
536
- attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
537
- attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
538
-
539
- for i, module in enumerate(attn_modules):
540
- module._original_inner_forward = module.forward
541
- module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
542
- module.bank = []
543
- module.attn_weight = float(i) / float(len(attn_modules))
544
-
545
- if self.reference_adain:
546
- gn_modules = [self.unet.mid_block]
547
- self.unet.mid_block.gn_weight = 0
548
-
549
- down_blocks = self.unet.down_blocks
550
- for w, module in enumerate(down_blocks):
551
- module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
552
- gn_modules.append(module)
553
-
554
- up_blocks = self.unet.up_blocks
555
- for w, module in enumerate(up_blocks):
556
- module.gn_weight = float(w) / float(len(up_blocks))
557
- gn_modules.append(module)
558
-
559
- for i, module in enumerate(gn_modules):
560
- if getattr(module, "original_forward", None) is None:
561
- module.original_forward = module.forward
562
- if i == 0:
563
- # mid_block
564
- module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
565
- elif isinstance(module, CrossAttnDownBlock2D):
566
- module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
567
- elif isinstance(module, DownBlock2D):
568
- module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
569
- elif isinstance(module, CrossAttnUpBlock2D):
570
- module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
571
- elif isinstance(module, UpBlock2D):
572
- module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
573
- module.mean_bank = []
574
- module.var_bank = []
575
- module.gn_weight *= 2
576
-
577
- def update(self, writer, dtype=torch.float16):
578
- if self.reference_attn:
579
- if self.fusion_blocks == "midup":
580
- reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)]
581
- writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
582
- elif self.fusion_blocks == "full":
583
- reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock)]
584
- writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, BasicTransformerBlock)]
585
- reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
586
- writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
587
- for r, w in zip(reader_attn_modules, writer_attn_modules):
588
- r.bank = [v.clone().to(dtype) for v in w.bank]
589
- # w.bank.clear()
590
- if self.reference_adain:
591
- reader_gn_modules = [self.unet.mid_block]
592
-
593
- down_blocks = self.unet.down_blocks
594
- for w, module in enumerate(down_blocks):
595
- reader_gn_modules.append(module)
596
-
597
- up_blocks = self.unet.up_blocks
598
- for w, module in enumerate(up_blocks):
599
- reader_gn_modules.append(module)
600
-
601
- writer_gn_modules = [writer.unet.mid_block]
602
-
603
- down_blocks = writer.unet.down_blocks
604
- for w, module in enumerate(down_blocks):
605
- writer_gn_modules.append(module)
606
-
607
- up_blocks = writer.unet.up_blocks
608
- for w, module in enumerate(up_blocks):
609
- writer_gn_modules.append(module)
610
-
611
- for r, w in zip(reader_gn_modules, writer_gn_modules):
612
- if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list):
613
- r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank]
614
- r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank]
615
- else:
616
- r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank]
617
- r.var_bank = [v.clone().to(dtype) for v in w.var_bank]
618
-
619
- def clear(self):
620
- if self.reference_attn:
621
- if self.fusion_blocks == "midup":
622
- reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
623
- elif self.fusion_blocks == "full":
624
- reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
625
- reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
626
- for r in reader_attn_modules:
627
- r.bank.clear()
628
- if self.reference_adain:
629
- reader_gn_modules = [self.unet.mid_block]
630
-
631
- down_blocks = self.unet.down_blocks
632
- for w, module in enumerate(down_blocks):
633
- reader_gn_modules.append(module)
634
-
635
- up_blocks = self.unet.up_blocks
636
- for w, module in enumerate(up_blocks):
637
- reader_gn_modules.append(module)
638
-
639
- for r in reader_gn_modules:
640
- r.mean_bank.clear()
641
- r.var_bank.clear()
642
 
 
1
+ # Copyright 2023 ByteDance and/or its affiliates.
2
+ #
3
+ # Copyright (2023) MagicAnimate Authors
4
+ #
5
+ # ByteDance, its affiliates and licensors retain all intellectual
6
+ # property and proprietary rights in and to this material, related
7
+ # documentation and any modifications thereto. Any use, reproduction,
8
+ # disclosure or distribution of this material and related documentation
9
+ # without an express license agreement from ByteDance or
10
+ # its affiliates is strictly prohibited.
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from einops import rearrange
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ from diffusers.models.attention import BasicTransformerBlock
19
+ from magicanimate.models.attention import BasicTransformerBlock as _BasicTransformerBlock
20
+ from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
21
+ from .stable_diffusion_controlnet_reference import torch_dfs
22
+
23
+
24
+ class AttentionBase:
25
+ def __init__(self):
26
+ self.cur_step = 0
27
+ self.num_att_layers = -1
28
+ self.cur_att_layer = 0
29
+
30
+ def after_step(self):
31
+ pass
32
+
33
+ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
34
+ out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
35
+ self.cur_att_layer += 1
36
+ if self.cur_att_layer == self.num_att_layers:
37
+ self.cur_att_layer = 0
38
+ self.cur_step += 1
39
+ # after step
40
+ self.after_step()
41
+ return out
42
+
43
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
44
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
45
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
46
+ return out
47
+
48
+ def reset(self):
49
+ self.cur_step = 0
50
+ self.cur_att_layer = 0
51
+
52
+
53
+ class MutualSelfAttentionControl(AttentionBase):
54
+
55
+ def __init__(self, total_steps=50, hijack_init_state=True, with_negative_guidance=False, appearance_control_alpha=0.5, mode='enqueue'):
56
+ """
57
+ Mutual self-attention control for Stable-Diffusion MODEl
58
+ Args:
59
+ total_steps: the total number of steps
60
+ """
61
+ super().__init__()
62
+ self.total_steps = total_steps
63
+ self.hijack = hijack_init_state
64
+ self.with_negative_guidance = with_negative_guidance
65
+
66
+ # alpha: mutual self attention intensity
67
+ # TODO: make alpha learnable
68
+ self.alpha = appearance_control_alpha
69
+ self.GLOBAL_ATTN_QUEUE = []
70
+ assert mode in ['enqueue', 'dequeue']
71
+ self.MODE = mode; self.kv_queue = []  # stored on the instance; read by mutual_self_attn_wq
72
+
73
+ def attn_batch(self, q, k, v, num_heads, **kwargs):
74
+ """
75
+ Performing attention for a batch of queries, keys, and values
76
+ """
77
+ b = q.shape[0] // num_heads
78
+ q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
79
+ k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
80
+ v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
81
+
82
+ sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
83
+ attn = sim.softmax(-1)
84
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
85
+ out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
86
+ return out
87
+
88
+ def mutual_self_attn(self, q, k, v, num_heads, **kwargs):
89
+ q_tgt, q_src = q.chunk(2)
90
+ k_tgt, k_src = k.chunk(2)
91
+ v_tgt, v_src = v.chunk(2)
92
+
93
+ # out_tgt = self.attn_batch(q_tgt, k_src, v_src, num_heads, **kwargs) * self.alpha + \
94
+ # self.attn_batch(q_tgt, k_tgt, v_tgt, num_heads, **kwargs) * (1 - self.alpha)
95
+ out_tgt = self.attn_batch(q_tgt, torch.cat([k_tgt, k_src], dim=1), torch.cat([v_tgt, v_src], dim=1), num_heads, **kwargs)
96
+ out_src = self.attn_batch(q_src, k_src, v_src, num_heads, **kwargs)
97
+ out = torch.cat([out_tgt, out_src], dim=0)
98
+ return out
99
+
100
+ def mutual_self_attn_wq(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
101
+ if self.MODE == 'dequeue' and len(self.kv_queue) > 0:
102
+ k_src, v_src = self.kv_queue.pop(0)
103
+ out = self.attn_batch(q, torch.cat([k, k_src], dim=1), torch.cat([v, v_src], dim=1), num_heads, **kwargs)
104
+ return out
105
+ else:
106
+ self.kv_queue.append([k.clone(), v.clone()])
107
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
108
+
109
+ def get_queue(self):
110
+ return self.GLOBAL_ATTN_QUEUE
111
+
112
+ def set_queue(self, attn_queue):
113
+ self.GLOBAL_ATTN_QUEUE = attn_queue
114
+
115
+ def clear_queue(self):
116
+ self.GLOBAL_ATTN_QUEUE = []
117
+
118
+ def to(self, dtype):
119
+ self.GLOBAL_ATTN_QUEUE = [p.to(dtype) for p in self.GLOBAL_ATTN_QUEUE]
120
+
121
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
122
+ """
123
+ Attention forward function
124
+ """
125
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
126
+
127
+
128
+ class ReferenceAttentionControl():
129
+
130
+ def __init__(self,
131
+ unet,
132
+ mode="write",
133
+ do_classifier_free_guidance=False,
134
+ attention_auto_machine_weight = float('inf'),
135
+ gn_auto_machine_weight = 1.0,
136
+ style_fidelity = 1.0,
137
+ reference_attn=True,
138
+ reference_adain=False,
139
+ fusion_blocks="midup",
140
+ batch_size=1,
141
+ ) -> None:
142
+ # 10. Modify self attention and group norm
143
+ self.unet = unet
144
+ assert mode in ["read", "write"]
145
+ assert fusion_blocks in ["midup", "full"]
146
+ self.reference_attn = reference_attn
147
+ self.reference_adain = reference_adain
148
+ self.fusion_blocks = fusion_blocks
149
+ self.register_reference_hooks(
150
+ mode,
151
+ do_classifier_free_guidance,
152
+ attention_auto_machine_weight,
153
+ gn_auto_machine_weight,
154
+ style_fidelity,
155
+ reference_attn,
156
+ reference_adain,
157
+ fusion_blocks,
158
+ batch_size=batch_size,
159
+ )
160
+
161
+ def register_reference_hooks(
162
+ self,
163
+ mode,
164
+ do_classifier_free_guidance,
165
+ attention_auto_machine_weight,
166
+ gn_auto_machine_weight,
167
+ style_fidelity,
168
+ reference_attn,
169
+ reference_adain,
170
+ dtype=torch.float16,
171
+ batch_size=1,
172
+ num_images_per_prompt=1,
173
+ device=torch.device("cpu"),
174
+ fusion_blocks='midup',
175
+ ):
176
+ MODE = mode
177
+ do_classifier_free_guidance = do_classifier_free_guidance
178
+ attention_auto_machine_weight = attention_auto_machine_weight
179
+ gn_auto_machine_weight = gn_auto_machine_weight
180
+ style_fidelity = style_fidelity
181
+ reference_attn = reference_attn
182
+ reference_adain = reference_adain
183
+ fusion_blocks = fusion_blocks
184
+ num_images_per_prompt = num_images_per_prompt
185
+ dtype=dtype
186
+ if do_classifier_free_guidance:
187
+ uc_mask = (
188
+ torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16)
189
+ .to(device)
190
+ .bool()
191
+ )
192
+ else:
193
+ uc_mask = (
194
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
195
+ .to(device)
196
+ .bool()
197
+ )
198
+
199
+ def hacked_basic_transformer_inner_forward(
200
+ self,
201
+ hidden_states: torch.FloatTensor,
202
+ attention_mask: Optional[torch.FloatTensor] = None,
203
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
204
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
205
+ timestep: Optional[torch.LongTensor] = None,
206
+ cross_attention_kwargs: Dict[str, Any] = None,
207
+ class_labels: Optional[torch.LongTensor] = None,
208
+ video_length=None,
209
+ ):
210
+ if self.use_ada_layer_norm:
211
+ norm_hidden_states = self.norm1(hidden_states, timestep)
212
+ elif self.use_ada_layer_norm_zero:
213
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
214
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
215
+ )
216
+ else:
217
+ norm_hidden_states = self.norm1(hidden_states)
218
+
219
+ # 1. Self-Attention
220
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
221
+ if self.only_cross_attention:
222
+ attn_output = self.attn1(
223
+ norm_hidden_states,
224
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
225
+ attention_mask=attention_mask,
226
+ **cross_attention_kwargs,
227
+ )
228
+ else:
229
+ if MODE == "write":
230
+ self.bank.append(norm_hidden_states.clone())
231
+ attn_output = self.attn1(
232
+ norm_hidden_states,
233
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
234
+ attention_mask=attention_mask,
235
+ **cross_attention_kwargs,
236
+ )
237
+ if MODE == "read":
238
+ self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), "b t l c -> (b t) l c")[:hidden_states.shape[0]] for d in self.bank]
239
+ hidden_states_uc = self.attn1(norm_hidden_states,
240
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
241
+ attention_mask=attention_mask) + hidden_states
242
+ hidden_states_c = hidden_states_uc.clone()
243
+ _uc_mask = uc_mask.clone()
244
+ if do_classifier_free_guidance:
245
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
246
+ _uc_mask = (
247
+ torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2))
248
+ .to(device)
249
+ .bool()
250
+ )
251
+ hidden_states_c[_uc_mask] = self.attn1(
252
+ norm_hidden_states[_uc_mask],
253
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
254
+ attention_mask=attention_mask,
255
+ ) + hidden_states[_uc_mask]
256
+ hidden_states = hidden_states_c.clone()
257
+
258
+ self.bank.clear()
259
+ if self.attn2 is not None:
260
+ # Cross-Attention
261
+ norm_hidden_states = (
262
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
263
+ )
264
+ hidden_states = (
265
+ self.attn2(
266
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
267
+ )
268
+ + hidden_states
269
+ )
270
+
271
+ # Feed-forward
272
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
273
+
274
+ # Temporal-Attention
275
+ if self.unet_use_temporal_attention:
276
+ d = hidden_states.shape[1]
277
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
278
+ norm_hidden_states = (
279
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
280
+ )
281
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
282
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
283
+
284
+ return hidden_states
285
+
286
+ if self.use_ada_layer_norm_zero:
287
+ attn_output = gate_msa.unsqueeze(1) * attn_output
288
+ hidden_states = attn_output + hidden_states
289
+
290
+ if self.attn2 is not None:
291
+ norm_hidden_states = (
292
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
293
+ )
294
+
295
+ # 2. Cross-Attention
296
+ attn_output = self.attn2(
297
+ norm_hidden_states,
298
+ encoder_hidden_states=encoder_hidden_states,
299
+ attention_mask=encoder_attention_mask,
300
+ **cross_attention_kwargs,
301
+ )
302
+ hidden_states = attn_output + hidden_states
303
+
304
+ # 3. Feed-forward
305
+ norm_hidden_states = self.norm3(hidden_states)
306
+
307
+ if self.use_ada_layer_norm_zero:
308
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
309
+
310
+ ff_output = self.ff(norm_hidden_states)
311
+
312
+ if self.use_ada_layer_norm_zero:
313
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
314
+
315
+ hidden_states = ff_output + hidden_states
316
+
317
+ return hidden_states
318
+
319
+ def hacked_mid_forward(self, *args, **kwargs):
320
+ eps = 1e-6
321
+ x = self.original_forward(*args, **kwargs)
322
+ if MODE == "write":
323
+ if gn_auto_machine_weight >= self.gn_weight:
324
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
325
+ self.mean_bank.append(mean)
326
+ self.var_bank.append(var)
327
+ if MODE == "read":
328
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
329
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
330
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
331
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
332
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
333
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
334
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
335
+ x_c = x_uc.clone()
336
+ if do_classifier_free_guidance and style_fidelity > 0:
337
+ x_c[uc_mask] = x[uc_mask]
338
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
339
+ self.mean_bank = []
340
+ self.var_bank = []
341
+ return x
342
+
343
+ def hack_CrossAttnDownBlock2D_forward(
344
+ self,
345
+ hidden_states: torch.FloatTensor,
346
+ temb: Optional[torch.FloatTensor] = None,
347
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
348
+ attention_mask: Optional[torch.FloatTensor] = None,
349
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
350
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
351
+ ):
352
+ eps = 1e-6
353
+
354
+ # TODO(Patrick, William) - attention mask is not used
355
+ output_states = ()
356
+
357
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
358
+ hidden_states = resnet(hidden_states, temb)
359
+ hidden_states = attn(
360
+ hidden_states,
361
+ encoder_hidden_states=encoder_hidden_states,
362
+ cross_attention_kwargs=cross_attention_kwargs,
363
+ attention_mask=attention_mask,
364
+ encoder_attention_mask=encoder_attention_mask,
365
+ return_dict=False,
366
+ )[0]
367
+ if MODE == "write":
368
+ if gn_auto_machine_weight >= self.gn_weight:
369
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
370
+ self.mean_bank.append([mean])
371
+ self.var_bank.append([var])
372
+ if MODE == "read":
373
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
374
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
375
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
376
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
377
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
378
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
379
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
380
+ hidden_states_c = hidden_states_uc.clone()
381
+ if do_classifier_free_guidance and style_fidelity > 0:
382
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
383
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
384
+
385
+ output_states = output_states + (hidden_states,)
386
+
387
+ if MODE == "read":
388
+ self.mean_bank = []
389
+ self.var_bank = []
390
+
391
+ if self.downsamplers is not None:
392
+ for downsampler in self.downsamplers:
393
+ hidden_states = downsampler(hidden_states)
394
+
395
+ output_states = output_states + (hidden_states,)
396
+
397
+ return hidden_states, output_states
398
+
399
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
400
+ eps = 1e-6
401
+
402
+ output_states = ()
403
+
404
+ for i, resnet in enumerate(self.resnets):
405
+ hidden_states = resnet(hidden_states, temb)
406
+
407
+ if MODE == "write":
408
+ if gn_auto_machine_weight >= self.gn_weight:
409
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
410
+ self.mean_bank.append([mean])
411
+ self.var_bank.append([var])
412
+ if MODE == "read":
413
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
414
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
415
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
416
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
417
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
418
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
419
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
420
+ hidden_states_c = hidden_states_uc.clone()
421
+ if do_classifier_free_guidance and style_fidelity > 0:
422
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
423
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
424
+
425
+ output_states = output_states + (hidden_states,)
426
+
427
+ if MODE == "read":
428
+ self.mean_bank = []
429
+ self.var_bank = []
430
+
431
+ if self.downsamplers is not None:
432
+ for downsampler in self.downsamplers:
433
+ hidden_states = downsampler(hidden_states)
434
+
435
+ output_states = output_states + (hidden_states,)
436
+
437
+ return hidden_states, output_states
438
+
439
+ def hacked_CrossAttnUpBlock2D_forward(
440
+ self,
441
+ hidden_states: torch.FloatTensor,
442
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
443
+ temb: Optional[torch.FloatTensor] = None,
444
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
445
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
446
+ upsample_size: Optional[int] = None,
447
+ attention_mask: Optional[torch.FloatTensor] = None,
448
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
449
+ ):
450
+ eps = 1e-6
451
+ # TODO(Patrick, William) - attention mask is not used
452
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
453
+ # pop res hidden states
454
+ res_hidden_states = res_hidden_states_tuple[-1]
455
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
456
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
457
+ hidden_states = resnet(hidden_states, temb)
458
+ hidden_states = attn(
459
+ hidden_states,
460
+ encoder_hidden_states=encoder_hidden_states,
461
+ cross_attention_kwargs=cross_attention_kwargs,
462
+ attention_mask=attention_mask,
463
+ encoder_attention_mask=encoder_attention_mask,
464
+ return_dict=False,
465
+ )[0]
466
+
467
+ if MODE == "write":
468
+ if gn_auto_machine_weight >= self.gn_weight:
469
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
470
+ self.mean_bank.append([mean])
471
+ self.var_bank.append([var])
472
+ if MODE == "read":
473
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
474
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
475
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
476
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
477
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
478
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
479
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
480
+ hidden_states_c = hidden_states_uc.clone()
481
+ if do_classifier_free_guidance and style_fidelity > 0:
482
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
483
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
484
+
485
+ if MODE == "read":
486
+ self.mean_bank = []
487
+ self.var_bank = []
488
+
489
+ if self.upsamplers is not None:
490
+ for upsampler in self.upsamplers:
491
+ hidden_states = upsampler(hidden_states, upsample_size)
492
+
493
+ return hidden_states
494
+
495
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
496
+ eps = 1e-6
497
+ for i, resnet in enumerate(self.resnets):
498
+ # pop res hidden states
499
+ res_hidden_states = res_hidden_states_tuple[-1]
500
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
501
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
502
+ hidden_states = resnet(hidden_states, temb)
503
+
504
+ if MODE == "write":
505
+ if gn_auto_machine_weight >= self.gn_weight:
506
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
507
+ self.mean_bank.append([mean])
508
+ self.var_bank.append([var])
509
+ if MODE == "read":
510
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
511
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
512
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
513
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
514
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
515
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
516
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
517
+ hidden_states_c = hidden_states_uc.clone()
518
+ if do_classifier_free_guidance and style_fidelity > 0:
519
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
520
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
521
+
522
+ if MODE == "read":
523
+ self.mean_bank = []
524
+ self.var_bank = []
525
+
526
+ if self.upsamplers is not None:
527
+ for upsampler in self.upsamplers:
528
+ hidden_states = upsampler(hidden_states, upsample_size)
529
+
530
+ return hidden_states
531
+
532
+ if self.reference_attn:
533
+ if self.fusion_blocks == "midup":
534
+ attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
535
+ elif self.fusion_blocks == "full":
536
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
537
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
538
+
539
+ for i, module in enumerate(attn_modules):
540
+ module._original_inner_forward = module.forward
541
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
542
+ module.bank = []
543
+ module.attn_weight = float(i) / float(len(attn_modules))
544
+
545
+ if self.reference_adain:
546
+ gn_modules = [self.unet.mid_block]
547
+ self.unet.mid_block.gn_weight = 0
548
+
549
+ down_blocks = self.unet.down_blocks
550
+ for w, module in enumerate(down_blocks):
551
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
552
+ gn_modules.append(module)
553
+
554
+ up_blocks = self.unet.up_blocks
555
+ for w, module in enumerate(up_blocks):
556
+ module.gn_weight = float(w) / float(len(up_blocks))
557
+ gn_modules.append(module)
558
+
559
+ for i, module in enumerate(gn_modules):
560
+ if getattr(module, "original_forward", None) is None:
561
+ module.original_forward = module.forward
562
+ if i == 0:
563
+ # mid_block
564
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
565
+ elif isinstance(module, CrossAttnDownBlock2D):
566
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
567
+ elif isinstance(module, DownBlock2D):
568
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
569
+ elif isinstance(module, CrossAttnUpBlock2D):
570
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
571
+ elif isinstance(module, UpBlock2D):
572
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
573
+ module.mean_bank = []
574
+ module.var_bank = []
575
+ module.gn_weight *= 2
576
+
577
+ def update(self, writer, dtype=torch.float16):
578
+ if self.reference_attn:
579
+ if self.fusion_blocks == "midup":
580
+ reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)]
581
+ writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
582
+ elif self.fusion_blocks == "full":
583
+ reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock)]
584
+ writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, BasicTransformerBlock)]
585
+ reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
586
+ writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
587
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
588
+ r.bank = [v.clone().to(dtype) for v in w.bank]
589
+ # w.bank.clear()
590
+ if self.reference_adain:
591
+ reader_gn_modules = [self.unet.mid_block]
592
+
593
+ down_blocks = self.unet.down_blocks
594
+ for w, module in enumerate(down_blocks):
595
+ reader_gn_modules.append(module)
596
+
597
+ up_blocks = self.unet.up_blocks
598
+ for w, module in enumerate(up_blocks):
599
+ reader_gn_modules.append(module)
600
+
601
+ writer_gn_modules = [writer.unet.mid_block]
602
+
603
+ down_blocks = writer.unet.down_blocks
604
+ for w, module in enumerate(down_blocks):
605
+ writer_gn_modules.append(module)
606
+
607
+ up_blocks = writer.unet.up_blocks
608
+ for w, module in enumerate(up_blocks):
609
+ writer_gn_modules.append(module)
610
+
611
+ for r, w in zip(reader_gn_modules, writer_gn_modules):
612
+ if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list):
613
+ r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank]
614
+ r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank]
615
+ else:
616
+ r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank]
617
+ r.var_bank = [v.clone().to(dtype) for v in w.var_bank]
618
+
619
+ def clear(self):
620
+ if self.reference_attn:
621
+ if self.fusion_blocks == "midup":
622
+ reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
623
+ elif self.fusion_blocks == "full":
624
+ reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
625
+ reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
626
+ for r in reader_attn_modules:
627
+ r.bank.clear()
628
+ if self.reference_adain:
629
+ reader_gn_modules = [self.unet.mid_block]
630
+
631
+ down_blocks = self.unet.down_blocks
632
+ for w, module in enumerate(down_blocks):
633
+ reader_gn_modules.append(module)
634
+
635
+ up_blocks = self.unet.up_blocks
636
+ for w, module in enumerate(up_blocks):
637
+ reader_gn_modules.append(module)
638
+
639
+ for r in reader_gn_modules:
640
+ r.mean_bank.clear()
641
+ r.var_bank.clear()
642
 
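The hunk above wires a write/read reference-attention mechanism into the UNet: in "write" mode the patched transformer blocks append their normalized self-attention hidden states to per-module banks (and the patched down/mid/up blocks record GroupNorm mean/variance), in "read" mode those banks are concatenated into the self-attention keys/values and used to re-normalize features AdaIN-style, `update()` copies the banks from a writer controller to a reader controller, and `clear()` empties them. The sketch below shows one plausible way to drive such a writer/reader pair inside a denoising loop. It is illustrative only and not part of this commit: the controller class, both UNets, the scheduler and the latents are assumed to be defined elsewhere in the repository, classifier-free-guidance chunking is omitted, and the keyword names simply mirror the parameters visible in `register_reference_hooks` above.

import torch

def denoise_with_reference(controller_cls, appearance_unet, video_unet, scheduler,
                           ref_latents, latents, encoder_hidden_states):
    # Hypothetical wiring; `controller_cls` stands in for the controller class
    # defined in this file, assumed to accept the kwargs shown in the hunk above.
    writer = controller_cls(appearance_unet, mode="write",
                            do_classifier_free_guidance=True, fusion_blocks="midup")
    reader = controller_cls(video_unet, mode="read",
                            do_classifier_free_guidance=True, fusion_blocks="midup")

    for t in scheduler.timesteps:
        # 1. Run the appearance UNet so the hooked "write" blocks fill their banks
        #    with reference hidden states and GroupNorm statistics.
        appearance_unet(ref_latents, t, encoder_hidden_states=encoder_hidden_states)
        # 2. Copy the recorded banks from the writer's modules to the reader's.
        reader.update(writer, dtype=torch.float16)
        # 3. Denoise the video latents; the hooked "read" blocks now attend to the
        #    reference features and match its normalization statistics.
        noise_pred = video_unet(latents, t,
                                encoder_hidden_states=encoder_hidden_states).sample
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        # 4. Empty all banks so the next timestep starts from a clean state.
        reader.clear()
        writer.clear()
    return latents

In the actual pipeline there is more bookkeeping around these four steps (classifier-free-guidance batching, temporal context windows, fp16 casting); the point of the sketch is only the write, update, read, clear ordering that the hooks above rely on.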
magicanimate/models/orig_attention.py CHANGED
@@ -1,988 +1,988 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Copyright 2022 The HuggingFace Team. All rights reserved.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- import math
21
- from dataclasses import dataclass
22
- from typing import Optional
23
-
24
- import torch
25
- import torch.nn.functional as F
26
- from torch import nn
27
-
28
- from diffusers.configuration_utils import ConfigMixin, register_to_config
29
- from diffusers.models.modeling_utils import ModelMixin
30
- from diffusers.models.embeddings import ImagePositionalEmbeddings
31
- from diffusers.utils import BaseOutput
32
- from diffusers.utils.import_utils import is_xformers_available
33
-
34
-
35
- @dataclass
36
- class Transformer2DModelOutput(BaseOutput):
37
- """
38
- Args:
39
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
40
- Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
41
- for the unnoised latent pixels.
42
- """
43
-
44
- sample: torch.FloatTensor
45
-
46
-
47
- if is_xformers_available():
48
- import xformers
49
- import xformers.ops
50
- else:
51
- xformers = None
52
-
53
-
54
- class Transformer2DModel(ModelMixin, ConfigMixin):
55
- """
56
- Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
57
- embeddings) inputs.
58
-
59
- When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
60
- transformer action. Finally, reshape to image.
61
-
62
- When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
63
- embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
64
- classes of unnoised image.
65
-
66
- Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
67
- image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
68
-
69
- Parameters:
70
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
71
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
72
- in_channels (`int`, *optional*):
73
- Pass if the input is continuous. The number of channels in the input and output.
74
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
75
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
76
- cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
77
- sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
78
- Note that this is fixed at training time as it is used for learning a number of position embeddings. See
79
- `ImagePositionalEmbeddings`.
80
- num_vector_embeds (`int`, *optional*):
81
- Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
82
- Includes the class for the masked latent pixel.
83
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
84
- num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
85
- The number of diffusion steps used during training. Note that this is fixed at training time as it is used
86
- to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
87
- up to but not more than steps than `num_embeds_ada_norm`.
88
- attention_bias (`bool`, *optional*):
89
- Configure if the TransformerBlocks' attention should contain a bias parameter.
90
- """
91
-
92
- @register_to_config
93
- def __init__(
94
- self,
95
- num_attention_heads: int = 16,
96
- attention_head_dim: int = 88,
97
- in_channels: Optional[int] = None,
98
- num_layers: int = 1,
99
- dropout: float = 0.0,
100
- norm_num_groups: int = 32,
101
- cross_attention_dim: Optional[int] = None,
102
- attention_bias: bool = False,
103
- sample_size: Optional[int] = None,
104
- num_vector_embeds: Optional[int] = None,
105
- activation_fn: str = "geglu",
106
- num_embeds_ada_norm: Optional[int] = None,
107
- use_linear_projection: bool = False,
108
- only_cross_attention: bool = False,
109
- upcast_attention: bool = False,
110
- ):
111
- super().__init__()
112
- self.use_linear_projection = use_linear_projection
113
- self.num_attention_heads = num_attention_heads
114
- self.attention_head_dim = attention_head_dim
115
- inner_dim = num_attention_heads * attention_head_dim
116
-
117
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
118
- # Define whether input is continuous or discrete depending on configuration
119
- self.is_input_continuous = in_channels is not None
120
- self.is_input_vectorized = num_vector_embeds is not None
121
-
122
- if self.is_input_continuous and self.is_input_vectorized:
123
- raise ValueError(
124
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
125
- " sure that either `in_channels` or `num_vector_embeds` is None."
126
- )
127
- elif not self.is_input_continuous and not self.is_input_vectorized:
128
- raise ValueError(
129
- f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
130
- " sure that either `in_channels` or `num_vector_embeds` is not None."
131
- )
132
-
133
- # 2. Define input layers
134
- if self.is_input_continuous:
135
- self.in_channels = in_channels
136
-
137
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
138
- if use_linear_projection:
139
- self.proj_in = nn.Linear(in_channels, inner_dim)
140
- else:
141
- self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
142
- elif self.is_input_vectorized:
143
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
144
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
145
-
146
- self.height = sample_size
147
- self.width = sample_size
148
- self.num_vector_embeds = num_vector_embeds
149
- self.num_latent_pixels = self.height * self.width
150
-
151
- self.latent_image_embedding = ImagePositionalEmbeddings(
152
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
153
- )
154
-
155
- # 3. Define transformers blocks
156
- self.transformer_blocks = nn.ModuleList(
157
- [
158
- BasicTransformerBlock(
159
- inner_dim,
160
- num_attention_heads,
161
- attention_head_dim,
162
- dropout=dropout,
163
- cross_attention_dim=cross_attention_dim,
164
- activation_fn=activation_fn,
165
- num_embeds_ada_norm=num_embeds_ada_norm,
166
- attention_bias=attention_bias,
167
- only_cross_attention=only_cross_attention,
168
- upcast_attention=upcast_attention,
169
- )
170
- for d in range(num_layers)
171
- ]
172
- )
173
-
174
- # 4. Define output layers
175
- if self.is_input_continuous:
176
- if use_linear_projection:
177
- self.proj_out = nn.Linear(in_channels, inner_dim)
178
- else:
179
- self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
180
- elif self.is_input_vectorized:
181
- self.norm_out = nn.LayerNorm(inner_dim)
182
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
183
-
184
- def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
185
- """
186
- Args:
187
- hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
188
- When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
189
- hidden_states
190
- encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
191
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
192
- self-attention.
193
- timestep ( `torch.long`, *optional*):
194
- Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
195
- return_dict (`bool`, *optional*, defaults to `True`):
196
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
197
-
198
- Returns:
199
- [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
200
- if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
201
- tensor.
202
- """
203
- # 1. Input
204
- if self.is_input_continuous:
205
- batch, channel, height, weight = hidden_states.shape
206
- residual = hidden_states
207
-
208
- hidden_states = self.norm(hidden_states)
209
- if not self.use_linear_projection:
210
- hidden_states = self.proj_in(hidden_states)
211
- inner_dim = hidden_states.shape[1]
212
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
213
- else:
214
- inner_dim = hidden_states.shape[1]
215
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
216
- hidden_states = self.proj_in(hidden_states)
217
- elif self.is_input_vectorized:
218
- hidden_states = self.latent_image_embedding(hidden_states)
219
-
220
- # 2. Blocks
221
- for block in self.transformer_blocks:
222
- hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
223
-
224
- # 3. Output
225
- if self.is_input_continuous:
226
- if not self.use_linear_projection:
227
- hidden_states = (
228
- hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
229
- )
230
- hidden_states = self.proj_out(hidden_states)
231
- else:
232
- hidden_states = self.proj_out(hidden_states)
233
- hidden_states = (
234
- hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
235
- )
236
-
237
- output = hidden_states + residual
238
- elif self.is_input_vectorized:
239
- hidden_states = self.norm_out(hidden_states)
240
- logits = self.out(hidden_states)
241
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
242
- logits = logits.permute(0, 2, 1)
243
-
244
- # log(p(x_0))
245
- output = F.log_softmax(logits.double(), dim=1).float()
246
-
247
- if not return_dict:
248
- return (output,)
249
-
250
- return Transformer2DModelOutput(sample=output)
251
-
252
-
253
- class AttentionBlock(nn.Module):
254
- """
255
- An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
256
- to the N-d case.
257
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
258
- Uses three q, k, v linear layers to compute attention.
259
-
260
- Parameters:
261
- channels (`int`): The number of channels in the input and output.
262
- num_head_channels (`int`, *optional*):
263
- The number of channels in each head. If None, then `num_heads` = 1.
264
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
265
- rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
266
- eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
267
- """
268
-
269
- # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
270
-
271
- def __init__(
272
- self,
273
- channels: int,
274
- num_head_channels: Optional[int] = None,
275
- norm_num_groups: int = 32,
276
- rescale_output_factor: float = 1.0,
277
- eps: float = 1e-5,
278
- ):
279
- super().__init__()
280
- self.channels = channels
281
-
282
- self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
283
- self.num_head_size = num_head_channels
284
- self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
285
-
286
- # define q,k,v as linear layers
287
- self.query = nn.Linear(channels, channels)
288
- self.key = nn.Linear(channels, channels)
289
- self.value = nn.Linear(channels, channels)
290
-
291
- self.rescale_output_factor = rescale_output_factor
292
- self.proj_attn = nn.Linear(channels, channels, 1)
293
-
294
- self._use_memory_efficient_attention_xformers = False
295
-
296
- def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
297
- if not is_xformers_available():
298
- raise ModuleNotFoundError(
299
- "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
300
- " xformers",
301
- name="xformers",
302
- )
303
- elif not torch.cuda.is_available():
304
- raise ValueError(
305
- "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
306
- " available for GPU "
307
- )
308
- else:
309
- try:
310
- # Make sure we can run the memory efficient attention
311
- _ = xformers.ops.memory_efficient_attention(
312
- torch.randn((1, 2, 40), device="cuda"),
313
- torch.randn((1, 2, 40), device="cuda"),
314
- torch.randn((1, 2, 40), device="cuda"),
315
- )
316
- except Exception as e:
317
- raise e
318
- self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
319
-
320
- def reshape_heads_to_batch_dim(self, tensor):
321
- batch_size, seq_len, dim = tensor.shape
322
- head_size = self.num_heads
323
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
324
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
325
- return tensor
326
-
327
- def reshape_batch_dim_to_heads(self, tensor):
328
- batch_size, seq_len, dim = tensor.shape
329
- head_size = self.num_heads
330
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
331
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
332
- return tensor
333
-
334
- def forward(self, hidden_states):
335
- residual = hidden_states
336
- batch, channel, height, width = hidden_states.shape
337
-
338
- # norm
339
- hidden_states = self.group_norm(hidden_states)
340
-
341
- hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
342
-
343
- # proj to q, k, v
344
- query_proj = self.query(hidden_states)
345
- key_proj = self.key(hidden_states)
346
- value_proj = self.value(hidden_states)
347
-
348
- scale = 1 / math.sqrt(self.channels / self.num_heads)
349
-
350
- query_proj = self.reshape_heads_to_batch_dim(query_proj)
351
- key_proj = self.reshape_heads_to_batch_dim(key_proj)
352
- value_proj = self.reshape_heads_to_batch_dim(value_proj)
353
-
354
- if self._use_memory_efficient_attention_xformers:
355
- # Memory efficient attention
356
- hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
357
- hidden_states = hidden_states.to(query_proj.dtype)
358
- else:
359
- attention_scores = torch.baddbmm(
360
- torch.empty(
361
- query_proj.shape[0],
362
- query_proj.shape[1],
363
- key_proj.shape[1],
364
- dtype=query_proj.dtype,
365
- device=query_proj.device,
366
- ),
367
- query_proj,
368
- key_proj.transpose(-1, -2),
369
- beta=0,
370
- alpha=scale,
371
- )
372
- attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
373
- hidden_states = torch.bmm(attention_probs, value_proj)
374
-
375
- # reshape hidden_states
376
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
377
-
378
- # compute next hidden_states
379
- hidden_states = self.proj_attn(hidden_states)
380
-
381
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
382
-
383
- # res connect and rescale
384
- hidden_states = (hidden_states + residual) / self.rescale_output_factor
385
- return hidden_states
386
-
387
-
388
- class BasicTransformerBlock(nn.Module):
389
- r"""
390
- A basic Transformer block.
391
-
392
- Parameters:
393
- dim (`int`): The number of channels in the input and output.
394
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
395
- attention_head_dim (`int`): The number of channels in each head.
396
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
397
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
398
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
399
- num_embeds_ada_norm (:
400
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
401
- attention_bias (:
402
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
403
- """
404
-
405
- def __init__(
406
- self,
407
- dim: int,
408
- num_attention_heads: int,
409
- attention_head_dim: int,
410
- dropout=0.0,
411
- cross_attention_dim: Optional[int] = None,
412
- activation_fn: str = "geglu",
413
- num_embeds_ada_norm: Optional[int] = None,
414
- attention_bias: bool = False,
415
- only_cross_attention: bool = False,
416
- upcast_attention: bool = False,
417
- ):
418
- super().__init__()
419
- self.only_cross_attention = only_cross_attention
420
- self.use_ada_layer_norm = num_embeds_ada_norm is not None
421
-
422
- # 1. Self-Attn
423
- self.attn1 = CrossAttention(
424
- query_dim=dim,
425
- heads=num_attention_heads,
426
- dim_head=attention_head_dim,
427
- dropout=dropout,
428
- bias=attention_bias,
429
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
430
- upcast_attention=upcast_attention,
431
- ) # is a self-attention
432
- self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
433
-
434
- # 2. Cross-Attn
435
- if cross_attention_dim is not None:
436
- self.attn2 = CrossAttention(
437
- query_dim=dim,
438
- cross_attention_dim=cross_attention_dim,
439
- heads=num_attention_heads,
440
- dim_head=attention_head_dim,
441
- dropout=dropout,
442
- bias=attention_bias,
443
- upcast_attention=upcast_attention,
444
- ) # is self-attn if encoder_hidden_states is none
445
- else:
446
- self.attn2 = None
447
-
448
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
449
-
450
- if cross_attention_dim is not None:
451
- self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
452
- else:
453
- self.norm2 = None
454
-
455
- # 3. Feed-forward
456
- self.norm3 = nn.LayerNorm(dim)
457
-
458
- def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
459
- if not is_xformers_available():
460
- print("Here is how to install it")
461
- raise ModuleNotFoundError(
462
- "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
463
- " xformers",
464
- name="xformers",
465
- )
466
- elif not torch.cuda.is_available():
467
- raise ValueError(
468
- "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
469
- " available for GPU "
470
- )
471
- else:
472
- try:
473
- # Make sure we can run the memory efficient attention
474
- _ = xformers.ops.memory_efficient_attention(
475
- torch.randn((1, 2, 40), device="cuda"),
476
- torch.randn((1, 2, 40), device="cuda"),
477
- torch.randn((1, 2, 40), device="cuda"),
478
- )
479
- except Exception as e:
480
- raise e
481
- self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
482
- if self.attn2 is not None:
483
- self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
484
-
485
- def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
486
- # 1. Self-Attention
487
- norm_hidden_states = (
488
- self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
489
- )
490
-
491
- if self.only_cross_attention:
492
- hidden_states = (
493
- self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
494
- )
495
- else:
496
- hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
497
-
498
- if self.attn2 is not None:
499
- # 2. Cross-Attention
500
- norm_hidden_states = (
501
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
502
- )
503
- hidden_states = (
504
- self.attn2(
505
- norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
506
- )
507
- + hidden_states
508
- )
509
-
510
- # 3. Feed-forward
511
- hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
512
-
513
- return hidden_states
514
-
515
-
516
- class CrossAttention(nn.Module):
517
- r"""
518
- A cross attention layer.
519
-
520
- Parameters:
521
- query_dim (`int`): The number of channels in the query.
522
- cross_attention_dim (`int`, *optional*):
523
- The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
524
- heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
525
- dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
526
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
527
- bias (`bool`, *optional*, defaults to False):
528
- Set to `True` for the query, key, and value linear layers to contain a bias parameter.
529
- """
530
-
531
- def __init__(
532
- self,
533
- query_dim: int,
534
- cross_attention_dim: Optional[int] = None,
535
- heads: int = 8,
536
- dim_head: int = 64,
537
- dropout: float = 0.0,
538
- bias=False,
539
- upcast_attention: bool = False,
540
- upcast_softmax: bool = False,
541
- added_kv_proj_dim: Optional[int] = None,
542
- norm_num_groups: Optional[int] = None,
543
- ):
544
- super().__init__()
545
- inner_dim = dim_head * heads
546
- cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
547
- self.upcast_attention = upcast_attention
548
- self.upcast_softmax = upcast_softmax
549
-
550
- self.scale = dim_head**-0.5
551
-
552
- self.heads = heads
553
- # for slice_size > 0 the attention score computation
554
- # is split across the batch axis to save memory
555
- # You can set slice_size with `set_attention_slice`
556
- self.sliceable_head_dim = heads
557
- self._slice_size = None
558
- self._use_memory_efficient_attention_xformers = False
559
- self.added_kv_proj_dim = added_kv_proj_dim
560
-
561
- if norm_num_groups is not None:
562
- self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
563
- else:
564
- self.group_norm = None
565
-
566
- self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
567
- self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
568
- self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
569
-
570
- if self.added_kv_proj_dim is not None:
571
- self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
572
- self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
573
-
574
- self.to_out = nn.ModuleList([])
575
- self.to_out.append(nn.Linear(inner_dim, query_dim))
576
- self.to_out.append(nn.Dropout(dropout))
577
-
578
- def reshape_heads_to_batch_dim(self, tensor):
579
- batch_size, seq_len, dim = tensor.shape
580
- head_size = self.heads
581
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
582
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
583
- return tensor
584
-
585
- def reshape_batch_dim_to_heads(self, tensor):
586
- batch_size, seq_len, dim = tensor.shape
587
- head_size = self.heads
588
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
589
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
590
- return tensor
591
-
592
- def set_attention_slice(self, slice_size):
593
- if slice_size is not None and slice_size > self.sliceable_head_dim:
594
- raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
595
-
596
- self._slice_size = slice_size
597
-
598
- def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
599
- batch_size, sequence_length, _ = hidden_states.shape
600
-
601
- encoder_hidden_states = encoder_hidden_states
602
-
603
- if self.group_norm is not None:
604
- hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
605
-
606
- query = self.to_q(hidden_states)
607
- dim = query.shape[-1]
608
- query = self.reshape_heads_to_batch_dim(query)
609
-
610
- if self.added_kv_proj_dim is not None:
611
- key = self.to_k(hidden_states)
612
- value = self.to_v(hidden_states)
613
- encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
614
- encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
615
-
616
- key = self.reshape_heads_to_batch_dim(key)
617
- value = self.reshape_heads_to_batch_dim(value)
618
- encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
619
- encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
620
-
621
- key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
622
- value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
623
- else:
624
- encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
625
- key = self.to_k(encoder_hidden_states)
626
- value = self.to_v(encoder_hidden_states)
627
-
628
- key = self.reshape_heads_to_batch_dim(key)
629
- value = self.reshape_heads_to_batch_dim(value)
630
-
631
- if attention_mask is not None:
632
- if attention_mask.shape[-1] != query.shape[1]:
633
- target_length = query.shape[1]
634
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
635
- attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
636
-
637
- # attention, what we cannot get enough of
638
- if self._use_memory_efficient_attention_xformers:
639
- hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
640
- # Some versions of xformers return output in fp32, cast it back to the dtype of the input
641
- hidden_states = hidden_states.to(query.dtype)
642
- else:
643
- if self._slice_size is None or query.shape[0] // self._slice_size == 1:
644
- hidden_states = self._attention(query, key, value, attention_mask)
645
- else:
646
- hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
647
-
648
- # linear proj
649
- hidden_states = self.to_out[0](hidden_states)
650
-
651
- # dropout
652
- hidden_states = self.to_out[1](hidden_states)
653
- return hidden_states
654
-
655
- def _attention(self, query, key, value, attention_mask=None):
656
- if self.upcast_attention:
657
- query = query.float()
658
- key = key.float()
659
-
660
- attention_scores = torch.baddbmm(
661
- torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
662
- query,
663
- key.transpose(-1, -2),
664
- beta=0,
665
- alpha=self.scale,
666
- )
667
-
668
- if attention_mask is not None:
669
- attention_scores = attention_scores + attention_mask
670
-
671
- if self.upcast_softmax:
672
- attention_scores = attention_scores.float()
673
-
674
- attention_probs = attention_scores.softmax(dim=-1)
675
-
676
- # cast back to the original dtype
677
- attention_probs = attention_probs.to(value.dtype)
678
-
679
- # compute attention output
680
- hidden_states = torch.bmm(attention_probs, value)
681
-
682
- # reshape hidden_states
683
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
684
- return hidden_states
685
-
686
- def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
687
- batch_size_attention = query.shape[0]
688
- hidden_states = torch.zeros(
689
- (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
690
- )
691
- slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
692
- for i in range(hidden_states.shape[0] // slice_size):
693
- start_idx = i * slice_size
694
- end_idx = (i + 1) * slice_size
695
-
696
- query_slice = query[start_idx:end_idx]
697
- key_slice = key[start_idx:end_idx]
698
-
699
- if self.upcast_attention:
700
- query_slice = query_slice.float()
701
- key_slice = key_slice.float()
702
-
703
- attn_slice = torch.baddbmm(
704
- torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
705
- query_slice,
706
- key_slice.transpose(-1, -2),
707
- beta=0,
708
- alpha=self.scale,
709
- )
710
-
711
- if attention_mask is not None:
712
- attn_slice = attn_slice + attention_mask[start_idx:end_idx]
713
-
714
- if self.upcast_softmax:
715
- attn_slice = attn_slice.float()
716
-
717
- attn_slice = attn_slice.softmax(dim=-1)
718
-
719
- # cast back to the original dtype
720
- attn_slice = attn_slice.to(value.dtype)
721
- attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
722
-
723
- hidden_states[start_idx:end_idx] = attn_slice
724
-
725
- # reshape hidden_states
726
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
727
- return hidden_states
728
-
729
- def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
730
- # TODO attention_mask
731
- query = query.contiguous()
732
- key = key.contiguous()
733
- value = value.contiguous()
734
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
735
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
736
- return hidden_states
737
-
738
-
739
- class FeedForward(nn.Module):
740
- r"""
741
- A feed-forward layer.
742
-
743
- Parameters:
744
- dim (`int`): The number of channels in the input.
745
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
746
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
747
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
748
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
749
- """
750
-
751
- def __init__(
752
- self,
753
- dim: int,
754
- dim_out: Optional[int] = None,
755
- mult: int = 4,
756
- dropout: float = 0.0,
757
- activation_fn: str = "geglu",
758
- ):
759
- super().__init__()
760
- inner_dim = int(dim * mult)
761
- dim_out = dim_out if dim_out is not None else dim
762
-
763
- if activation_fn == "gelu":
764
- act_fn = GELU(dim, inner_dim)
765
- elif activation_fn == "geglu":
766
- act_fn = GEGLU(dim, inner_dim)
767
- elif activation_fn == "geglu-approximate":
768
- act_fn = ApproximateGELU(dim, inner_dim)
769
-
770
- self.net = nn.ModuleList([])
771
- # project in
772
- self.net.append(act_fn)
773
- # project dropout
774
- self.net.append(nn.Dropout(dropout))
775
- # project out
776
- self.net.append(nn.Linear(inner_dim, dim_out))
777
-
778
- def forward(self, hidden_states):
779
- for module in self.net:
780
- hidden_states = module(hidden_states)
781
- return hidden_states
782
-
783
-
784
- class GELU(nn.Module):
785
- r"""
786
- GELU activation function
787
- """
788
-
789
- def __init__(self, dim_in: int, dim_out: int):
790
- super().__init__()
791
- self.proj = nn.Linear(dim_in, dim_out)
792
-
793
- def gelu(self, gate):
794
- if gate.device.type != "mps":
795
- return F.gelu(gate)
796
- # mps: gelu is not implemented for float16
797
- return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
798
-
799
- def forward(self, hidden_states):
800
- hidden_states = self.proj(hidden_states)
801
- hidden_states = self.gelu(hidden_states)
802
- return hidden_states
803
-
804
-
805
- # feedforward
806
- class GEGLU(nn.Module):
807
- r"""
808
- A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
809
-
810
- Parameters:
811
- dim_in (`int`): The number of channels in the input.
812
- dim_out (`int`): The number of channels in the output.
813
- """
814
-
815
- def __init__(self, dim_in: int, dim_out: int):
816
- super().__init__()
817
- self.proj = nn.Linear(dim_in, dim_out * 2)
818
-
819
- def gelu(self, gate):
820
- if gate.device.type != "mps":
821
- return F.gelu(gate)
822
- # mps: gelu is not implemented for float16
823
- return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
824
-
825
- def forward(self, hidden_states):
826
- hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
827
- return hidden_states * self.gelu(gate)
828
-
829
-
830
- class ApproximateGELU(nn.Module):
831
- """
832
- The approximate form of Gaussian Error Linear Unit (GELU)
833
-
834
- For more details, see section 2: https://arxiv.org/abs/1606.08415
835
- """
836
-
837
- def __init__(self, dim_in: int, dim_out: int):
838
- super().__init__()
839
- self.proj = nn.Linear(dim_in, dim_out)
840
-
841
- def forward(self, x):
842
- x = self.proj(x)
843
- return x * torch.sigmoid(1.702 * x)
844
-
845
-
846
- class AdaLayerNorm(nn.Module):
847
- """
848
- Norm layer modified to incorporate timestep embeddings.
849
- """
850
-
851
- def __init__(self, embedding_dim, num_embeddings):
852
- super().__init__()
853
- self.emb = nn.Embedding(num_embeddings, embedding_dim)
854
- self.silu = nn.SiLU()
855
- self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
856
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
857
-
858
- def forward(self, x, timestep):
859
- emb = self.linear(self.silu(self.emb(timestep)))
860
- scale, shift = torch.chunk(emb, 2)
861
- x = self.norm(x) * (1 + scale) + shift
862
- return x
863
-
864
-
865
- class DualTransformer2DModel(nn.Module):
866
- """
867
- Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
868
-
869
- Parameters:
870
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
871
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
872
- in_channels (`int`, *optional*):
873
- Pass if the input is continuous. The number of channels in the input and output.
874
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
875
- dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
876
- cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
877
- sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
878
- Note that this is fixed at training time as it is used for learning a number of position embeddings. See
879
- `ImagePositionalEmbeddings`.
880
- num_vector_embeds (`int`, *optional*):
881
- Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
882
- Includes the class for the masked latent pixel.
883
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
884
- num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
885
- The number of diffusion steps used during training. Note that this is fixed at training time as it is used
886
- to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
887
- up to but not more than steps than `num_embeds_ada_norm`.
888
- attention_bias (`bool`, *optional*):
889
- Configure if the TransformerBlocks' attention should contain a bias parameter.
890
- """
891
-
892
- def __init__(
893
- self,
894
- num_attention_heads: int = 16,
895
- attention_head_dim: int = 88,
896
- in_channels: Optional[int] = None,
897
- num_layers: int = 1,
898
- dropout: float = 0.0,
899
- norm_num_groups: int = 32,
900
- cross_attention_dim: Optional[int] = None,
901
- attention_bias: bool = False,
902
- sample_size: Optional[int] = None,
903
- num_vector_embeds: Optional[int] = None,
904
- activation_fn: str = "geglu",
905
- num_embeds_ada_norm: Optional[int] = None,
906
- ):
907
- super().__init__()
908
- self.transformers = nn.ModuleList(
909
- [
910
- Transformer2DModel(
911
- num_attention_heads=num_attention_heads,
912
- attention_head_dim=attention_head_dim,
913
- in_channels=in_channels,
914
- num_layers=num_layers,
915
- dropout=dropout,
916
- norm_num_groups=norm_num_groups,
917
- cross_attention_dim=cross_attention_dim,
918
- attention_bias=attention_bias,
919
- sample_size=sample_size,
920
- num_vector_embeds=num_vector_embeds,
921
- activation_fn=activation_fn,
922
- num_embeds_ada_norm=num_embeds_ada_norm,
923
- )
924
- for _ in range(2)
925
- ]
926
- )
927
-
928
- # Variables that can be set by a pipeline:
929
-
930
- # The ratio of transformer1 to transformer2's output states to be combined during inference
931
- self.mix_ratio = 0.5
932
-
933
- # The shape of `encoder_hidden_states` is expected to be
934
- # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
935
- self.condition_lengths = [77, 257]
936
-
937
- # Which transformer to use to encode which condition.
938
- # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
939
- self.transformer_index_for_condition = [1, 0]
940
-
941
- def forward(
942
- self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True
943
- ):
944
- """
945
- Args:
946
- hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
947
- When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
948
- hidden_states
949
- encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
950
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
951
- self-attention.
952
- timestep ( `torch.long`, *optional*):
953
- Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
954
- attention_mask (`torch.FloatTensor`, *optional*):
955
- Optional attention mask to be applied in CrossAttention
956
- return_dict (`bool`, *optional*, defaults to `True`):
957
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
958
-
959
- Returns:
960
- [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
961
- if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
962
- tensor.
963
- """
964
- input_states = hidden_states
965
-
966
- encoded_states = []
967
- tokens_start = 0
968
- # attention_mask is not used yet
969
- for i in range(2):
970
- # for each of the two transformers, pass the corresponding condition tokens
971
- condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
972
- transformer_index = self.transformer_index_for_condition[i]
973
- encoded_state = self.transformers[transformer_index](
974
- input_states,
975
- encoder_hidden_states=condition_state,
976
- timestep=timestep,
977
- return_dict=False,
978
- )[0]
979
- encoded_states.append(encoded_state - input_states)
980
- tokens_start += self.condition_lengths[i]
981
-
982
- output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
983
- output_states = output_states + input_states
984
-
985
- if not return_dict:
986
- return (output_states,)
987
-
988
  return Transformer2DModelOutput(sample=output_states)
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ import math
21
+ from dataclasses import dataclass
22
+ from typing import Optional
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from torch import nn
27
+
28
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
31
+ from diffusers.utils import BaseOutput
32
+ from diffusers.utils.import_utils import is_xformers_available
33
+
34
+
35
+ @dataclass
36
+ class Transformer2DModelOutput(BaseOutput):
37
+ """
38
+ Args:
39
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
40
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
41
+ for the unnoised latent pixels.
42
+ """
43
+
44
+ sample: torch.FloatTensor
45
+
46
+
47
+ if is_xformers_available():
48
+ import xformers
49
+ import xformers.ops
50
+ else:
51
+ xformers = None
52
+
53
+
54
+ class Transformer2DModel(ModelMixin, ConfigMixin):
55
+ """
56
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
57
+ embeddings) inputs.
58
+
59
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
60
+ transformer action. Finally, reshape to image.
61
+
62
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
63
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
64
+ classes of unnoised image.
65
+
66
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
67
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
68
+
69
+ Parameters:
70
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
71
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
72
+ in_channels (`int`, *optional*):
73
+ Pass if the input is continuous. The number of channels in the input and output.
74
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
75
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
76
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
77
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
78
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
79
+ `ImagePositionalEmbeddings`.
80
+ num_vector_embeds (`int`, *optional*):
81
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
82
+ Includes the class for the masked latent pixel.
83
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
84
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
85
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
86
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
87
+ up to but not more than steps than `num_embeds_ada_norm`.
88
+ attention_bias (`bool`, *optional*):
89
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
90
+ """
91
+
92
+ @register_to_config
93
+ def __init__(
94
+ self,
95
+ num_attention_heads: int = 16,
96
+ attention_head_dim: int = 88,
97
+ in_channels: Optional[int] = None,
98
+ num_layers: int = 1,
99
+ dropout: float = 0.0,
100
+ norm_num_groups: int = 32,
101
+ cross_attention_dim: Optional[int] = None,
102
+ attention_bias: bool = False,
103
+ sample_size: Optional[int] = None,
104
+ num_vector_embeds: Optional[int] = None,
105
+ activation_fn: str = "geglu",
106
+ num_embeds_ada_norm: Optional[int] = None,
107
+ use_linear_projection: bool = False,
108
+ only_cross_attention: bool = False,
109
+ upcast_attention: bool = False,
110
+ ):
111
+ super().__init__()
112
+ self.use_linear_projection = use_linear_projection
113
+ self.num_attention_heads = num_attention_heads
114
+ self.attention_head_dim = attention_head_dim
115
+ inner_dim = num_attention_heads * attention_head_dim
116
+
117
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
118
+ # Define whether input is continuous or discrete depending on configuration
119
+ self.is_input_continuous = in_channels is not None
120
+ self.is_input_vectorized = num_vector_embeds is not None
121
+
122
+ if self.is_input_continuous and self.is_input_vectorized:
123
+ raise ValueError(
124
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
125
+ " sure that either `in_channels` or `num_vector_embeds` is None."
126
+ )
127
+ elif not self.is_input_continuous and not self.is_input_vectorized:
128
+ raise ValueError(
129
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
130
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
131
+ )
132
+
133
+ # 2. Define input layers
134
+ if self.is_input_continuous:
135
+ self.in_channels = in_channels
136
+
137
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
138
+ if use_linear_projection:
139
+ self.proj_in = nn.Linear(in_channels, inner_dim)
140
+ else:
141
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
142
+ elif self.is_input_vectorized:
143
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
144
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
145
+
146
+ self.height = sample_size
147
+ self.width = sample_size
148
+ self.num_vector_embeds = num_vector_embeds
149
+ self.num_latent_pixels = self.height * self.width
150
+
151
+ self.latent_image_embedding = ImagePositionalEmbeddings(
152
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
153
+ )
154
+
155
+ # 3. Define transformers blocks
156
+ self.transformer_blocks = nn.ModuleList(
157
+ [
158
+ BasicTransformerBlock(
159
+ inner_dim,
160
+ num_attention_heads,
161
+ attention_head_dim,
162
+ dropout=dropout,
163
+ cross_attention_dim=cross_attention_dim,
164
+ activation_fn=activation_fn,
165
+ num_embeds_ada_norm=num_embeds_ada_norm,
166
+ attention_bias=attention_bias,
167
+ only_cross_attention=only_cross_attention,
168
+ upcast_attention=upcast_attention,
169
+ )
170
+ for d in range(num_layers)
171
+ ]
172
+ )
173
+
174
+ # 4. Define output layers
175
+ if self.is_input_continuous:
176
+ if use_linear_projection:
177
+ self.proj_out = nn.Linear(in_channels, inner_dim)
178
+ else:
179
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
180
+ elif self.is_input_vectorized:
181
+ self.norm_out = nn.LayerNorm(inner_dim)
182
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
183
+
184
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
185
+ """
186
+ Args:
187
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
188
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
189
+ hidden_states
190
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
191
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
192
+ self-attention.
193
+ timestep ( `torch.long`, *optional*):
194
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
195
+ return_dict (`bool`, *optional*, defaults to `True`):
196
+ Whether or not to return a [`~models.attention.Transformer2DModelOutput`] instead of a plain tuple.
197
+
198
+ Returns:
199
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
200
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
201
+ tensor.
202
+ """
203
+ # 1. Input
204
+ if self.is_input_continuous:
205
+ batch, channel, height, weight = hidden_states.shape
206
+ residual = hidden_states
207
+
208
+ hidden_states = self.norm(hidden_states)
209
+ if not self.use_linear_projection:
210
+ hidden_states = self.proj_in(hidden_states)
211
+ inner_dim = hidden_states.shape[1]
212
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
213
+ else:
214
+ inner_dim = hidden_states.shape[1]
215
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
216
+ hidden_states = self.proj_in(hidden_states)
217
+ elif self.is_input_vectorized:
218
+ hidden_states = self.latent_image_embedding(hidden_states)
219
+
220
+ # 2. Blocks
221
+ for block in self.transformer_blocks:
222
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
223
+
224
+ # 3. Output
225
+ if self.is_input_continuous:
226
+ if not self.use_linear_projection:
227
+ hidden_states = (
228
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
229
+ )
230
+ hidden_states = self.proj_out(hidden_states)
231
+ else:
232
+ hidden_states = self.proj_out(hidden_states)
233
+ hidden_states = (
234
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
235
+ )
236
+
237
+ output = hidden_states + residual
238
+ elif self.is_input_vectorized:
239
+ hidden_states = self.norm_out(hidden_states)
240
+ logits = self.out(hidden_states)
241
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
242
+ logits = logits.permute(0, 2, 1)
243
+
244
+ # log(p(x_0))
245
+ output = F.log_softmax(logits.double(), dim=1).float()
246
+
247
+ if not return_dict:
248
+ return (output,)
249
+
250
+ return Transformer2DModelOutput(sample=output)
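For orientation, a minimal usage sketch of the continuous path above. The tensor sizes, the 8x40 head layout, and the 768-wide conditioning are illustrative assumptions, not values fixed by this file:

```py
import torch

# Continuous input: channels are projected to inner_dim, flattened to tokens,
# run through the transformer blocks, then projected back and added residually.
model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,       # inner_dim = 8 * 40 = 320
    in_channels=320,
    cross_attention_dim=768,     # assumed text-encoder width
)
hidden_states = torch.randn(2, 320, 16, 16)       # (batch, channels, height, width)
encoder_hidden_states = torch.randn(2, 77, 768)   # (batch, tokens, cross_attention_dim)
out = model(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
assert out.shape == hidden_states.shape           # the residual connection preserves the shape
```

The discrete path instead takes `(batch, num_latent_pixels)` long tensors and returns log-probabilities over `num_vector_embeds - 1` classes per latent pixel.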
251
+
252
+
253
+ class AttentionBlock(nn.Module):
254
+ """
255
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
256
+ to the N-d case.
257
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
258
+ Uses three q, k, v linear layers to compute attention.
259
+
260
+ Parameters:
261
+ channels (`int`): The number of channels in the input and output.
262
+ num_head_channels (`int`, *optional*):
263
+ The number of channels in each head. If None, then `num_heads` = 1.
264
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
265
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
266
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
267
+ """
268
+
269
+ # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
270
+
271
+ def __init__(
272
+ self,
273
+ channels: int,
274
+ num_head_channels: Optional[int] = None,
275
+ norm_num_groups: int = 32,
276
+ rescale_output_factor: float = 1.0,
277
+ eps: float = 1e-5,
278
+ ):
279
+ super().__init__()
280
+ self.channels = channels
281
+
282
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
283
+ self.num_head_size = num_head_channels
284
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
285
+
286
+ # define q,k,v as linear layers
287
+ self.query = nn.Linear(channels, channels)
288
+ self.key = nn.Linear(channels, channels)
289
+ self.value = nn.Linear(channels, channels)
290
+
291
+ self.rescale_output_factor = rescale_output_factor
292
+ self.proj_attn = nn.Linear(channels, channels, 1)
293
+
294
+ self._use_memory_efficient_attention_xformers = False
295
+
296
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
297
+ if not is_xformers_available():
298
+ raise ModuleNotFoundError(
299
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
300
+ " xformers",
301
+ name="xformers",
302
+ )
303
+ elif not torch.cuda.is_available():
304
+ raise ValueError(
305
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
306
+ " available for GPU "
307
+ )
308
+ else:
309
+ try:
310
+ # Make sure we can run the memory efficient attention
311
+ _ = xformers.ops.memory_efficient_attention(
312
+ torch.randn((1, 2, 40), device="cuda"),
313
+ torch.randn((1, 2, 40), device="cuda"),
314
+ torch.randn((1, 2, 40), device="cuda"),
315
+ )
316
+ except Exception as e:
317
+ raise e
318
+ self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
319
+
320
+ def reshape_heads_to_batch_dim(self, tensor):
321
+ batch_size, seq_len, dim = tensor.shape
322
+ head_size = self.num_heads
323
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
324
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
325
+ return tensor
326
+
327
+ def reshape_batch_dim_to_heads(self, tensor):
328
+ batch_size, seq_len, dim = tensor.shape
329
+ head_size = self.num_heads
330
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
331
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
332
+ return tensor
333
+
334
+ def forward(self, hidden_states):
335
+ residual = hidden_states
336
+ batch, channel, height, width = hidden_states.shape
337
+
338
+ # norm
339
+ hidden_states = self.group_norm(hidden_states)
340
+
341
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
342
+
343
+ # proj to q, k, v
344
+ query_proj = self.query(hidden_states)
345
+ key_proj = self.key(hidden_states)
346
+ value_proj = self.value(hidden_states)
347
+
348
+ scale = 1 / math.sqrt(self.channels / self.num_heads)
349
+
350
+ query_proj = self.reshape_heads_to_batch_dim(query_proj)
351
+ key_proj = self.reshape_heads_to_batch_dim(key_proj)
352
+ value_proj = self.reshape_heads_to_batch_dim(value_proj)
353
+
354
+ if self._use_memory_efficient_attention_xformers:
355
+ # Memory efficient attention
356
+ hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
357
+ hidden_states = hidden_states.to(query_proj.dtype)
358
+ else:
359
+ attention_scores = torch.baddbmm(
360
+ torch.empty(
361
+ query_proj.shape[0],
362
+ query_proj.shape[1],
363
+ key_proj.shape[1],
364
+ dtype=query_proj.dtype,
365
+ device=query_proj.device,
366
+ ),
367
+ query_proj,
368
+ key_proj.transpose(-1, -2),
369
+ beta=0,
370
+ alpha=scale,
371
+ )
372
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
373
+ hidden_states = torch.bmm(attention_probs, value_proj)
374
+
375
+ # reshape hidden_states
376
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
377
+
378
+ # compute next hidden_states
379
+ hidden_states = self.proj_attn(hidden_states)
380
+
381
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
382
+
383
+ # res connect and rescale
384
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
385
+ return hidden_states
386
+
387
+
388
+ class BasicTransformerBlock(nn.Module):
389
+ r"""
390
+ A basic Transformer block.
391
+
392
+ Parameters:
393
+ dim (`int`): The number of channels in the input and output.
394
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
395
+ attention_head_dim (`int`): The number of channels in each head.
396
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
397
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
398
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
399
+ num_embeds_ada_norm (`int`, *optional*):
400
+ The number of diffusion steps used during training. See `Transformer2DModel`.
401
+ attention_bias (`bool`, *optional*, defaults to `False`):
402
+ Configure if the attentions should contain a bias parameter.
403
+ """
404
+
405
+ def __init__(
406
+ self,
407
+ dim: int,
408
+ num_attention_heads: int,
409
+ attention_head_dim: int,
410
+ dropout=0.0,
411
+ cross_attention_dim: Optional[int] = None,
412
+ activation_fn: str = "geglu",
413
+ num_embeds_ada_norm: Optional[int] = None,
414
+ attention_bias: bool = False,
415
+ only_cross_attention: bool = False,
416
+ upcast_attention: bool = False,
417
+ ):
418
+ super().__init__()
419
+ self.only_cross_attention = only_cross_attention
420
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
421
+
422
+ # 1. Self-Attn
423
+ self.attn1 = CrossAttention(
424
+ query_dim=dim,
425
+ heads=num_attention_heads,
426
+ dim_head=attention_head_dim,
427
+ dropout=dropout,
428
+ bias=attention_bias,
429
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
430
+ upcast_attention=upcast_attention,
431
+ ) # is a self-attention
432
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
433
+
434
+ # 2. Cross-Attn
435
+ if cross_attention_dim is not None:
436
+ self.attn2 = CrossAttention(
437
+ query_dim=dim,
438
+ cross_attention_dim=cross_attention_dim,
439
+ heads=num_attention_heads,
440
+ dim_head=attention_head_dim,
441
+ dropout=dropout,
442
+ bias=attention_bias,
443
+ upcast_attention=upcast_attention,
444
+ ) # is self-attn if encoder_hidden_states is none
445
+ else:
446
+ self.attn2 = None
447
+
448
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
449
+
450
+ if cross_attention_dim is not None:
451
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
452
+ else:
453
+ self.norm2 = None
454
+
455
+ # 3. Feed-forward
456
+ self.norm3 = nn.LayerNorm(dim)
457
+
458
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
459
+ if not is_xformers_available():
460
+ print("Here is how to install it")
461
+ raise ModuleNotFoundError(
462
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
463
+ " xformers",
464
+ name="xformers",
465
+ )
466
+ elif not torch.cuda.is_available():
467
+ raise ValueError(
468
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
469
+ " available for GPU "
470
+ )
471
+ else:
472
+ try:
473
+ # Make sure we can run the memory efficient attention
474
+ _ = xformers.ops.memory_efficient_attention(
475
+ torch.randn((1, 2, 40), device="cuda"),
476
+ torch.randn((1, 2, 40), device="cuda"),
477
+ torch.randn((1, 2, 40), device="cuda"),
478
+ )
479
+ except Exception as e:
480
+ raise e
481
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
482
+ if self.attn2 is not None:
483
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
484
+
485
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
486
+ # 1. Self-Attention
487
+ norm_hidden_states = (
488
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
489
+ )
490
+
491
+ if self.only_cross_attention:
492
+ hidden_states = (
493
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
494
+ )
495
+ else:
496
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
497
+
498
+ if self.attn2 is not None:
499
+ # 2. Cross-Attention
500
+ norm_hidden_states = (
501
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
502
+ )
503
+ hidden_states = (
504
+ self.attn2(
505
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
506
+ )
507
+ + hidden_states
508
+ )
509
+
510
+ # 3. Feed-forward
511
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
512
+
513
+ return hidden_states
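A short sketch of a single block, reusing the illustrative sizes from the sketch above; `timestep` is only needed when `num_embeds_ada_norm` is set and AdaLayerNorm is used:

```py
import torch

block = BasicTransformerBlock(
    dim=320, num_attention_heads=8, attention_head_dim=40, cross_attention_dim=768
)
x = torch.randn(2, 64, 320)                   # (batch, tokens, dim), already flattened
context = torch.randn(2, 77, 768)             # conditioning tokens (assumed width)
y = block(x, encoder_hidden_states=context)   # self-attn -> cross-attn -> feed-forward
assert y.shape == x.shape
```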
514
+
515
+
516
+ class CrossAttention(nn.Module):
517
+ r"""
518
+ A cross attention layer.
519
+
520
+ Parameters:
521
+ query_dim (`int`): The number of channels in the query.
522
+ cross_attention_dim (`int`, *optional*):
523
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
524
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
525
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
526
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
527
+ bias (`bool`, *optional*, defaults to False):
528
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
529
+ """
530
+
531
+ def __init__(
532
+ self,
533
+ query_dim: int,
534
+ cross_attention_dim: Optional[int] = None,
535
+ heads: int = 8,
536
+ dim_head: int = 64,
537
+ dropout: float = 0.0,
538
+ bias=False,
539
+ upcast_attention: bool = False,
540
+ upcast_softmax: bool = False,
541
+ added_kv_proj_dim: Optional[int] = None,
542
+ norm_num_groups: Optional[int] = None,
543
+ ):
544
+ super().__init__()
545
+ inner_dim = dim_head * heads
546
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
547
+ self.upcast_attention = upcast_attention
548
+ self.upcast_softmax = upcast_softmax
549
+
550
+ self.scale = dim_head**-0.5
551
+
552
+ self.heads = heads
553
+ # for slice_size > 0 the attention score computation
554
+ # is split across the batch axis to save memory
555
+ # You can set slice_size with `set_attention_slice`
556
+ self.sliceable_head_dim = heads
557
+ self._slice_size = None
558
+ self._use_memory_efficient_attention_xformers = False
559
+ self.added_kv_proj_dim = added_kv_proj_dim
560
+
561
+ if norm_num_groups is not None:
562
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
563
+ else:
564
+ self.group_norm = None
565
+
566
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
567
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
568
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
569
+
570
+ if self.added_kv_proj_dim is not None:
571
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
572
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
573
+
574
+ self.to_out = nn.ModuleList([])
575
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
576
+ self.to_out.append(nn.Dropout(dropout))
577
+
578
+ def reshape_heads_to_batch_dim(self, tensor):
579
+ batch_size, seq_len, dim = tensor.shape
580
+ head_size = self.heads
581
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
582
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
583
+ return tensor
584
+
585
+ def reshape_batch_dim_to_heads(self, tensor):
586
+ batch_size, seq_len, dim = tensor.shape
587
+ head_size = self.heads
588
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
589
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
590
+ return tensor
591
+
592
+ def set_attention_slice(self, slice_size):
593
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
594
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
595
+
596
+ self._slice_size = slice_size
597
+
598
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
599
+ batch_size, sequence_length, _ = hidden_states.shape
600
+
601
+ encoder_hidden_states = encoder_hidden_states
602
+
603
+ if self.group_norm is not None:
604
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
605
+
606
+ query = self.to_q(hidden_states)
607
+ dim = query.shape[-1]
608
+ query = self.reshape_heads_to_batch_dim(query)
609
+
610
+ if self.added_kv_proj_dim is not None:
611
+ key = self.to_k(hidden_states)
612
+ value = self.to_v(hidden_states)
613
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
614
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
615
+
616
+ key = self.reshape_heads_to_batch_dim(key)
617
+ value = self.reshape_heads_to_batch_dim(value)
618
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
619
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
620
+
621
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
622
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
623
+ else:
624
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
625
+ key = self.to_k(encoder_hidden_states)
626
+ value = self.to_v(encoder_hidden_states)
627
+
628
+ key = self.reshape_heads_to_batch_dim(key)
629
+ value = self.reshape_heads_to_batch_dim(value)
630
+
631
+ if attention_mask is not None:
632
+ if attention_mask.shape[-1] != query.shape[1]:
633
+ target_length = query.shape[1]
634
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
635
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
636
+
637
+ # attention, what we cannot get enough of
638
+ if self._use_memory_efficient_attention_xformers:
639
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
640
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
641
+ hidden_states = hidden_states.to(query.dtype)
642
+ else:
643
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
644
+ hidden_states = self._attention(query, key, value, attention_mask)
645
+ else:
646
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
647
+
648
+ # linear proj
649
+ hidden_states = self.to_out[0](hidden_states)
650
+
651
+ # dropout
652
+ hidden_states = self.to_out[1](hidden_states)
653
+ return hidden_states
654
+
655
+ def _attention(self, query, key, value, attention_mask=None):
656
+ if self.upcast_attention:
657
+ query = query.float()
658
+ key = key.float()
659
+
660
+ attention_scores = torch.baddbmm(
661
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
662
+ query,
663
+ key.transpose(-1, -2),
664
+ beta=0,
665
+ alpha=self.scale,
666
+ )
667
+
668
+ if attention_mask is not None:
669
+ attention_scores = attention_scores + attention_mask
670
+
671
+ if self.upcast_softmax:
672
+ attention_scores = attention_scores.float()
673
+
674
+ attention_probs = attention_scores.softmax(dim=-1)
675
+
676
+ # cast back to the original dtype
677
+ attention_probs = attention_probs.to(value.dtype)
678
+
679
+ # compute attention output
680
+ hidden_states = torch.bmm(attention_probs, value)
681
+
682
+ # reshape hidden_states
683
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
684
+ return hidden_states
685
+
686
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
687
+ batch_size_attention = query.shape[0]
688
+ hidden_states = torch.zeros(
689
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
690
+ )
691
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
692
+ for i in range(hidden_states.shape[0] // slice_size):
693
+ start_idx = i * slice_size
694
+ end_idx = (i + 1) * slice_size
695
+
696
+ query_slice = query[start_idx:end_idx]
697
+ key_slice = key[start_idx:end_idx]
698
+
699
+ if self.upcast_attention:
700
+ query_slice = query_slice.float()
701
+ key_slice = key_slice.float()
702
+
703
+ attn_slice = torch.baddbmm(
704
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
705
+ query_slice,
706
+ key_slice.transpose(-1, -2),
707
+ beta=0,
708
+ alpha=self.scale,
709
+ )
710
+
711
+ if attention_mask is not None:
712
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
713
+
714
+ if self.upcast_softmax:
715
+ attn_slice = attn_slice.float()
716
+
717
+ attn_slice = attn_slice.softmax(dim=-1)
718
+
719
+ # cast back to the original dtype
720
+ attn_slice = attn_slice.to(value.dtype)
721
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
722
+
723
+ hidden_states[start_idx:end_idx] = attn_slice
724
+
725
+ # reshape hidden_states
726
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
727
+ return hidden_states
728
+
729
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
730
+ # TODO attention_mask
731
+ query = query.contiguous()
732
+ key = key.contiguous()
733
+ value = value.contiguous()
734
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
735
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
736
+ return hidden_states
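A sketch of the head bookkeeping in `CrossAttention`: heads are folded into the batch axis before the matmuls and unfolded afterwards. The sizes are assumptions:

```py
import torch

attn = CrossAttention(query_dim=320, heads=8, dim_head=40)
x = torch.randn(2, 64, 320)                        # (batch, tokens, query_dim)
q = attn.reshape_heads_to_batch_dim(attn.to_q(x))
print(q.shape)                                     # torch.Size([16, 64, 40]) = (batch*heads, tokens, dim_head)
out = attn(x)                                      # no encoder_hidden_states -> plain self-attention
assert out.shape == x.shape
```

Calling `set_attention_slice(n)` only changes how many of the `batch * heads` rows are processed per matmul in `_sliced_attention`; the result is unchanged up to floating-point rounding.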
737
+
738
+
739
+ class FeedForward(nn.Module):
740
+ r"""
741
+ A feed-forward layer.
742
+
743
+ Parameters:
744
+ dim (`int`): The number of channels in the input.
745
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
746
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
747
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
748
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
749
+ """
750
+
751
+ def __init__(
752
+ self,
753
+ dim: int,
754
+ dim_out: Optional[int] = None,
755
+ mult: int = 4,
756
+ dropout: float = 0.0,
757
+ activation_fn: str = "geglu",
758
+ ):
759
+ super().__init__()
760
+ inner_dim = int(dim * mult)
761
+ dim_out = dim_out if dim_out is not None else dim
762
+
763
+ if activation_fn == "gelu":
764
+ act_fn = GELU(dim, inner_dim)
765
+ elif activation_fn == "geglu":
766
+ act_fn = GEGLU(dim, inner_dim)
767
+ elif activation_fn == "geglu-approximate":
768
+ act_fn = ApproximateGELU(dim, inner_dim)
769
+
770
+ self.net = nn.ModuleList([])
771
+ # project in
772
+ self.net.append(act_fn)
773
+ # project dropout
774
+ self.net.append(nn.Dropout(dropout))
775
+ # project out
776
+ self.net.append(nn.Linear(inner_dim, dim_out))
777
+
778
+ def forward(self, hidden_states):
779
+ for module in self.net:
780
+ hidden_states = module(hidden_states)
781
+ return hidden_states
782
+
783
+
784
+ class GELU(nn.Module):
785
+ r"""
786
+ GELU activation function
787
+ """
788
+
789
+ def __init__(self, dim_in: int, dim_out: int):
790
+ super().__init__()
791
+ self.proj = nn.Linear(dim_in, dim_out)
792
+
793
+ def gelu(self, gate):
794
+ if gate.device.type != "mps":
795
+ return F.gelu(gate)
796
+ # mps: gelu is not implemented for float16
797
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
798
+
799
+ def forward(self, hidden_states):
800
+ hidden_states = self.proj(hidden_states)
801
+ hidden_states = self.gelu(hidden_states)
802
+ return hidden_states
803
+
804
+
805
+ # feedforward
806
+ class GEGLU(nn.Module):
807
+ r"""
808
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
809
+
810
+ Parameters:
811
+ dim_in (`int`): The number of channels in the input.
812
+ dim_out (`int`): The number of channels in the output.
813
+ """
814
+
815
+ def __init__(self, dim_in: int, dim_out: int):
816
+ super().__init__()
817
+ self.proj = nn.Linear(dim_in, dim_out * 2)
818
+
819
+ def gelu(self, gate):
820
+ if gate.device.type != "mps":
821
+ return F.gelu(gate)
822
+ # mps: gelu is not implemented for float16
823
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
824
+
825
+ def forward(self, hidden_states):
826
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
827
+ return hidden_states * self.gelu(gate)
828
+
829
+
830
+ class ApproximateGELU(nn.Module):
831
+ """
832
+ The approximate form of Gaussian Error Linear Unit (GELU)
833
+
834
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
835
+ """
836
+
837
+ def __init__(self, dim_in: int, dim_out: int):
838
+ super().__init__()
839
+ self.proj = nn.Linear(dim_in, dim_out)
840
+
841
+ def forward(self, x):
842
+ x = self.proj(x)
843
+ return x * torch.sigmoid(1.702 * x)
844
+
845
+
846
+ class AdaLayerNorm(nn.Module):
847
+ """
848
+ Norm layer modified to incorporate timestep embeddings.
849
+ """
850
+
851
+ def __init__(self, embedding_dim, num_embeddings):
852
+ super().__init__()
853
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
854
+ self.silu = nn.SiLU()
855
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
856
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
857
+
858
+ def forward(self, x, timestep):
859
+ emb = self.linear(self.silu(self.emb(timestep)))
860
+ scale, shift = torch.chunk(emb, 2)
861
+ x = self.norm(x) * (1 + scale) + shift
862
+ return x
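A brief sketch of `AdaLayerNorm` with assumed sizes. Because `torch.chunk(emb, 2)` splits along dim 0, this variant expects a single scalar step index rather than a per-sample batch of timesteps:

```py
import torch

ada = AdaLayerNorm(embedding_dim=320, num_embeddings=1000)
x = torch.randn(2, 64, 320)
t = torch.tensor(25)          # scalar diffusion-step index
y = ada(x, t)                 # LayerNorm(x) * (1 + scale) + shift
assert y.shape == x.shape
```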
863
+
864
+
865
+ class DualTransformer2DModel(nn.Module):
866
+ """
867
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
868
+
869
+ Parameters:
870
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
871
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
872
+ in_channels (`int`, *optional*):
873
+ Pass if the input is continuous. The number of channels in the input and output.
874
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
875
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
876
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
877
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
878
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
879
+ `ImagePositionalEmbeddings`.
880
+ num_vector_embeds (`int`, *optional*):
881
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
882
+ Includes the class for the masked latent pixel.
883
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
884
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
885
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
886
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
887
+ up to but not more than steps than `num_embeds_ada_norm`.
888
+ attention_bias (`bool`, *optional*):
889
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
890
+ """
891
+
892
+ def __init__(
893
+ self,
894
+ num_attention_heads: int = 16,
895
+ attention_head_dim: int = 88,
896
+ in_channels: Optional[int] = None,
897
+ num_layers: int = 1,
898
+ dropout: float = 0.0,
899
+ norm_num_groups: int = 32,
900
+ cross_attention_dim: Optional[int] = None,
901
+ attention_bias: bool = False,
902
+ sample_size: Optional[int] = None,
903
+ num_vector_embeds: Optional[int] = None,
904
+ activation_fn: str = "geglu",
905
+ num_embeds_ada_norm: Optional[int] = None,
906
+ ):
907
+ super().__init__()
908
+ self.transformers = nn.ModuleList(
909
+ [
910
+ Transformer2DModel(
911
+ num_attention_heads=num_attention_heads,
912
+ attention_head_dim=attention_head_dim,
913
+ in_channels=in_channels,
914
+ num_layers=num_layers,
915
+ dropout=dropout,
916
+ norm_num_groups=norm_num_groups,
917
+ cross_attention_dim=cross_attention_dim,
918
+ attention_bias=attention_bias,
919
+ sample_size=sample_size,
920
+ num_vector_embeds=num_vector_embeds,
921
+ activation_fn=activation_fn,
922
+ num_embeds_ada_norm=num_embeds_ada_norm,
923
+ )
924
+ for _ in range(2)
925
+ ]
926
+ )
927
+
928
+ # Variables that can be set by a pipeline:
929
+
930
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
931
+ self.mix_ratio = 0.5
932
+
933
+ # The shape of `encoder_hidden_states` is expected to be
934
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
935
+ self.condition_lengths = [77, 257]
936
+
937
+ # Which transformer to use to encode which condition.
938
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
939
+ self.transformer_index_for_condition = [1, 0]
940
+
941
+ def forward(
942
+ self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True
943
+ ):
944
+ """
945
+ Args:
946
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
947
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
948
+ hidden_states
949
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
950
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
951
+ self-attention.
952
+ timestep ( `torch.long`, *optional*):
953
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
954
+ attention_mask (`torch.FloatTensor`, *optional*):
955
+ Optional attention mask to be applied in CrossAttention
956
+ return_dict (`bool`, *optional*, defaults to `True`):
957
+ Whether or not to return a [`~models.attention.Transformer2DModelOutput`] instead of a plain tuple.
958
+
959
+ Returns:
960
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
961
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
962
+ tensor.
963
+ """
964
+ input_states = hidden_states
965
+
966
+ encoded_states = []
967
+ tokens_start = 0
968
+ # attention_mask is not used yet
969
+ for i in range(2):
970
+ # for each of the two transformers, pass the corresponding condition tokens
971
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
972
+ transformer_index = self.transformer_index_for_condition[i]
973
+ encoded_state = self.transformers[transformer_index](
974
+ input_states,
975
+ encoder_hidden_states=condition_state,
976
+ timestep=timestep,
977
+ return_dict=False,
978
+ )[0]
979
+ encoded_states.append(encoded_state - input_states)
980
+ tokens_start += self.condition_lengths[i]
981
+
982
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
983
+ output_states = output_states + input_states
984
+
985
+ if not return_dict:
986
+ return (output_states,)
987
+
988
  return Transformer2DModelOutput(sample=output_states)
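A sketch of how `DualTransformer2DModel` is driven: the two condition streams are concatenated along the token axis, split back inside `forward` according to `condition_lengths`, routed per `transformer_index_for_condition`, and their residuals blended with `mix_ratio`. All sizes here are assumptions:

```py
import torch

dual = DualTransformer2DModel(
    num_attention_heads=8, attention_head_dim=40, in_channels=320, cross_attention_dim=768
)
hidden_states = torch.randn(1, 320, 16, 16)
text_tokens = torch.randn(1, 77, 768)     # first 77 tokens go to transformers[1]
image_tokens = torch.randn(1, 257, 768)   # next 257 tokens go to transformers[0]
conditions = torch.cat([text_tokens, image_tokens], dim=1)
out = dual(hidden_states, conditions).sample
assert out.shape == hidden_states.shape
```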
magicanimate/models/resnet.py CHANGED
@@ -1,212 +1,212 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Adapted from https://github.com/guoyww/AnimateDiff
8
-
9
- # Copyright 2023 The HuggingFace Team. All rights reserved.
10
- # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
11
- #
12
- # Licensed under the Apache License, Version 2.0 (the "License");
13
- # you may not use this file except in compliance with the License.
14
- # You may obtain a copy of the License at
15
- #
16
- # http://www.apache.org/licenses/LICENSE-2.0
17
- #
18
- # Unless required by applicable law or agreed to in writing, software
19
- # distributed under the License is distributed on an "AS IS" BASIS,
20
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21
- # See the License for the specific language governing permissions and
22
- # limitations under the License.
23
- import torch
24
- import torch.nn as nn
25
- import torch.nn.functional as F
26
-
27
- from einops import rearrange
28
-
29
-
30
- class InflatedConv3d(nn.Conv2d):
31
- def forward(self, x):
32
- video_length = x.shape[2]
33
-
34
- x = rearrange(x, "b c f h w -> (b f) c h w")
35
- x = super().forward(x)
36
- x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
37
-
38
- return x
39
-
40
-
41
- class Upsample3D(nn.Module):
42
- def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
43
- super().__init__()
44
- self.channels = channels
45
- self.out_channels = out_channels or channels
46
- self.use_conv = use_conv
47
- self.use_conv_transpose = use_conv_transpose
48
- self.name = name
49
-
50
- conv = None
51
- if use_conv_transpose:
52
- raise NotImplementedError
53
- elif use_conv:
54
- self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
55
-
56
- def forward(self, hidden_states, output_size=None):
57
- assert hidden_states.shape[1] == self.channels
58
-
59
- if self.use_conv_transpose:
60
- raise NotImplementedError
61
-
62
- # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
63
- dtype = hidden_states.dtype
64
- if dtype == torch.bfloat16:
65
- hidden_states = hidden_states.to(torch.float32)
66
-
67
- # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
68
- if hidden_states.shape[0] >= 64:
69
- hidden_states = hidden_states.contiguous()
70
-
71
- # if `output_size` is passed we force the interpolation output
72
- # size and do not make use of `scale_factor=2`
73
- if output_size is None:
74
- hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
75
- else:
76
- hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
77
-
78
- # If the input is bfloat16, we cast back to bfloat16
79
- if dtype == torch.bfloat16:
80
- hidden_states = hidden_states.to(dtype)
81
-
82
- hidden_states = self.conv(hidden_states)
83
-
84
- return hidden_states
85
-
86
-
87
- class Downsample3D(nn.Module):
88
- def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
89
- super().__init__()
90
- self.channels = channels
91
- self.out_channels = out_channels or channels
92
- self.use_conv = use_conv
93
- self.padding = padding
94
- stride = 2
95
- self.name = name
96
-
97
- if use_conv:
98
- self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
99
- else:
100
- raise NotImplementedError
101
-
102
- def forward(self, hidden_states):
103
- assert hidden_states.shape[1] == self.channels
104
- if self.use_conv and self.padding == 0:
105
- raise NotImplementedError
106
-
107
- assert hidden_states.shape[1] == self.channels
108
- hidden_states = self.conv(hidden_states)
109
-
110
- return hidden_states
111
-
112
-
113
- class ResnetBlock3D(nn.Module):
114
- def __init__(
115
- self,
116
- *,
117
- in_channels,
118
- out_channels=None,
119
- conv_shortcut=False,
120
- dropout=0.0,
121
- temb_channels=512,
122
- groups=32,
123
- groups_out=None,
124
- pre_norm=True,
125
- eps=1e-6,
126
- non_linearity="swish",
127
- time_embedding_norm="default",
128
- output_scale_factor=1.0,
129
- use_in_shortcut=None,
130
- ):
131
- super().__init__()
132
- self.pre_norm = pre_norm
133
- self.pre_norm = True
134
- self.in_channels = in_channels
135
- out_channels = in_channels if out_channels is None else out_channels
136
- self.out_channels = out_channels
137
- self.use_conv_shortcut = conv_shortcut
138
- self.time_embedding_norm = time_embedding_norm
139
- self.output_scale_factor = output_scale_factor
140
-
141
- if groups_out is None:
142
- groups_out = groups
143
-
144
- self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
145
-
146
- self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
147
-
148
- if temb_channels is not None:
149
- if self.time_embedding_norm == "default":
150
- time_emb_proj_out_channels = out_channels
151
- elif self.time_embedding_norm == "scale_shift":
152
- time_emb_proj_out_channels = out_channels * 2
153
- else:
154
- raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
155
-
156
- self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
157
- else:
158
- self.time_emb_proj = None
159
-
160
- self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
161
- self.dropout = torch.nn.Dropout(dropout)
162
- self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
163
-
164
- if non_linearity == "swish":
165
- self.nonlinearity = lambda x: F.silu(x)
166
- elif non_linearity == "mish":
167
- self.nonlinearity = Mish()
168
- elif non_linearity == "silu":
169
- self.nonlinearity = nn.SiLU()
170
-
171
- self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
172
-
173
- self.conv_shortcut = None
174
- if self.use_in_shortcut:
175
- self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
176
-
177
- def forward(self, input_tensor, temb):
178
- hidden_states = input_tensor
179
-
180
- hidden_states = self.norm1(hidden_states)
181
- hidden_states = self.nonlinearity(hidden_states)
182
-
183
- hidden_states = self.conv1(hidden_states)
184
-
185
- if temb is not None:
186
- temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
187
-
188
- if temb is not None and self.time_embedding_norm == "default":
189
- hidden_states = hidden_states + temb
190
-
191
- hidden_states = self.norm2(hidden_states)
192
-
193
- if temb is not None and self.time_embedding_norm == "scale_shift":
194
- scale, shift = torch.chunk(temb, 2, dim=1)
195
- hidden_states = hidden_states * (1 + scale) + shift
196
-
197
- hidden_states = self.nonlinearity(hidden_states)
198
-
199
- hidden_states = self.dropout(hidden_states)
200
- hidden_states = self.conv2(hidden_states)
201
-
202
- if self.conv_shortcut is not None:
203
- input_tensor = self.conv_shortcut(input_tensor)
204
-
205
- output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
206
-
207
- return output_tensor
208
-
209
-
210
- class Mish(torch.nn.Module):
211
- def forward(self, hidden_states):
212
  return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+
9
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
10
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
11
+ #
12
+ # Licensed under the Apache License, Version 2.0 (the "License");
13
+ # you may not use this file except in compliance with the License.
14
+ # You may obtain a copy of the License at
15
+ #
16
+ # http://www.apache.org/licenses/LICENSE-2.0
17
+ #
18
+ # Unless required by applicable law or agreed to in writing, software
19
+ # distributed under the License is distributed on an "AS IS" BASIS,
20
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21
+ # See the License for the specific language governing permissions and
22
+ # limitations under the License.
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+ from einops import rearrange
28
+
29
+
30
+ class InflatedConv3d(nn.Conv2d):
31
+ def forward(self, x):
32
+ video_length = x.shape[2]
33
+
34
+ x = rearrange(x, "b c f h w -> (b f) c h w")
35
+ x = super().forward(x)
36
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
37
+
38
+ return x
39
+
40
+
41
+ class Upsample3D(nn.Module):
42
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
43
+ super().__init__()
44
+ self.channels = channels
45
+ self.out_channels = out_channels or channels
46
+ self.use_conv = use_conv
47
+ self.use_conv_transpose = use_conv_transpose
48
+ self.name = name
49
+
50
+ conv = None
51
+ if use_conv_transpose:
52
+ raise NotImplementedError
53
+ elif use_conv:
54
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
55
+
56
+ def forward(self, hidden_states, output_size=None):
57
+ assert hidden_states.shape[1] == self.channels
58
+
59
+ if self.use_conv_transpose:
60
+ raise NotImplementedError
61
+
62
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
63
+ dtype = hidden_states.dtype
64
+ if dtype == torch.bfloat16:
65
+ hidden_states = hidden_states.to(torch.float32)
66
+
67
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
68
+ if hidden_states.shape[0] >= 64:
69
+ hidden_states = hidden_states.contiguous()
70
+
71
+ # if `output_size` is passed we force the interpolation output
72
+ # size and do not make use of `scale_factor=2`
73
+ if output_size is None:
74
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
75
+ else:
76
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
77
+
78
+ # If the input is bfloat16, we cast back to bfloat16
79
+ if dtype == torch.bfloat16:
80
+ hidden_states = hidden_states.to(dtype)
81
+
82
+ hidden_states = self.conv(hidden_states)
83
+
84
+ return hidden_states
85
+
86
+
87
+ class Downsample3D(nn.Module):
88
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
89
+ super().__init__()
90
+ self.channels = channels
91
+ self.out_channels = out_channels or channels
92
+ self.use_conv = use_conv
93
+ self.padding = padding
94
+ stride = 2
95
+ self.name = name
96
+
97
+ if use_conv:
98
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
99
+ else:
100
+ raise NotImplementedError
101
+
102
+ def forward(self, hidden_states):
103
+ assert hidden_states.shape[1] == self.channels
104
+ if self.use_conv and self.padding == 0:
105
+ raise NotImplementedError
106
+
107
+ assert hidden_states.shape[1] == self.channels
108
+ hidden_states = self.conv(hidden_states)
109
+
110
+ return hidden_states
111
+
112
+
113
+ class ResnetBlock3D(nn.Module):
114
+ def __init__(
115
+ self,
116
+ *,
117
+ in_channels,
118
+ out_channels=None,
119
+ conv_shortcut=False,
120
+ dropout=0.0,
121
+ temb_channels=512,
122
+ groups=32,
123
+ groups_out=None,
124
+ pre_norm=True,
125
+ eps=1e-6,
126
+ non_linearity="swish",
127
+ time_embedding_norm="default",
128
+ output_scale_factor=1.0,
129
+ use_in_shortcut=None,
130
+ ):
131
+ super().__init__()
132
+ self.pre_norm = pre_norm
133
+ self.pre_norm = True
134
+ self.in_channels = in_channels
135
+ out_channels = in_channels if out_channels is None else out_channels
136
+ self.out_channels = out_channels
137
+ self.use_conv_shortcut = conv_shortcut
138
+ self.time_embedding_norm = time_embedding_norm
139
+ self.output_scale_factor = output_scale_factor
140
+
141
+ if groups_out is None:
142
+ groups_out = groups
143
+
144
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
145
+
146
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
147
+
148
+ if temb_channels is not None:
149
+ if self.time_embedding_norm == "default":
150
+ time_emb_proj_out_channels = out_channels
151
+ elif self.time_embedding_norm == "scale_shift":
152
+ time_emb_proj_out_channels = out_channels * 2
153
+ else:
154
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
155
+
156
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
157
+ else:
158
+ self.time_emb_proj = None
159
+
160
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
161
+ self.dropout = torch.nn.Dropout(dropout)
162
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
163
+
164
+ if non_linearity == "swish":
165
+ self.nonlinearity = lambda x: F.silu(x)
166
+ elif non_linearity == "mish":
167
+ self.nonlinearity = Mish()
168
+ elif non_linearity == "silu":
169
+ self.nonlinearity = nn.SiLU()
170
+
171
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
172
+
173
+ self.conv_shortcut = None
174
+ if self.use_in_shortcut:
175
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
176
+
177
+ def forward(self, input_tensor, temb):
178
+ hidden_states = input_tensor
179
+
180
+ hidden_states = self.norm1(hidden_states)
181
+ hidden_states = self.nonlinearity(hidden_states)
182
+
183
+ hidden_states = self.conv1(hidden_states)
184
+
185
+ if temb is not None:
186
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
187
+
188
+ if temb is not None and self.time_embedding_norm == "default":
189
+ hidden_states = hidden_states + temb
190
+
191
+ hidden_states = self.norm2(hidden_states)
192
+
193
+ if temb is not None and self.time_embedding_norm == "scale_shift":
194
+ scale, shift = torch.chunk(temb, 2, dim=1)
195
+ hidden_states = hidden_states * (1 + scale) + shift
196
+
197
+ hidden_states = self.nonlinearity(hidden_states)
198
+
199
+ hidden_states = self.dropout(hidden_states)
200
+ hidden_states = self.conv2(hidden_states)
201
+
202
+ if self.conv_shortcut is not None:
203
+ input_tensor = self.conv_shortcut(input_tensor)
204
+
205
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
206
+
207
+ return output_tensor
208
+
209
+
210
+ class Mish(torch.nn.Module):
211
+ def forward(self, hidden_states):
212
  return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
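The `Downsample3D` and `ResnetBlock3D` blocks above build on `InflatedConv3d`, which is defined elsewhere in this repository. As a rough sketch only (an assumption about the behaviour, not the repository's actual implementation), an "inflated" convolution can be read as an ordinary 2D convolution applied frame-by-frame to a `(batch, channels, frames, height, width)` tensor:

```py
# Minimal sketch, assuming InflatedConv3d folds the frame axis into the batch
# axis and runs an ordinary 2D convolution; the class name here is hypothetical.
import torch
import torch.nn as nn


class InflatedConv3dSketch(nn.Conv2d):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, f, h, w = x.shape
        x = x.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)  # (b*f, c, h, w)
        x = super().forward(x)                                # per-frame 2D conv
        _, c_out, h_out, w_out = x.shape
        return x.reshape(b, f, c_out, h_out, w_out).permute(0, 2, 1, 3, 4)


video = torch.randn(2, 320, 16, 32, 32)  # (batch, channels, frames, height, width)
conv = InflatedConv3dSketch(320, 320, kernel_size=3, stride=2, padding=1)
print(conv(video).shape)  # torch.Size([2, 320, 16, 16, 16])
```

Under that assumption, `Downsample3D` with `stride=2` halves only the spatial resolution: the frame axis rides along in the batch dimension and is left untouched.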
magicanimate/models/stable_diffusion_controlnet_reference.py CHANGED
@@ -1,840 +1,840 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
8
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
-
10
- import numpy as np
11
- import PIL.Image
12
- import torch
13
-
14
- from diffusers import StableDiffusionControlNetPipeline
15
- from diffusers.models import ControlNetModel
16
- from diffusers.models.attention import BasicTransformerBlock
17
- from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
18
- from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
19
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
20
- from diffusers.utils import logging
21
- from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
22
-
23
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
-
25
- EXAMPLE_DOC_STRING = """
26
- Examples:
27
- ```py
28
- >>> import cv2
29
- >>> import torch
30
- >>> import numpy as np
31
- >>> from PIL import Image
32
- >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
33
- >>> from diffusers.utils import load_image
34
-
35
- >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
36
-
37
- >>> # get canny image
38
- >>> image = cv2.Canny(np.array(input_image), 100, 200)
39
- >>> image = image[:, :, None]
40
- >>> image = np.concatenate([image, image, image], axis=2)
41
- >>> canny_image = Image.fromarray(image)
42
-
43
- >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
44
- >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
45
- "runwayml/stable-diffusion-v1-5",
46
- controlnet=controlnet,
47
- safety_checker=None,
48
- torch_dtype=torch.float16
49
- ).to('cuda:0')
50
-
51
- >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)
52
-
53
- >>> result_img = pipe(ref_image=input_image,
54
- prompt="1girl",
55
- image=canny_image,
56
- num_inference_steps=20,
57
- reference_attn=True,
58
- reference_adain=True).images[0]
59
-
60
- >>> result_img.show()
61
- ```
62
- """
63
-
64
-
65
- def torch_dfs(model: torch.nn.Module):
66
- result = [model]
67
- for child in model.children():
68
- result += torch_dfs(child)
69
- return result
70
-
71
-
72
- class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
73
- def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
74
- refimage = refimage.to(device=device, dtype=dtype)
75
-
76
- # encode the reference image into latent space so we can concatenate it to the latents
77
- if isinstance(generator, list):
78
- ref_image_latents = [
79
- self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
80
- for i in range(batch_size)
81
- ]
82
- ref_image_latents = torch.cat(ref_image_latents, dim=0)
83
- else:
84
- ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
85
- ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
86
-
87
- # duplicate ref_image_latents for each generation per prompt, using mps friendly method
88
- if ref_image_latents.shape[0] < batch_size:
89
- if not batch_size % ref_image_latents.shape[0] == 0:
90
- raise ValueError(
91
- "The passed images and the required batch size don't match. Images are supposed to be duplicated"
92
- f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
93
- " Make sure the number of images that you pass is divisible by the total requested batch size."
94
- )
95
- ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
96
-
97
- ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
98
-
99
- # aligning device to prevent device errors when concatenating it with the latent model input
100
- ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
101
- return ref_image_latents
102
-
103
- @torch.no_grad()
104
- def __call__(
105
- self,
106
- prompt: Union[str, List[str]] = None,
107
- image: Union[
108
- torch.FloatTensor,
109
- PIL.Image.Image,
110
- np.ndarray,
111
- List[torch.FloatTensor],
112
- List[PIL.Image.Image],
113
- List[np.ndarray],
114
- ] = None,
115
- ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
116
- height: Optional[int] = None,
117
- width: Optional[int] = None,
118
- num_inference_steps: int = 50,
119
- guidance_scale: float = 7.5,
120
- negative_prompt: Optional[Union[str, List[str]]] = None,
121
- num_images_per_prompt: Optional[int] = 1,
122
- eta: float = 0.0,
123
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
124
- latents: Optional[torch.FloatTensor] = None,
125
- prompt_embeds: Optional[torch.FloatTensor] = None,
126
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
127
- output_type: Optional[str] = "pil",
128
- return_dict: bool = True,
129
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
130
- callback_steps: int = 1,
131
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
132
- controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
133
- guess_mode: bool = False,
134
- attention_auto_machine_weight: float = 1.0,
135
- gn_auto_machine_weight: float = 1.0,
136
- style_fidelity: float = 0.5,
137
- reference_attn: bool = True,
138
- reference_adain: bool = True,
139
- ):
140
- r"""
141
- Function invoked when calling the pipeline for generation.
142
-
143
- Args:
144
- prompt (`str` or `List[str]`, *optional*):
145
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
146
- instead.
147
- image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
148
- `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
149
- The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
150
- the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
151
- also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
152
- height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
153
- specified in init, images must be passed as a list such that each element of the list can be correctly
154
- batched for input to a single controlnet.
155
- ref_image (`torch.FloatTensor`, `PIL.Image.Image`):
156
- The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
157
- the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
158
- also be accepted as an image.
159
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
160
- The height in pixels of the generated image.
161
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
162
- The width in pixels of the generated image.
163
- num_inference_steps (`int`, *optional*, defaults to 50):
164
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
165
- expense of slower inference.
166
- guidance_scale (`float`, *optional*, defaults to 7.5):
167
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
168
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
169
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
170
- 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
171
- usually at the expense of lower image quality.
172
- negative_prompt (`str` or `List[str]`, *optional*):
173
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
174
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
175
- less than `1`).
176
- num_images_per_prompt (`int`, *optional*, defaults to 1):
177
- The number of images to generate per prompt.
178
- eta (`float`, *optional*, defaults to 0.0):
179
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
180
- [`schedulers.DDIMScheduler`], will be ignored for others.
181
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
182
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
183
- to make generation deterministic.
184
- latents (`torch.FloatTensor`, *optional*):
185
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
186
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
187
- tensor will be generated by sampling using the supplied random `generator`.
188
- prompt_embeds (`torch.FloatTensor`, *optional*):
189
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
190
- provided, text embeddings will be generated from `prompt` input argument.
191
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
192
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
193
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
194
- argument.
195
- output_type (`str`, *optional*, defaults to `"pil"`):
196
- The output format of the generated image. Choose between
197
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
198
- return_dict (`bool`, *optional*, defaults to `True`):
199
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
200
- plain tuple.
201
- callback (`Callable`, *optional*):
202
- A function that will be called every `callback_steps` steps during inference. The function will be
203
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
204
- callback_steps (`int`, *optional*, defaults to 1):
205
- The frequency at which the `callback` function will be called. If not specified, the callback will be
206
- called at every step.
207
- cross_attention_kwargs (`dict`, *optional*):
208
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
209
- `self.processor` in
210
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
211
- controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
212
- The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
213
- to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
214
- corresponding scale as a list.
215
- guess_mode (`bool`, *optional*, defaults to `False`):
216
- In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
217
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
218
- attention_auto_machine_weight (`float`):
219
- Weight of using reference query for self attention's context.
220
- If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
221
- gn_auto_machine_weight (`float`):
222
- Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
223
- style_fidelity (`float`):
224
- style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important,
225
- elif style_fidelity=0.0, prompt more important, else balanced.
226
- reference_attn (`bool`):
227
- Whether to use reference query for self attention's context.
228
- reference_adain (`bool`):
229
- Whether to use reference adain.
230
-
231
- Examples:
232
-
233
- Returns:
234
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
235
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
236
- When returning a tuple, the first element is a list with the generated images, and the second element is a
237
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
238
- (nsfw) content, according to the `safety_checker`.
239
- """
240
- assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
241
-
242
- # 1. Check inputs. Raise error if not correct
243
- self.check_inputs(
244
- prompt,
245
- image,
246
- callback_steps,
247
- negative_prompt,
248
- prompt_embeds,
249
- negative_prompt_embeds,
250
- controlnet_conditioning_scale,
251
- )
252
-
253
- # 2. Define call parameters
254
- if prompt is not None and isinstance(prompt, str):
255
- batch_size = 1
256
- elif prompt is not None and isinstance(prompt, list):
257
- batch_size = len(prompt)
258
- else:
259
- batch_size = prompt_embeds.shape[0]
260
-
261
- device = self._execution_device
262
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
263
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
264
- # corresponds to doing no classifier free guidance.
265
- do_classifier_free_guidance = guidance_scale > 1.0
266
-
267
- controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
268
-
269
- if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
270
- controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
271
-
272
- global_pool_conditions = (
273
- controlnet.config.global_pool_conditions
274
- if isinstance(controlnet, ControlNetModel)
275
- else controlnet.nets[0].config.global_pool_conditions
276
- )
277
- guess_mode = guess_mode or global_pool_conditions
278
-
279
- # 3. Encode input prompt
280
- text_encoder_lora_scale = (
281
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
282
- )
283
- prompt_embeds = self._encode_prompt(
284
- prompt,
285
- device,
286
- num_images_per_prompt,
287
- do_classifier_free_guidance,
288
- negative_prompt,
289
- prompt_embeds=prompt_embeds,
290
- negative_prompt_embeds=negative_prompt_embeds,
291
- lora_scale=text_encoder_lora_scale,
292
- )
293
-
294
- # 4. Prepare image
295
- if isinstance(controlnet, ControlNetModel):
296
- image = self.prepare_image(
297
- image=image,
298
- width=width,
299
- height=height,
300
- batch_size=batch_size * num_images_per_prompt,
301
- num_images_per_prompt=num_images_per_prompt,
302
- device=device,
303
- dtype=controlnet.dtype,
304
- do_classifier_free_guidance=do_classifier_free_guidance,
305
- guess_mode=guess_mode,
306
- )
307
- height, width = image.shape[-2:]
308
- elif isinstance(controlnet, MultiControlNetModel):
309
- images = []
310
-
311
- for image_ in image:
312
- image_ = self.prepare_image(
313
- image=image_,
314
- width=width,
315
- height=height,
316
- batch_size=batch_size * num_images_per_prompt,
317
- num_images_per_prompt=num_images_per_prompt,
318
- device=device,
319
- dtype=controlnet.dtype,
320
- do_classifier_free_guidance=do_classifier_free_guidance,
321
- guess_mode=guess_mode,
322
- )
323
-
324
- images.append(image_)
325
-
326
- image = images
327
- height, width = image[0].shape[-2:]
328
- else:
329
- assert False
330
-
331
- # 5. Preprocess reference image
332
- ref_image = self.prepare_image(
333
- image=ref_image,
334
- width=width,
335
- height=height,
336
- batch_size=batch_size * num_images_per_prompt,
337
- num_images_per_prompt=num_images_per_prompt,
338
- device=device,
339
- dtype=prompt_embeds.dtype,
340
- )
341
-
342
- # 6. Prepare timesteps
343
- self.scheduler.set_timesteps(num_inference_steps, device=device)
344
- timesteps = self.scheduler.timesteps
345
-
346
- # 7. Prepare latent variables
347
- num_channels_latents = self.unet.config.in_channels
348
- latents = self.prepare_latents(
349
- batch_size * num_images_per_prompt,
350
- num_channels_latents,
351
- height,
352
- width,
353
- prompt_embeds.dtype,
354
- device,
355
- generator,
356
- latents,
357
- )
358
-
359
- # 8. Prepare reference latent variables
360
- ref_image_latents = self.prepare_ref_latents(
361
- ref_image,
362
- batch_size * num_images_per_prompt,
363
- prompt_embeds.dtype,
364
- device,
365
- generator,
366
- do_classifier_free_guidance,
367
- )
368
-
369
- # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
370
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
371
-
372
- # 10. Modify self attention and group norm
373
- MODE = "write"
374
- uc_mask = (
375
- torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
376
- .type_as(ref_image_latents)
377
- .bool()
378
- )
379
-
380
- def hacked_basic_transformer_inner_forward(
381
- self,
382
- hidden_states: torch.FloatTensor,
383
- attention_mask: Optional[torch.FloatTensor] = None,
384
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
385
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
386
- timestep: Optional[torch.LongTensor] = None,
387
- cross_attention_kwargs: Dict[str, Any] = None,
388
- class_labels: Optional[torch.LongTensor] = None,
389
- ):
390
- if self.use_ada_layer_norm:
391
- norm_hidden_states = self.norm1(hidden_states, timestep)
392
- elif self.use_ada_layer_norm_zero:
393
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
394
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
395
- )
396
- else:
397
- norm_hidden_states = self.norm1(hidden_states)
398
-
399
- # 1. Self-Attention
400
- cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
401
- if self.only_cross_attention:
402
- attn_output = self.attn1(
403
- norm_hidden_states,
404
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
405
- attention_mask=attention_mask,
406
- **cross_attention_kwargs,
407
- )
408
- else:
409
- if MODE == "write":
410
- self.bank.append(norm_hidden_states.detach().clone())
411
- attn_output = self.attn1(
412
- norm_hidden_states,
413
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
414
- attention_mask=attention_mask,
415
- **cross_attention_kwargs,
416
- )
417
- if MODE == "read":
418
- if attention_auto_machine_weight > self.attn_weight:
419
- attn_output_uc = self.attn1(
420
- norm_hidden_states,
421
- encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
422
- # attention_mask=attention_mask,
423
- **cross_attention_kwargs,
424
- )
425
- attn_output_c = attn_output_uc.clone()
426
- if do_classifier_free_guidance and style_fidelity > 0:
427
- attn_output_c[uc_mask] = self.attn1(
428
- norm_hidden_states[uc_mask],
429
- encoder_hidden_states=norm_hidden_states[uc_mask],
430
- **cross_attention_kwargs,
431
- )
432
- attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
433
- self.bank.clear()
434
- else:
435
- attn_output = self.attn1(
436
- norm_hidden_states,
437
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
438
- attention_mask=attention_mask,
439
- **cross_attention_kwargs,
440
- )
441
- if self.use_ada_layer_norm_zero:
442
- attn_output = gate_msa.unsqueeze(1) * attn_output
443
- hidden_states = attn_output + hidden_states
444
-
445
- if self.attn2 is not None:
446
- norm_hidden_states = (
447
- self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
448
- )
449
-
450
- # 2. Cross-Attention
451
- attn_output = self.attn2(
452
- norm_hidden_states,
453
- encoder_hidden_states=encoder_hidden_states,
454
- attention_mask=encoder_attention_mask,
455
- **cross_attention_kwargs,
456
- )
457
- hidden_states = attn_output + hidden_states
458
-
459
- # 3. Feed-forward
460
- norm_hidden_states = self.norm3(hidden_states)
461
-
462
- if self.use_ada_layer_norm_zero:
463
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
464
-
465
- ff_output = self.ff(norm_hidden_states)
466
-
467
- if self.use_ada_layer_norm_zero:
468
- ff_output = gate_mlp.unsqueeze(1) * ff_output
469
-
470
- hidden_states = ff_output + hidden_states
471
-
472
- return hidden_states
473
-
474
- def hacked_mid_forward(self, *args, **kwargs):
475
- eps = 1e-6
476
- x = self.original_forward(*args, **kwargs)
477
- if MODE == "write":
478
- if gn_auto_machine_weight >= self.gn_weight:
479
- var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
480
- self.mean_bank.append(mean)
481
- self.var_bank.append(var)
482
- if MODE == "read":
483
- if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
484
- var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
485
- std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
486
- mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
487
- var_acc = sum(self.var_bank) / float(len(self.var_bank))
488
- std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
489
- x_uc = (((x - mean) / std) * std_acc) + mean_acc
490
- x_c = x_uc.clone()
491
- if do_classifier_free_guidance and style_fidelity > 0:
492
- x_c[uc_mask] = x[uc_mask]
493
- x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
494
- self.mean_bank = []
495
- self.var_bank = []
496
- return x
497
-
498
- def hack_CrossAttnDownBlock2D_forward(
499
- self,
500
- hidden_states: torch.FloatTensor,
501
- temb: Optional[torch.FloatTensor] = None,
502
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
503
- attention_mask: Optional[torch.FloatTensor] = None,
504
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
505
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
506
- ):
507
- eps = 1e-6
508
-
509
- # TODO(Patrick, William) - attention mask is not used
510
- output_states = ()
511
-
512
- for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
513
- hidden_states = resnet(hidden_states, temb)
514
- hidden_states = attn(
515
- hidden_states,
516
- encoder_hidden_states=encoder_hidden_states,
517
- cross_attention_kwargs=cross_attention_kwargs,
518
- attention_mask=attention_mask,
519
- encoder_attention_mask=encoder_attention_mask,
520
- return_dict=False,
521
- )[0]
522
- if MODE == "write":
523
- if gn_auto_machine_weight >= self.gn_weight:
524
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
525
- self.mean_bank.append([mean])
526
- self.var_bank.append([var])
527
- if MODE == "read":
528
- if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
529
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
530
- std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
531
- mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
532
- var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
533
- std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
534
- hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
535
- hidden_states_c = hidden_states_uc.clone()
536
- if do_classifier_free_guidance and style_fidelity > 0:
537
- hidden_states_c[uc_mask] = hidden_states[uc_mask]
538
- hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
539
-
540
- output_states = output_states + (hidden_states,)
541
-
542
- if MODE == "read":
543
- self.mean_bank = []
544
- self.var_bank = []
545
-
546
- if self.downsamplers is not None:
547
- for downsampler in self.downsamplers:
548
- hidden_states = downsampler(hidden_states)
549
-
550
- output_states = output_states + (hidden_states,)
551
-
552
- return hidden_states, output_states
553
-
554
- def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
555
- eps = 1e-6
556
-
557
- output_states = ()
558
-
559
- for i, resnet in enumerate(self.resnets):
560
- hidden_states = resnet(hidden_states, temb)
561
-
562
- if MODE == "write":
563
- if gn_auto_machine_weight >= self.gn_weight:
564
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
565
- self.mean_bank.append([mean])
566
- self.var_bank.append([var])
567
- if MODE == "read":
568
- if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
569
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
570
- std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
571
- mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
572
- var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
573
- std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
574
- hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
575
- hidden_states_c = hidden_states_uc.clone()
576
- if do_classifier_free_guidance and style_fidelity > 0:
577
- hidden_states_c[uc_mask] = hidden_states[uc_mask]
578
- hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
579
-
580
- output_states = output_states + (hidden_states,)
581
-
582
- if MODE == "read":
583
- self.mean_bank = []
584
- self.var_bank = []
585
-
586
- if self.downsamplers is not None:
587
- for downsampler in self.downsamplers:
588
- hidden_states = downsampler(hidden_states)
589
-
590
- output_states = output_states + (hidden_states,)
591
-
592
- return hidden_states, output_states
593
-
594
- def hacked_CrossAttnUpBlock2D_forward(
595
- self,
596
- hidden_states: torch.FloatTensor,
597
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
598
- temb: Optional[torch.FloatTensor] = None,
599
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
600
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
601
- upsample_size: Optional[int] = None,
602
- attention_mask: Optional[torch.FloatTensor] = None,
603
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
604
- ):
605
- eps = 1e-6
606
- # TODO(Patrick, William) - attention mask is not used
607
- for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
608
- # pop res hidden states
609
- res_hidden_states = res_hidden_states_tuple[-1]
610
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
611
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
612
- hidden_states = resnet(hidden_states, temb)
613
- hidden_states = attn(
614
- hidden_states,
615
- encoder_hidden_states=encoder_hidden_states,
616
- cross_attention_kwargs=cross_attention_kwargs,
617
- attention_mask=attention_mask,
618
- encoder_attention_mask=encoder_attention_mask,
619
- return_dict=False,
620
- )[0]
621
-
622
- if MODE == "write":
623
- if gn_auto_machine_weight >= self.gn_weight:
624
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
625
- self.mean_bank.append([mean])
626
- self.var_bank.append([var])
627
- if MODE == "read":
628
- if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
629
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
630
- std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
631
- mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
632
- var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
633
- std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
634
- hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
635
- hidden_states_c = hidden_states_uc.clone()
636
- if do_classifier_free_guidance and style_fidelity > 0:
637
- hidden_states_c[uc_mask] = hidden_states[uc_mask]
638
- hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
639
-
640
- if MODE == "read":
641
- self.mean_bank = []
642
- self.var_bank = []
643
-
644
- if self.upsamplers is not None:
645
- for upsampler in self.upsamplers:
646
- hidden_states = upsampler(hidden_states, upsample_size)
647
-
648
- return hidden_states
649
-
650
- def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
651
- eps = 1e-6
652
- for i, resnet in enumerate(self.resnets):
653
- # pop res hidden states
654
- res_hidden_states = res_hidden_states_tuple[-1]
655
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
656
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
657
- hidden_states = resnet(hidden_states, temb)
658
-
659
- if MODE == "write":
660
- if gn_auto_machine_weight >= self.gn_weight:
661
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
662
- self.mean_bank.append([mean])
663
- self.var_bank.append([var])
664
- if MODE == "read":
665
- if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
666
- var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
667
- std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
668
- mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
669
- var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
670
- std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
671
- hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
672
- hidden_states_c = hidden_states_uc.clone()
673
- if do_classifier_free_guidance and style_fidelity > 0:
674
- hidden_states_c[uc_mask] = hidden_states[uc_mask]
675
- hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
676
-
677
- if MODE == "read":
678
- self.mean_bank = []
679
- self.var_bank = []
680
-
681
- if self.upsamplers is not None:
682
- for upsampler in self.upsamplers:
683
- hidden_states = upsampler(hidden_states, upsample_size)
684
-
685
- return hidden_states
686
-
687
- if reference_attn:
688
- attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
689
- attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
690
-
691
- for i, module in enumerate(attn_modules):
692
- module._original_inner_forward = module.forward
693
- module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
694
- module.bank = []
695
- module.attn_weight = float(i) / float(len(attn_modules))
696
-
697
- if reference_adain:
698
- gn_modules = [self.unet.mid_block]
699
- self.unet.mid_block.gn_weight = 0
700
-
701
- down_blocks = self.unet.down_blocks
702
- for w, module in enumerate(down_blocks):
703
- module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
704
- gn_modules.append(module)
705
-
706
- up_blocks = self.unet.up_blocks
707
- for w, module in enumerate(up_blocks):
708
- module.gn_weight = float(w) / float(len(up_blocks))
709
- gn_modules.append(module)
710
-
711
- for i, module in enumerate(gn_modules):
712
- if getattr(module, "original_forward", None) is None:
713
- module.original_forward = module.forward
714
- if i == 0:
715
- # mid_block
716
- module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
717
- elif isinstance(module, CrossAttnDownBlock2D):
718
- module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
719
- elif isinstance(module, DownBlock2D):
720
- module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
721
- elif isinstance(module, CrossAttnUpBlock2D):
722
- module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
723
- elif isinstance(module, UpBlock2D):
724
- module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
725
- module.mean_bank = []
726
- module.var_bank = []
727
- module.gn_weight *= 2
728
-
729
- # 11. Denoising loop
730
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
731
- with self.progress_bar(total=num_inference_steps) as progress_bar:
732
- for i, t in enumerate(timesteps):
733
- # expand the latents if we are doing classifier free guidance
734
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
735
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
736
-
737
- # controlnet(s) inference
738
- if guess_mode and do_classifier_free_guidance:
739
- # Infer ControlNet only for the conditional batch.
740
- control_model_input = latents
741
- control_model_input = self.scheduler.scale_model_input(control_model_input, t)
742
- controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
743
- else:
744
- control_model_input = latent_model_input
745
- controlnet_prompt_embeds = prompt_embeds
746
-
747
- down_block_res_samples, mid_block_res_sample = self.controlnet(
748
- control_model_input,
749
- t,
750
- encoder_hidden_states=controlnet_prompt_embeds,
751
- controlnet_cond=image,
752
- conditioning_scale=controlnet_conditioning_scale,
753
- guess_mode=guess_mode,
754
- return_dict=False,
755
- )
756
-
757
- if guess_mode and do_classifier_free_guidance:
758
- # Inferred ControlNet only for the conditional batch.
759
- # To apply the output of ControlNet to both the unconditional and conditional batches,
760
- # add 0 to the unconditional batch to keep it unchanged.
761
- down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
762
- mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
763
-
764
- # ref only part
765
- noise = randn_tensor(
766
- ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
767
- )
768
- ref_xt = self.scheduler.add_noise(
769
- ref_image_latents,
770
- noise,
771
- t.reshape(
772
- 1,
773
- ),
774
- )
775
- ref_xt = self.scheduler.scale_model_input(ref_xt, t)
776
-
777
- MODE = "write"
778
- self.unet(
779
- ref_xt,
780
- t,
781
- encoder_hidden_states=prompt_embeds,
782
- cross_attention_kwargs=cross_attention_kwargs,
783
- return_dict=False,
784
- )
785
-
786
- # predict the noise residual
787
- MODE = "read"
788
- noise_pred = self.unet(
789
- latent_model_input,
790
- t,
791
- encoder_hidden_states=prompt_embeds,
792
- cross_attention_kwargs=cross_attention_kwargs,
793
- down_block_additional_residuals=down_block_res_samples,
794
- mid_block_additional_residual=mid_block_res_sample,
795
- return_dict=False,
796
- )[0]
797
-
798
- # perform guidance
799
- if do_classifier_free_guidance:
800
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
801
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
802
-
803
- # compute the previous noisy sample x_t -> x_t-1
804
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
805
-
806
- # call the callback, if provided
807
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
808
- progress_bar.update()
809
- if callback is not None and i % callback_steps == 0:
810
- callback(i, t, latents)
811
-
812
- # If we do sequential model offloading, let's offload unet and controlnet
813
- # manually for max memory savings
814
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
815
- self.unet.to("cpu")
816
- self.controlnet.to("cpu")
817
- torch.cuda.empty_cache()
818
-
819
- if not output_type == "latent":
820
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
821
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
822
- else:
823
- image = latents
824
- has_nsfw_concept = None
825
-
826
- if has_nsfw_concept is None:
827
- do_denormalize = [True] * image.shape[0]
828
- else:
829
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
830
-
831
- image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
832
-
833
- # Offload last model to CPU
834
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
835
- self.final_offload_hook.offload()
836
-
837
- if not return_dict:
838
- return (image, has_nsfw_concept)
839
-
840
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import PIL.Image
12
+ import torch
13
+
14
+ from diffusers import StableDiffusionControlNetPipeline
15
+ from diffusers.models import ControlNetModel
16
+ from diffusers.models.attention import BasicTransformerBlock
17
+ from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
18
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
19
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
20
+ from diffusers.utils import logging
21
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+ EXAMPLE_DOC_STRING = """
26
+ Examples:
27
+ ```py
28
+ >>> import cv2
29
+ >>> import torch
30
+ >>> import numpy as np
31
+ >>> from PIL import Image
32
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
33
+ >>> from diffusers.utils import load_image
34
+
35
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
36
+
37
+ >>> # get canny image
38
+ >>> image = cv2.Canny(np.array(input_image), 100, 200)
39
+ >>> image = image[:, :, None]
40
+ >>> image = np.concatenate([image, image, image], axis=2)
41
+ >>> canny_image = Image.fromarray(image)
42
+
43
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
44
+ >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
45
+ "runwayml/stable-diffusion-v1-5",
46
+ controlnet=controlnet,
47
+ safety_checker=None,
48
+ torch_dtype=torch.float16
49
+ ).to('cuda:0')
50
+
51
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
52
+
53
+ >>> result_img = pipe(ref_image=input_image,
54
+ prompt="1girl",
55
+ image=canny_image,
56
+ num_inference_steps=20,
57
+ reference_attn=True,
58
+ reference_adain=True).images[0]
59
+
60
+ >>> result_img.show()
61
+ ```
62
+ """
63
+
64
+
65
+ def torch_dfs(model: torch.nn.Module):
66
+ result = [model]
67
+ for child in model.children():
68
+ result += torch_dfs(child)
69
+ return result
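`torch_dfs` simply returns a module together with all of its descendants in depth-first order; further down in this file it is what gathers every `BasicTransformerBlock` (and the UNet down/mid/up blocks) whose `forward` methods get patched for reference-only control. A toy illustration (a made-up module, not the actual UNet):

```py
>>> import torch.nn as nn
>>> net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Sequential(nn.ReLU(), nn.Conv2d(8, 8, 3)))
>>> [type(m).__name__ for m in torch_dfs(net)]
['Sequential', 'Conv2d', 'Sequential', 'ReLU', 'Conv2d']
```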
70
+
71
+
72
+ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
73
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
74
+ refimage = refimage.to(device=device, dtype=dtype)
75
+
76
+ # encode the reference image into latent space so we can concatenate it to the latents
77
+ if isinstance(generator, list):
78
+ ref_image_latents = [
79
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
80
+ for i in range(batch_size)
81
+ ]
82
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
83
+ else:
84
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
85
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
86
+
87
+ # duplicate ref_image_latents for each generation per prompt, using mps friendly method
88
+ if ref_image_latents.shape[0] < batch_size:
89
+ if not batch_size % ref_image_latents.shape[0] == 0:
90
+ raise ValueError(
91
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
92
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
93
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
94
+ )
95
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
96
+
97
+ ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
98
+
99
+ # aligning device to prevent device errors when concatenating it with the latent model input
100
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
101
+ return ref_image_latents
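Stripped of the batching and classifier-free-guidance bookkeeping, `prepare_ref_latents` reduces to encoding the preprocessed reference image with the VAE and scaling by the VAE's `scaling_factor`. A stand-alone sketch of that core step, assuming the SD 1.5 checkpoint referenced in the example docstring above:

```py
# Hedged sketch of the encode-and-scale step only; the method above additionally
# duplicates the latents per prompt and for classifier-free guidance.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
ref = torch.randn(1, 3, 512, 512)  # preprocessed reference image, values in [-1, 1]
with torch.no_grad():
    ref_latents = vae.encode(ref).latent_dist.sample() * vae.config.scaling_factor
print(ref_latents.shape)  # torch.Size([1, 4, 64, 64])
```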
102
+
103
+ @torch.no_grad()
104
+ def __call__(
105
+ self,
106
+ prompt: Union[str, List[str]] = None,
107
+ image: Union[
108
+ torch.FloatTensor,
109
+ PIL.Image.Image,
110
+ np.ndarray,
111
+ List[torch.FloatTensor],
112
+ List[PIL.Image.Image],
113
+ List[np.ndarray],
114
+ ] = None,
115
+ ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
116
+ height: Optional[int] = None,
117
+ width: Optional[int] = None,
118
+ num_inference_steps: int = 50,
119
+ guidance_scale: float = 7.5,
120
+ negative_prompt: Optional[Union[str, List[str]]] = None,
121
+ num_images_per_prompt: Optional[int] = 1,
122
+ eta: float = 0.0,
123
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
124
+ latents: Optional[torch.FloatTensor] = None,
125
+ prompt_embeds: Optional[torch.FloatTensor] = None,
126
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
127
+ output_type: Optional[str] = "pil",
128
+ return_dict: bool = True,
129
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
130
+ callback_steps: int = 1,
131
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
132
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
133
+ guess_mode: bool = False,
134
+ attention_auto_machine_weight: float = 1.0,
135
+ gn_auto_machine_weight: float = 1.0,
136
+ style_fidelity: float = 0.5,
137
+ reference_attn: bool = True,
138
+ reference_adain: bool = True,
139
+ ):
140
+ r"""
141
+ Function invoked when calling the pipeline for generation.
142
+
143
+ Args:
144
+ prompt (`str` or `List[str]`, *optional*):
145
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
146
+ instead.
147
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
148
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
149
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
150
+ the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
151
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
152
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
153
+ specified in init, images must be passed as a list such that each element of the list can be correctly
154
+ batched for input to a single controlnet.
155
+ ref_image (`torch.FloatTensor`, `PIL.Image.Image`):
156
+ The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
157
+ the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
158
+ also be accepted as an image.
159
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
160
+ The height in pixels of the generated image.
161
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
162
+ The width in pixels of the generated image.
163
+ num_inference_steps (`int`, *optional*, defaults to 50):
164
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
165
+ expense of slower inference.
166
+ guidance_scale (`float`, *optional*, defaults to 7.5):
167
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
168
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
169
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
170
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
171
+ usually at the expense of lower image quality.
172
+ negative_prompt (`str` or `List[str]`, *optional*):
173
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
174
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
175
+ less than `1`).
176
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
177
+ The number of images to generate per prompt.
178
+ eta (`float`, *optional*, defaults to 0.0):
179
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
180
+ [`schedulers.DDIMScheduler`], will be ignored for others.
181
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
182
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
183
+ to make generation deterministic.
184
+ latents (`torch.FloatTensor`, *optional*):
185
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
186
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
187
+ tensor will ge generated by sampling using the supplied random `generator`.
188
+ prompt_embeds (`torch.FloatTensor`, *optional*):
189
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
190
+ provided, text embeddings will be generated from `prompt` input argument.
191
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
192
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
193
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
194
+ argument.
195
+ output_type (`str`, *optional*, defaults to `"pil"`):
196
+ The output format of the generate image. Choose between
197
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
198
+ return_dict (`bool`, *optional*, defaults to `True`):
199
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
200
+ plain tuple.
201
+ callback (`Callable`, *optional*):
202
+ A function that will be called every `callback_steps` steps during inference. The function will be
203
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
204
+ callback_steps (`int`, *optional*, defaults to 1):
205
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
206
+ called at every step.
207
+ cross_attention_kwargs (`dict`, *optional*):
208
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
209
+ `self.processor` in
210
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
211
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
212
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
213
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
214
+ corresponding scale as a list.
215
+ guess_mode (`bool`, *optional*, defaults to `False`):
216
+ In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
217
+ you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
218
+ attention_auto_machine_weight (`float`):
219
+ Weight of using reference query for self attention's context.
220
+ If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
221
+ gn_auto_machine_weight (`float`):
222
+ Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
223
+ style_fidelity (`float`):
224
+ style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important,
225
+ elif style_fidelity=0.0, prompt more important, else balanced.
226
+ reference_attn (`bool`):
227
+ Whether to use reference query for self attention's context.
228
+ reference_adain (`bool`):
229
+ Whether to use reference adain.
230
+
231
+ Examples:
232
+
233
+ Returns:
234
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
235
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
236
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
237
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
238
+ (nsfw) content, according to the `safety_checker`.
239
+ """
240
+ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
241
+
242
+ # 1. Check inputs. Raise error if not correct
243
+ self.check_inputs(
244
+ prompt,
245
+ image,
246
+ callback_steps,
247
+ negative_prompt,
248
+ prompt_embeds,
249
+ negative_prompt_embeds,
250
+ controlnet_conditioning_scale,
251
+ )
252
+
253
+ # 2. Define call parameters
254
+ if prompt is not None and isinstance(prompt, str):
255
+ batch_size = 1
256
+ elif prompt is not None and isinstance(prompt, list):
257
+ batch_size = len(prompt)
258
+ else:
259
+ batch_size = prompt_embeds.shape[0]
260
+
261
+ device = self._execution_device
262
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
263
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
264
+ # corresponds to doing no classifier free guidance.
265
+ do_classifier_free_guidance = guidance_scale > 1.0
266
+
267
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
268
+
269
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
270
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
271
+
272
+ global_pool_conditions = (
273
+ controlnet.config.global_pool_conditions
274
+ if isinstance(controlnet, ControlNetModel)
275
+ else controlnet.nets[0].config.global_pool_conditions
276
+ )
277
+ guess_mode = guess_mode or global_pool_conditions
278
+
279
+ # 3. Encode input prompt
280
+ text_encoder_lora_scale = (
281
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
282
+ )
283
+ prompt_embeds = self._encode_prompt(
284
+ prompt,
285
+ device,
286
+ num_images_per_prompt,
287
+ do_classifier_free_guidance,
288
+ negative_prompt,
289
+ prompt_embeds=prompt_embeds,
290
+ negative_prompt_embeds=negative_prompt_embeds,
291
+ lora_scale=text_encoder_lora_scale,
292
+ )
293
+
294
+ # 4. Prepare image
295
+ if isinstance(controlnet, ControlNetModel):
296
+ image = self.prepare_image(
297
+ image=image,
298
+ width=width,
299
+ height=height,
300
+ batch_size=batch_size * num_images_per_prompt,
301
+ num_images_per_prompt=num_images_per_prompt,
302
+ device=device,
303
+ dtype=controlnet.dtype,
304
+ do_classifier_free_guidance=do_classifier_free_guidance,
305
+ guess_mode=guess_mode,
306
+ )
307
+ height, width = image.shape[-2:]
308
+ elif isinstance(controlnet, MultiControlNetModel):
309
+ images = []
310
+
311
+ for image_ in image:
312
+ image_ = self.prepare_image(
313
+ image=image_,
314
+ width=width,
315
+ height=height,
316
+ batch_size=batch_size * num_images_per_prompt,
317
+ num_images_per_prompt=num_images_per_prompt,
318
+ device=device,
319
+ dtype=controlnet.dtype,
320
+ do_classifier_free_guidance=do_classifier_free_guidance,
321
+ guess_mode=guess_mode,
322
+ )
323
+
324
+ images.append(image_)
325
+
326
+ image = images
327
+ height, width = image[0].shape[-2:]
328
+ else:
329
+ assert False
330
+
331
+ # 5. Preprocess reference image
332
+ ref_image = self.prepare_image(
333
+ image=ref_image,
334
+ width=width,
335
+ height=height,
336
+ batch_size=batch_size * num_images_per_prompt,
337
+ num_images_per_prompt=num_images_per_prompt,
338
+ device=device,
339
+ dtype=prompt_embeds.dtype,
340
+ )
341
+
342
+ # 6. Prepare timesteps
343
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
344
+ timesteps = self.scheduler.timesteps
345
+
346
+ # 7. Prepare latent variables
347
+ num_channels_latents = self.unet.config.in_channels
348
+ latents = self.prepare_latents(
349
+ batch_size * num_images_per_prompt,
350
+ num_channels_latents,
351
+ height,
352
+ width,
353
+ prompt_embeds.dtype,
354
+ device,
355
+ generator,
356
+ latents,
357
+ )
358
+
359
+ # 8. Prepare reference latent variables
360
+ ref_image_latents = self.prepare_ref_latents(
361
+ ref_image,
362
+ batch_size * num_images_per_prompt,
363
+ prompt_embeds.dtype,
364
+ device,
365
+ generator,
366
+ do_classifier_free_guidance,
367
+ )
368
+
369
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
370
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
371
+
372
+ # 10. Modify self attention and group norm
373
+ MODE = "write"
374
+ uc_mask = (
375
+ torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
376
+ .type_as(ref_image_latents)
377
+ .bool()
378
+ )
379
+
380
+ def hacked_basic_transformer_inner_forward(
381
+ self,
382
+ hidden_states: torch.FloatTensor,
383
+ attention_mask: Optional[torch.FloatTensor] = None,
384
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
385
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
386
+ timestep: Optional[torch.LongTensor] = None,
387
+ cross_attention_kwargs: Dict[str, Any] = None,
388
+ class_labels: Optional[torch.LongTensor] = None,
389
+ ):
390
+ if self.use_ada_layer_norm:
391
+ norm_hidden_states = self.norm1(hidden_states, timestep)
392
+ elif self.use_ada_layer_norm_zero:
393
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
394
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
395
+ )
396
+ else:
397
+ norm_hidden_states = self.norm1(hidden_states)
398
+
399
+ # 1. Self-Attention
400
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
401
+ if self.only_cross_attention:
402
+ attn_output = self.attn1(
403
+ norm_hidden_states,
404
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
405
+ attention_mask=attention_mask,
406
+ **cross_attention_kwargs,
407
+ )
408
+ else:
409
+ if MODE == "write":
410
+ self.bank.append(norm_hidden_states.detach().clone())
411
+ attn_output = self.attn1(
412
+ norm_hidden_states,
413
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
414
+ attention_mask=attention_mask,
415
+ **cross_attention_kwargs,
416
+ )
417
+ if MODE == "read":
418
+ if attention_auto_machine_weight > self.attn_weight:
419
+ attn_output_uc = self.attn1(
420
+ norm_hidden_states,
421
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
422
+ # attention_mask=attention_mask,
423
+ **cross_attention_kwargs,
424
+ )
425
+ attn_output_c = attn_output_uc.clone()
426
+ if do_classifier_free_guidance and style_fidelity > 0:
427
+ attn_output_c[uc_mask] = self.attn1(
428
+ norm_hidden_states[uc_mask],
429
+ encoder_hidden_states=norm_hidden_states[uc_mask],
430
+ **cross_attention_kwargs,
431
+ )
432
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
433
+ self.bank.clear()
434
+ else:
435
+ attn_output = self.attn1(
436
+ norm_hidden_states,
437
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
438
+ attention_mask=attention_mask,
439
+ **cross_attention_kwargs,
440
+ )
441
+ if self.use_ada_layer_norm_zero:
442
+ attn_output = gate_msa.unsqueeze(1) * attn_output
443
+ hidden_states = attn_output + hidden_states
444
+
445
+ if self.attn2 is not None:
446
+ norm_hidden_states = (
447
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
448
+ )
449
+
450
+ # 2. Cross-Attention
451
+ attn_output = self.attn2(
452
+ norm_hidden_states,
453
+ encoder_hidden_states=encoder_hidden_states,
454
+ attention_mask=encoder_attention_mask,
455
+ **cross_attention_kwargs,
456
+ )
457
+ hidden_states = attn_output + hidden_states
458
+
459
+ # 3. Feed-forward
460
+ norm_hidden_states = self.norm3(hidden_states)
461
+
462
+ if self.use_ada_layer_norm_zero:
463
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
464
+
465
+ ff_output = self.ff(norm_hidden_states)
466
+
467
+ if self.use_ada_layer_norm_zero:
468
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
469
+
470
+ hidden_states = ff_output + hidden_states
471
+
472
+ return hidden_states
473
+
474
+ def hacked_mid_forward(self, *args, **kwargs):
475
+ eps = 1e-6
476
+ x = self.original_forward(*args, **kwargs)
477
+ if MODE == "write":
478
+ if gn_auto_machine_weight >= self.gn_weight:
479
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
480
+ self.mean_bank.append(mean)
481
+ self.var_bank.append(var)
482
+ if MODE == "read":
483
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
484
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
485
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
486
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
487
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
488
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
489
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
490
+ x_c = x_uc.clone()
491
+ if do_classifier_free_guidance and style_fidelity > 0:
492
+ x_c[uc_mask] = x[uc_mask]
493
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
494
+ self.mean_bank = []
495
+ self.var_bank = []
496
+ return x
497
+
498
+ def hack_CrossAttnDownBlock2D_forward(
499
+ self,
500
+ hidden_states: torch.FloatTensor,
501
+ temb: Optional[torch.FloatTensor] = None,
502
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
503
+ attention_mask: Optional[torch.FloatTensor] = None,
504
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
505
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
506
+ ):
507
+ eps = 1e-6
508
+
509
+ # TODO(Patrick, William) - attention mask is not used
510
+ output_states = ()
511
+
512
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
513
+ hidden_states = resnet(hidden_states, temb)
514
+ hidden_states = attn(
515
+ hidden_states,
516
+ encoder_hidden_states=encoder_hidden_states,
517
+ cross_attention_kwargs=cross_attention_kwargs,
518
+ attention_mask=attention_mask,
519
+ encoder_attention_mask=encoder_attention_mask,
520
+ return_dict=False,
521
+ )[0]
522
+ if MODE == "write":
523
+ if gn_auto_machine_weight >= self.gn_weight:
524
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
525
+ self.mean_bank.append([mean])
526
+ self.var_bank.append([var])
527
+ if MODE == "read":
528
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
529
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
530
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
531
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
532
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
533
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
534
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
535
+ hidden_states_c = hidden_states_uc.clone()
536
+ if do_classifier_free_guidance and style_fidelity > 0:
537
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
538
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
539
+
540
+ output_states = output_states + (hidden_states,)
541
+
542
+ if MODE == "read":
543
+ self.mean_bank = []
544
+ self.var_bank = []
545
+
546
+ if self.downsamplers is not None:
547
+ for downsampler in self.downsamplers:
548
+ hidden_states = downsampler(hidden_states)
549
+
550
+ output_states = output_states + (hidden_states,)
551
+
552
+ return hidden_states, output_states
553
+
554
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
555
+ eps = 1e-6
556
+
557
+ output_states = ()
558
+
559
+ for i, resnet in enumerate(self.resnets):
560
+ hidden_states = resnet(hidden_states, temb)
561
+
562
+ if MODE == "write":
563
+ if gn_auto_machine_weight >= self.gn_weight:
564
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
565
+ self.mean_bank.append([mean])
566
+ self.var_bank.append([var])
567
+ if MODE == "read":
568
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
569
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
570
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
571
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
572
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
573
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
574
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
575
+ hidden_states_c = hidden_states_uc.clone()
576
+ if do_classifier_free_guidance and style_fidelity > 0:
577
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
578
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
579
+
580
+ output_states = output_states + (hidden_states,)
581
+
582
+ if MODE == "read":
583
+ self.mean_bank = []
584
+ self.var_bank = []
585
+
586
+ if self.downsamplers is not None:
587
+ for downsampler in self.downsamplers:
588
+ hidden_states = downsampler(hidden_states)
589
+
590
+ output_states = output_states + (hidden_states,)
591
+
592
+ return hidden_states, output_states
593
+
594
+ def hacked_CrossAttnUpBlock2D_forward(
595
+ self,
596
+ hidden_states: torch.FloatTensor,
597
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
598
+ temb: Optional[torch.FloatTensor] = None,
599
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
600
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
601
+ upsample_size: Optional[int] = None,
602
+ attention_mask: Optional[torch.FloatTensor] = None,
603
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
604
+ ):
605
+ eps = 1e-6
606
+ # TODO(Patrick, William) - attention mask is not used
607
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
608
+ # pop res hidden states
609
+ res_hidden_states = res_hidden_states_tuple[-1]
610
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
611
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
612
+ hidden_states = resnet(hidden_states, temb)
613
+ hidden_states = attn(
614
+ hidden_states,
615
+ encoder_hidden_states=encoder_hidden_states,
616
+ cross_attention_kwargs=cross_attention_kwargs,
617
+ attention_mask=attention_mask,
618
+ encoder_attention_mask=encoder_attention_mask,
619
+ return_dict=False,
620
+ )[0]
621
+
622
+ if MODE == "write":
623
+ if gn_auto_machine_weight >= self.gn_weight:
624
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
625
+ self.mean_bank.append([mean])
626
+ self.var_bank.append([var])
627
+ if MODE == "read":
628
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
629
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
630
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
631
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
632
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
633
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
634
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
635
+ hidden_states_c = hidden_states_uc.clone()
636
+ if do_classifier_free_guidance and style_fidelity > 0:
637
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
638
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
639
+
640
+ if MODE == "read":
641
+ self.mean_bank = []
642
+ self.var_bank = []
643
+
644
+ if self.upsamplers is not None:
645
+ for upsampler in self.upsamplers:
646
+ hidden_states = upsampler(hidden_states, upsample_size)
647
+
648
+ return hidden_states
649
+
650
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
651
+ eps = 1e-6
652
+ for i, resnet in enumerate(self.resnets):
653
+ # pop res hidden states
654
+ res_hidden_states = res_hidden_states_tuple[-1]
655
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
656
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
657
+ hidden_states = resnet(hidden_states, temb)
658
+
659
+ if MODE == "write":
660
+ if gn_auto_machine_weight >= self.gn_weight:
661
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
662
+ self.mean_bank.append([mean])
663
+ self.var_bank.append([var])
664
+ if MODE == "read":
665
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
666
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
667
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
668
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
669
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
670
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
671
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
672
+ hidden_states_c = hidden_states_uc.clone()
673
+ if do_classifier_free_guidance and style_fidelity > 0:
674
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
675
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
676
+
677
+ if MODE == "read":
678
+ self.mean_bank = []
679
+ self.var_bank = []
680
+
681
+ if self.upsamplers is not None:
682
+ for upsampler in self.upsamplers:
683
+ hidden_states = upsampler(hidden_states, upsample_size)
684
+
685
+ return hidden_states
686
+
687
+ if reference_attn:
688
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
689
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
690
+
691
+ for i, module in enumerate(attn_modules):
692
+ module._original_inner_forward = module.forward
693
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
694
+ module.bank = []
695
+ module.attn_weight = float(i) / float(len(attn_modules))
696
+
697
+ if reference_adain:
698
+ gn_modules = [self.unet.mid_block]
699
+ self.unet.mid_block.gn_weight = 0
700
+
701
+ down_blocks = self.unet.down_blocks
702
+ for w, module in enumerate(down_blocks):
703
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
704
+ gn_modules.append(module)
705
+
706
+ up_blocks = self.unet.up_blocks
707
+ for w, module in enumerate(up_blocks):
708
+ module.gn_weight = float(w) / float(len(up_blocks))
709
+ gn_modules.append(module)
710
+
711
+ for i, module in enumerate(gn_modules):
712
+ if getattr(module, "original_forward", None) is None:
713
+ module.original_forward = module.forward
714
+ if i == 0:
715
+ # mid_block
716
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
717
+ elif isinstance(module, CrossAttnDownBlock2D):
718
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
719
+ elif isinstance(module, DownBlock2D):
720
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
721
+ elif isinstance(module, CrossAttnUpBlock2D):
722
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
723
+ elif isinstance(module, UpBlock2D):
724
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
725
+ module.mean_bank = []
726
+ module.var_bank = []
727
+ module.gn_weight *= 2
728
+
729
+ # 11. Denoising loop
730
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
731
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
732
+ for i, t in enumerate(timesteps):
733
+ # expand the latents if we are doing classifier free guidance
734
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
735
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
736
+
737
+ # controlnet(s) inference
738
+ if guess_mode and do_classifier_free_guidance:
739
+ # Infer ControlNet only for the conditional batch.
740
+ control_model_input = latents
741
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
742
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
743
+ else:
744
+ control_model_input = latent_model_input
745
+ controlnet_prompt_embeds = prompt_embeds
746
+
747
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
748
+ control_model_input,
749
+ t,
750
+ encoder_hidden_states=controlnet_prompt_embeds,
751
+ controlnet_cond=image,
752
+ conditioning_scale=controlnet_conditioning_scale,
753
+ guess_mode=guess_mode,
754
+ return_dict=False,
755
+ )
756
+
757
+ if guess_mode and do_classifier_free_guidance:
758
+                     # Inferred ControlNet only for the conditional batch.
759
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
760
+ # add 0 to the unconditional batch to keep it unchanged.
761
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
762
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
763
+
764
+ # ref only part
765
+ noise = randn_tensor(
766
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
767
+ )
768
+ ref_xt = self.scheduler.add_noise(
769
+ ref_image_latents,
770
+ noise,
771
+ t.reshape(
772
+ 1,
773
+ ),
774
+ )
775
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
776
+
777
+ MODE = "write"
778
+ self.unet(
779
+ ref_xt,
780
+ t,
781
+ encoder_hidden_states=prompt_embeds,
782
+ cross_attention_kwargs=cross_attention_kwargs,
783
+ return_dict=False,
784
+ )
785
+
786
+ # predict the noise residual
787
+ MODE = "read"
788
+ noise_pred = self.unet(
789
+ latent_model_input,
790
+ t,
791
+ encoder_hidden_states=prompt_embeds,
792
+ cross_attention_kwargs=cross_attention_kwargs,
793
+ down_block_additional_residuals=down_block_res_samples,
794
+ mid_block_additional_residual=mid_block_res_sample,
795
+ return_dict=False,
796
+ )[0]
797
+
798
+ # perform guidance
799
+ if do_classifier_free_guidance:
800
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
801
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
802
+
803
+ # compute the previous noisy sample x_t -> x_t-1
804
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
805
+
806
+ # call the callback, if provided
807
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
808
+ progress_bar.update()
809
+ if callback is not None and i % callback_steps == 0:
810
+ callback(i, t, latents)
811
+
812
+ # If we do sequential model offloading, let's offload unet and controlnet
813
+ # manually for max memory savings
814
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
815
+ self.unet.to("cpu")
816
+ self.controlnet.to("cpu")
817
+ torch.cuda.empty_cache()
818
+
819
+ if not output_type == "latent":
820
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
821
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
822
+ else:
823
+ image = latents
824
+ has_nsfw_concept = None
825
+
826
+ if has_nsfw_concept is None:
827
+ do_denormalize = [True] * image.shape[0]
828
+ else:
829
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
830
+
831
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
832
+
833
+ # Offload last model to CPU
834
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
835
+ self.final_offload_hook.offload()
836
+
837
+ if not return_dict:
838
+ return (image, has_nsfw_concept)
839
+
840
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
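The hunk above ends the reference-only ControlNet sampling loop: at each denoising step the UNet is first run on the noised reference latents in "write" mode to fill the attention/GroupNorm banks, then in "read" mode to denoise the actual latents with the ControlNet residuals and classifier-free guidance. A minimal usage sketch follows; the pipeline class name, import path, checkpoint IDs, and image paths are placeholders (not taken from this repository), and the real MagicAnimate entry points wire this call up through their own configs.

import torch
from PIL import Image
from diffusers import ControlNetModel

# Placeholder import: assumes the pipeline defined above is exposed under this name.
from pipeline_controlnet_reference import StableDiffusionControlNetReferencePipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

pose_image = Image.open("pose.png")         # ControlNet conditioning image (placeholder path)
reference_image = Image.open("person.png")  # appearance reference consumed by the "write"/"read" attention hack

result = pipe(
    prompt="a person dancing, best quality",
    image=pose_image,
    ref_image=reference_image,
    num_inference_steps=30,
    guidance_scale=7.5,
    reference_attn=True,    # share self-attention features from the reference image
    reference_adain=False,  # optionally also match GroupNorm statistics
)
result.images[0].save("out.png")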
magicanimate/models/unet.py CHANGED
@@ -1,508 +1,508 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Adapted from https://github.com/guoyww/AnimateDiff
8
-
9
- # Copyright 2023 The HuggingFace Team. All rights reserved.
10
- #
11
- # Licensed under the Apache License, Version 2.0 (the "License");
12
- # you may not use this file except in compliance with the License.
13
- # You may obtain a copy of the License at
14
- #
15
- # http://www.apache.org/licenses/LICENSE-2.0
16
- #
17
- # Unless required by applicable law or agreed to in writing, software
18
- # distributed under the License is distributed on an "AS IS" BASIS,
19
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
- # See the License for the specific language governing permissions and
21
- # limitations under the License.
22
- from dataclasses import dataclass
23
- from typing import List, Optional, Tuple, Union
24
-
25
- import os
26
- import json
27
- import pdb
28
-
29
- import torch
30
- import torch.nn as nn
31
- import torch.utils.checkpoint
32
-
33
- from diffusers.configuration_utils import ConfigMixin, register_to_config
34
- from diffusers.models.modeling_utils import ModelMixin
35
- from diffusers.utils import BaseOutput, logging
36
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
37
- from .unet_3d_blocks import (
38
- CrossAttnDownBlock3D,
39
- CrossAttnUpBlock3D,
40
- DownBlock3D,
41
- UNetMidBlock3DCrossAttn,
42
- UpBlock3D,
43
- get_down_block,
44
- get_up_block,
45
- )
46
- from .resnet import InflatedConv3d
47
-
48
-
49
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
-
51
-
52
- @dataclass
53
- class UNet3DConditionOutput(BaseOutput):
54
- sample: torch.FloatTensor
55
-
56
-
57
- class UNet3DConditionModel(ModelMixin, ConfigMixin):
58
- _supports_gradient_checkpointing = True
59
-
60
- @register_to_config
61
- def __init__(
62
- self,
63
- sample_size: Optional[int] = None,
64
- in_channels: int = 4,
65
- out_channels: int = 4,
66
- center_input_sample: bool = False,
67
- flip_sin_to_cos: bool = True,
68
- freq_shift: int = 0,
69
- down_block_types: Tuple[str] = (
70
- "CrossAttnDownBlock3D",
71
- "CrossAttnDownBlock3D",
72
- "CrossAttnDownBlock3D",
73
- "DownBlock3D",
74
- ),
75
- mid_block_type: str = "UNetMidBlock3DCrossAttn",
76
- up_block_types: Tuple[str] = (
77
- "UpBlock3D",
78
- "CrossAttnUpBlock3D",
79
- "CrossAttnUpBlock3D",
80
- "CrossAttnUpBlock3D"
81
- ),
82
- only_cross_attention: Union[bool, Tuple[bool]] = False,
83
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
84
- layers_per_block: int = 2,
85
- downsample_padding: int = 1,
86
- mid_block_scale_factor: float = 1,
87
- act_fn: str = "silu",
88
- norm_num_groups: int = 32,
89
- norm_eps: float = 1e-5,
90
- cross_attention_dim: int = 1280,
91
- attention_head_dim: Union[int, Tuple[int]] = 8,
92
- dual_cross_attention: bool = False,
93
- use_linear_projection: bool = False,
94
- class_embed_type: Optional[str] = None,
95
- num_class_embeds: Optional[int] = None,
96
- upcast_attention: bool = False,
97
- resnet_time_scale_shift: str = "default",
98
-
99
- # Additional
100
- use_motion_module = False,
101
- motion_module_resolutions = ( 1,2,4,8 ),
102
- motion_module_mid_block = False,
103
- motion_module_decoder_only = False,
104
- motion_module_type = None,
105
- motion_module_kwargs = {},
106
- unet_use_cross_frame_attention = None,
107
- unet_use_temporal_attention = None,
108
- ):
109
- super().__init__()
110
-
111
- self.sample_size = sample_size
112
- time_embed_dim = block_out_channels[0] * 4
113
-
114
- # input
115
- self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
116
-
117
- # time
118
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
119
- timestep_input_dim = block_out_channels[0]
120
-
121
- self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
122
-
123
- # class embedding
124
- if class_embed_type is None and num_class_embeds is not None:
125
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
126
- elif class_embed_type == "timestep":
127
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
128
- elif class_embed_type == "identity":
129
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
130
- else:
131
- self.class_embedding = None
132
-
133
- self.down_blocks = nn.ModuleList([])
134
- self.mid_block = None
135
- self.up_blocks = nn.ModuleList([])
136
-
137
- if isinstance(only_cross_attention, bool):
138
- only_cross_attention = [only_cross_attention] * len(down_block_types)
139
-
140
- if isinstance(attention_head_dim, int):
141
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
142
-
143
- # down
144
- output_channel = block_out_channels[0]
145
- for i, down_block_type in enumerate(down_block_types):
146
- res = 2 ** i
147
- input_channel = output_channel
148
- output_channel = block_out_channels[i]
149
- is_final_block = i == len(block_out_channels) - 1
150
-
151
- down_block = get_down_block(
152
- down_block_type,
153
- num_layers=layers_per_block,
154
- in_channels=input_channel,
155
- out_channels=output_channel,
156
- temb_channels=time_embed_dim,
157
- add_downsample=not is_final_block,
158
- resnet_eps=norm_eps,
159
- resnet_act_fn=act_fn,
160
- resnet_groups=norm_num_groups,
161
- cross_attention_dim=cross_attention_dim,
162
- attn_num_head_channels=attention_head_dim[i],
163
- downsample_padding=downsample_padding,
164
- dual_cross_attention=dual_cross_attention,
165
- use_linear_projection=use_linear_projection,
166
- only_cross_attention=only_cross_attention[i],
167
- upcast_attention=upcast_attention,
168
- resnet_time_scale_shift=resnet_time_scale_shift,
169
-
170
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
171
- unet_use_temporal_attention=unet_use_temporal_attention,
172
-
173
- use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
174
- motion_module_type=motion_module_type,
175
- motion_module_kwargs=motion_module_kwargs,
176
- )
177
- self.down_blocks.append(down_block)
178
-
179
- # mid
180
- if mid_block_type == "UNetMidBlock3DCrossAttn":
181
- self.mid_block = UNetMidBlock3DCrossAttn(
182
- in_channels=block_out_channels[-1],
183
- temb_channels=time_embed_dim,
184
- resnet_eps=norm_eps,
185
- resnet_act_fn=act_fn,
186
- output_scale_factor=mid_block_scale_factor,
187
- resnet_time_scale_shift=resnet_time_scale_shift,
188
- cross_attention_dim=cross_attention_dim,
189
- attn_num_head_channels=attention_head_dim[-1],
190
- resnet_groups=norm_num_groups,
191
- dual_cross_attention=dual_cross_attention,
192
- use_linear_projection=use_linear_projection,
193
- upcast_attention=upcast_attention,
194
-
195
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
196
- unet_use_temporal_attention=unet_use_temporal_attention,
197
-
198
- use_motion_module=use_motion_module and motion_module_mid_block,
199
- motion_module_type=motion_module_type,
200
- motion_module_kwargs=motion_module_kwargs,
201
- )
202
- else:
203
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
204
-
205
- # count how many layers upsample the videos
206
- self.num_upsamplers = 0
207
-
208
- # up
209
- reversed_block_out_channels = list(reversed(block_out_channels))
210
- reversed_attention_head_dim = list(reversed(attention_head_dim))
211
- only_cross_attention = list(reversed(only_cross_attention))
212
- output_channel = reversed_block_out_channels[0]
213
- for i, up_block_type in enumerate(up_block_types):
214
- res = 2 ** (3 - i)
215
- is_final_block = i == len(block_out_channels) - 1
216
-
217
- prev_output_channel = output_channel
218
- output_channel = reversed_block_out_channels[i]
219
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
220
-
221
- # add upsample block for all BUT final layer
222
- if not is_final_block:
223
- add_upsample = True
224
- self.num_upsamplers += 1
225
- else:
226
- add_upsample = False
227
-
228
- up_block = get_up_block(
229
- up_block_type,
230
- num_layers=layers_per_block + 1,
231
- in_channels=input_channel,
232
- out_channels=output_channel,
233
- prev_output_channel=prev_output_channel,
234
- temb_channels=time_embed_dim,
235
- add_upsample=add_upsample,
236
- resnet_eps=norm_eps,
237
- resnet_act_fn=act_fn,
238
- resnet_groups=norm_num_groups,
239
- cross_attention_dim=cross_attention_dim,
240
- attn_num_head_channels=reversed_attention_head_dim[i],
241
- dual_cross_attention=dual_cross_attention,
242
- use_linear_projection=use_linear_projection,
243
- only_cross_attention=only_cross_attention[i],
244
- upcast_attention=upcast_attention,
245
- resnet_time_scale_shift=resnet_time_scale_shift,
246
-
247
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
248
- unet_use_temporal_attention=unet_use_temporal_attention,
249
-
250
- use_motion_module=use_motion_module and (res in motion_module_resolutions),
251
- motion_module_type=motion_module_type,
252
- motion_module_kwargs=motion_module_kwargs,
253
- )
254
- self.up_blocks.append(up_block)
255
- prev_output_channel = output_channel
256
-
257
- # out
258
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
259
- self.conv_act = nn.SiLU()
260
- self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
261
-
262
- def set_attention_slice(self, slice_size):
263
- r"""
264
- Enable sliced attention computation.
265
-
266
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
267
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
268
-
269
- Args:
270
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
271
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
272
-                 `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
273
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
274
- must be a multiple of `slice_size`.
275
- """
276
- sliceable_head_dims = []
277
-
278
- def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
279
- if hasattr(module, "set_attention_slice"):
280
- sliceable_head_dims.append(module.sliceable_head_dim)
281
-
282
- for child in module.children():
283
- fn_recursive_retrieve_slicable_dims(child)
284
-
285
- # retrieve number of attention layers
286
- for module in self.children():
287
- fn_recursive_retrieve_slicable_dims(module)
288
-
289
- num_slicable_layers = len(sliceable_head_dims)
290
-
291
- if slice_size == "auto":
292
- # half the attention head size is usually a good trade-off between
293
- # speed and memory
294
- slice_size = [dim // 2 for dim in sliceable_head_dims]
295
- elif slice_size == "max":
296
- # make smallest slice possible
297
- slice_size = num_slicable_layers * [1]
298
-
299
- slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
300
-
301
- if len(slice_size) != len(sliceable_head_dims):
302
- raise ValueError(
303
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
304
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
305
- )
306
-
307
- for i in range(len(slice_size)):
308
- size = slice_size[i]
309
- dim = sliceable_head_dims[i]
310
- if size is not None and size > dim:
311
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
312
-
313
- # Recursively walk through all the children.
314
-         # Any child that exposes the set_attention_slice method
315
- # gets the message
316
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
317
- if hasattr(module, "set_attention_slice"):
318
- module.set_attention_slice(slice_size.pop())
319
-
320
- for child in module.children():
321
- fn_recursive_set_attention_slice(child, slice_size)
322
-
323
- reversed_slice_size = list(reversed(slice_size))
324
- for module in self.children():
325
- fn_recursive_set_attention_slice(module, reversed_slice_size)
326
-
327
- def _set_gradient_checkpointing(self, module, value=False):
328
- if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
329
- module.gradient_checkpointing = value
330
-
331
- def forward(
332
- self,
333
- sample: torch.FloatTensor,
334
- timestep: Union[torch.Tensor, float, int],
335
- encoder_hidden_states: torch.Tensor,
336
- class_labels: Optional[torch.Tensor] = None,
337
- attention_mask: Optional[torch.Tensor] = None,
338
- return_dict: bool = True,
339
- ) -> Union[UNet3DConditionOutput, Tuple]:
340
- r"""
341
- Args:
342
- sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
343
- timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
344
- encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
345
- return_dict (`bool`, *optional*, defaults to `True`):
346
-                 Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
347
-
348
- Returns:
349
-             [`UNet3DConditionOutput`] or `tuple`:
350
-                 [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
351
- returning a tuple, the first element is the sample tensor.
352
- """
353
-         # By default samples have to be at least a multiple of the overall upsampling factor.
354
-         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
355
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
356
- # on the fly if necessary.
357
- default_overall_up_factor = 2**self.num_upsamplers
358
-
359
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
360
- forward_upsample_size = False
361
- upsample_size = None
362
-
363
- if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
364
- logger.info("Forward upsample size to force interpolation output size.")
365
- forward_upsample_size = True
366
-
367
- # prepare attention_mask
368
- if attention_mask is not None:
369
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
370
- attention_mask = attention_mask.unsqueeze(1)
371
-
372
- # center input if necessary
373
- if self.config.center_input_sample:
374
- sample = 2 * sample - 1.0
375
-
376
- # time
377
- timesteps = timestep
378
- if not torch.is_tensor(timesteps):
379
- # This would be a good case for the `match` statement (Python 3.10+)
380
- is_mps = sample.device.type == "mps"
381
- if isinstance(timestep, float):
382
- dtype = torch.float32 if is_mps else torch.float64
383
- else:
384
- dtype = torch.int32 if is_mps else torch.int64
385
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
386
- elif len(timesteps.shape) == 0:
387
- timesteps = timesteps[None].to(sample.device)
388
-
389
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
390
- timesteps = timesteps.expand(sample.shape[0])
391
-
392
- t_emb = self.time_proj(timesteps)
393
-
394
- # timesteps does not contain any weights and will always return f32 tensors
395
- # but time_embedding might actually be running in fp16. so we need to cast here.
396
- # there might be better ways to encapsulate this.
397
- t_emb = t_emb.to(dtype=self.dtype)
398
- emb = self.time_embedding(t_emb)
399
-
400
- if self.class_embedding is not None:
401
- if class_labels is None:
402
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
403
-
404
- if self.config.class_embed_type == "timestep":
405
- class_labels = self.time_proj(class_labels)
406
-
407
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
408
- emb = emb + class_emb
409
-
410
- # pre-process
411
- sample = self.conv_in(sample)
412
-
413
- # down
414
- down_block_res_samples = (sample,)
415
- for downsample_block in self.down_blocks:
416
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
417
- sample, res_samples = downsample_block(
418
- hidden_states=sample,
419
- temb=emb,
420
- encoder_hidden_states=encoder_hidden_states,
421
- attention_mask=attention_mask,
422
- )
423
- else:
424
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
425
-
426
- down_block_res_samples += res_samples
427
-
428
- # mid
429
- sample = self.mid_block(
430
- sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
431
- )
432
-
433
- # up
434
- for i, upsample_block in enumerate(self.up_blocks):
435
- is_final_block = i == len(self.up_blocks) - 1
436
-
437
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
438
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
439
-
440
- # if we have not reached the final block and need to forward the
441
- # upsample size, we do it here
442
- if not is_final_block and forward_upsample_size:
443
- upsample_size = down_block_res_samples[-1].shape[2:]
444
-
445
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
446
- sample = upsample_block(
447
- hidden_states=sample,
448
- temb=emb,
449
- res_hidden_states_tuple=res_samples,
450
- encoder_hidden_states=encoder_hidden_states,
451
- upsample_size=upsample_size,
452
- attention_mask=attention_mask,
453
- )
454
- else:
455
- sample = upsample_block(
456
- hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
457
- )
458
-
459
- # post-process
460
- sample = self.conv_norm_out(sample)
461
- sample = self.conv_act(sample)
462
- sample = self.conv_out(sample)
463
-
464
- if not return_dict:
465
- return (sample,)
466
-
467
- return UNet3DConditionOutput(sample=sample)
468
-
469
- @classmethod
470
- def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
471
- if subfolder is not None:
472
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
473
- print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
474
-
475
- config_file = os.path.join(pretrained_model_path, 'config.json')
476
- if not os.path.isfile(config_file):
477
- raise RuntimeError(f"{config_file} does not exist")
478
- with open(config_file, "r") as f:
479
- config = json.load(f)
480
- config["_class_name"] = cls.__name__
481
- config["down_block_types"] = [
482
- "CrossAttnDownBlock3D",
483
- "CrossAttnDownBlock3D",
484
- "CrossAttnDownBlock3D",
485
- "DownBlock3D"
486
- ]
487
- config["up_block_types"] = [
488
- "UpBlock3D",
489
- "CrossAttnUpBlock3D",
490
- "CrossAttnUpBlock3D",
491
- "CrossAttnUpBlock3D"
492
- ]
493
-
494
- from diffusers.utils import WEIGHTS_NAME
495
- model = cls.from_config(config, **unet_additional_kwargs)
496
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
497
- if not os.path.isfile(model_file):
498
- raise RuntimeError(f"{model_file} does not exist")
499
- state_dict = torch.load(model_file, map_location="cpu")
500
-
501
- m, u = model.load_state_dict(state_dict, strict=False)
502
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
503
- # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
504
-
505
- params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
506
- print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
507
-
508
- return model
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+
9
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ from dataclasses import dataclass
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import os
26
+ import json
27
+ import pdb
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.utils.checkpoint
32
+
33
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
34
+ from diffusers.models.modeling_utils import ModelMixin
35
+ from diffusers.utils import BaseOutput, logging
36
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
37
+ from .unet_3d_blocks import (
38
+ CrossAttnDownBlock3D,
39
+ CrossAttnUpBlock3D,
40
+ DownBlock3D,
41
+ UNetMidBlock3DCrossAttn,
42
+ UpBlock3D,
43
+ get_down_block,
44
+ get_up_block,
45
+ )
46
+ from .resnet import InflatedConv3d
47
+
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+
52
+ @dataclass
53
+ class UNet3DConditionOutput(BaseOutput):
54
+ sample: torch.FloatTensor
55
+
56
+
57
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
58
+ _supports_gradient_checkpointing = True
59
+
60
+ @register_to_config
61
+ def __init__(
62
+ self,
63
+ sample_size: Optional[int] = None,
64
+ in_channels: int = 4,
65
+ out_channels: int = 4,
66
+ center_input_sample: bool = False,
67
+ flip_sin_to_cos: bool = True,
68
+ freq_shift: int = 0,
69
+ down_block_types: Tuple[str] = (
70
+ "CrossAttnDownBlock3D",
71
+ "CrossAttnDownBlock3D",
72
+ "CrossAttnDownBlock3D",
73
+ "DownBlock3D",
74
+ ),
75
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
76
+ up_block_types: Tuple[str] = (
77
+ "UpBlock3D",
78
+ "CrossAttnUpBlock3D",
79
+ "CrossAttnUpBlock3D",
80
+ "CrossAttnUpBlock3D"
81
+ ),
82
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
83
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
84
+ layers_per_block: int = 2,
85
+ downsample_padding: int = 1,
86
+ mid_block_scale_factor: float = 1,
87
+ act_fn: str = "silu",
88
+ norm_num_groups: int = 32,
89
+ norm_eps: float = 1e-5,
90
+ cross_attention_dim: int = 1280,
91
+ attention_head_dim: Union[int, Tuple[int]] = 8,
92
+ dual_cross_attention: bool = False,
93
+ use_linear_projection: bool = False,
94
+ class_embed_type: Optional[str] = None,
95
+ num_class_embeds: Optional[int] = None,
96
+ upcast_attention: bool = False,
97
+ resnet_time_scale_shift: str = "default",
98
+
99
+ # Additional
100
+ use_motion_module = False,
101
+ motion_module_resolutions = ( 1,2,4,8 ),
102
+ motion_module_mid_block = False,
103
+ motion_module_decoder_only = False,
104
+ motion_module_type = None,
105
+ motion_module_kwargs = {},
106
+ unet_use_cross_frame_attention = None,
107
+ unet_use_temporal_attention = None,
108
+ ):
109
+ super().__init__()
110
+
111
+ self.sample_size = sample_size
112
+ time_embed_dim = block_out_channels[0] * 4
113
+
114
+ # input
115
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
116
+
117
+ # time
118
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
119
+ timestep_input_dim = block_out_channels[0]
120
+
121
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
122
+
123
+ # class embedding
124
+ if class_embed_type is None and num_class_embeds is not None:
125
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
126
+ elif class_embed_type == "timestep":
127
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
128
+ elif class_embed_type == "identity":
129
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
130
+ else:
131
+ self.class_embedding = None
132
+
133
+ self.down_blocks = nn.ModuleList([])
134
+ self.mid_block = None
135
+ self.up_blocks = nn.ModuleList([])
136
+
137
+ if isinstance(only_cross_attention, bool):
138
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
139
+
140
+ if isinstance(attention_head_dim, int):
141
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
142
+
143
+ # down
144
+ output_channel = block_out_channels[0]
145
+ for i, down_block_type in enumerate(down_block_types):
146
+ res = 2 ** i
147
+ input_channel = output_channel
148
+ output_channel = block_out_channels[i]
149
+ is_final_block = i == len(block_out_channels) - 1
150
+
151
+ down_block = get_down_block(
152
+ down_block_type,
153
+ num_layers=layers_per_block,
154
+ in_channels=input_channel,
155
+ out_channels=output_channel,
156
+ temb_channels=time_embed_dim,
157
+ add_downsample=not is_final_block,
158
+ resnet_eps=norm_eps,
159
+ resnet_act_fn=act_fn,
160
+ resnet_groups=norm_num_groups,
161
+ cross_attention_dim=cross_attention_dim,
162
+ attn_num_head_channels=attention_head_dim[i],
163
+ downsample_padding=downsample_padding,
164
+ dual_cross_attention=dual_cross_attention,
165
+ use_linear_projection=use_linear_projection,
166
+ only_cross_attention=only_cross_attention[i],
167
+ upcast_attention=upcast_attention,
168
+ resnet_time_scale_shift=resnet_time_scale_shift,
169
+
170
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
171
+ unet_use_temporal_attention=unet_use_temporal_attention,
172
+
173
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
174
+ motion_module_type=motion_module_type,
175
+ motion_module_kwargs=motion_module_kwargs,
176
+ )
177
+ self.down_blocks.append(down_block)
178
+
179
+ # mid
180
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
181
+ self.mid_block = UNetMidBlock3DCrossAttn(
182
+ in_channels=block_out_channels[-1],
183
+ temb_channels=time_embed_dim,
184
+ resnet_eps=norm_eps,
185
+ resnet_act_fn=act_fn,
186
+ output_scale_factor=mid_block_scale_factor,
187
+ resnet_time_scale_shift=resnet_time_scale_shift,
188
+ cross_attention_dim=cross_attention_dim,
189
+ attn_num_head_channels=attention_head_dim[-1],
190
+ resnet_groups=norm_num_groups,
191
+ dual_cross_attention=dual_cross_attention,
192
+ use_linear_projection=use_linear_projection,
193
+ upcast_attention=upcast_attention,
194
+
195
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
196
+ unet_use_temporal_attention=unet_use_temporal_attention,
197
+
198
+ use_motion_module=use_motion_module and motion_module_mid_block,
199
+ motion_module_type=motion_module_type,
200
+ motion_module_kwargs=motion_module_kwargs,
201
+ )
202
+ else:
203
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
204
+
205
+ # count how many layers upsample the videos
206
+ self.num_upsamplers = 0
207
+
208
+ # up
209
+ reversed_block_out_channels = list(reversed(block_out_channels))
210
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
211
+ only_cross_attention = list(reversed(only_cross_attention))
212
+ output_channel = reversed_block_out_channels[0]
213
+ for i, up_block_type in enumerate(up_block_types):
214
+ res = 2 ** (3 - i)
215
+ is_final_block = i == len(block_out_channels) - 1
216
+
217
+ prev_output_channel = output_channel
218
+ output_channel = reversed_block_out_channels[i]
219
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
220
+
221
+ # add upsample block for all BUT final layer
222
+ if not is_final_block:
223
+ add_upsample = True
224
+ self.num_upsamplers += 1
225
+ else:
226
+ add_upsample = False
227
+
228
+ up_block = get_up_block(
229
+ up_block_type,
230
+ num_layers=layers_per_block + 1,
231
+ in_channels=input_channel,
232
+ out_channels=output_channel,
233
+ prev_output_channel=prev_output_channel,
234
+ temb_channels=time_embed_dim,
235
+ add_upsample=add_upsample,
236
+ resnet_eps=norm_eps,
237
+ resnet_act_fn=act_fn,
238
+ resnet_groups=norm_num_groups,
239
+ cross_attention_dim=cross_attention_dim,
240
+ attn_num_head_channels=reversed_attention_head_dim[i],
241
+ dual_cross_attention=dual_cross_attention,
242
+ use_linear_projection=use_linear_projection,
243
+ only_cross_attention=only_cross_attention[i],
244
+ upcast_attention=upcast_attention,
245
+ resnet_time_scale_shift=resnet_time_scale_shift,
246
+
247
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
248
+ unet_use_temporal_attention=unet_use_temporal_attention,
249
+
250
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
251
+ motion_module_type=motion_module_type,
252
+ motion_module_kwargs=motion_module_kwargs,
253
+ )
254
+ self.up_blocks.append(up_block)
255
+ prev_output_channel = output_channel
256
+
257
+ # out
258
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
259
+ self.conv_act = nn.SiLU()
260
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
261
+
262
+ def set_attention_slice(self, slice_size):
263
+ r"""
264
+ Enable sliced attention computation.
265
+
266
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
267
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
268
+
269
+ Args:
270
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
271
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
272
+                 `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
273
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
274
+ must be a multiple of `slice_size`.
275
+ """
276
+ sliceable_head_dims = []
277
+
278
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
279
+ if hasattr(module, "set_attention_slice"):
280
+ sliceable_head_dims.append(module.sliceable_head_dim)
281
+
282
+ for child in module.children():
283
+ fn_recursive_retrieve_slicable_dims(child)
284
+
285
+ # retrieve number of attention layers
286
+ for module in self.children():
287
+ fn_recursive_retrieve_slicable_dims(module)
288
+
289
+ num_slicable_layers = len(sliceable_head_dims)
290
+
291
+ if slice_size == "auto":
292
+ # half the attention head size is usually a good trade-off between
293
+ # speed and memory
294
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
295
+ elif slice_size == "max":
296
+ # make smallest slice possible
297
+ slice_size = num_slicable_layers * [1]
298
+
299
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
300
+
301
+ if len(slice_size) != len(sliceable_head_dims):
302
+ raise ValueError(
303
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
304
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
305
+ )
306
+
307
+ for i in range(len(slice_size)):
308
+ size = slice_size[i]
309
+ dim = sliceable_head_dims[i]
310
+ if size is not None and size > dim:
311
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
312
+
313
+ # Recursively walk through all the children.
314
+         # Any child that exposes the set_attention_slice method
315
+ # gets the message
316
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
317
+ if hasattr(module, "set_attention_slice"):
318
+ module.set_attention_slice(slice_size.pop())
319
+
320
+ for child in module.children():
321
+ fn_recursive_set_attention_slice(child, slice_size)
322
+
323
+ reversed_slice_size = list(reversed(slice_size))
324
+ for module in self.children():
325
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
326
+
327
+ def _set_gradient_checkpointing(self, module, value=False):
328
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
329
+ module.gradient_checkpointing = value
330
+
331
+ def forward(
332
+ self,
333
+ sample: torch.FloatTensor,
334
+ timestep: Union[torch.Tensor, float, int],
335
+ encoder_hidden_states: torch.Tensor,
336
+ class_labels: Optional[torch.Tensor] = None,
337
+ attention_mask: Optional[torch.Tensor] = None,
338
+ return_dict: bool = True,
339
+ ) -> Union[UNet3DConditionOutput, Tuple]:
340
+ r"""
341
+ Args:
342
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
343
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
344
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
345
+ return_dict (`bool`, *optional*, defaults to `True`):
346
+                 Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
347
+
348
+ Returns:
349
+             [`UNet3DConditionOutput`] or `tuple`:
350
+                 [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
351
+ returning a tuple, the first element is the sample tensor.
352
+ """
353
+         # By default samples have to be at least a multiple of the overall upsampling factor.
354
+         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
355
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
356
+ # on the fly if necessary.
357
+ default_overall_up_factor = 2**self.num_upsamplers
358
+
359
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
360
+ forward_upsample_size = False
361
+ upsample_size = None
362
+
363
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
364
+ logger.info("Forward upsample size to force interpolation output size.")
365
+ forward_upsample_size = True
366
+
367
+ # prepare attention_mask
368
+ if attention_mask is not None:
369
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
370
+ attention_mask = attention_mask.unsqueeze(1)
371
+
372
+ # center input if necessary
373
+ if self.config.center_input_sample:
374
+ sample = 2 * sample - 1.0
375
+
376
+ # time
377
+ timesteps = timestep
378
+ if not torch.is_tensor(timesteps):
379
+ # This would be a good case for the `match` statement (Python 3.10+)
380
+ is_mps = sample.device.type == "mps"
381
+ if isinstance(timestep, float):
382
+ dtype = torch.float32 if is_mps else torch.float64
383
+ else:
384
+ dtype = torch.int32 if is_mps else torch.int64
385
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
386
+ elif len(timesteps.shape) == 0:
387
+ timesteps = timesteps[None].to(sample.device)
388
+
389
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
390
+ timesteps = timesteps.expand(sample.shape[0])
391
+
392
+ t_emb = self.time_proj(timesteps)
393
+
394
+ # timesteps does not contain any weights and will always return f32 tensors
395
+ # but time_embedding might actually be running in fp16. so we need to cast here.
396
+ # there might be better ways to encapsulate this.
397
+ t_emb = t_emb.to(dtype=self.dtype)
398
+ emb = self.time_embedding(t_emb)
399
+
400
+ if self.class_embedding is not None:
401
+ if class_labels is None:
402
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
403
+
404
+ if self.config.class_embed_type == "timestep":
405
+ class_labels = self.time_proj(class_labels)
406
+
407
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
408
+ emb = emb + class_emb
409
+
410
+ # pre-process
411
+ sample = self.conv_in(sample)
412
+
413
+ # down
414
+ down_block_res_samples = (sample,)
415
+ for downsample_block in self.down_blocks:
416
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
417
+ sample, res_samples = downsample_block(
418
+ hidden_states=sample,
419
+ temb=emb,
420
+ encoder_hidden_states=encoder_hidden_states,
421
+ attention_mask=attention_mask,
422
+ )
423
+ else:
424
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
425
+
426
+ down_block_res_samples += res_samples
427
+
428
+ # mid
429
+ sample = self.mid_block(
430
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
431
+ )
432
+
433
+ # up
434
+ for i, upsample_block in enumerate(self.up_blocks):
435
+ is_final_block = i == len(self.up_blocks) - 1
436
+
437
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
438
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
439
+
440
+ # if we have not reached the final block and need to forward the
441
+ # upsample size, we do it here
442
+ if not is_final_block and forward_upsample_size:
443
+ upsample_size = down_block_res_samples[-1].shape[2:]
444
+
445
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
446
+ sample = upsample_block(
447
+ hidden_states=sample,
448
+ temb=emb,
449
+ res_hidden_states_tuple=res_samples,
450
+ encoder_hidden_states=encoder_hidden_states,
451
+ upsample_size=upsample_size,
452
+ attention_mask=attention_mask,
453
+ )
454
+ else:
455
+ sample = upsample_block(
456
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
457
+ )
458
+
459
+ # post-process
460
+ sample = self.conv_norm_out(sample)
461
+ sample = self.conv_act(sample)
462
+ sample = self.conv_out(sample)
463
+
464
+ if not return_dict:
465
+ return (sample,)
466
+
467
+ return UNet3DConditionOutput(sample=sample)
468
+
469
+ @classmethod
470
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
471
+ if subfolder is not None:
472
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
473
+ print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
474
+
475
+ config_file = os.path.join(pretrained_model_path, 'config.json')
476
+ if not os.path.isfile(config_file):
477
+ raise RuntimeError(f"{config_file} does not exist")
478
+ with open(config_file, "r") as f:
479
+ config = json.load(f)
480
+ config["_class_name"] = cls.__name__
481
+ config["down_block_types"] = [
482
+ "CrossAttnDownBlock3D",
483
+ "CrossAttnDownBlock3D",
484
+ "CrossAttnDownBlock3D",
485
+ "DownBlock3D"
486
+ ]
487
+ config["up_block_types"] = [
488
+ "UpBlock3D",
489
+ "CrossAttnUpBlock3D",
490
+ "CrossAttnUpBlock3D",
491
+ "CrossAttnUpBlock3D"
492
+ ]
493
+
494
+ from diffusers.utils import WEIGHTS_NAME
495
+ model = cls.from_config(config, **(unet_additional_kwargs or {}))
496
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
497
+ if not os.path.isfile(model_file):
498
+ raise RuntimeError(f"{model_file} does not exist")
499
+ state_dict = torch.load(model_file, map_location="cpu")
500
+
501
+ m, u = model.load_state_dict(state_dict, strict=False)
502
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
503
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
504
+
505
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
506
+ print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
507
+
508
+ return model
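
For reference, a minimal usage sketch of the `from_pretrained_2d` helper above. The import path, checkpoint directory, and motion-module settings below are illustrative assumptions (MagicAnimate reads the real values from its YAML configs), not something this diff defines:

    from magicanimate.models.unet import UNet3DConditionModel  # assumed import path

    unet = UNet3DConditionModel.from_pretrained_2d(
        "/path/to/stable-diffusion-v1-5",   # hypothetical local checkpoint directory
        subfolder="unet",                   # must contain config.json and diffusion_pytorch_model.bin
        unet_additional_kwargs={
            "use_motion_module": True,                  # enable the temporal modules
            "motion_module_resolutions": [1, 2, 4, 8],
            "motion_module_type": "Vanilla",            # assumed name accepted by get_motion_module
            "motion_module_kwargs": {},
        },
    )

Because the temporal modules have no counterpart in the 2D checkpoint, the loader calls load_state_dict with strict=False and prints the resulting missing/unexpected key counts; that is expected.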
magicanimate/models/unet_3d_blocks.py CHANGED
@@ -1,751 +1,751 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Adapted from https://github.com/guoyww/AnimateDiff
8
-
9
- # Copyright 2023 The HuggingFace Team. All rights reserved.
10
- #
11
- # Licensed under the Apache License, Version 2.0 (the "License");
12
- # you may not use this file except in compliance with the License.
13
- # You may obtain a copy of the License at
14
- #
15
- # http://www.apache.org/licenses/LICENSE-2.0
16
- #
17
- # Unless required by applicable law or agreed to in writing, software
18
- # distributed under the License is distributed on an "AS IS" BASIS,
19
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
- # See the License for the specific language governing permissions and
21
- # limitations under the License.
22
- import torch
23
- from torch import nn
24
-
25
- from .attention import Transformer3DModel
26
- from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
27
- from .motion_module import get_motion_module
28
-
29
-
30
- def get_down_block(
31
- down_block_type,
32
- num_layers,
33
- in_channels,
34
- out_channels,
35
- temb_channels,
36
- add_downsample,
37
- resnet_eps,
38
- resnet_act_fn,
39
- attn_num_head_channels,
40
- resnet_groups=None,
41
- cross_attention_dim=None,
42
- downsample_padding=None,
43
- dual_cross_attention=False,
44
- use_linear_projection=False,
45
- only_cross_attention=False,
46
- upcast_attention=False,
47
- resnet_time_scale_shift="default",
48
-
49
- unet_use_cross_frame_attention=None,
50
- unet_use_temporal_attention=None,
51
-
52
- use_motion_module=None,
53
-
54
- motion_module_type=None,
55
- motion_module_kwargs=None,
56
- ):
57
- down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
58
- if down_block_type == "DownBlock3D":
59
- return DownBlock3D(
60
- num_layers=num_layers,
61
- in_channels=in_channels,
62
- out_channels=out_channels,
63
- temb_channels=temb_channels,
64
- add_downsample=add_downsample,
65
- resnet_eps=resnet_eps,
66
- resnet_act_fn=resnet_act_fn,
67
- resnet_groups=resnet_groups,
68
- downsample_padding=downsample_padding,
69
- resnet_time_scale_shift=resnet_time_scale_shift,
70
-
71
- use_motion_module=use_motion_module,
72
- motion_module_type=motion_module_type,
73
- motion_module_kwargs=motion_module_kwargs,
74
- )
75
- elif down_block_type == "CrossAttnDownBlock3D":
76
- if cross_attention_dim is None:
77
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
78
- return CrossAttnDownBlock3D(
79
- num_layers=num_layers,
80
- in_channels=in_channels,
81
- out_channels=out_channels,
82
- temb_channels=temb_channels,
83
- add_downsample=add_downsample,
84
- resnet_eps=resnet_eps,
85
- resnet_act_fn=resnet_act_fn,
86
- resnet_groups=resnet_groups,
87
- downsample_padding=downsample_padding,
88
- cross_attention_dim=cross_attention_dim,
89
- attn_num_head_channels=attn_num_head_channels,
90
- dual_cross_attention=dual_cross_attention,
91
- use_linear_projection=use_linear_projection,
92
- only_cross_attention=only_cross_attention,
93
- upcast_attention=upcast_attention,
94
- resnet_time_scale_shift=resnet_time_scale_shift,
95
-
96
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
97
- unet_use_temporal_attention=unet_use_temporal_attention,
98
-
99
- use_motion_module=use_motion_module,
100
- motion_module_type=motion_module_type,
101
- motion_module_kwargs=motion_module_kwargs,
102
- )
103
- raise ValueError(f"{down_block_type} does not exist.")
104
-
105
-
106
- def get_up_block(
107
- up_block_type,
108
- num_layers,
109
- in_channels,
110
- out_channels,
111
- prev_output_channel,
112
- temb_channels,
113
- add_upsample,
114
- resnet_eps,
115
- resnet_act_fn,
116
- attn_num_head_channels,
117
- resnet_groups=None,
118
- cross_attention_dim=None,
119
- dual_cross_attention=False,
120
- use_linear_projection=False,
121
- only_cross_attention=False,
122
- upcast_attention=False,
123
- resnet_time_scale_shift="default",
124
-
125
- unet_use_cross_frame_attention=None,
126
- unet_use_temporal_attention=None,
127
-
128
- use_motion_module=None,
129
- motion_module_type=None,
130
- motion_module_kwargs=None,
131
- ):
132
- up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
133
- if up_block_type == "UpBlock3D":
134
- return UpBlock3D(
135
- num_layers=num_layers,
136
- in_channels=in_channels,
137
- out_channels=out_channels,
138
- prev_output_channel=prev_output_channel,
139
- temb_channels=temb_channels,
140
- add_upsample=add_upsample,
141
- resnet_eps=resnet_eps,
142
- resnet_act_fn=resnet_act_fn,
143
- resnet_groups=resnet_groups,
144
- resnet_time_scale_shift=resnet_time_scale_shift,
145
-
146
- use_motion_module=use_motion_module,
147
- motion_module_type=motion_module_type,
148
- motion_module_kwargs=motion_module_kwargs,
149
- )
150
- elif up_block_type == "CrossAttnUpBlock3D":
151
- if cross_attention_dim is None:
152
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
153
- return CrossAttnUpBlock3D(
154
- num_layers=num_layers,
155
- in_channels=in_channels,
156
- out_channels=out_channels,
157
- prev_output_channel=prev_output_channel,
158
- temb_channels=temb_channels,
159
- add_upsample=add_upsample,
160
- resnet_eps=resnet_eps,
161
- resnet_act_fn=resnet_act_fn,
162
- resnet_groups=resnet_groups,
163
- cross_attention_dim=cross_attention_dim,
164
- attn_num_head_channels=attn_num_head_channels,
165
- dual_cross_attention=dual_cross_attention,
166
- use_linear_projection=use_linear_projection,
167
- only_cross_attention=only_cross_attention,
168
- upcast_attention=upcast_attention,
169
- resnet_time_scale_shift=resnet_time_scale_shift,
170
-
171
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
172
- unet_use_temporal_attention=unet_use_temporal_attention,
173
-
174
- use_motion_module=use_motion_module,
175
- motion_module_type=motion_module_type,
176
- motion_module_kwargs=motion_module_kwargs,
177
- )
178
- raise ValueError(f"{up_block_type} does not exist.")
179
-
180
-
181
- class UNetMidBlock3DCrossAttn(nn.Module):
182
- def __init__(
183
- self,
184
- in_channels: int,
185
- temb_channels: int,
186
- dropout: float = 0.0,
187
- num_layers: int = 1,
188
- resnet_eps: float = 1e-6,
189
- resnet_time_scale_shift: str = "default",
190
- resnet_act_fn: str = "swish",
191
- resnet_groups: int = 32,
192
- resnet_pre_norm: bool = True,
193
- attn_num_head_channels=1,
194
- output_scale_factor=1.0,
195
- cross_attention_dim=1280,
196
- dual_cross_attention=False,
197
- use_linear_projection=False,
198
- upcast_attention=False,
199
-
200
- unet_use_cross_frame_attention=None,
201
- unet_use_temporal_attention=None,
202
-
203
- use_motion_module=None,
204
-
205
- motion_module_type=None,
206
- motion_module_kwargs=None,
207
- ):
208
- super().__init__()
209
-
210
- self.has_cross_attention = True
211
- self.attn_num_head_channels = attn_num_head_channels
212
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
213
-
214
- # there is always at least one resnet
215
- resnets = [
216
- ResnetBlock3D(
217
- in_channels=in_channels,
218
- out_channels=in_channels,
219
- temb_channels=temb_channels,
220
- eps=resnet_eps,
221
- groups=resnet_groups,
222
- dropout=dropout,
223
- time_embedding_norm=resnet_time_scale_shift,
224
- non_linearity=resnet_act_fn,
225
- output_scale_factor=output_scale_factor,
226
- pre_norm=resnet_pre_norm,
227
- )
228
- ]
229
- attentions = []
230
- motion_modules = []
231
-
232
- for _ in range(num_layers):
233
- if dual_cross_attention:
234
- raise NotImplementedError
235
- attentions.append(
236
- Transformer3DModel(
237
- attn_num_head_channels,
238
- in_channels // attn_num_head_channels,
239
- in_channels=in_channels,
240
- num_layers=1,
241
- cross_attention_dim=cross_attention_dim,
242
- norm_num_groups=resnet_groups,
243
- use_linear_projection=use_linear_projection,
244
- upcast_attention=upcast_attention,
245
-
246
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
247
- unet_use_temporal_attention=unet_use_temporal_attention,
248
- )
249
- )
250
- motion_modules.append(
251
- get_motion_module(
252
- in_channels=in_channels,
253
- motion_module_type=motion_module_type,
254
- motion_module_kwargs=motion_module_kwargs,
255
- ) if use_motion_module else None
256
- )
257
- resnets.append(
258
- ResnetBlock3D(
259
- in_channels=in_channels,
260
- out_channels=in_channels,
261
- temb_channels=temb_channels,
262
- eps=resnet_eps,
263
- groups=resnet_groups,
264
- dropout=dropout,
265
- time_embedding_norm=resnet_time_scale_shift,
266
- non_linearity=resnet_act_fn,
267
- output_scale_factor=output_scale_factor,
268
- pre_norm=resnet_pre_norm,
269
- )
270
- )
271
-
272
- self.attentions = nn.ModuleList(attentions)
273
- self.resnets = nn.ModuleList(resnets)
274
- self.motion_modules = nn.ModuleList(motion_modules)
275
-
276
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
277
- hidden_states = self.resnets[0](hidden_states, temb)
278
- for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
279
- hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
280
- hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
281
- hidden_states = resnet(hidden_states, temb)
282
-
283
- return hidden_states
284
-
285
-
286
- class CrossAttnDownBlock3D(nn.Module):
287
- def __init__(
288
- self,
289
- in_channels: int,
290
- out_channels: int,
291
- temb_channels: int,
292
- dropout: float = 0.0,
293
- num_layers: int = 1,
294
- resnet_eps: float = 1e-6,
295
- resnet_time_scale_shift: str = "default",
296
- resnet_act_fn: str = "swish",
297
- resnet_groups: int = 32,
298
- resnet_pre_norm: bool = True,
299
- attn_num_head_channels=1,
300
- cross_attention_dim=1280,
301
- output_scale_factor=1.0,
302
- downsample_padding=1,
303
- add_downsample=True,
304
- dual_cross_attention=False,
305
- use_linear_projection=False,
306
- only_cross_attention=False,
307
- upcast_attention=False,
308
-
309
- unet_use_cross_frame_attention=None,
310
- unet_use_temporal_attention=None,
311
-
312
- use_motion_module=None,
313
-
314
- motion_module_type=None,
315
- motion_module_kwargs=None,
316
- ):
317
- super().__init__()
318
- resnets = []
319
- attentions = []
320
- motion_modules = []
321
-
322
- self.has_cross_attention = True
323
- self.attn_num_head_channels = attn_num_head_channels
324
-
325
- for i in range(num_layers):
326
- in_channels = in_channels if i == 0 else out_channels
327
- resnets.append(
328
- ResnetBlock3D(
329
- in_channels=in_channels,
330
- out_channels=out_channels,
331
- temb_channels=temb_channels,
332
- eps=resnet_eps,
333
- groups=resnet_groups,
334
- dropout=dropout,
335
- time_embedding_norm=resnet_time_scale_shift,
336
- non_linearity=resnet_act_fn,
337
- output_scale_factor=output_scale_factor,
338
- pre_norm=resnet_pre_norm,
339
- )
340
- )
341
- if dual_cross_attention:
342
- raise NotImplementedError
343
- attentions.append(
344
- Transformer3DModel(
345
- attn_num_head_channels,
346
- out_channels // attn_num_head_channels,
347
- in_channels=out_channels,
348
- num_layers=1,
349
- cross_attention_dim=cross_attention_dim,
350
- norm_num_groups=resnet_groups,
351
- use_linear_projection=use_linear_projection,
352
- only_cross_attention=only_cross_attention,
353
- upcast_attention=upcast_attention,
354
-
355
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
356
- unet_use_temporal_attention=unet_use_temporal_attention,
357
- )
358
- )
359
- motion_modules.append(
360
- get_motion_module(
361
- in_channels=out_channels,
362
- motion_module_type=motion_module_type,
363
- motion_module_kwargs=motion_module_kwargs,
364
- ) if use_motion_module else None
365
- )
366
-
367
- self.attentions = nn.ModuleList(attentions)
368
- self.resnets = nn.ModuleList(resnets)
369
- self.motion_modules = nn.ModuleList(motion_modules)
370
-
371
- if add_downsample:
372
- self.downsamplers = nn.ModuleList(
373
- [
374
- Downsample3D(
375
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
376
- )
377
- ]
378
- )
379
- else:
380
- self.downsamplers = None
381
-
382
- self.gradient_checkpointing = False
383
-
384
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
385
- output_states = ()
386
-
387
- for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
388
- if self.training and self.gradient_checkpointing:
389
-
390
- def create_custom_forward(module, return_dict=None):
391
- def custom_forward(*inputs):
392
- if return_dict is not None:
393
- return module(*inputs, return_dict=return_dict)
394
- else:
395
- return module(*inputs)
396
-
397
- return custom_forward
398
-
399
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
400
- hidden_states = torch.utils.checkpoint.checkpoint(
401
- create_custom_forward(attn, return_dict=False),
402
- hidden_states,
403
- encoder_hidden_states,
404
- )[0]
405
- if motion_module is not None:
406
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
407
-
408
- else:
409
- hidden_states = resnet(hidden_states, temb)
410
- hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
411
-
412
- # add motion module
413
- hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
414
-
415
- output_states += (hidden_states,)
416
-
417
- if self.downsamplers is not None:
418
- for downsampler in self.downsamplers:
419
- hidden_states = downsampler(hidden_states)
420
-
421
- output_states += (hidden_states,)
422
-
423
- return hidden_states, output_states
424
-
425
-
426
- class DownBlock3D(nn.Module):
427
- def __init__(
428
- self,
429
- in_channels: int,
430
- out_channels: int,
431
- temb_channels: int,
432
- dropout: float = 0.0,
433
- num_layers: int = 1,
434
- resnet_eps: float = 1e-6,
435
- resnet_time_scale_shift: str = "default",
436
- resnet_act_fn: str = "swish",
437
- resnet_groups: int = 32,
438
- resnet_pre_norm: bool = True,
439
- output_scale_factor=1.0,
440
- add_downsample=True,
441
- downsample_padding=1,
442
-
443
- use_motion_module=None,
444
- motion_module_type=None,
445
- motion_module_kwargs=None,
446
- ):
447
- super().__init__()
448
- resnets = []
449
- motion_modules = []
450
-
451
- for i in range(num_layers):
452
- in_channels = in_channels if i == 0 else out_channels
453
- resnets.append(
454
- ResnetBlock3D(
455
- in_channels=in_channels,
456
- out_channels=out_channels,
457
- temb_channels=temb_channels,
458
- eps=resnet_eps,
459
- groups=resnet_groups,
460
- dropout=dropout,
461
- time_embedding_norm=resnet_time_scale_shift,
462
- non_linearity=resnet_act_fn,
463
- output_scale_factor=output_scale_factor,
464
- pre_norm=resnet_pre_norm,
465
- )
466
- )
467
- motion_modules.append(
468
- get_motion_module(
469
- in_channels=out_channels,
470
- motion_module_type=motion_module_type,
471
- motion_module_kwargs=motion_module_kwargs,
472
- ) if use_motion_module else None
473
- )
474
-
475
- self.resnets = nn.ModuleList(resnets)
476
- self.motion_modules = nn.ModuleList(motion_modules)
477
-
478
- if add_downsample:
479
- self.downsamplers = nn.ModuleList(
480
- [
481
- Downsample3D(
482
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
483
- )
484
- ]
485
- )
486
- else:
487
- self.downsamplers = None
488
-
489
- self.gradient_checkpointing = False
490
-
491
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
492
- output_states = ()
493
-
494
- for resnet, motion_module in zip(self.resnets, self.motion_modules):
495
- if self.training and self.gradient_checkpointing:
496
- def create_custom_forward(module):
497
- def custom_forward(*inputs):
498
- return module(*inputs)
499
-
500
- return custom_forward
501
-
502
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
503
- if motion_module is not None:
504
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
505
- else:
506
- hidden_states = resnet(hidden_states, temb)
507
-
508
- # add motion module
509
- hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
510
-
511
- output_states += (hidden_states,)
512
-
513
- if self.downsamplers is not None:
514
- for downsampler in self.downsamplers:
515
- hidden_states = downsampler(hidden_states)
516
-
517
- output_states += (hidden_states,)
518
-
519
- return hidden_states, output_states
520
-
521
-
522
- class CrossAttnUpBlock3D(nn.Module):
523
- def __init__(
524
- self,
525
- in_channels: int,
526
- out_channels: int,
527
- prev_output_channel: int,
528
- temb_channels: int,
529
- dropout: float = 0.0,
530
- num_layers: int = 1,
531
- resnet_eps: float = 1e-6,
532
- resnet_time_scale_shift: str = "default",
533
- resnet_act_fn: str = "swish",
534
- resnet_groups: int = 32,
535
- resnet_pre_norm: bool = True,
536
- attn_num_head_channels=1,
537
- cross_attention_dim=1280,
538
- output_scale_factor=1.0,
539
- add_upsample=True,
540
- dual_cross_attention=False,
541
- use_linear_projection=False,
542
- only_cross_attention=False,
543
- upcast_attention=False,
544
-
545
- unet_use_cross_frame_attention=None,
546
- unet_use_temporal_attention=None,
547
-
548
- use_motion_module=None,
549
-
550
- motion_module_type=None,
551
- motion_module_kwargs=None,
552
- ):
553
- super().__init__()
554
- resnets = []
555
- attentions = []
556
- motion_modules = []
557
-
558
- self.has_cross_attention = True
559
- self.attn_num_head_channels = attn_num_head_channels
560
-
561
- for i in range(num_layers):
562
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
563
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
564
-
565
- resnets.append(
566
- ResnetBlock3D(
567
- in_channels=resnet_in_channels + res_skip_channels,
568
- out_channels=out_channels,
569
- temb_channels=temb_channels,
570
- eps=resnet_eps,
571
- groups=resnet_groups,
572
- dropout=dropout,
573
- time_embedding_norm=resnet_time_scale_shift,
574
- non_linearity=resnet_act_fn,
575
- output_scale_factor=output_scale_factor,
576
- pre_norm=resnet_pre_norm,
577
- )
578
- )
579
- if dual_cross_attention:
580
- raise NotImplementedError
581
- attentions.append(
582
- Transformer3DModel(
583
- attn_num_head_channels,
584
- out_channels // attn_num_head_channels,
585
- in_channels=out_channels,
586
- num_layers=1,
587
- cross_attention_dim=cross_attention_dim,
588
- norm_num_groups=resnet_groups,
589
- use_linear_projection=use_linear_projection,
590
- only_cross_attention=only_cross_attention,
591
- upcast_attention=upcast_attention,
592
-
593
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
594
- unet_use_temporal_attention=unet_use_temporal_attention,
595
- )
596
- )
597
- motion_modules.append(
598
- get_motion_module(
599
- in_channels=out_channels,
600
- motion_module_type=motion_module_type,
601
- motion_module_kwargs=motion_module_kwargs,
602
- ) if use_motion_module else None
603
- )
604
-
605
- self.attentions = nn.ModuleList(attentions)
606
- self.resnets = nn.ModuleList(resnets)
607
- self.motion_modules = nn.ModuleList(motion_modules)
608
-
609
- if add_upsample:
610
- self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
611
- else:
612
- self.upsamplers = None
613
-
614
- self.gradient_checkpointing = False
615
-
616
- def forward(
617
- self,
618
- hidden_states,
619
- res_hidden_states_tuple,
620
- temb=None,
621
- encoder_hidden_states=None,
622
- upsample_size=None,
623
- attention_mask=None,
624
- ):
625
- for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
626
- # pop res hidden states
627
- res_hidden_states = res_hidden_states_tuple[-1]
628
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
629
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
630
-
631
- if self.training and self.gradient_checkpointing:
632
-
633
- def create_custom_forward(module, return_dict=None):
634
- def custom_forward(*inputs):
635
- if return_dict is not None:
636
- return module(*inputs, return_dict=return_dict)
637
- else:
638
- return module(*inputs)
639
-
640
- return custom_forward
641
-
642
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
643
- hidden_states = torch.utils.checkpoint.checkpoint(
644
- create_custom_forward(attn, return_dict=False),
645
- hidden_states,
646
- encoder_hidden_states,
647
- )[0]
648
- if motion_module is not None:
649
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
650
-
651
- else:
652
- hidden_states = resnet(hidden_states, temb)
653
- hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
654
-
655
- # add motion module
656
- hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
657
-
658
- if self.upsamplers is not None:
659
- for upsampler in self.upsamplers:
660
- hidden_states = upsampler(hidden_states, upsample_size)
661
-
662
- return hidden_states
663
-
664
-
665
- class UpBlock3D(nn.Module):
666
- def __init__(
667
- self,
668
- in_channels: int,
669
- prev_output_channel: int,
670
- out_channels: int,
671
- temb_channels: int,
672
- dropout: float = 0.0,
673
- num_layers: int = 1,
674
- resnet_eps: float = 1e-6,
675
- resnet_time_scale_shift: str = "default",
676
- resnet_act_fn: str = "swish",
677
- resnet_groups: int = 32,
678
- resnet_pre_norm: bool = True,
679
- output_scale_factor=1.0,
680
- add_upsample=True,
681
-
682
- use_motion_module=None,
683
- motion_module_type=None,
684
- motion_module_kwargs=None,
685
- ):
686
- super().__init__()
687
- resnets = []
688
- motion_modules = []
689
-
690
- for i in range(num_layers):
691
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
692
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
693
-
694
- resnets.append(
695
- ResnetBlock3D(
696
- in_channels=resnet_in_channels + res_skip_channels,
697
- out_channels=out_channels,
698
- temb_channels=temb_channels,
699
- eps=resnet_eps,
700
- groups=resnet_groups,
701
- dropout=dropout,
702
- time_embedding_norm=resnet_time_scale_shift,
703
- non_linearity=resnet_act_fn,
704
- output_scale_factor=output_scale_factor,
705
- pre_norm=resnet_pre_norm,
706
- )
707
- )
708
- motion_modules.append(
709
- get_motion_module(
710
- in_channels=out_channels,
711
- motion_module_type=motion_module_type,
712
- motion_module_kwargs=motion_module_kwargs,
713
- ) if use_motion_module else None
714
- )
715
-
716
- self.resnets = nn.ModuleList(resnets)
717
- self.motion_modules = nn.ModuleList(motion_modules)
718
-
719
- if add_upsample:
720
- self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
721
- else:
722
- self.upsamplers = None
723
-
724
- self.gradient_checkpointing = False
725
-
726
- def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
727
- for resnet, motion_module in zip(self.resnets, self.motion_modules):
728
- # pop res hidden states
729
- res_hidden_states = res_hidden_states_tuple[-1]
730
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
731
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
732
-
733
- if self.training and self.gradient_checkpointing:
734
- def create_custom_forward(module):
735
- def custom_forward(*inputs):
736
- return module(*inputs)
737
-
738
- return custom_forward
739
-
740
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
741
- if motion_module is not None:
742
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
743
- else:
744
- hidden_states = resnet(hidden_states, temb)
745
- hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
746
-
747
- if self.upsamplers is not None:
748
- for upsampler in self.upsamplers:
749
- hidden_states = upsampler(hidden_states, upsample_size)
750
-
751
  return hidden_states
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+
9
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ import torch
23
+ from torch import nn
24
+
25
+ from .attention import Transformer3DModel
26
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
27
+ from .motion_module import get_motion_module
28
+
29
+
30
+ def get_down_block(
31
+ down_block_type,
32
+ num_layers,
33
+ in_channels,
34
+ out_channels,
35
+ temb_channels,
36
+ add_downsample,
37
+ resnet_eps,
38
+ resnet_act_fn,
39
+ attn_num_head_channels,
40
+ resnet_groups=None,
41
+ cross_attention_dim=None,
42
+ downsample_padding=None,
43
+ dual_cross_attention=False,
44
+ use_linear_projection=False,
45
+ only_cross_attention=False,
46
+ upcast_attention=False,
47
+ resnet_time_scale_shift="default",
48
+
49
+ unet_use_cross_frame_attention=None,
50
+ unet_use_temporal_attention=None,
51
+
52
+ use_motion_module=None,
53
+
54
+ motion_module_type=None,
55
+ motion_module_kwargs=None,
56
+ ):
57
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
58
+ if down_block_type == "DownBlock3D":
59
+ return DownBlock3D(
60
+ num_layers=num_layers,
61
+ in_channels=in_channels,
62
+ out_channels=out_channels,
63
+ temb_channels=temb_channels,
64
+ add_downsample=add_downsample,
65
+ resnet_eps=resnet_eps,
66
+ resnet_act_fn=resnet_act_fn,
67
+ resnet_groups=resnet_groups,
68
+ downsample_padding=downsample_padding,
69
+ resnet_time_scale_shift=resnet_time_scale_shift,
70
+
71
+ use_motion_module=use_motion_module,
72
+ motion_module_type=motion_module_type,
73
+ motion_module_kwargs=motion_module_kwargs,
74
+ )
75
+ elif down_block_type == "CrossAttnDownBlock3D":
76
+ if cross_attention_dim is None:
77
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
78
+ return CrossAttnDownBlock3D(
79
+ num_layers=num_layers,
80
+ in_channels=in_channels,
81
+ out_channels=out_channels,
82
+ temb_channels=temb_channels,
83
+ add_downsample=add_downsample,
84
+ resnet_eps=resnet_eps,
85
+ resnet_act_fn=resnet_act_fn,
86
+ resnet_groups=resnet_groups,
87
+ downsample_padding=downsample_padding,
88
+ cross_attention_dim=cross_attention_dim,
89
+ attn_num_head_channels=attn_num_head_channels,
90
+ dual_cross_attention=dual_cross_attention,
91
+ use_linear_projection=use_linear_projection,
92
+ only_cross_attention=only_cross_attention,
93
+ upcast_attention=upcast_attention,
94
+ resnet_time_scale_shift=resnet_time_scale_shift,
95
+
96
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
97
+ unet_use_temporal_attention=unet_use_temporal_attention,
98
+
99
+ use_motion_module=use_motion_module,
100
+ motion_module_type=motion_module_type,
101
+ motion_module_kwargs=motion_module_kwargs,
102
+ )
103
+ raise ValueError(f"{down_block_type} does not exist.")
104
+
105
+
106
+ def get_up_block(
107
+ up_block_type,
108
+ num_layers,
109
+ in_channels,
110
+ out_channels,
111
+ prev_output_channel,
112
+ temb_channels,
113
+ add_upsample,
114
+ resnet_eps,
115
+ resnet_act_fn,
116
+ attn_num_head_channels,
117
+ resnet_groups=None,
118
+ cross_attention_dim=None,
119
+ dual_cross_attention=False,
120
+ use_linear_projection=False,
121
+ only_cross_attention=False,
122
+ upcast_attention=False,
123
+ resnet_time_scale_shift="default",
124
+
125
+ unet_use_cross_frame_attention=None,
126
+ unet_use_temporal_attention=None,
127
+
128
+ use_motion_module=None,
129
+ motion_module_type=None,
130
+ motion_module_kwargs=None,
131
+ ):
132
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
133
+ if up_block_type == "UpBlock3D":
134
+ return UpBlock3D(
135
+ num_layers=num_layers,
136
+ in_channels=in_channels,
137
+ out_channels=out_channels,
138
+ prev_output_channel=prev_output_channel,
139
+ temb_channels=temb_channels,
140
+ add_upsample=add_upsample,
141
+ resnet_eps=resnet_eps,
142
+ resnet_act_fn=resnet_act_fn,
143
+ resnet_groups=resnet_groups,
144
+ resnet_time_scale_shift=resnet_time_scale_shift,
145
+
146
+ use_motion_module=use_motion_module,
147
+ motion_module_type=motion_module_type,
148
+ motion_module_kwargs=motion_module_kwargs,
149
+ )
150
+ elif up_block_type == "CrossAttnUpBlock3D":
151
+ if cross_attention_dim is None:
152
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
153
+ return CrossAttnUpBlock3D(
154
+ num_layers=num_layers,
155
+ in_channels=in_channels,
156
+ out_channels=out_channels,
157
+ prev_output_channel=prev_output_channel,
158
+ temb_channels=temb_channels,
159
+ add_upsample=add_upsample,
160
+ resnet_eps=resnet_eps,
161
+ resnet_act_fn=resnet_act_fn,
162
+ resnet_groups=resnet_groups,
163
+ cross_attention_dim=cross_attention_dim,
164
+ attn_num_head_channels=attn_num_head_channels,
165
+ dual_cross_attention=dual_cross_attention,
166
+ use_linear_projection=use_linear_projection,
167
+ only_cross_attention=only_cross_attention,
168
+ upcast_attention=upcast_attention,
169
+ resnet_time_scale_shift=resnet_time_scale_shift,
170
+
171
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
172
+ unet_use_temporal_attention=unet_use_temporal_attention,
173
+
174
+ use_motion_module=use_motion_module,
175
+ motion_module_type=motion_module_type,
176
+ motion_module_kwargs=motion_module_kwargs,
177
+ )
178
+ raise ValueError(f"{up_block_type} does not exist.")
179
+
180
+
181
+ class UNetMidBlock3DCrossAttn(nn.Module):
182
+ def __init__(
183
+ self,
184
+ in_channels: int,
185
+ temb_channels: int,
186
+ dropout: float = 0.0,
187
+ num_layers: int = 1,
188
+ resnet_eps: float = 1e-6,
189
+ resnet_time_scale_shift: str = "default",
190
+ resnet_act_fn: str = "swish",
191
+ resnet_groups: int = 32,
192
+ resnet_pre_norm: bool = True,
193
+ attn_num_head_channels=1,
194
+ output_scale_factor=1.0,
195
+ cross_attention_dim=1280,
196
+ dual_cross_attention=False,
197
+ use_linear_projection=False,
198
+ upcast_attention=False,
199
+
200
+ unet_use_cross_frame_attention=None,
201
+ unet_use_temporal_attention=None,
202
+
203
+ use_motion_module=None,
204
+
205
+ motion_module_type=None,
206
+ motion_module_kwargs=None,
207
+ ):
208
+ super().__init__()
209
+
210
+ self.has_cross_attention = True
211
+ self.attn_num_head_channels = attn_num_head_channels
212
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
213
+
214
+ # there is always at least one resnet
215
+ resnets = [
216
+ ResnetBlock3D(
217
+ in_channels=in_channels,
218
+ out_channels=in_channels,
219
+ temb_channels=temb_channels,
220
+ eps=resnet_eps,
221
+ groups=resnet_groups,
222
+ dropout=dropout,
223
+ time_embedding_norm=resnet_time_scale_shift,
224
+ non_linearity=resnet_act_fn,
225
+ output_scale_factor=output_scale_factor,
226
+ pre_norm=resnet_pre_norm,
227
+ )
228
+ ]
229
+ attentions = []
230
+ motion_modules = []
231
+
232
+ for _ in range(num_layers):
233
+ if dual_cross_attention:
234
+ raise NotImplementedError
235
+ attentions.append(
236
+ Transformer3DModel(
237
+ attn_num_head_channels,
238
+ in_channels // attn_num_head_channels,
239
+ in_channels=in_channels,
240
+ num_layers=1,
241
+ cross_attention_dim=cross_attention_dim,
242
+ norm_num_groups=resnet_groups,
243
+ use_linear_projection=use_linear_projection,
244
+ upcast_attention=upcast_attention,
245
+
246
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
247
+ unet_use_temporal_attention=unet_use_temporal_attention,
248
+ )
249
+ )
250
+ motion_modules.append(
251
+ get_motion_module(
252
+ in_channels=in_channels,
253
+ motion_module_type=motion_module_type,
254
+ motion_module_kwargs=motion_module_kwargs,
255
+ ) if use_motion_module else None
256
+ )
257
+ resnets.append(
258
+ ResnetBlock3D(
259
+ in_channels=in_channels,
260
+ out_channels=in_channels,
261
+ temb_channels=temb_channels,
262
+ eps=resnet_eps,
263
+ groups=resnet_groups,
264
+ dropout=dropout,
265
+ time_embedding_norm=resnet_time_scale_shift,
266
+ non_linearity=resnet_act_fn,
267
+ output_scale_factor=output_scale_factor,
268
+ pre_norm=resnet_pre_norm,
269
+ )
270
+ )
271
+
272
+ self.attentions = nn.ModuleList(attentions)
273
+ self.resnets = nn.ModuleList(resnets)
274
+ self.motion_modules = nn.ModuleList(motion_modules)
275
+
276
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
277
+ hidden_states = self.resnets[0](hidden_states, temb)
278
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
279
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
280
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
281
+ hidden_states = resnet(hidden_states, temb)
282
+
283
+ return hidden_states
284
+
285
+
286
+ class CrossAttnDownBlock3D(nn.Module):
287
+ def __init__(
288
+ self,
289
+ in_channels: int,
290
+ out_channels: int,
291
+ temb_channels: int,
292
+ dropout: float = 0.0,
293
+ num_layers: int = 1,
294
+ resnet_eps: float = 1e-6,
295
+ resnet_time_scale_shift: str = "default",
296
+ resnet_act_fn: str = "swish",
297
+ resnet_groups: int = 32,
298
+ resnet_pre_norm: bool = True,
299
+ attn_num_head_channels=1,
300
+ cross_attention_dim=1280,
301
+ output_scale_factor=1.0,
302
+ downsample_padding=1,
303
+ add_downsample=True,
304
+ dual_cross_attention=False,
305
+ use_linear_projection=False,
306
+ only_cross_attention=False,
307
+ upcast_attention=False,
308
+
309
+ unet_use_cross_frame_attention=None,
310
+ unet_use_temporal_attention=None,
311
+
312
+ use_motion_module=None,
313
+
314
+ motion_module_type=None,
315
+ motion_module_kwargs=None,
316
+ ):
317
+ super().__init__()
318
+ resnets = []
319
+ attentions = []
320
+ motion_modules = []
321
+
322
+ self.has_cross_attention = True
323
+ self.attn_num_head_channels = attn_num_head_channels
324
+
325
+ for i in range(num_layers):
326
+ in_channels = in_channels if i == 0 else out_channels
327
+ resnets.append(
328
+ ResnetBlock3D(
329
+ in_channels=in_channels,
330
+ out_channels=out_channels,
331
+ temb_channels=temb_channels,
332
+ eps=resnet_eps,
333
+ groups=resnet_groups,
334
+ dropout=dropout,
335
+ time_embedding_norm=resnet_time_scale_shift,
336
+ non_linearity=resnet_act_fn,
337
+ output_scale_factor=output_scale_factor,
338
+ pre_norm=resnet_pre_norm,
339
+ )
340
+ )
341
+ if dual_cross_attention:
342
+ raise NotImplementedError
343
+ attentions.append(
344
+ Transformer3DModel(
345
+ attn_num_head_channels,
346
+ out_channels // attn_num_head_channels,
347
+ in_channels=out_channels,
348
+ num_layers=1,
349
+ cross_attention_dim=cross_attention_dim,
350
+ norm_num_groups=resnet_groups,
351
+ use_linear_projection=use_linear_projection,
352
+ only_cross_attention=only_cross_attention,
353
+ upcast_attention=upcast_attention,
354
+
355
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
356
+ unet_use_temporal_attention=unet_use_temporal_attention,
357
+ )
358
+ )
359
+ motion_modules.append(
360
+ get_motion_module(
361
+ in_channels=out_channels,
362
+ motion_module_type=motion_module_type,
363
+ motion_module_kwargs=motion_module_kwargs,
364
+ ) if use_motion_module else None
365
+ )
366
+
367
+ self.attentions = nn.ModuleList(attentions)
368
+ self.resnets = nn.ModuleList(resnets)
369
+ self.motion_modules = nn.ModuleList(motion_modules)
370
+
371
+ if add_downsample:
372
+ self.downsamplers = nn.ModuleList(
373
+ [
374
+ Downsample3D(
375
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
376
+ )
377
+ ]
378
+ )
379
+ else:
380
+ self.downsamplers = None
381
+
382
+ self.gradient_checkpointing = False
383
+
384
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
385
+ output_states = ()
386
+
387
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
388
+ if self.training and self.gradient_checkpointing:
389
+
390
+ def create_custom_forward(module, return_dict=None):
391
+ def custom_forward(*inputs):
392
+ if return_dict is not None:
393
+ return module(*inputs, return_dict=return_dict)
394
+ else:
395
+ return module(*inputs)
396
+
397
+ return custom_forward
398
+
399
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
400
+ hidden_states = torch.utils.checkpoint.checkpoint(
401
+ create_custom_forward(attn, return_dict=False),
402
+ hidden_states,
403
+ encoder_hidden_states,
404
+ )[0]
405
+ if motion_module is not None:
406
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
407
+
408
+ else:
409
+ hidden_states = resnet(hidden_states, temb)
410
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
411
+
412
+ # add motion module
413
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
414
+
415
+ output_states += (hidden_states,)
416
+
417
+ if self.downsamplers is not None:
418
+ for downsampler in self.downsamplers:
419
+ hidden_states = downsampler(hidden_states)
420
+
421
+ output_states += (hidden_states,)
422
+
423
+ return hidden_states, output_states
424
+
425
+
426
+ class DownBlock3D(nn.Module):
427
+ def __init__(
428
+ self,
429
+ in_channels: int,
430
+ out_channels: int,
431
+ temb_channels: int,
432
+ dropout: float = 0.0,
433
+ num_layers: int = 1,
434
+ resnet_eps: float = 1e-6,
435
+ resnet_time_scale_shift: str = "default",
436
+ resnet_act_fn: str = "swish",
437
+ resnet_groups: int = 32,
438
+ resnet_pre_norm: bool = True,
439
+ output_scale_factor=1.0,
440
+ add_downsample=True,
441
+ downsample_padding=1,
442
+
443
+ use_motion_module=None,
444
+ motion_module_type=None,
445
+ motion_module_kwargs=None,
446
+ ):
447
+ super().__init__()
448
+ resnets = []
449
+ motion_modules = []
450
+
451
+ for i in range(num_layers):
452
+ in_channels = in_channels if i == 0 else out_channels
453
+ resnets.append(
454
+ ResnetBlock3D(
455
+ in_channels=in_channels,
456
+ out_channels=out_channels,
457
+ temb_channels=temb_channels,
458
+ eps=resnet_eps,
459
+ groups=resnet_groups,
460
+ dropout=dropout,
461
+ time_embedding_norm=resnet_time_scale_shift,
462
+ non_linearity=resnet_act_fn,
463
+ output_scale_factor=output_scale_factor,
464
+ pre_norm=resnet_pre_norm,
465
+ )
466
+ )
467
+ motion_modules.append(
468
+ get_motion_module(
469
+ in_channels=out_channels,
470
+ motion_module_type=motion_module_type,
471
+ motion_module_kwargs=motion_module_kwargs,
472
+ ) if use_motion_module else None
473
+ )
474
+
475
+ self.resnets = nn.ModuleList(resnets)
476
+ self.motion_modules = nn.ModuleList(motion_modules)
477
+
478
+ if add_downsample:
479
+ self.downsamplers = nn.ModuleList(
480
+ [
481
+ Downsample3D(
482
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
483
+ )
484
+ ]
485
+ )
486
+ else:
487
+ self.downsamplers = None
488
+
489
+ self.gradient_checkpointing = False
490
+
491
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
492
+ output_states = ()
493
+
494
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
495
+ if self.training and self.gradient_checkpointing:
496
+ def create_custom_forward(module):
497
+ def custom_forward(*inputs):
498
+ return module(*inputs)
499
+
500
+ return custom_forward
501
+
502
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
503
+ if motion_module is not None:
504
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
505
+ else:
506
+ hidden_states = resnet(hidden_states, temb)
507
+
508
+ # add motion module
509
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
510
+
511
+ output_states += (hidden_states,)
512
+
513
+ if self.downsamplers is not None:
514
+ for downsampler in self.downsamplers:
515
+ hidden_states = downsampler(hidden_states)
516
+
517
+ output_states += (hidden_states,)
518
+
519
+ return hidden_states, output_states
520
+
521
+
522
+ class CrossAttnUpBlock3D(nn.Module):
523
+ def __init__(
524
+ self,
525
+ in_channels: int,
526
+ out_channels: int,
527
+ prev_output_channel: int,
528
+ temb_channels: int,
529
+ dropout: float = 0.0,
530
+ num_layers: int = 1,
531
+ resnet_eps: float = 1e-6,
532
+ resnet_time_scale_shift: str = "default",
533
+ resnet_act_fn: str = "swish",
534
+ resnet_groups: int = 32,
535
+ resnet_pre_norm: bool = True,
536
+ attn_num_head_channels=1,
537
+ cross_attention_dim=1280,
538
+ output_scale_factor=1.0,
539
+ add_upsample=True,
540
+ dual_cross_attention=False,
541
+ use_linear_projection=False,
542
+ only_cross_attention=False,
543
+ upcast_attention=False,
544
+
545
+ unet_use_cross_frame_attention=None,
546
+ unet_use_temporal_attention=None,
547
+
548
+ use_motion_module=None,
549
+
550
+ motion_module_type=None,
551
+ motion_module_kwargs=None,
552
+ ):
553
+ super().__init__()
554
+ resnets = []
555
+ attentions = []
556
+ motion_modules = []
557
+
558
+ self.has_cross_attention = True
559
+ self.attn_num_head_channels = attn_num_head_channels
560
+
561
+ for i in range(num_layers):
562
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
563
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
564
+
565
+ resnets.append(
566
+ ResnetBlock3D(
567
+ in_channels=resnet_in_channels + res_skip_channels,
568
+ out_channels=out_channels,
569
+ temb_channels=temb_channels,
570
+ eps=resnet_eps,
571
+ groups=resnet_groups,
572
+ dropout=dropout,
573
+ time_embedding_norm=resnet_time_scale_shift,
574
+ non_linearity=resnet_act_fn,
575
+ output_scale_factor=output_scale_factor,
576
+ pre_norm=resnet_pre_norm,
577
+ )
578
+ )
579
+ if dual_cross_attention:
580
+ raise NotImplementedError
581
+ attentions.append(
582
+ Transformer3DModel(
583
+ attn_num_head_channels,
584
+ out_channels // attn_num_head_channels,
585
+ in_channels=out_channels,
586
+ num_layers=1,
587
+ cross_attention_dim=cross_attention_dim,
588
+ norm_num_groups=resnet_groups,
589
+ use_linear_projection=use_linear_projection,
590
+ only_cross_attention=only_cross_attention,
591
+ upcast_attention=upcast_attention,
592
+
593
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
594
+ unet_use_temporal_attention=unet_use_temporal_attention,
595
+ )
596
+ )
597
+ motion_modules.append(
598
+ get_motion_module(
599
+ in_channels=out_channels,
600
+ motion_module_type=motion_module_type,
601
+ motion_module_kwargs=motion_module_kwargs,
602
+ ) if use_motion_module else None
603
+ )
604
+
605
+ self.attentions = nn.ModuleList(attentions)
606
+ self.resnets = nn.ModuleList(resnets)
607
+ self.motion_modules = nn.ModuleList(motion_modules)
608
+
609
+ if add_upsample:
610
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
611
+ else:
612
+ self.upsamplers = None
613
+
614
+ self.gradient_checkpointing = False
615
+
616
+ def forward(
617
+ self,
618
+ hidden_states,
619
+ res_hidden_states_tuple,
620
+ temb=None,
621
+ encoder_hidden_states=None,
622
+ upsample_size=None,
623
+ attention_mask=None,
624
+ ):
625
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
626
+ # pop res hidden states
627
+ res_hidden_states = res_hidden_states_tuple[-1]
628
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
629
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
630
+
631
+ if self.training and self.gradient_checkpointing:
632
+
633
+ def create_custom_forward(module, return_dict=None):
634
+ def custom_forward(*inputs):
635
+ if return_dict is not None:
636
+ return module(*inputs, return_dict=return_dict)
637
+ else:
638
+ return module(*inputs)
639
+
640
+ return custom_forward
641
+
642
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
643
+ hidden_states = torch.utils.checkpoint.checkpoint(
644
+ create_custom_forward(attn, return_dict=False),
645
+ hidden_states,
646
+ encoder_hidden_states,
647
+ )[0]
648
+ if motion_module is not None:
649
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
650
+
651
+ else:
652
+ hidden_states = resnet(hidden_states, temb)
653
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
654
+
655
+ # add motion module
656
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
657
+
658
+ if self.upsamplers is not None:
659
+ for upsampler in self.upsamplers:
660
+ hidden_states = upsampler(hidden_states, upsample_size)
661
+
662
+ return hidden_states
663
+
664
+
665
+ class UpBlock3D(nn.Module):
666
+ def __init__(
667
+ self,
668
+ in_channels: int,
669
+ prev_output_channel: int,
670
+ out_channels: int,
671
+ temb_channels: int,
672
+ dropout: float = 0.0,
673
+ num_layers: int = 1,
674
+ resnet_eps: float = 1e-6,
675
+ resnet_time_scale_shift: str = "default",
676
+ resnet_act_fn: str = "swish",
677
+ resnet_groups: int = 32,
678
+ resnet_pre_norm: bool = True,
679
+ output_scale_factor=1.0,
680
+ add_upsample=True,
681
+
682
+ use_motion_module=None,
683
+ motion_module_type=None,
684
+ motion_module_kwargs=None,
685
+ ):
686
+ super().__init__()
687
+ resnets = []
688
+ motion_modules = []
689
+
690
+ for i in range(num_layers):
691
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
692
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
693
+
694
+ resnets.append(
695
+ ResnetBlock3D(
696
+ in_channels=resnet_in_channels + res_skip_channels,
697
+ out_channels=out_channels,
698
+ temb_channels=temb_channels,
699
+ eps=resnet_eps,
700
+ groups=resnet_groups,
701
+ dropout=dropout,
702
+ time_embedding_norm=resnet_time_scale_shift,
703
+ non_linearity=resnet_act_fn,
704
+ output_scale_factor=output_scale_factor,
705
+ pre_norm=resnet_pre_norm,
706
+ )
707
+ )
708
+ motion_modules.append(
709
+ get_motion_module(
710
+ in_channels=out_channels,
711
+ motion_module_type=motion_module_type,
712
+ motion_module_kwargs=motion_module_kwargs,
713
+ ) if use_motion_module else None
714
+ )
715
+
716
+ self.resnets = nn.ModuleList(resnets)
717
+ self.motion_modules = nn.ModuleList(motion_modules)
718
+
719
+ if add_upsample:
720
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
721
+ else:
722
+ self.upsamplers = None
723
+
724
+ self.gradient_checkpointing = False
725
+
726
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
727
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
728
+ # pop res hidden states
729
+ res_hidden_states = res_hidden_states_tuple[-1]
730
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
731
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
732
+
733
+ if self.training and self.gradient_checkpointing:
734
+ def create_custom_forward(module):
735
+ def custom_forward(*inputs):
736
+ return module(*inputs)
737
+
738
+ return custom_forward
739
+
740
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
741
+ if motion_module is not None:
742
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
743
+ else:
744
+ hidden_states = resnet(hidden_states, temb)
745
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
746
+
747
+ if self.upsamplers is not None:
748
+ for upsampler in self.upsamplers:
749
+ hidden_states = upsampler(hidden_states, upsample_size)
750
+
751
  return hidden_states
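
For reference, a minimal sketch of driving the block factory above outside the full UNet. The channel sizes, head count, and tensor shapes are illustrative, and the 5-D (batch, channels, frames, height, width) layout is assumed from the ResnetBlock3D/Transformer3DModel convention; the exact per-frame handling lives in attention.py and resnet.py, which are not shown here:

    import torch
    from magicanimate.models.unet_3d_blocks import get_down_block

    # Build one cross-attention down block without a motion module (sketch values).
    block = get_down_block(
        "CrossAttnDownBlock3D",
        num_layers=2,
        in_channels=320,
        out_channels=320,
        temb_channels=1280,
        add_downsample=True,
        resnet_eps=1e-5,
        resnet_act_fn="silu",
        attn_num_head_channels=8,
        resnet_groups=32,
        cross_attention_dim=768,
        downsample_padding=1,
        use_motion_module=False,
    )

    sample = torch.randn(1, 320, 4, 8, 8)   # (batch, channels, frames, height, width)
    temb = torch.randn(1, 1280)             # time embedding, temb_channels wide
    text = torch.randn(1, 77, 768)          # encoder hidden states, cross_attention_dim wide
    sample, res_samples = block(hidden_states=sample, temb=temb, encoder_hidden_states=text)
    # res_samples holds the per-layer residuals that the matching up block later consumes.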
magicanimate/models/unet_controlnet.py CHANGED
@@ -1,525 +1,525 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Copyright 2023 The HuggingFace Team. All rights reserved.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- from dataclasses import dataclass
21
- from typing import List, Optional, Tuple, Union
22
-
23
- import os
24
- import json
25
-
26
- import torch
27
- import torch.nn as nn
28
- import torch.utils.checkpoint
29
-
30
- from diffusers.configuration_utils import ConfigMixin, register_to_config
31
- from diffusers.models.modeling_utils import ModelMixin
32
- from diffusers.utils import BaseOutput, logging
33
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
34
- from magicanimate.models.unet_3d_blocks import (
35
- CrossAttnDownBlock3D,
36
- CrossAttnUpBlock3D,
37
- DownBlock3D,
38
- UNetMidBlock3DCrossAttn,
39
- UpBlock3D,
40
- get_down_block,
41
- get_up_block,
42
- )
43
- from .resnet import InflatedConv3d
44
-
45
-
46
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
-
48
-
49
- @dataclass
50
- class UNet3DConditionOutput(BaseOutput):
51
- sample: torch.FloatTensor
52
-
53
-
54
- class UNet3DConditionModel(ModelMixin, ConfigMixin):
55
- _supports_gradient_checkpointing = True
56
-
57
- @register_to_config
58
- def __init__(
59
- self,
60
- sample_size: Optional[int] = None,
61
- in_channels: int = 4,
62
- out_channels: int = 4,
63
- center_input_sample: bool = False,
64
- flip_sin_to_cos: bool = True,
65
- freq_shift: int = 0,
66
- down_block_types: Tuple[str] = (
67
- "CrossAttnDownBlock3D",
68
- "CrossAttnDownBlock3D",
69
- "CrossAttnDownBlock3D",
70
- "DownBlock3D",
71
- ),
72
- mid_block_type: str = "UNetMidBlock3DCrossAttn",
73
- up_block_types: Tuple[str] = (
74
- "UpBlock3D",
75
- "CrossAttnUpBlock3D",
76
- "CrossAttnUpBlock3D",
77
- "CrossAttnUpBlock3D"
78
- ),
79
- only_cross_attention: Union[bool, Tuple[bool]] = False,
80
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
81
- layers_per_block: int = 2,
82
- downsample_padding: int = 1,
83
- mid_block_scale_factor: float = 1,
84
- act_fn: str = "silu",
85
- norm_num_groups: int = 32,
86
- norm_eps: float = 1e-5,
87
- cross_attention_dim: int = 1280,
88
- attention_head_dim: Union[int, Tuple[int]] = 8,
89
- dual_cross_attention: bool = False,
90
- use_linear_projection: bool = False,
91
- class_embed_type: Optional[str] = None,
92
- num_class_embeds: Optional[int] = None,
93
- upcast_attention: bool = False,
94
- resnet_time_scale_shift: str = "default",
95
-
96
- # Additional
97
- use_motion_module = False,
98
- motion_module_resolutions = ( 1,2,4,8 ),
99
- motion_module_mid_block = False,
100
- motion_module_decoder_only = False,
101
- motion_module_type = None,
102
- motion_module_kwargs = {},
103
- unet_use_cross_frame_attention = None,
104
- unet_use_temporal_attention = None,
105
- ):
106
- super().__init__()
107
-
108
- self.sample_size = sample_size
109
- time_embed_dim = block_out_channels[0] * 4
110
-
111
- # input
112
- self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
113
-
114
- # time
115
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
116
- timestep_input_dim = block_out_channels[0]
117
-
118
- self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
119
-
120
- # class embedding
121
- if class_embed_type is None and num_class_embeds is not None:
122
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
123
- elif class_embed_type == "timestep":
124
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
125
- elif class_embed_type == "identity":
126
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
127
- else:
128
- self.class_embedding = None
129
-
130
- self.down_blocks = nn.ModuleList([])
131
- self.mid_block = None
132
- self.up_blocks = nn.ModuleList([])
133
-
134
- if isinstance(only_cross_attention, bool):
135
- only_cross_attention = [only_cross_attention] * len(down_block_types)
136
-
137
- if isinstance(attention_head_dim, int):
138
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
139
-
140
- # down
141
- output_channel = block_out_channels[0]
142
- for i, down_block_type in enumerate(down_block_types):
143
- res = 2 ** i
144
- input_channel = output_channel
145
- output_channel = block_out_channels[i]
146
- is_final_block = i == len(block_out_channels) - 1
147
-
148
- down_block = get_down_block(
149
- down_block_type,
150
- num_layers=layers_per_block,
151
- in_channels=input_channel,
152
- out_channels=output_channel,
153
- temb_channels=time_embed_dim,
154
- add_downsample=not is_final_block,
155
- resnet_eps=norm_eps,
156
- resnet_act_fn=act_fn,
157
- resnet_groups=norm_num_groups,
158
- cross_attention_dim=cross_attention_dim,
159
- attn_num_head_channels=attention_head_dim[i],
160
- downsample_padding=downsample_padding,
161
- dual_cross_attention=dual_cross_attention,
162
- use_linear_projection=use_linear_projection,
163
- only_cross_attention=only_cross_attention[i],
164
- upcast_attention=upcast_attention,
165
- resnet_time_scale_shift=resnet_time_scale_shift,
166
-
167
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
168
- unet_use_temporal_attention=unet_use_temporal_attention,
169
-
170
- use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
171
- motion_module_type=motion_module_type,
172
- motion_module_kwargs=motion_module_kwargs,
173
- )
174
- self.down_blocks.append(down_block)
175
-
176
- # mid
177
- if mid_block_type == "UNetMidBlock3DCrossAttn":
178
- self.mid_block = UNetMidBlock3DCrossAttn(
179
- in_channels=block_out_channels[-1],
180
- temb_channels=time_embed_dim,
181
- resnet_eps=norm_eps,
182
- resnet_act_fn=act_fn,
183
- output_scale_factor=mid_block_scale_factor,
184
- resnet_time_scale_shift=resnet_time_scale_shift,
185
- cross_attention_dim=cross_attention_dim,
186
- attn_num_head_channels=attention_head_dim[-1],
187
- resnet_groups=norm_num_groups,
188
- dual_cross_attention=dual_cross_attention,
189
- use_linear_projection=use_linear_projection,
190
- upcast_attention=upcast_attention,
191
-
192
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
193
- unet_use_temporal_attention=unet_use_temporal_attention,
194
-
195
- use_motion_module=use_motion_module and motion_module_mid_block,
196
- motion_module_type=motion_module_type,
197
- motion_module_kwargs=motion_module_kwargs,
198
- )
199
- else:
200
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
201
-
202
- # count how many layers upsample the videos
203
- self.num_upsamplers = 0
204
-
205
- # up
206
- reversed_block_out_channels = list(reversed(block_out_channels))
207
- reversed_attention_head_dim = list(reversed(attention_head_dim))
208
- only_cross_attention = list(reversed(only_cross_attention))
209
- output_channel = reversed_block_out_channels[0]
210
- for i, up_block_type in enumerate(up_block_types):
211
- res = 2 ** (3 - i)
212
- is_final_block = i == len(block_out_channels) - 1
213
-
214
- prev_output_channel = output_channel
215
- output_channel = reversed_block_out_channels[i]
216
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
217
-
218
- # add upsample block for all BUT final layer
219
- if not is_final_block:
220
- add_upsample = True
221
- self.num_upsamplers += 1
222
- else:
223
- add_upsample = False
224
-
225
- up_block = get_up_block(
226
- up_block_type,
227
- num_layers=layers_per_block + 1,
228
- in_channels=input_channel,
229
- out_channels=output_channel,
230
- prev_output_channel=prev_output_channel,
231
- temb_channels=time_embed_dim,
232
- add_upsample=add_upsample,
233
- resnet_eps=norm_eps,
234
- resnet_act_fn=act_fn,
235
- resnet_groups=norm_num_groups,
236
- cross_attention_dim=cross_attention_dim,
237
- attn_num_head_channels=reversed_attention_head_dim[i],
238
- dual_cross_attention=dual_cross_attention,
239
- use_linear_projection=use_linear_projection,
240
- only_cross_attention=only_cross_attention[i],
241
- upcast_attention=upcast_attention,
242
- resnet_time_scale_shift=resnet_time_scale_shift,
243
-
244
- unet_use_cross_frame_attention=unet_use_cross_frame_attention,
245
- unet_use_temporal_attention=unet_use_temporal_attention,
246
-
247
- use_motion_module=use_motion_module and (res in motion_module_resolutions),
248
- motion_module_type=motion_module_type,
249
- motion_module_kwargs=motion_module_kwargs,
250
- )
251
- self.up_blocks.append(up_block)
252
- prev_output_channel = output_channel
253
-
254
- # out
255
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
256
- self.conv_act = nn.SiLU()
257
- self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
258
-
259
- def set_attention_slice(self, slice_size):
260
- r"""
261
- Enable sliced attention computation.
262
-
263
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
264
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
265
-
266
- Args:
267
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
268
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
269
- `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
270
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
271
- must be a multiple of `slice_size`.
272
- """
273
- sliceable_head_dims = []
274
-
275
- def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
276
- if hasattr(module, "set_attention_slice"):
277
- sliceable_head_dims.append(module.sliceable_head_dim)
278
-
279
- for child in module.children():
280
- fn_recursive_retrieve_slicable_dims(child)
281
-
282
- # retrieve number of attention layers
283
- for module in self.children():
284
- fn_recursive_retrieve_slicable_dims(module)
285
-
286
- num_slicable_layers = len(sliceable_head_dims)
287
-
288
- if slice_size == "auto":
289
- # half the attention head size is usually a good trade-off between
290
- # speed and memory
291
- slice_size = [dim // 2 for dim in sliceable_head_dims]
292
- elif slice_size == "max":
293
- # make smallest slice possible
294
- slice_size = num_slicable_layers * [1]
295
-
296
- slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
297
-
298
- if len(slice_size) != len(sliceable_head_dims):
299
- raise ValueError(
300
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
301
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
302
- )
303
-
304
- for i in range(len(slice_size)):
305
- size = slice_size[i]
306
- dim = sliceable_head_dims[i]
307
- if size is not None and size > dim:
308
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
309
-
310
- # Recursively walk through all the children.
311
- # Any children which exposes the set_attention_slice method
312
- # gets the message
313
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
314
- if hasattr(module, "set_attention_slice"):
315
- module.set_attention_slice(slice_size.pop())
316
-
317
- for child in module.children():
318
- fn_recursive_set_attention_slice(child, slice_size)
319
-
320
- reversed_slice_size = list(reversed(slice_size))
321
- for module in self.children():
322
- fn_recursive_set_attention_slice(module, reversed_slice_size)
323
-
324
- def _set_gradient_checkpointing(self, module, value=False):
325
- if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
326
- module.gradient_checkpointing = value
327
-
328
- def forward(
329
- self,
330
- sample: torch.FloatTensor,
331
- timestep: Union[torch.Tensor, float, int],
332
- encoder_hidden_states: torch.Tensor,
333
- class_labels: Optional[torch.Tensor] = None,
334
- attention_mask: Optional[torch.Tensor] = None,
335
- # for controlnet
336
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
337
- mid_block_additional_residual: Optional[torch.Tensor] = None,
338
- return_dict: bool = True,
339
- ) -> Union[UNet3DConditionOutput, Tuple]:
340
- r"""
341
- Args:
342
- sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
343
- timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
344
- encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
345
- return_dict (`bool`, *optional*, defaults to `True`):
346
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
347
-
348
- Returns:
349
- [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
350
- [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
351
- returning a tuple, the first element is the sample tensor.
352
- """
353
- # By default samples have to be AT least a multiple of the overall upsampling factor.
354
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
355
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
356
- # on the fly if necessary.
357
- default_overall_up_factor = 2**self.num_upsamplers
358
-
359
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
360
- forward_upsample_size = False
361
- upsample_size = None
362
-
363
- if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
364
- logger.info("Forward upsample size to force interpolation output size.")
365
- forward_upsample_size = True
366
-
367
- # prepare attention_mask
368
- if attention_mask is not None:
369
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
370
- attention_mask = attention_mask.unsqueeze(1)
371
-
372
- # center input if necessary
373
- if self.config.center_input_sample:
374
- sample = 2 * sample - 1.0
375
-
376
- # time
377
- timesteps = timestep
378
- if not torch.is_tensor(timesteps):
379
- # This would be a good case for the `match` statement (Python 3.10+)
380
- is_mps = sample.device.type == "mps"
381
- if isinstance(timestep, float):
382
- dtype = torch.float32 if is_mps else torch.float64
383
- else:
384
- dtype = torch.int32 if is_mps else torch.int64
385
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
386
- elif len(timesteps.shape) == 0:
387
- timesteps = timesteps[None].to(sample.device)
388
-
389
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
390
- timesteps = timesteps.expand(sample.shape[0])
391
-
392
- t_emb = self.time_proj(timesteps)
393
-
394
- # timesteps does not contain any weights and will always return f32 tensors
395
- # but time_embedding might actually be running in fp16. so we need to cast here.
396
- # there might be better ways to encapsulate this.
397
- t_emb = t_emb.to(dtype=self.dtype)
398
- emb = self.time_embedding(t_emb)
399
-
400
- if self.class_embedding is not None:
401
- if class_labels is None:
402
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
403
-
404
- if self.config.class_embed_type == "timestep":
405
- class_labels = self.time_proj(class_labels)
406
-
407
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
408
- emb = emb + class_emb
409
-
410
- # pre-process
411
- sample = self.conv_in(sample)
412
-
413
- # down
414
- is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
415
-
416
- down_block_res_samples = (sample,)
417
- for downsample_block in self.down_blocks:
418
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
419
- sample, res_samples = downsample_block(
420
- hidden_states=sample,
421
- temb=emb,
422
- encoder_hidden_states=encoder_hidden_states,
423
- attention_mask=attention_mask,
424
- )
425
- else:
426
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
427
-
428
- down_block_res_samples += res_samples
429
-
430
- if is_controlnet:
431
- new_down_block_res_samples = ()
432
-
433
- for down_block_res_sample, down_block_additional_residual in zip(
434
- down_block_res_samples, down_block_additional_residuals
435
- ):
436
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
437
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
438
-
439
- down_block_res_samples = new_down_block_res_samples
440
-
441
- # mid
442
- sample = self.mid_block(
443
- sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
444
- )
445
-
446
- if is_controlnet:
447
- sample = sample + mid_block_additional_residual
448
-
449
- # up
450
- for i, upsample_block in enumerate(self.up_blocks):
451
- is_final_block = i == len(self.up_blocks) - 1
452
-
453
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
454
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
455
-
456
- # if we have not reached the final block and need to forward the
457
- # upsample size, we do it here
458
- if not is_final_block and forward_upsample_size:
459
- upsample_size = down_block_res_samples[-1].shape[2:]
460
-
461
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
462
- sample = upsample_block(
463
- hidden_states=sample,
464
- temb=emb,
465
- res_hidden_states_tuple=res_samples,
466
- encoder_hidden_states=encoder_hidden_states,
467
- upsample_size=upsample_size,
468
- attention_mask=attention_mask,
469
- )
470
- else:
471
- sample = upsample_block(
472
- hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
473
- )
474
-
475
- # post-process
476
- sample = self.conv_norm_out(sample)
477
- sample = self.conv_act(sample)
478
- sample = self.conv_out(sample)
479
-
480
- if not return_dict:
481
- return (sample,)
482
-
483
- return UNet3DConditionOutput(sample=sample)
484
-
485
- @classmethod
486
- def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
487
- if subfolder is not None:
488
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
489
- print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
490
-
491
- config_file = os.path.join(pretrained_model_path, 'config.json')
492
- if not os.path.isfile(config_file):
493
- raise RuntimeError(f"{config_file} does not exist")
494
- with open(config_file, "r") as f:
495
- config = json.load(f)
496
- config["_class_name"] = cls.__name__
497
- config["down_block_types"] = [
498
- "CrossAttnDownBlock3D",
499
- "CrossAttnDownBlock3D",
500
- "CrossAttnDownBlock3D",
501
- "DownBlock3D"
502
- ]
503
- config["up_block_types"] = [
504
- "UpBlock3D",
505
- "CrossAttnUpBlock3D",
506
- "CrossAttnUpBlock3D",
507
- "CrossAttnUpBlock3D"
508
- ]
509
- # config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
510
-
511
- from diffusers.utils import WEIGHTS_NAME
512
- model = cls.from_config(config, **unet_additional_kwargs)
513
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
514
- if not os.path.isfile(model_file):
515
- raise RuntimeError(f"{model_file} does not exist")
516
- state_dict = torch.load(model_file, map_location="cpu")
517
-
518
- m, u = model.load_state_dict(state_dict, strict=False)
519
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
520
- # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
521
-
522
- params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
523
- print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
524
-
525
- return model
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import os
24
+ import json
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.utils.checkpoint
29
+
30
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.utils import BaseOutput, logging
33
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
34
+ from magicanimate.models.unet_3d_blocks import (
35
+ CrossAttnDownBlock3D,
36
+ CrossAttnUpBlock3D,
37
+ DownBlock3D,
38
+ UNetMidBlock3DCrossAttn,
39
+ UpBlock3D,
40
+ get_down_block,
41
+ get_up_block,
42
+ )
43
+ from .resnet import InflatedConv3d
44
+
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ @dataclass
50
+ class UNet3DConditionOutput(BaseOutput):
51
+ sample: torch.FloatTensor
52
+
53
+
54
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
55
+ _supports_gradient_checkpointing = True
56
+
57
+ @register_to_config
58
+ def __init__(
59
+ self,
60
+ sample_size: Optional[int] = None,
61
+ in_channels: int = 4,
62
+ out_channels: int = 4,
63
+ center_input_sample: bool = False,
64
+ flip_sin_to_cos: bool = True,
65
+ freq_shift: int = 0,
66
+ down_block_types: Tuple[str] = (
67
+ "CrossAttnDownBlock3D",
68
+ "CrossAttnDownBlock3D",
69
+ "CrossAttnDownBlock3D",
70
+ "DownBlock3D",
71
+ ),
72
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
73
+ up_block_types: Tuple[str] = (
74
+ "UpBlock3D",
75
+ "CrossAttnUpBlock3D",
76
+ "CrossAttnUpBlock3D",
77
+ "CrossAttnUpBlock3D"
78
+ ),
79
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
80
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
81
+ layers_per_block: int = 2,
82
+ downsample_padding: int = 1,
83
+ mid_block_scale_factor: float = 1,
84
+ act_fn: str = "silu",
85
+ norm_num_groups: int = 32,
86
+ norm_eps: float = 1e-5,
87
+ cross_attention_dim: int = 1280,
88
+ attention_head_dim: Union[int, Tuple[int]] = 8,
89
+ dual_cross_attention: bool = False,
90
+ use_linear_projection: bool = False,
91
+ class_embed_type: Optional[str] = None,
92
+ num_class_embeds: Optional[int] = None,
93
+ upcast_attention: bool = False,
94
+ resnet_time_scale_shift: str = "default",
95
+
96
+ # Additional
97
+ use_motion_module = False,
98
+ motion_module_resolutions = ( 1,2,4,8 ),
99
+ motion_module_mid_block = False,
100
+ motion_module_decoder_only = False,
101
+ motion_module_type = None,
102
+ motion_module_kwargs = {},
103
+ unet_use_cross_frame_attention = None,
104
+ unet_use_temporal_attention = None,
105
+ ):
106
+ super().__init__()
107
+
108
+ self.sample_size = sample_size
109
+ time_embed_dim = block_out_channels[0] * 4
110
+
111
+ # input
112
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
113
+
114
+ # time
115
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
116
+ timestep_input_dim = block_out_channels[0]
117
+
118
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
119
+
120
+ # class embedding
121
+ if class_embed_type is None and num_class_embeds is not None:
122
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
123
+ elif class_embed_type == "timestep":
124
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
125
+ elif class_embed_type == "identity":
126
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
127
+ else:
128
+ self.class_embedding = None
129
+
130
+ self.down_blocks = nn.ModuleList([])
131
+ self.mid_block = None
132
+ self.up_blocks = nn.ModuleList([])
133
+
134
+ if isinstance(only_cross_attention, bool):
135
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
136
+
137
+ if isinstance(attention_head_dim, int):
138
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
139
+
140
+ # down
141
+ output_channel = block_out_channels[0]
142
+ for i, down_block_type in enumerate(down_block_types):
143
+ res = 2 ** i
144
+ input_channel = output_channel
145
+ output_channel = block_out_channels[i]
146
+ is_final_block = i == len(block_out_channels) - 1
147
+
148
+ down_block = get_down_block(
149
+ down_block_type,
150
+ num_layers=layers_per_block,
151
+ in_channels=input_channel,
152
+ out_channels=output_channel,
153
+ temb_channels=time_embed_dim,
154
+ add_downsample=not is_final_block,
155
+ resnet_eps=norm_eps,
156
+ resnet_act_fn=act_fn,
157
+ resnet_groups=norm_num_groups,
158
+ cross_attention_dim=cross_attention_dim,
159
+ attn_num_head_channels=attention_head_dim[i],
160
+ downsample_padding=downsample_padding,
161
+ dual_cross_attention=dual_cross_attention,
162
+ use_linear_projection=use_linear_projection,
163
+ only_cross_attention=only_cross_attention[i],
164
+ upcast_attention=upcast_attention,
165
+ resnet_time_scale_shift=resnet_time_scale_shift,
166
+
167
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
168
+ unet_use_temporal_attention=unet_use_temporal_attention,
169
+
170
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
171
+ motion_module_type=motion_module_type,
172
+ motion_module_kwargs=motion_module_kwargs,
173
+ )
174
+ self.down_blocks.append(down_block)
175
+
176
+ # mid
177
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
178
+ self.mid_block = UNetMidBlock3DCrossAttn(
179
+ in_channels=block_out_channels[-1],
180
+ temb_channels=time_embed_dim,
181
+ resnet_eps=norm_eps,
182
+ resnet_act_fn=act_fn,
183
+ output_scale_factor=mid_block_scale_factor,
184
+ resnet_time_scale_shift=resnet_time_scale_shift,
185
+ cross_attention_dim=cross_attention_dim,
186
+ attn_num_head_channels=attention_head_dim[-1],
187
+ resnet_groups=norm_num_groups,
188
+ dual_cross_attention=dual_cross_attention,
189
+ use_linear_projection=use_linear_projection,
190
+ upcast_attention=upcast_attention,
191
+
192
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
193
+ unet_use_temporal_attention=unet_use_temporal_attention,
194
+
195
+ use_motion_module=use_motion_module and motion_module_mid_block,
196
+ motion_module_type=motion_module_type,
197
+ motion_module_kwargs=motion_module_kwargs,
198
+ )
199
+ else:
200
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
201
+
202
+ # count how many layers upsample the videos
203
+ self.num_upsamplers = 0
204
+
205
+ # up
206
+ reversed_block_out_channels = list(reversed(block_out_channels))
207
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
208
+ only_cross_attention = list(reversed(only_cross_attention))
209
+ output_channel = reversed_block_out_channels[0]
210
+ for i, up_block_type in enumerate(up_block_types):
211
+ res = 2 ** (3 - i)
212
+ is_final_block = i == len(block_out_channels) - 1
213
+
214
+ prev_output_channel = output_channel
215
+ output_channel = reversed_block_out_channels[i]
216
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
217
+
218
+ # add upsample block for all BUT final layer
219
+ if not is_final_block:
220
+ add_upsample = True
221
+ self.num_upsamplers += 1
222
+ else:
223
+ add_upsample = False
224
+
225
+ up_block = get_up_block(
226
+ up_block_type,
227
+ num_layers=layers_per_block + 1,
228
+ in_channels=input_channel,
229
+ out_channels=output_channel,
230
+ prev_output_channel=prev_output_channel,
231
+ temb_channels=time_embed_dim,
232
+ add_upsample=add_upsample,
233
+ resnet_eps=norm_eps,
234
+ resnet_act_fn=act_fn,
235
+ resnet_groups=norm_num_groups,
236
+ cross_attention_dim=cross_attention_dim,
237
+ attn_num_head_channels=reversed_attention_head_dim[i],
238
+ dual_cross_attention=dual_cross_attention,
239
+ use_linear_projection=use_linear_projection,
240
+ only_cross_attention=only_cross_attention[i],
241
+ upcast_attention=upcast_attention,
242
+ resnet_time_scale_shift=resnet_time_scale_shift,
243
+
244
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
245
+ unet_use_temporal_attention=unet_use_temporal_attention,
246
+
247
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
248
+ motion_module_type=motion_module_type,
249
+ motion_module_kwargs=motion_module_kwargs,
250
+ )
251
+ self.up_blocks.append(up_block)
252
+ prev_output_channel = output_channel
253
+
254
+ # out
255
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
256
+ self.conv_act = nn.SiLU()
257
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
258
+
259
+ def set_attention_slice(self, slice_size):
260
+ r"""
261
+ Enable sliced attention computation.
262
+
263
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
264
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
265
+
266
+ Args:
267
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
268
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
269
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
270
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
271
+ must be a multiple of `slice_size`.
272
+ """
273
+ sliceable_head_dims = []
274
+
275
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
276
+ if hasattr(module, "set_attention_slice"):
277
+ sliceable_head_dims.append(module.sliceable_head_dim)
278
+
279
+ for child in module.children():
280
+ fn_recursive_retrieve_slicable_dims(child)
281
+
282
+ # retrieve number of attention layers
283
+ for module in self.children():
284
+ fn_recursive_retrieve_slicable_dims(module)
285
+
286
+ num_slicable_layers = len(sliceable_head_dims)
287
+
288
+ if slice_size == "auto":
289
+ # half the attention head size is usually a good trade-off between
290
+ # speed and memory
291
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
292
+ elif slice_size == "max":
293
+ # make smallest slice possible
294
+ slice_size = num_slicable_layers * [1]
295
+
296
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
297
+
298
+ if len(slice_size) != len(sliceable_head_dims):
299
+ raise ValueError(
300
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
301
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
302
+ )
303
+
304
+ for i in range(len(slice_size)):
305
+ size = slice_size[i]
306
+ dim = sliceable_head_dims[i]
307
+ if size is not None and size > dim:
308
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
309
+
310
+ # Recursively walk through all the children.
311
+ # Any child that exposes the set_attention_slice method
312
+ # gets the message
313
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
314
+ if hasattr(module, "set_attention_slice"):
315
+ module.set_attention_slice(slice_size.pop())
316
+
317
+ for child in module.children():
318
+ fn_recursive_set_attention_slice(child, slice_size)
319
+
320
+ reversed_slice_size = list(reversed(slice_size))
321
+ for module in self.children():
322
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
323
+
324
+ def _set_gradient_checkpointing(self, module, value=False):
325
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
326
+ module.gradient_checkpointing = value
327
+
328
+ def forward(
329
+ self,
330
+ sample: torch.FloatTensor,
331
+ timestep: Union[torch.Tensor, float, int],
332
+ encoder_hidden_states: torch.Tensor,
333
+ class_labels: Optional[torch.Tensor] = None,
334
+ attention_mask: Optional[torch.Tensor] = None,
335
+ # for controlnet
336
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
337
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
338
+ return_dict: bool = True,
339
+ ) -> Union[UNet3DConditionOutput, Tuple]:
340
+ r"""
341
+ Args:
342
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
343
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
344
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
345
+ return_dict (`bool`, *optional*, defaults to `True`):
346
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
347
+
348
+ Returns:
349
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
350
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
351
+ returning a tuple, the first element is the sample tensor.
352
+ """
353
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
354
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
355
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
356
+ # on the fly if necessary.
357
+ default_overall_up_factor = 2**self.num_upsamplers
358
+
359
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
360
+ forward_upsample_size = False
361
+ upsample_size = None
362
+
363
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
364
+ logger.info("Forward upsample size to force interpolation output size.")
365
+ forward_upsample_size = True
366
+
367
+ # prepare attention_mask
368
+ if attention_mask is not None:
369
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
370
+ attention_mask = attention_mask.unsqueeze(1)
371
+
372
+ # center input if necessary
373
+ if self.config.center_input_sample:
374
+ sample = 2 * sample - 1.0
375
+
376
+ # time
377
+ timesteps = timestep
378
+ if not torch.is_tensor(timesteps):
379
+ # This would be a good case for the `match` statement (Python 3.10+)
380
+ is_mps = sample.device.type == "mps"
381
+ if isinstance(timestep, float):
382
+ dtype = torch.float32 if is_mps else torch.float64
383
+ else:
384
+ dtype = torch.int32 if is_mps else torch.int64
385
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
386
+ elif len(timesteps.shape) == 0:
387
+ timesteps = timesteps[None].to(sample.device)
388
+
389
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
390
+ timesteps = timesteps.expand(sample.shape[0])
391
+
392
+ t_emb = self.time_proj(timesteps)
393
+
394
+ # timesteps does not contain any weights and will always return f32 tensors
395
+ # but time_embedding might actually be running in fp16. so we need to cast here.
396
+ # there might be better ways to encapsulate this.
397
+ t_emb = t_emb.to(dtype=self.dtype)
398
+ emb = self.time_embedding(t_emb)
399
+
400
+ if self.class_embedding is not None:
401
+ if class_labels is None:
402
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
403
+
404
+ if self.config.class_embed_type == "timestep":
405
+ class_labels = self.time_proj(class_labels)
406
+
407
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
408
+ emb = emb + class_emb
409
+
410
+ # pre-process
411
+ sample = self.conv_in(sample)
412
+
413
+ # down
414
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
415
+
416
+ down_block_res_samples = (sample,)
417
+ for downsample_block in self.down_blocks:
418
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
419
+ sample, res_samples = downsample_block(
420
+ hidden_states=sample,
421
+ temb=emb,
422
+ encoder_hidden_states=encoder_hidden_states,
423
+ attention_mask=attention_mask,
424
+ )
425
+ else:
426
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
427
+
428
+ down_block_res_samples += res_samples
429
+
430
+ if is_controlnet:
431
+ new_down_block_res_samples = ()
432
+
433
+ for down_block_res_sample, down_block_additional_residual in zip(
434
+ down_block_res_samples, down_block_additional_residuals
435
+ ):
436
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
437
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
438
+
439
+ down_block_res_samples = new_down_block_res_samples
440
+
441
+ # mid
442
+ sample = self.mid_block(
443
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
444
+ )
445
+
446
+ if is_controlnet:
447
+ sample = sample + mid_block_additional_residual
448
+
449
+ # up
450
+ for i, upsample_block in enumerate(self.up_blocks):
451
+ is_final_block = i == len(self.up_blocks) - 1
452
+
453
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
454
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
455
+
456
+ # if we have not reached the final block and need to forward the
457
+ # upsample size, we do it here
458
+ if not is_final_block and forward_upsample_size:
459
+ upsample_size = down_block_res_samples[-1].shape[2:]
460
+
461
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
462
+ sample = upsample_block(
463
+ hidden_states=sample,
464
+ temb=emb,
465
+ res_hidden_states_tuple=res_samples,
466
+ encoder_hidden_states=encoder_hidden_states,
467
+ upsample_size=upsample_size,
468
+ attention_mask=attention_mask,
469
+ )
470
+ else:
471
+ sample = upsample_block(
472
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
473
+ )
474
+
475
+ # post-process
476
+ sample = self.conv_norm_out(sample)
477
+ sample = self.conv_act(sample)
478
+ sample = self.conv_out(sample)
479
+
480
+ if not return_dict:
481
+ return (sample,)
482
+
483
+ return UNet3DConditionOutput(sample=sample)
484
+
485
+ @classmethod
486
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
487
+ if subfolder is not None:
488
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
489
+ print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
490
+
491
+ config_file = os.path.join(pretrained_model_path, 'config.json')
492
+ if not os.path.isfile(config_file):
493
+ raise RuntimeError(f"{config_file} does not exist")
494
+ with open(config_file, "r") as f:
495
+ config = json.load(f)
496
+ config["_class_name"] = cls.__name__
497
+ config["down_block_types"] = [
498
+ "CrossAttnDownBlock3D",
499
+ "CrossAttnDownBlock3D",
500
+ "CrossAttnDownBlock3D",
501
+ "DownBlock3D"
502
+ ]
503
+ config["up_block_types"] = [
504
+ "UpBlock3D",
505
+ "CrossAttnUpBlock3D",
506
+ "CrossAttnUpBlock3D",
507
+ "CrossAttnUpBlock3D"
508
+ ]
509
+ # config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
510
+
511
+ from diffusers.utils import WEIGHTS_NAME
512
+ model = cls.from_config(config, **unet_additional_kwargs)
513
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
514
+ if not os.path.isfile(model_file):
515
+ raise RuntimeError(f"{model_file} does not exist")
516
+ state_dict = torch.load(model_file, map_location="cpu", weights_only=True)
517
+
518
+ m, u = model.load_state_dict(state_dict, strict=False)
519
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
520
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
521
+
522
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
523
+ print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
524
+
525
+ return model
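
The switch to `torch.load(..., weights_only=True)` in `from_pretrained_2d` above restricts unpickling to tensors and primitive containers (supported in newer PyTorch releases), so loading the 2D weights does not execute arbitrary pickled code. Below is a minimal usage sketch for `from_pretrained_2d`, mirroring the call made in the animation script that follows; the checkpoint directory and config path are illustrative placeholders, not files guaranteed to ship with this repository:

from omegaconf import OmegaConf
from magicanimate.models.unet_controlnet import UNet3DConditionModel

# Placeholder paths; point these at a local Stable Diffusion checkpoint and at an inference
# config that defines `unet_additional_kwargs` (the motion-module and attention options
# consumed by UNet3DConditionModel.__init__).
inference_config = OmegaConf.load("path/to/inference.yaml")
unet = UNet3DConditionModel.from_pretrained_2d(
    "path/to/stable-diffusion-v1-5",  # directory with a unet/ subfolder holding config.json and the 2D UNet weights
    subfolder="unet",
    unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
)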
magicanimate/pipelines/animation.py CHANGED
@@ -1,282 +1,282 @@
1
- # Copyright 2023 ByteDance and/or its affiliates.
2
- #
3
- # Copyright (2023) MagicAnimate Authors
4
- #
5
- # ByteDance, its affiliates and licensors retain all intellectual
6
- # property and proprietary rights in and to this material, related
7
- # documentation and any modifications thereto. Any use, reproduction,
8
- # disclosure or distribution of this material and related documentation
9
- # without an express license agreement from ByteDance or
10
- # its affiliates is strictly prohibited.
11
- import argparse
12
- import datetime
13
- import inspect
14
- import os
15
- import random
16
- import numpy as np
17
-
18
- from PIL import Image
19
- from omegaconf import OmegaConf
20
- from collections import OrderedDict
21
-
22
- import torch
23
- import torch.distributed as dist
24
-
25
- from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
26
-
27
- from tqdm import tqdm
28
- from transformers import CLIPTextModel, CLIPTokenizer
29
-
30
- from magicanimate.models.unet_controlnet import UNet3DConditionModel
31
- from magicanimate.models.controlnet import ControlNetModel
32
- from magicanimate.models.appearance_encoder import AppearanceEncoderModel
33
- from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
34
- from magicanimate.pipelines.pipeline_animation import AnimationPipeline
35
- from magicanimate.utils.util import save_videos_grid
36
- from magicanimate.utils.dist_tools import distributed_init
37
- from accelerate.utils import set_seed
38
-
39
- from magicanimate.utils.videoreader import VideoReader
40
-
41
- from einops import rearrange
42
-
43
- from pathlib import Path
44
-
45
-
46
- def main(args):
47
-
48
- *_, func_args = inspect.getargvalues(inspect.currentframe())
49
- func_args = dict(func_args)
50
-
51
- config = OmegaConf.load(args.config)
52
-
53
- # Initialize distributed training
54
- device = torch.device(f"cuda:{args.rank}")
55
- dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist}
56
-
57
- if config.savename is None:
58
- time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
59
- savedir = f"samples/{Path(args.config).stem}-{time_str}"
60
- else:
61
- savedir = f"samples/{config.savename}"
62
-
63
- if args.dist:
64
- dist.broadcast_object_list([savedir], 0)
65
- dist.barrier()
66
-
67
- if args.rank == 0:
68
- os.makedirs(savedir, exist_ok=True)
69
-
70
- inference_config = OmegaConf.load(config.inference_config)
71
-
72
- motion_module = config.motion_module
73
-
74
- ### >>> create animation pipeline >>> ###
75
- tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
76
- text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
77
- if config.pretrained_unet_path:
78
- unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
79
- else:
80
- unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
81
- appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device)
82
- reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
83
- reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
84
- if config.pretrained_vae_path is not None:
85
- vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
86
- else:
87
- vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
88
-
89
- ### Load controlnet
90
- controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)
91
-
92
- unet.enable_xformers_memory_efficient_attention()
93
- appearance_encoder.enable_xformers_memory_efficient_attention()
94
- controlnet.enable_xformers_memory_efficient_attention()
95
-
96
- vae.to(torch.float16)
97
- unet.to(torch.float16)
98
- text_encoder.to(torch.float16)
99
- appearance_encoder.to(torch.float16)
100
- controlnet.to(torch.float16)
101
-
102
- pipeline = AnimationPipeline(
103
- vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
104
- scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
105
- # NOTE: UniPCMultistepScheduler
106
- )
107
-
108
- # 1. unet ckpt
109
- # 1.1 motion module
110
- motion_module_state_dict = torch.load(motion_module, map_location="cpu")
111
- if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
112
- motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
113
- try:
114
- # extra steps for self-trained models
115
- state_dict = OrderedDict()
116
- for key in motion_module_state_dict.keys():
117
- if key.startswith("module."):
118
- _key = key.split("module.")[-1]
119
- state_dict[_key] = motion_module_state_dict[key]
120
- else:
121
- state_dict[key] = motion_module_state_dict[key]
122
- motion_module_state_dict = state_dict
123
- del state_dict
124
- missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
125
- assert len(unexpected) == 0
126
- except:
127
- _tmp_ = OrderedDict()
128
- for key in motion_module_state_dict.keys():
129
- if "motion_modules" in key:
130
- if key.startswith("unet."):
131
- _key = key.split('unet.')[-1]
132
- _tmp_[_key] = motion_module_state_dict[key]
133
- else:
134
- _tmp_[key] = motion_module_state_dict[key]
135
- missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
136
- assert len(unexpected) == 0
137
- del _tmp_
138
- del motion_module_state_dict
139
-
140
- pipeline.to(device)
141
- ### <<< create validation pipeline <<< ###
142
-
143
- random_seeds = config.get("seed", [-1])
144
- random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
145
- random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds
146
-
147
- # input test videos (either source video/ conditions)
148
-
149
- test_videos = config.video_path
150
- source_images = config.source_image
151
- num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps)
152
-
153
- # read size, step from yaml file
154
- sizes = [config.size] * len(test_videos)
155
- steps = [config.S] * len(test_videos)
156
-
157
- config.random_seed = []
158
- prompt = n_prompt = ""
159
- for idx, (source_image, test_video, random_seed, size, step) in tqdm(
160
- enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)),
161
- total=len(test_videos),
162
- disable=(args.rank!=0)
163
- ):
164
- samples_per_video = []
165
- samples_per_clip = []
166
- # manually set random seed for reproduction
167
- if random_seed != -1:
168
- torch.manual_seed(random_seed)
169
- set_seed(random_seed)
170
- else:
171
- torch.seed()
172
- config.random_seed.append(torch.initial_seed())
173
-
174
- if test_video.endswith('.mp4'):
175
- control = VideoReader(test_video).read()
176
- if control[0].shape[0] != size:
177
- control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
178
- if config.max_length is not None:
179
- control = control[config.offset: (config.offset+config.max_length)]
180
- control = np.array(control)
181
-
182
- if source_image.endswith(".mp4"):
183
- source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size)))
184
- else:
185
- source_image = np.array(Image.open(source_image).resize((size, size)))
186
- H, W, C = source_image.shape
187
-
188
- print(f"current seed: {torch.initial_seed()}")
189
- init_latents = None
190
-
191
- # print(f"sampling {prompt} ...")
192
- original_length = control.shape[0]
193
- if control.shape[0] % config.L > 0:
194
- control = np.pad(control, ((0, config.L-control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge')
195
- generator = torch.Generator(device=torch.device("cuda:0"))
196
- generator.manual_seed(torch.initial_seed())
197
- sample = pipeline(
198
- prompt,
199
- negative_prompt = n_prompt,
200
- num_inference_steps = config.steps,
201
- guidance_scale = config.guidance_scale,
202
- width = W,
203
- height = H,
204
- video_length = len(control),
205
- controlnet_condition = control,
206
- init_latents = init_latents,
207
- generator = generator,
208
- num_actual_inference_steps = num_actual_inference_steps,
209
- appearance_encoder = appearance_encoder,
210
- reference_control_writer = reference_control_writer,
211
- reference_control_reader = reference_control_reader,
212
- source_image = source_image,
213
- **dist_kwargs,
214
- ).videos
215
-
216
- if args.rank == 0:
217
- source_images = np.array([source_image] * original_length)
218
- source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
219
- samples_per_video.append(source_images)
220
-
221
- control = control / 255.0
222
- control = rearrange(control, "t h w c -> 1 c t h w")
223
- control = torch.from_numpy(control)
224
- samples_per_video.append(control[:, :, :original_length])
225
-
226
- samples_per_video.append(sample[:, :, :original_length])
227
-
228
- samples_per_video = torch.cat(samples_per_video)
229
-
230
- video_name = os.path.basename(test_video)[:-4]
231
- source_name = os.path.basename(config.source_image[idx]).split(".")[0]
232
- save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4")
233
- save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4")
234
-
235
- if config.save_individual_videos:
236
- save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4")
237
- save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4")
238
-
239
- if args.dist:
240
- dist.barrier()
241
-
242
- if args.rank == 0:
243
- OmegaConf.save(config, f"{savedir}/config.yaml")
244
-
245
-
246
- def distributed_main(device_id, args):
247
- args.rank = device_id
248
- args.device_id = device_id
249
- if torch.cuda.is_available():
250
- torch.cuda.set_device(args.device_id)
251
- torch.cuda.init()
252
- distributed_init(args)
253
- main(args)
254
-
255
-
256
- def run(args):
257
-
258
- if args.dist:
259
- args.world_size = max(1, torch.cuda.device_count())
260
- assert args.world_size <= torch.cuda.device_count()
261
-
262
- if args.world_size > 0 and torch.cuda.device_count() > 1:
263
- port = random.randint(10000, 20000)
264
- args.init_method = f"tcp://localhost:{port}"
265
- torch.multiprocessing.spawn(
266
- fn=distributed_main,
267
- args=(args,),
268
- nprocs=args.world_size,
269
- )
270
- else:
271
- main(args)
272
-
273
-
274
- if __name__ == "__main__":
275
- parser = argparse.ArgumentParser()
276
- parser.add_argument("--config", type=str, required=True)
277
- parser.add_argument("--dist", action="store_true", required=False)
278
- parser.add_argument("--rank", type=int, default=0, required=False)
279
- parser.add_argument("--world_size", type=int, default=1, required=False)
280
-
281
- args = parser.parse_args()
282
- run(args)
 
1
+ # Copyright 2023 ByteDance and/or its affiliates.
2
+ #
3
+ # Copyright (2023) MagicAnimate Authors
4
+ #
5
+ # ByteDance, its affiliates and licensors retain all intellectual
6
+ # property and proprietary rights in and to this material, related
7
+ # documentation and any modifications thereto. Any use, reproduction,
8
+ # disclosure or distribution of this material and related documentation
9
+ # without an express license agreement from ByteDance or
10
+ # its affiliates is strictly prohibited.
11
+ import argparse
12
+ import datetime
13
+ import inspect
14
+ import os
15
+ import random
16
+ import numpy as np
17
+
18
+ from PIL import Image
19
+ from omegaconf import OmegaConf
20
+ from collections import OrderedDict
21
+
22
+ import torch
23
+ import torch.distributed as dist
24
+
25
+ from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
26
+
27
+ from tqdm import tqdm
28
+ from transformers import CLIPTextModel, CLIPTokenizer
29
+
30
+ from magicanimate.models.unet_controlnet import UNet3DConditionModel
31
+ from magicanimate.models.controlnet import ControlNetModel
32
+ from magicanimate.models.appearance_encoder import AppearanceEncoderModel
33
+ from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
34
+ from magicanimate.pipelines.pipeline_animation import AnimationPipeline
35
+ from magicanimate.utils.util import save_videos_grid
36
+ from magicanimate.utils.dist_tools import distributed_init
37
+ from accelerate.utils import set_seed
38
+
39
+ from magicanimate.utils.videoreader import VideoReader
40
+
41
+ from einops import rearrange
42
+
43
+ from pathlib import Path
44
+
45
+
46
+ def main(args):
47
+
48
+ *_, func_args = inspect.getargvalues(inspect.currentframe())
49
+ func_args = dict(func_args)
50
+
51
+ config = OmegaConf.load(args.config)
52
+
53
+ # Initialize distributed inference
54
+ device = torch.device(f"cuda:{args.rank}")
55
+ dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist}
56
+
57
+ if config.savename is None:
58
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
59
+ savedir = f"samples/{Path(args.config).stem}-{time_str}"
60
+ else:
61
+ savedir = f"samples/{config.savename}"
62
+
63
+ if args.dist:
64
+ dist.broadcast_object_list([savedir], 0)
65
+ dist.barrier()
66
+
67
+ if args.rank == 0:
68
+ os.makedirs(savedir, exist_ok=True)
69
+
70
+ inference_config = OmegaConf.load(config.inference_config)
71
+
72
+ motion_module = config.motion_module
73
+
74
+ ### >>> create animation pipeline >>> ###
75
+ tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
76
+ text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
77
+ if config.pretrained_unet_path:
78
+ unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
79
+ else:
80
+ unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
81
+ appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device)
82
+ reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
83
+ reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
84
+ if config.pretrained_vae_path is not None:
85
+ vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
86
+ else:
87
+ vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
88
+
89
+ ### Load controlnet
90
+ controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)
91
+
92
+ unet.enable_xformers_memory_efficient_attention()
93
+ appearance_encoder.enable_xformers_memory_efficient_attention()
94
+ controlnet.enable_xformers_memory_efficient_attention()
95
+
96
+ vae.to(torch.float16)
97
+ unet.to(torch.float16)
98
+ text_encoder.to(torch.float16)
99
+ appearance_encoder.to(torch.float16)
100
+ controlnet.to(torch.float16)
101
+
102
+ pipeline = AnimationPipeline(
103
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
104
+ scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
105
+ # NOTE: UniPCMultistepScheduler
106
+ )
107
+
108
+ # 1. unet ckpt
109
+ # 1.1 motion module
110
+ motion_module_state_dict = torch.load(motion_module, map_location="cpu")
111
+ if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
112
+ motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
113
+ try:
114
+ # extra steps for self-trained models
115
+ state_dict = OrderedDict()
116
+ for key in motion_module_state_dict.keys():
117
+ if key.startswith("module."):
118
+ _key = key.split("module.")[-1]
119
+ state_dict[_key] = motion_module_state_dict[key]
120
+ else:
121
+ state_dict[key] = motion_module_state_dict[key]
122
+ motion_module_state_dict = state_dict
123
+ del state_dict
124
+ missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
125
+ assert len(unexpected) == 0
126
+ except Exception:
127
+ _tmp_ = OrderedDict()
128
+ for key in motion_module_state_dict.keys():
129
+ if "motion_modules" in key:
130
+ if key.startswith("unet."):
131
+ _key = key.split('unet.')[-1]
132
+ _tmp_[_key] = motion_module_state_dict[key]
133
+ else:
134
+ _tmp_[key] = motion_module_state_dict[key]
135
+ missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
136
+ assert len(unexpected) == 0
137
+ del _tmp_
138
+ del motion_module_state_dict
139
+
140
+ pipeline.to(device)
141
+ ### <<< create validation pipeline <<< ###
142
+
143
+ random_seeds = config.get("seed", [-1])
144
+ random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
145
+ random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds
146
+
147
+ # input test videos (either source videos or conditions)
148
+
149
+ test_videos = config.video_path
150
+ source_images = config.source_image
151
+ num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps)
152
+
153
+ # read size, step from yaml file
154
+ sizes = [config.size] * len(test_videos)
155
+ steps = [config.S] * len(test_videos)
156
+
157
+ config.random_seed = []
158
+ prompt = n_prompt = ""
159
+ for idx, (source_image, test_video, random_seed, size, step) in tqdm(
160
+ enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)),
161
+ total=len(test_videos),
162
+ disable=(args.rank!=0)
163
+ ):
164
+ samples_per_video = []
165
+ samples_per_clip = []
166
+ # manually set random seed for reproducibility
167
+ if random_seed != -1:
168
+ torch.manual_seed(random_seed)
169
+ set_seed(random_seed)
170
+ else:
171
+ torch.seed()
172
+ config.random_seed.append(torch.initial_seed())
173
+
174
+ if test_video.endswith('.mp4'):
175
+ control = VideoReader(test_video).read()
176
+ if control[0].shape[0] != size:
177
+ control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
178
+ if config.max_length is not None:
179
+ control = control[config.offset: (config.offset+config.max_length)]
180
+ control = np.array(control)
181
+
182
+ if source_image.endswith(".mp4"):
183
+ source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size)))
184
+ else:
185
+ source_image = np.array(Image.open(source_image).resize((size, size)))
186
+ H, W, C = source_image.shape
187
+
188
+ print(f"current seed: {torch.initial_seed()}")
189
+ init_latents = None
190
+
191
+ # print(f"sampling {prompt} ...")
192
+ original_length = control.shape[0]
193
+ if control.shape[0] % config.L > 0:
194
+ control = np.pad(control, ((0, config.L-control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge')
195
+ generator = torch.Generator(device=torch.device("cuda:0"))
196
+ generator.manual_seed(torch.initial_seed())
197
+ sample = pipeline(
198
+ prompt,
199
+ negative_prompt = n_prompt,
200
+ num_inference_steps = config.steps,
201
+ guidance_scale = config.guidance_scale,
202
+ width = W,
203
+ height = H,
204
+ video_length = len(control),
205
+ controlnet_condition = control,
206
+ init_latents = init_latents,
207
+ generator = generator,
208
+ num_actual_inference_steps = num_actual_inference_steps,
209
+ appearance_encoder = appearance_encoder,
210
+ reference_control_writer = reference_control_writer,
211
+ reference_control_reader = reference_control_reader,
212
+ source_image = source_image,
213
+ **dist_kwargs,
214
+ ).videos
215
+
216
+ if args.rank == 0:
217
+ source_images = np.array([source_image] * original_length)
218
+ source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
219
+ samples_per_video.append(source_images)
220
+
221
+ control = control / 255.0
222
+ control = rearrange(control, "t h w c -> 1 c t h w")
223
+ control = torch.from_numpy(control)
224
+ samples_per_video.append(control[:, :, :original_length])
225
+
226
+ samples_per_video.append(sample[:, :, :original_length])
227
+
228
+ samples_per_video = torch.cat(samples_per_video)
229
+
230
+ video_name = os.path.basename(test_video)[:-4]
231
+ source_name = os.path.basename(config.source_image[idx]).split(".")[0]
232
+ save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4")
233
+ save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4")
234
+
235
+ if config.save_individual_videos:
236
+ save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4")
237
+ save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4")
238
+
239
+ if args.dist:
240
+ dist.barrier()
241
+
242
+ if args.rank == 0:
243
+ OmegaConf.save(config, f"{savedir}/config.yaml")
244
+
245
+
246
+ def distributed_main(device_id, args):
247
+ args.rank = device_id
248
+ args.device_id = device_id
249
+ if torch.cuda.is_available():
250
+ torch.cuda.set_device(args.device_id)
251
+ torch.cuda.init()
252
+ distributed_init(args)
253
+ main(args)
254
+
255
+
256
+ def run(args):
257
+
258
+ if args.dist:
259
+ args.world_size = max(1, torch.cuda.device_count())
260
+ assert args.world_size <= torch.cuda.device_count()
261
+
262
+ if args.world_size > 0 and torch.cuda.device_count() > 1:
263
+ port = random.randint(10000, 20000)
264
+ args.init_method = f"tcp://localhost:{port}"
265
+ torch.multiprocessing.spawn(
266
+ fn=distributed_main,
267
+ args=(args,),
268
+ nprocs=args.world_size,
269
+ )
270
+ else:
271
+ main(args)
272
+
273
+
274
+ if __name__ == "__main__":
275
+ parser = argparse.ArgumentParser()
276
+ parser.add_argument("--config", type=str, required=True)
277
+ parser.add_argument("--dist", action="store_true", required=False)
278
+ parser.add_argument("--rank", type=int, default=0, required=False)
279
+ parser.add_argument("--world_size", type=int, default=1, required=False)
280
+
281
+ args = parser.parse_args()
282
+ run(args)
magicanimate/pipelines/context.py CHANGED
@@ -1,76 +1,76 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Adapted from https://github.com/s9roll7/animatediff-cli-prompt-travel/tree/main
8
- import numpy as np
9
- from typing import Callable, Optional, List
10
-
11
-
12
- def ordered_halving(val):
13
- bin_str = f"{val:064b}"
14
- bin_flip = bin_str[::-1]
15
- as_int = int(bin_flip, 2)
16
-
17
- return as_int / (1 << 64)
18
-
19
-
20
- def uniform(
21
- step: int = ...,
22
- num_steps: Optional[int] = None,
23
- num_frames: int = ...,
24
- context_size: Optional[int] = None,
25
- context_stride: int = 3,
26
- context_overlap: int = 4,
27
- closed_loop: bool = True,
28
- ):
29
- if num_frames <= context_size:
30
- yield list(range(num_frames))
31
- return
32
-
33
- context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
34
-
35
- for context_step in 1 << np.arange(context_stride):
36
- pad = int(round(num_frames * ordered_halving(step)))
37
- for j in range(
38
- int(ordered_halving(step) * context_step) + pad,
39
- num_frames + pad + (0 if closed_loop else -context_overlap),
40
- (context_size * context_step - context_overlap),
41
- ):
42
- yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]
43
-
44
-
45
- def get_context_scheduler(name: str) -> Callable:
46
- if name == "uniform":
47
- return uniform
48
- else:
49
- raise ValueError(f"Unknown context_overlap policy {name}")
50
-
51
-
52
- def get_total_steps(
53
- scheduler,
54
- timesteps: List[int],
55
- num_steps: Optional[int] = None,
56
- num_frames: int = ...,
57
- context_size: Optional[int] = None,
58
- context_stride: int = 3,
59
- context_overlap: int = 4,
60
- closed_loop: bool = True,
61
- ):
62
- return sum(
63
- len(
64
- list(
65
- scheduler(
66
- i,
67
- num_steps,
68
- num_frames,
69
- context_size,
70
- context_stride,
71
- context_overlap,
72
- )
73
- )
74
- )
75
- for i in range(len(timesteps))
76
- )
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/s9roll7/animatediff-cli-prompt-travel/tree/main
8
+ import numpy as np
9
+ from typing import Callable, Optional, List
10
+
11
+
12
+ def ordered_halving(val):
13
+ bin_str = f"{val:064b}"
14
+ bin_flip = bin_str[::-1]
15
+ as_int = int(bin_flip, 2)
16
+
17
+ return as_int / (1 << 64)
18
+
19
+
20
+ def uniform(
21
+ step: int = ...,
22
+ num_steps: Optional[int] = None,
23
+ num_frames: int = ...,
24
+ context_size: Optional[int] = None,
25
+ context_stride: int = 3,
26
+ context_overlap: int = 4,
27
+ closed_loop: bool = True,
28
+ ):
29
+ if num_frames <= context_size:
30
+ yield list(range(num_frames))
31
+ return
32
+
33
+ context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
34
+
35
+ for context_step in 1 << np.arange(context_stride):
36
+ pad = int(round(num_frames * ordered_halving(step)))
37
+ for j in range(
38
+ int(ordered_halving(step) * context_step) + pad,
39
+ num_frames + pad + (0 if closed_loop else -context_overlap),
40
+ (context_size * context_step - context_overlap),
41
+ ):
42
+ yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]
43
+
44
+
45
+ def get_context_scheduler(name: str) -> Callable:
46
+ if name == "uniform":
47
+ return uniform
48
+ else:
49
+ raise ValueError(f"Unknown context schedule policy {name}")
50
+
51
+
52
+ def get_total_steps(
53
+ scheduler,
54
+ timesteps: List[int],
55
+ num_steps: Optional[int] = None,
56
+ num_frames: int = ...,
57
+ context_size: Optional[int] = None,
58
+ context_stride: int = 3,
59
+ context_overlap: int = 4,
60
+ closed_loop: bool = True,
61
+ ):
62
+ return sum(
63
+ len(
64
+ list(
65
+ scheduler(
66
+ i,
67
+ num_steps,
68
+ num_frames,
69
+ context_size,
70
+ context_stride,
71
+ context_overlap,
72
+ )
73
+ )
74
+ )
75
+ for i in range(len(timesteps))
76
+ )
magicanimate/pipelines/pipeline_animation.py CHANGED
@@ -1,799 +1,800 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
8
-
9
- # Copyright 2023 The HuggingFace Team. All rights reserved.
10
- #
11
- # Licensed under the Apache License, Version 2.0 (the "License");
12
- # you may not use this file except in compliance with the License.
13
- # You may obtain a copy of the License at
14
- #
15
- # http://www.apache.org/licenses/LICENSE-2.0
16
- #
17
- # Unless required by applicable law or agreed to in writing, software
18
- # distributed under the License is distributed on an "AS IS" BASIS,
19
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
- # See the License for the specific language governing permissions and
21
- # limitations under the License.
22
- """
23
- TODO:
24
- 1. support multi-controlnet
25
- 2. [DONE] support DDIM inversion
26
- 3. support Prompt-to-prompt
27
- """
28
-
29
- import inspect, math
30
- from typing import Callable, List, Optional, Union
31
- from dataclasses import dataclass
32
- from PIL import Image
33
- import numpy as np
34
- import torch
35
- import torch.distributed as dist
36
- from tqdm import tqdm
37
- from diffusers.utils import is_accelerate_available
38
- from packaging import version
39
- from transformers import CLIPTextModel, CLIPTokenizer
40
-
41
- from diffusers.configuration_utils import FrozenDict
42
- from diffusers.models import AutoencoderKL
43
- from diffusers.pipeline_utils import DiffusionPipeline
44
- from diffusers.schedulers import (
45
- DDIMScheduler,
46
- DPMSolverMultistepScheduler,
47
- EulerAncestralDiscreteScheduler,
48
- EulerDiscreteScheduler,
49
- LMSDiscreteScheduler,
50
- PNDMScheduler,
51
- )
52
- from diffusers.utils import deprecate, logging, BaseOutput
53
-
54
- from einops import rearrange
55
-
56
- from magicanimate.models.unet_controlnet import UNet3DConditionModel
57
- from magicanimate.models.controlnet import ControlNetModel
58
- from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
59
- from magicanimate.pipelines.context import (
60
- get_context_scheduler,
61
- get_total_steps
62
- )
63
- from magicanimate.utils.util import get_tensor_interpolation_method
64
-
65
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
66
-
67
-
68
- @dataclass
69
- class AnimationPipelineOutput(BaseOutput):
70
- videos: Union[torch.Tensor, np.ndarray]
71
-
72
-
73
- class AnimationPipeline(DiffusionPipeline):
74
- _optional_components = []
75
-
76
- def __init__(
77
- self,
78
- vae: AutoencoderKL,
79
- text_encoder: CLIPTextModel,
80
- tokenizer: CLIPTokenizer,
81
- unet: UNet3DConditionModel,
82
- controlnet: ControlNetModel,
83
- scheduler: Union[
84
- DDIMScheduler,
85
- PNDMScheduler,
86
- LMSDiscreteScheduler,
87
- EulerDiscreteScheduler,
88
- EulerAncestralDiscreteScheduler,
89
- DPMSolverMultistepScheduler,
90
- ],
91
- ):
92
- super().__init__()
93
-
94
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
95
- deprecation_message = (
96
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
97
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
98
- "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
99
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
100
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
101
- " file"
102
- )
103
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
104
- new_config = dict(scheduler.config)
105
- new_config["steps_offset"] = 1
106
- scheduler._internal_dict = FrozenDict(new_config)
107
-
108
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
109
- deprecation_message = (
110
- f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
111
- " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
112
- " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
113
- " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
114
- " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
115
- )
116
- deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
117
- new_config = dict(scheduler.config)
118
- new_config["clip_sample"] = False
119
- scheduler._internal_dict = FrozenDict(new_config)
120
-
121
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
122
- version.parse(unet.config._diffusers_version).base_version
123
- ) < version.parse("0.9.0.dev0")
124
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
125
- if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
126
- deprecation_message = (
127
- "The configuration file of the unet has set the default `sample_size` to smaller than"
128
- " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
129
- " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
130
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
131
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
132
- " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
133
- " in the config might lead to incorrect results in future versions. If you have downloaded this"
134
- " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
135
- " the `unet/config.json` file"
136
- )
137
- deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
138
- new_config = dict(unet.config)
139
- new_config["sample_size"] = 64
140
- unet._internal_dict = FrozenDict(new_config)
141
-
142
- self.register_modules(
143
- vae=vae,
144
- text_encoder=text_encoder,
145
- tokenizer=tokenizer,
146
- unet=unet,
147
- controlnet=controlnet,
148
- scheduler=scheduler,
149
- )
150
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
151
-
152
- def enable_vae_slicing(self):
153
- self.vae.enable_slicing()
154
-
155
- def disable_vae_slicing(self):
156
- self.vae.disable_slicing()
157
-
158
- def enable_sequential_cpu_offload(self, gpu_id=0):
159
- if is_accelerate_available():
160
- from accelerate import cpu_offload
161
- else:
162
- raise ImportError("Please install accelerate via `pip install accelerate`")
163
-
164
- device = torch.device(f"cuda:{gpu_id}")
165
-
166
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
167
- if cpu_offloaded_model is not None:
168
- cpu_offload(cpu_offloaded_model, device)
169
-
170
-
171
- @property
172
- def _execution_device(self):
173
- if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
174
- return self.device
175
- for module in self.unet.modules():
176
- if (
177
- hasattr(module, "_hf_hook")
178
- and hasattr(module._hf_hook, "execution_device")
179
- and module._hf_hook.execution_device is not None
180
- ):
181
- return torch.device(module._hf_hook.execution_device)
182
- return self.device
183
-
184
- def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
185
- batch_size = len(prompt) if isinstance(prompt, list) else 1
186
-
187
- text_inputs = self.tokenizer(
188
- prompt,
189
- padding="max_length",
190
- max_length=self.tokenizer.model_max_length,
191
- truncation=True,
192
- return_tensors="pt",
193
- )
194
- text_input_ids = text_inputs.input_ids
195
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
196
-
197
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
198
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
199
- logger.warning(
200
- "The following part of your input was truncated because CLIP can only handle sequences up to"
201
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
202
- )
203
-
204
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
205
- attention_mask = text_inputs.attention_mask.to(device)
206
- else:
207
- attention_mask = None
208
-
209
- text_embeddings = self.text_encoder(
210
- text_input_ids.to(device),
211
- attention_mask=attention_mask,
212
- )
213
- text_embeddings = text_embeddings[0]
214
-
215
- # duplicate text embeddings for each generation per prompt, using mps friendly method
216
- bs_embed, seq_len, _ = text_embeddings.shape
217
- text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
218
- text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
219
-
220
- # get unconditional embeddings for classifier free guidance
221
- if do_classifier_free_guidance:
222
- uncond_tokens: List[str]
223
- if negative_prompt is None:
224
- uncond_tokens = [""] * batch_size
225
- elif type(prompt) is not type(negative_prompt):
226
- raise TypeError(
227
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
228
- f" {type(prompt)}."
229
- )
230
- elif isinstance(negative_prompt, str):
231
- uncond_tokens = [negative_prompt]
232
- elif batch_size != len(negative_prompt):
233
- raise ValueError(
234
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
235
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
236
- " the batch size of `prompt`."
237
- )
238
- else:
239
- uncond_tokens = negative_prompt
240
-
241
- max_length = text_input_ids.shape[-1]
242
- uncond_input = self.tokenizer(
243
- uncond_tokens,
244
- padding="max_length",
245
- max_length=max_length,
246
- truncation=True,
247
- return_tensors="pt",
248
- )
249
-
250
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
251
- attention_mask = uncond_input.attention_mask.to(device)
252
- else:
253
- attention_mask = None
254
-
255
- uncond_embeddings = self.text_encoder(
256
- uncond_input.input_ids.to(device),
257
- attention_mask=attention_mask,
258
- )
259
- uncond_embeddings = uncond_embeddings[0]
260
-
261
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
262
- seq_len = uncond_embeddings.shape[1]
263
- uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
264
- uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
265
-
266
- # For classifier free guidance, we need to do two forward passes.
267
- # Here we concatenate the unconditional and text embeddings into a single batch
268
- # to avoid doing two forward passes
269
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
270
-
271
- return text_embeddings
272
-
273
- def decode_latents(self, latents, rank, decoder_consistency=None):
274
- video_length = latents.shape[2]
275
- latents = 1 / 0.18215 * latents
276
- latents = rearrange(latents, "b c f h w -> (b f) c h w")
277
- # video = self.vae.decode(latents).sample
278
- video = []
279
- for frame_idx in tqdm(range(latents.shape[0]), disable=(rank!=0)):
280
- if decoder_consistency is not None:
281
- video.append(decoder_consistency(latents[frame_idx:frame_idx+1]))
282
- else:
283
- video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
284
- video = torch.cat(video)
285
- video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
286
- video = (video / 2 + 0.5).clamp(0, 1)
287
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
288
- video = video.cpu().float().numpy()
289
- return video
290
-
291
- def prepare_extra_step_kwargs(self, generator, eta):
292
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
293
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
294
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
295
- # and should be between [0, 1]
296
-
297
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
298
- extra_step_kwargs = {}
299
- if accepts_eta:
300
- extra_step_kwargs["eta"] = eta
301
-
302
- # check if the scheduler accepts generator
303
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
304
- if accepts_generator:
305
- extra_step_kwargs["generator"] = generator
306
- return extra_step_kwargs
307
-
308
- def check_inputs(self, prompt, height, width, callback_steps):
309
- if not isinstance(prompt, str) and not isinstance(prompt, list):
310
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
311
-
312
- if height % 8 != 0 or width % 8 != 0:
313
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
314
-
315
- if (callback_steps is None) or (
316
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
317
- ):
318
- raise ValueError(
319
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
320
- f" {type(callback_steps)}."
321
- )
322
-
323
- def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, clip_length=16):
324
- shape = (batch_size, num_channels_latents, clip_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
325
- if isinstance(generator, list) and len(generator) != batch_size:
326
- raise ValueError(
327
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
328
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
329
- )
330
- if latents is None:
331
- rand_device = "cpu" if device.type == "mps" else device
332
-
333
- if isinstance(generator, list):
334
- latents = [
335
- torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
336
- for i in range(batch_size)
337
- ]
338
- latents = torch.cat(latents, dim=0).to(device)
339
- else:
340
- latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
341
-
342
- latents = latents.repeat(1, 1, video_length//clip_length, 1, 1)
343
- else:
344
- if latents.shape != shape:
345
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
346
- latents = latents.to(device)
347
-
348
- # scale the initial noise by the standard deviation required by the scheduler
349
- latents = latents * self.scheduler.init_noise_sigma
350
- return latents
351
-
352
- def prepare_condition(self, condition, num_videos_per_prompt, device, dtype, do_classifier_free_guidance):
353
- # prepare conditions for controlnet
354
- condition = torch.from_numpy(condition.copy()).to(device=device, dtype=dtype) / 255.0
355
- condition = torch.stack([condition for _ in range(num_videos_per_prompt)], dim=0)
356
- condition = rearrange(condition, 'b f h w c -> (b f) c h w').clone()
357
- if do_classifier_free_guidance:
358
- condition = torch.cat([condition] * 2)
359
- return condition
360
-
361
- def next_step(
362
- self,
363
- model_output: torch.FloatTensor,
364
- timestep: int,
365
- x: torch.FloatTensor,
366
- eta=0.,
367
- verbose=False
368
- ):
369
- """
370
- Inverse sampling for DDIM Inversion
371
- """
372
- if verbose:
373
- print("timestep: ", timestep)
374
- next_step = timestep
375
- timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
376
- alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
377
- alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
378
- beta_prod_t = 1 - alpha_prod_t
379
- pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
380
- pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
381
- x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
382
- return x_next, pred_x0
383
-
384
- @torch.no_grad()
385
- def images2latents(self, images, dtype):
386
- """
387
- Convert RGB image to VAE latents
388
- """
389
- device = self._execution_device
390
- images = torch.from_numpy(images).float().to(dtype) / 127.5 - 1
391
- images = rearrange(images, "f h w c -> f c h w").to(device)
392
- latents = []
393
- for frame_idx in range(images.shape[0]):
394
- latents.append(self.vae.encode(images[frame_idx:frame_idx+1])['latent_dist'].mean * 0.18215)
395
- latents = torch.cat(latents)
396
- return latents
397
-
398
- @torch.no_grad()
399
- def invert(
400
- self,
401
- image: torch.Tensor,
402
- prompt,
403
- num_inference_steps=20,
404
- num_actual_inference_steps=10,
405
- eta=0.0,
406
- return_intermediates=False,
407
- **kwargs):
408
- """
409
- Adapted from: https://github.com/Yujun-Shi/DragDiffusion/blob/main/drag_pipeline.py#L440
410
- invert a real image into a noise map with deterministic DDIM inversion
411
- """
412
- device = self._execution_device
413
- batch_size = image.shape[0]
414
- if isinstance(prompt, list):
415
- if batch_size == 1:
416
- image = image.expand(len(prompt), -1, -1, -1)
417
- elif isinstance(prompt, str):
418
- if batch_size > 1:
419
- prompt = [prompt] * batch_size
420
-
421
- # text embeddings
422
- text_input = self.tokenizer(
423
- prompt,
424
- padding="max_length",
425
- max_length=77,
426
- return_tensors="pt"
427
- )
428
- text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
429
- print("input text embeddings :", text_embeddings.shape)
430
- # define initial latents
431
- latents = self.images2latents(image)
432
-
433
- print("latents shape: ", latents.shape)
434
- # iterative sampling
435
- self.scheduler.set_timesteps(num_inference_steps)
436
- print("Valid timesteps: ", reversed(self.scheduler.timesteps))
437
- latents_list = [latents]
438
- pred_x0_list = [latents]
439
- for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
440
-
441
- if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
442
- continue
443
- model_inputs = latents
444
-
445
- # predict the noise
446
- # NOTE: the u-net here is UNet3D, therefore the model_inputs need to be of shape (b c f h w)
447
- model_inputs = rearrange(model_inputs, "f c h w -> 1 c f h w")
448
- noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
449
- noise_pred = rearrange(noise_pred, "b c f h w -> (b f) c h w")
450
-
451
- # compute the previous noise sample x_t-1 -> x_t
452
- latents, pred_x0 = self.next_step(noise_pred, t, latents)
453
- latents_list.append(latents)
454
- pred_x0_list.append(pred_x0)
455
-
456
- if return_intermediates:
457
- # return the intermediate latents during inversion
458
- return latents, latents_list
459
- return latents
460
-
461
- def interpolate_latents(self, latents: torch.Tensor, interpolation_factor:int, device ):
462
- if interpolation_factor < 2:
463
- return latents
464
-
465
- new_latents = torch.zeros(
466
- (latents.shape[0],latents.shape[1],((latents.shape[2]-1) * interpolation_factor)+1, latents.shape[3],latents.shape[4]),
467
- device=latents.device,
468
- dtype=latents.dtype,
469
- )
470
-
471
- org_video_length = latents.shape[2]
472
- rate = [i/interpolation_factor for i in range(interpolation_factor)][1:]
473
-
474
- new_index = 0
475
-
476
- v0 = None
477
- v1 = None
478
-
479
- for i0,i1 in zip( range( org_video_length ),range( org_video_length )[1:] ):
480
- v0 = latents[:,:,i0,:,:]
481
- v1 = latents[:,:,i1,:,:]
482
-
483
- new_latents[:,:,new_index,:,:] = v0
484
- new_index += 1
485
-
486
- for f in rate:
487
- v = get_tensor_interpolation_method()(v0.to(device=device),v1.to(device=device),f)
488
- new_latents[:,:,new_index,:,:] = v.to(latents.device)
489
- new_index += 1
490
-
491
- new_latents[:,:,new_index,:,:] = v1
492
- new_index += 1
493
-
494
- return new_latents
495
-
496
- def select_controlnet_res_samples(self, controlnet_res_samples_cache_dict, context, do_classifier_free_guidance, b, f):
497
- _down_block_res_samples = []
498
- _mid_block_res_sample = []
499
- for i in np.concatenate(np.array(context)):
500
- _down_block_res_samples.append(controlnet_res_samples_cache_dict[i][0])
501
- _mid_block_res_sample.append(controlnet_res_samples_cache_dict[i][1])
502
- down_block_res_samples = [[] for _ in range(len(controlnet_res_samples_cache_dict[i][0]))]
503
- for res_t in _down_block_res_samples:
504
- for i, res in enumerate(res_t):
505
- down_block_res_samples[i].append(res)
506
- down_block_res_samples = [torch.cat(res) for res in down_block_res_samples]
507
- mid_block_res_sample = torch.cat(_mid_block_res_sample)
508
-
509
- # reshape controlnet output to match the unet3d inputs
510
- b = b // 2 if do_classifier_free_guidance else b
511
- _down_block_res_samples = []
512
- for sample in down_block_res_samples:
513
- sample = rearrange(sample, '(b f) c h w -> b c f h w', b=b, f=f)
514
- if do_classifier_free_guidance:
515
- sample = sample.repeat(2, 1, 1, 1, 1)
516
- _down_block_res_samples.append(sample)
517
- down_block_res_samples = _down_block_res_samples
518
- mid_block_res_sample = rearrange(mid_block_res_sample, '(b f) c h w -> b c f h w', b=b, f=f)
519
- if do_classifier_free_guidance:
520
- mid_block_res_sample = mid_block_res_sample.repeat(2, 1, 1, 1, 1)
521
-
522
- return down_block_res_samples, mid_block_res_sample
523
-
524
- @torch.no_grad()
525
- def __call__(
526
- self,
527
- prompt: Union[str, List[str]],
528
- video_length: Optional[int],
529
- height: Optional[int] = None,
530
- width: Optional[int] = None,
531
- num_inference_steps: int = 50,
532
- guidance_scale: float = 7.5,
533
- negative_prompt: Optional[Union[str, List[str]]] = None,
534
- num_videos_per_prompt: Optional[int] = 1,
535
- eta: float = 0.0,
536
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
537
- latents: Optional[torch.FloatTensor] = None,
538
- output_type: Optional[str] = "tensor",
539
- return_dict: bool = True,
540
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
541
- callback_steps: Optional[int] = 1,
542
- controlnet_condition: list = None,
543
- controlnet_conditioning_scale: float = 1.0,
544
- context_frames: int = 16,
545
- context_stride: int = 1,
546
- context_overlap: int = 4,
547
- context_batch_size: int = 1,
548
- context_schedule: str = "uniform",
549
- init_latents: Optional[torch.FloatTensor] = None,
550
- num_actual_inference_steps: Optional[int] = None,
551
- appearance_encoder = None,
552
- reference_control_writer = None,
553
- reference_control_reader = None,
554
- source_image: str = None,
555
- decoder_consistency = None,
556
- **kwargs,
557
- ):
558
- """
559
- New args:
560
- - controlnet_condition : condition map (e.g., depth, canny, keypoints) for controlnet
561
- - controlnet_conditioning_scale : conditioning scale for controlnet
562
- - init_latents : initial latents to begin with (used along with invert())
563
- - num_actual_inference_steps : number of actual inference steps (while total steps is num_inference_steps)
564
- """
565
- controlnet = self.controlnet
566
-
567
- # Default height and width to unet
568
- height = height or self.unet.config.sample_size * self.vae_scale_factor
569
- width = width or self.unet.config.sample_size * self.vae_scale_factor
570
-
571
- # Check inputs. Raise error if not correct
572
- self.check_inputs(prompt, height, width, callback_steps)
573
-
574
- # Define call parameters
575
- # batch_size = 1 if isinstance(prompt, str) else len(prompt)
576
- batch_size = 1
577
- if latents is not None:
578
- batch_size = latents.shape[0]
579
- if isinstance(prompt, list):
580
- batch_size = len(prompt)
581
-
582
- device = self._execution_device
583
- # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
584
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
585
- # corresponds to doing no classifier free guidance.
586
- do_classifier_free_guidance = guidance_scale > 1.0
587
-
588
- # Encode input prompt
589
- prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
590
- if negative_prompt is not None:
591
- negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
592
- text_embeddings = self._encode_prompt(
593
- prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
594
- )
595
- text_embeddings = torch.cat([text_embeddings] * context_batch_size)
596
-
597
- reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', batch_size=context_batch_size)
598
- reference_control_reader = ReferenceAttentionControl(self.unet, do_classifier_free_guidance=True, mode='read', batch_size=context_batch_size)
599
-
600
- is_dist_initialized = kwargs.get("dist", False)
601
- rank = kwargs.get("rank", 0)
602
- world_size = kwargs.get("world_size", 1)
603
-
604
- # Prepare video
605
- assert num_videos_per_prompt == 1 # FIXME: verify if num_videos_per_prompt > 1 works
606
- assert batch_size == 1 # FIXME: verify if batch_size > 1 works
607
- control = self.prepare_condition(
608
- condition=controlnet_condition,
609
- device=device,
610
- dtype=controlnet.dtype,
611
- num_videos_per_prompt=num_videos_per_prompt,
612
- do_classifier_free_guidance=do_classifier_free_guidance,
613
- )
614
- controlnet_uncond_images, controlnet_cond_images = control.chunk(2)
615
-
616
- # Prepare timesteps
617
- self.scheduler.set_timesteps(num_inference_steps, device=device)
618
- timesteps = self.scheduler.timesteps
619
-
620
- # Prepare latent variables
621
- if init_latents is not None:
622
- latents = rearrange(init_latents, "(b f) c h w -> b c f h w", f=video_length)
623
- else:
624
- num_channels_latents = self.unet.in_channels
625
- latents = self.prepare_latents(
626
- batch_size * num_videos_per_prompt,
627
- num_channels_latents,
628
- video_length,
629
- height,
630
- width,
631
- text_embeddings.dtype,
632
- device,
633
- generator,
634
- latents,
635
- )
636
- latents_dtype = latents.dtype
637
-
638
- # Prepare extra step kwargs.
639
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
640
-
641
- # Prepare text embeddings for controlnet
642
- controlnet_text_embeddings = text_embeddings.repeat_interleave(video_length, 0)
643
- _, controlnet_text_embeddings_c = controlnet_text_embeddings.chunk(2)
644
-
645
- controlnet_res_samples_cache_dict = {i:None for i in range(video_length)}
646
-
647
- # For img2img setting
648
- if num_actual_inference_steps is None:
649
- num_actual_inference_steps = num_inference_steps
650
-
651
- if isinstance(source_image, str):
652
- ref_image_latents = self.images2latents(np.array(Image.open(source_image).resize((width, height)))[None, :], latents_dtype).cuda()
653
- elif isinstance(source_image, np.ndarray):
654
- ref_image_latents = self.images2latents(source_image[None, :], latents_dtype).cuda()
655
-
656
- context_scheduler = get_context_scheduler(context_schedule)
657
-
658
- # Denoising loop
659
- for i, t in tqdm(enumerate(timesteps), total=len(timesteps), disable=(rank!=0)):
660
- if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps:
661
- continue
662
-
663
- noise_pred = torch.zeros(
664
- (latents.shape[0] * (2 if do_classifier_free_guidance else 1), *latents.shape[1:]),
665
- device=latents.device,
666
- dtype=latents.dtype,
667
- )
668
- counter = torch.zeros(
669
- (1, 1, latents.shape[2], 1, 1), device=latents.device, dtype=latents.dtype
670
- )
671
-
672
- appearance_encoder(
673
- ref_image_latents.repeat(context_batch_size * (2 if do_classifier_free_guidance else 1), 1, 1, 1),
674
- t,
675
- encoder_hidden_states=text_embeddings,
676
- return_dict=False,
677
- )
678
-
679
- context_queue = list(context_scheduler(
680
- 0, num_inference_steps, latents.shape[2], context_frames, context_stride, 0
681
- ))
682
- num_context_batches = math.ceil(len(context_queue) / context_batch_size)
683
- for i in range(num_context_batches):
684
- context = context_queue[i*context_batch_size: (i+1)*context_batch_size]
685
- # expand the latents if we are doing classifier free guidance
686
- controlnet_latent_input = (
687
- torch.cat([latents[:, :, c] for c in context])
688
- .to(device)
689
- )
690
- controlnet_latent_input = self.scheduler.scale_model_input(controlnet_latent_input, t)
691
-
692
- # prepare inputs for controlnet
693
- b, c, f, h, w = controlnet_latent_input.shape
694
- controlnet_latent_input = rearrange(controlnet_latent_input, "b c f h w -> (b f) c h w")
695
-
696
- # controlnet inference
697
- down_block_res_samples, mid_block_res_sample = self.controlnet(
698
- controlnet_latent_input,
699
- t,
700
- encoder_hidden_states=torch.cat([controlnet_text_embeddings_c[c] for c in context]),
701
- controlnet_cond=torch.cat([controlnet_cond_images[c] for c in context]),
702
- conditioning_scale=controlnet_conditioning_scale,
703
- return_dict=False,
704
- )
705
-
706
- for j, k in enumerate(np.concatenate(np.array(context))):
707
- controlnet_res_samples_cache_dict[k] = ([sample[j:j+1] for sample in down_block_res_samples], mid_block_res_sample[j:j+1])
708
-
709
- context_queue = list(context_scheduler(
710
- 0, num_inference_steps, latents.shape[2], context_frames, context_stride, context_overlap
711
- ))
712
-
713
- num_context_batches = math.ceil(len(context_queue) / context_batch_size)
714
- global_context = []
715
- for i in range(num_context_batches):
716
- global_context.append(context_queue[i*context_batch_size: (i+1)*context_batch_size])
717
-
718
- for context in global_context[rank::world_size]:
719
- # expand the latents if we are doing classifier free guidance
720
- latent_model_input = (
721
- torch.cat([latents[:, :, c] for c in context])
722
- .to(device)
723
- .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
724
- )
725
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
726
-
727
- b, c, f, h, w = latent_model_input.shape
728
- down_block_res_samples, mid_block_res_sample = self.select_controlnet_res_samples(
729
- controlnet_res_samples_cache_dict,
730
- context,
731
- do_classifier_free_guidance,
732
- b, f
733
- )
734
-
735
- reference_control_reader.update(reference_control_writer)
736
-
737
- # predict the noise residual
738
- pred = self.unet(
739
- latent_model_input,
740
- t,
741
- encoder_hidden_states=text_embeddings[:b],
742
- down_block_additional_residuals=down_block_res_samples,
743
- mid_block_additional_residual=mid_block_res_sample,
744
- return_dict=False,
745
- )[0]
746
-
747
- reference_control_reader.clear()
748
-
749
- pred_uc, pred_c = pred.chunk(2)
750
- pred = torch.cat([pred_uc.unsqueeze(0), pred_c.unsqueeze(0)])
751
- for j, c in enumerate(context):
752
- noise_pred[:, :, c] = noise_pred[:, :, c] + pred[:, j]
753
- counter[:, :, c] = counter[:, :, c] + 1
754
-
755
- if is_dist_initialized:
756
- noise_pred_gathered = [torch.zeros_like(noise_pred) for _ in range(world_size)]
757
- if rank == 0:
758
- dist.gather(tensor=noise_pred, gather_list=noise_pred_gathered, dst=0)
759
- else:
760
- dist.gather(tensor=noise_pred, gather_list=[], dst=0)
761
- dist.barrier()
762
-
763
- if rank == 0:
764
- for k in range(1, world_size):
765
- for context in global_context[k::world_size]:
766
- for j, c in enumerate(context):
767
- noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_gathered[k][:, :, c]
768
- counter[:, :, c] = counter[:, :, c] + 1
769
-
770
- # perform guidance
771
- if do_classifier_free_guidance:
772
- noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
773
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
774
-
775
- # compute the previous noisy sample x_t -> x_t-1
776
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
777
-
778
- if is_dist_initialized:
779
- dist.broadcast(latents, 0)
780
- dist.barrier()
781
-
782
- reference_control_writer.clear()
783
-
784
- interpolation_factor = 1
785
- latents = self.interpolate_latents(latents, interpolation_factor, device)
786
- # Post-processing
787
- video = self.decode_latents(latents, rank, decoder_consistency=decoder_consistency)
788
-
789
- if is_dist_initialized:
790
- dist.barrier()
791
-
792
- # Convert to tensor
793
- if output_type == "tensor":
794
- video = torch.from_numpy(video)
795
-
796
- if not return_dict:
797
- return video
798
-
799
- return AnimationPipelineOutput(videos=video)
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
8
+
9
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ """
23
+ TODO:
24
+ 1. support multi-controlnet
25
+ 2. [DONE] support DDIM inversion
26
+ 3. support Prompt-to-prompt
27
+ """
28
+
29
+ import inspect, math
30
+ from typing import Callable, List, Optional, Union
31
+ from dataclasses import dataclass
32
+ from PIL import Image
33
+ import numpy as np
34
+ import torch
35
+ import torch.distributed as dist
36
+ from tqdm import tqdm
37
+ from diffusers.utils import is_accelerate_available
38
+ from packaging import version
39
+ from transformers import CLIPTextModel, CLIPTokenizer
40
+
41
+ from diffusers.configuration_utils import FrozenDict
42
+ from diffusers.models import AutoencoderKL
43
+ from diffusers import DiffusionPipeline
44
+ from diffusers.schedulers import (
45
+ DDIMScheduler,
46
+ DPMSolverMultistepScheduler,
47
+ EulerAncestralDiscreteScheduler,
48
+ EulerDiscreteScheduler,
49
+ LMSDiscreteScheduler,
50
+ PNDMScheduler,
51
+ )
52
+ from diffusers.utils import deprecate, logging, BaseOutput
53
+
54
+ from einops import rearrange
55
+
56
+ from magicanimate.models.unet_controlnet import UNet3DConditionModel
57
+ from magicanimate.models.controlnet import ControlNetModel
58
+ from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
59
+ from magicanimate.pipelines.context import (
60
+ get_context_scheduler,
61
+ get_total_steps
62
+ )
63
+ from magicanimate.utils.util import get_tensor_interpolation_method
64
+
65
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
66
+
67
+ device = "cuda" if torch.cuda.is_available() else "cpu"
68
+
69
+ @dataclass
70
+ class AnimationPipelineOutput(BaseOutput):
71
+ videos: Union[torch.Tensor, np.ndarray]
72
+
73
+
74
+ class AnimationPipeline(DiffusionPipeline):
75
+ _optional_components = []
76
+
77
+ def __init__(
78
+ self,
79
+ vae: AutoencoderKL,
80
+ text_encoder: CLIPTextModel,
81
+ tokenizer: CLIPTokenizer,
82
+ unet: UNet3DConditionModel,
83
+ controlnet: ControlNetModel,
84
+ scheduler: Union[
85
+ DDIMScheduler,
86
+ PNDMScheduler,
87
+ LMSDiscreteScheduler,
88
+ EulerDiscreteScheduler,
89
+ EulerAncestralDiscreteScheduler,
90
+ DPMSolverMultistepScheduler,
91
+ ],
92
+ ):
93
+ super().__init__()
94
+
95
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
96
+ deprecation_message = (
97
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
98
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
99
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
100
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
101
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
102
+ " file"
103
+ )
104
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
105
+ new_config = dict(scheduler.config)
106
+ new_config["steps_offset"] = 1
107
+ scheduler._internal_dict = FrozenDict(new_config)
108
+
109
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
110
+ deprecation_message = (
111
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
112
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
113
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
114
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
115
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
116
+ )
117
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
118
+ new_config = dict(scheduler.config)
119
+ new_config["clip_sample"] = False
120
+ scheduler._internal_dict = FrozenDict(new_config)
121
+
122
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
123
+ version.parse(unet.config._diffusers_version).base_version
124
+ ) < version.parse("0.9.0.dev0")
125
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
126
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
127
+ deprecation_message = (
128
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
129
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
130
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
131
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
132
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
133
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
134
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
135
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
136
+ " the `unet/config.json` file"
137
+ )
138
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
139
+ new_config = dict(unet.config)
140
+ new_config["sample_size"] = 64
141
+ unet._internal_dict = FrozenDict(new_config)
142
+
143
+ self.register_modules(
144
+ vae=vae,
145
+ text_encoder=text_encoder,
146
+ tokenizer=tokenizer,
147
+ unet=unet,
148
+ controlnet=controlnet,
149
+ scheduler=scheduler,
150
+ )
151
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
152
+
153
+ def enable_vae_slicing(self):
154
+ self.vae.enable_slicing()
155
+
156
+ def disable_vae_slicing(self):
157
+ self.vae.disable_slicing()
158
+
159
+ def enable_sequential_cpu_offload(self, gpu_id=0):
160
+ if is_accelerate_available():
161
+ from accelerate import cpu_offload
162
+ else:
163
+ raise ImportError("Please install accelerate via `pip install accelerate`")
164
+
165
+ device = torch.device(f"cuda:{gpu_id}")
166
+
167
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
168
+ if cpu_offloaded_model is not None:
169
+ cpu_offload(cpu_offloaded_model, device)
170
+
171
+
172
+ @property
173
+ def _execution_device(self):
174
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
175
+ return self.device
176
+ for module in self.unet.modules():
177
+ if (
178
+ hasattr(module, "_hf_hook")
179
+ and hasattr(module._hf_hook, "execution_device")
180
+ and module._hf_hook.execution_device is not None
181
+ ):
182
+ return torch.device(module._hf_hook.execution_device)
183
+ return self.device
184
+
185
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
186
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
187
+
188
+ text_inputs = self.tokenizer(
189
+ prompt,
190
+ padding="max_length",
191
+ max_length=self.tokenizer.model_max_length,
192
+ truncation=True,
193
+ return_tensors="pt",
194
+ )
195
+ text_input_ids = text_inputs.input_ids
196
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
197
+
198
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
199
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
200
+ logger.warning(
201
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
202
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
203
+ )
204
+
205
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
206
+ attention_mask = text_inputs.attention_mask.to(device)
207
+ else:
208
+ attention_mask = None
209
+
210
+ text_embeddings = self.text_encoder(
211
+ text_input_ids.to(device),
212
+ attention_mask=attention_mask,
213
+ )
214
+ text_embeddings = text_embeddings[0]
215
+
216
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
217
+ bs_embed, seq_len, _ = text_embeddings.shape
218
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
219
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
220
+
221
+ # get unconditional embeddings for classifier free guidance
222
+ if do_classifier_free_guidance:
223
+ uncond_tokens: List[str]
224
+ if negative_prompt is None:
225
+ uncond_tokens = [""] * batch_size
226
+ elif type(prompt) is not type(negative_prompt):
227
+ raise TypeError(
228
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
229
+ f" {type(prompt)}."
230
+ )
231
+ elif isinstance(negative_prompt, str):
232
+ uncond_tokens = [negative_prompt]
233
+ elif batch_size != len(negative_prompt):
234
+ raise ValueError(
235
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
236
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
237
+ " the batch size of `prompt`."
238
+ )
239
+ else:
240
+ uncond_tokens = negative_prompt
241
+
242
+ max_length = text_input_ids.shape[-1]
243
+ uncond_input = self.tokenizer(
244
+ uncond_tokens,
245
+ padding="max_length",
246
+ max_length=max_length,
247
+ truncation=True,
248
+ return_tensors="pt",
249
+ )
250
+
251
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
252
+ attention_mask = uncond_input.attention_mask.to(device)
253
+ else:
254
+ attention_mask = None
255
+
256
+ uncond_embeddings = self.text_encoder(
257
+ uncond_input.input_ids.to(device),
258
+ attention_mask=attention_mask,
259
+ )
260
+ uncond_embeddings = uncond_embeddings[0]
261
+
262
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
263
+ seq_len = uncond_embeddings.shape[1]
264
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
265
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
266
+
267
+ # For classifier free guidance, we need to do two forward passes.
268
+ # Here we concatenate the unconditional and text embeddings into a single batch
269
+ # to avoid doing two forward passes
270
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
271
+
272
+ return text_embeddings
273
+
274
+ def decode_latents(self, latents, rank, decoder_consistency=None):
275
+ video_length = latents.shape[2]
276
+ latents = 1 / 0.18215 * latents
277
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
278
+ # video = self.vae.decode(latents).sample
279
+ video = []
280
+ for frame_idx in tqdm(range(latents.shape[0]), disable=(rank!=0)):
281
+ if decoder_consistency is not None:
282
+ video.append(decoder_consistency(latents[frame_idx:frame_idx+1]))
283
+ else:
284
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
285
+ video = torch.cat(video)
286
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
287
+ video = (video / 2 + 0.5).clamp(0, 1)
288
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
289
+ video = video.cpu().float().numpy()
290
+ return video
291
+
292
+ def prepare_extra_step_kwargs(self, generator, eta):
293
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
294
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
295
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
296
+ # and should be between [0, 1]
297
+
298
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
299
+ extra_step_kwargs = {}
300
+ if accepts_eta:
301
+ extra_step_kwargs["eta"] = eta
302
+
303
+ # check if the scheduler accepts generator
304
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
305
+ if accepts_generator:
306
+ extra_step_kwargs["generator"] = generator
307
+ return extra_step_kwargs
308
+
309
+ def check_inputs(self, prompt, height, width, callback_steps):
310
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
311
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
312
+
313
+ if height % 8 != 0 or width % 8 != 0:
314
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
315
+
316
+ if (callback_steps is None) or (
317
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
318
+ ):
319
+ raise ValueError(
320
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
321
+ f" {type(callback_steps)}."
322
+ )
323
+
324
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, clip_length=16):
325
+ shape = (batch_size, num_channels_latents, clip_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
326
+ if isinstance(generator, list) and len(generator) != batch_size:
327
+ raise ValueError(
328
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
329
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
330
+ )
331
+ if latents is None:
332
+ rand_device = "cpu" if device.type == "mps" else device
333
+
334
+ if isinstance(generator, list):
335
+ latents = [
336
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
337
+ for i in range(batch_size)
338
+ ]
339
+ latents = torch.cat(latents, dim=0).to(device)
340
+ else:
341
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
342
+
343
+ latents = latents.repeat(1, 1, video_length//clip_length, 1, 1)
344
+ else:
345
+ if latents.shape != shape:
346
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
347
+ latents = latents.to(device)
348
+
349
+ # scale the initial noise by the standard deviation required by the scheduler
350
+ latents = latents * self.scheduler.init_noise_sigma
351
+ return latents
352
+
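# Shape bookkeeping for prepare_latents, as a standalone check. The concrete
# numbers (512x512 target, vae_scale_factor = 8, 4 latent channels) are
# illustrative assumptions, not values read from this pipeline at runtime.
height, width, vae_scale_factor = 512, 512, 8
batch_size, num_channels_latents = 1, 4
clip_length, video_length = 16, 32
shape = (batch_size, num_channels_latents, clip_length,
         height // vae_scale_factor, width // vae_scale_factor)
print(shape)                          # (1, 4, 16, 64, 64) noise sampled per clip
print(video_length // clip_length)    # repeat factor along the frame axis -> 2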
353
+ def prepare_condition(self, condition, num_videos_per_prompt, device, dtype, do_classifier_free_guidance):
354
+ # prepare conditions for controlnet
355
+ condition = torch.from_numpy(condition.copy()).to(device=device, dtype=dtype) / 255.0
356
+ condition = torch.stack([condition for _ in range(num_videos_per_prompt)], dim=0)
357
+ condition = rearrange(condition, 'b f h w c -> (b f) c h w').clone()
358
+ if do_classifier_free_guidance:
359
+ condition = torch.cat([condition] * 2)
360
+ return condition
361
+
362
+ def next_step(
363
+ self,
364
+ model_output: torch.FloatTensor,
365
+ timestep: int,
366
+ x: torch.FloatTensor,
367
+ eta=0.,
368
+ verbose=False
369
+ ):
370
+ """
371
+ Inverse sampling for DDIM Inversion
372
+ """
373
+ if verbose:
374
+ print("timestep: ", timestep)
375
+ next_step = timestep
376
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
377
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
378
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
379
+ beta_prod_t = 1 - alpha_prod_t
380
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
381
+ pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
382
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
383
+ return x_next, pred_x0
384
+
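# The update above is the DDIM step run "forwards" in noise level:
#   x0_hat = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
#   x_next = sqrt(a_next) * x0_hat + sqrt(1 - a_next) * eps
# Scalar sanity check with made-up cumulative alphas (the values are assumptions):
import torch

a_t, a_next = torch.tensor(0.90), torch.tensor(0.85)
x_t, eps = torch.tensor(1.0), torch.tensor(0.1)
x0_hat = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()
x_next = a_next.sqrt() * x0_hat + (1 - a_next).sqrt() * eps
print(float(x0_hat), float(x_next))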
385
+ @torch.no_grad()
386
+ def images2latents(self, images, dtype):
387
+ """
388
+ Convert RGB image to VAE latents
389
+ """
390
+ device = self._execution_device
391
+ images = torch.from_numpy(images).float().to(dtype) / 127.5 - 1
392
+ images = rearrange(images, "f h w c -> f c h w").to(device)
393
+ latents = []
394
+ for frame_idx in range(images.shape[0]):
395
+ latents.append(self.vae.encode(images[frame_idx:frame_idx+1])['latent_dist'].mean * 0.18215)
396
+ latents = torch.cat(latents)
397
+ return latents
398
+
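# Pixel-range bookkeeping shared by images2latents and decode_latents, as a
# standalone round-trip check: uint8 [0, 255] is mapped to [-1, 1] before the VAE
# encoder, and decode_latents maps the decoder output back with video / 2 + 0.5.
import numpy as np

pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)
normed = pixels / 127.5 - 1                    # [-1.0, 0.0, 1.0]
restored = (normed / 2 + 0.5).clip(0, 1) * 255
print(normed, restored)                        # restores [0, 127.5, 255]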
399
+ @torch.no_grad()
400
+ def invert(
401
+ self,
402
+ image: torch.Tensor,
403
+ prompt,
404
+ num_inference_steps=20,
405
+ num_actual_inference_steps=10,
406
+ eta=0.0,
407
+ return_intermediates=False,
408
+ **kwargs):
409
+ """
410
+ Adapted from: https://github.com/Yujun-Shi/DragDiffusion/blob/main/drag_pipeline.py#L440
411
+ invert a real image into a noise map with deterministic DDIM inversion
412
+ """
413
+ device = self._execution_device
414
+ batch_size = image.shape[0]
415
+ if isinstance(prompt, list):
416
+ if batch_size == 1:
417
+ image = image.expand(len(prompt), -1, -1, -1)
418
+ elif isinstance(prompt, str):
419
+ if batch_size > 1:
420
+ prompt = [prompt] * batch_size
421
+
422
+ # text embeddings
423
+ text_input = self.tokenizer(
424
+ prompt,
425
+ padding="max_length",
426
+ max_length=77,
427
+ return_tensors="pt"
428
+ )
429
+ text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
430
+ print("input text embeddings :", text_embeddings.shape)
431
+ # define initial latents
432
+ latents = self.images2latents(image, text_embeddings.dtype)
433
+
434
+ print("latents shape: ", latents.shape)
435
+ # iterative sampling
436
+ self.scheduler.set_timesteps(num_inference_steps)
437
+ print("Valid timesteps: ", reversed(self.scheduler.timesteps))
438
+ latents_list = [latents]
439
+ pred_x0_list = [latents]
440
+ for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
441
+
442
+ if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
443
+ continue
444
+ model_inputs = latents
445
+
446
+ # predict the noise
447
+ # NOTE: the u-net here is UNet3D, therefore the model_inputs need to be of shape (b c f h w)
448
+ model_inputs = rearrange(model_inputs, "f c h w -> 1 c f h w")
449
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
450
+ noise_pred = rearrange(noise_pred, "b c f h w -> (b f) c h w")
451
+
452
+ # compute the previous noise sample x_t-1 -> x_t
453
+ latents, pred_x0 = self.next_step(noise_pred, t, latents)
454
+ latents_list.append(latents)
455
+ pred_x0_list.append(pred_x0)
456
+
457
+ if return_intermediates:
458
+ # return the intermediate latents during inversion
459
+ return latents, latents_list
460
+ return latents
461
+
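# invert() walks the schedule in reverse and only inverts the first
# num_actual_inference_steps of the num_inference_steps total, skipping the rest.
# Sketch of which iterations run, with plain integers standing in for the reversed
# scheduler timesteps (the spacing below is an assumption, not the real schedule):
num_inference_steps, num_actual_inference_steps = 20, 10
timesteps = list(range(0, 1000, 1000 // num_inference_steps))
ran = [t for i, t in enumerate(timesteps) if i < num_actual_inference_steps]
print(len(ran), ran[0], ran[-1])   # 10 of the 20 steps actually add noise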
462
+ def interpolate_latents(self, latents: torch.Tensor, interpolation_factor:int, device ):
463
+ if interpolation_factor < 2:
464
+ return latents
465
+
466
+ new_latents = torch.zeros(
467
+ (latents.shape[0],latents.shape[1],((latents.shape[2]-1) * interpolation_factor)+1, latents.shape[3],latents.shape[4]),
468
+ device=latents.device,
469
+ dtype=latents.dtype,
470
+ )
471
+
472
+ org_video_length = latents.shape[2]
473
+ rate = [i/interpolation_factor for i in range(interpolation_factor)][1:]
474
+
475
+ new_index = 0
476
+
477
+ v0 = None
478
+ v1 = None
479
+
480
+ for i0,i1 in zip( range( org_video_length ),range( org_video_length )[1:] ):
481
+ v0 = latents[:,:,i0,:,:]
482
+ v1 = latents[:,:,i1,:,:]
483
+
484
+ new_latents[:,:,new_index,:,:] = v0
485
+ new_index += 1
486
+
487
+ for f in rate:
488
+ v = get_tensor_interpolation_method()(v0.to(device=device),v1.to(device=device),f)
489
+ new_latents[:,:,new_index,:,:] = v.to(latents.device)
490
+ new_index += 1
491
+
492
+ new_latents[:,:,new_index,:,:] = v1
493
+ new_index += 1
494
+
495
+ return new_latents
496
+
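# interpolate_latents produces (f - 1) * k + 1 frames for interpolation factor k,
# because k - 1 blended frames are inserted between each adjacent pair.
# Quick arithmetic check (f and k are illustrative values):
f, k = 16, 2
new_f = (f - 1) * k + 1
print(new_f)    # 31, matching the new_latents allocation above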
497
+ def select_controlnet_res_samples(self, controlnet_res_samples_cache_dict, context, do_classifier_free_guidance, b, f):
498
+ _down_block_res_samples = []
499
+ _mid_block_res_sample = []
500
+ for i in np.concatenate(np.array(context)):
501
+ _down_block_res_samples.append(controlnet_res_samples_cache_dict[i][0])
502
+ _mid_block_res_sample.append(controlnet_res_samples_cache_dict[i][1])
503
+ down_block_res_samples = [[] for _ in range(len(controlnet_res_samples_cache_dict[i][0]))]
504
+ for res_t in _down_block_res_samples:
505
+ for i, res in enumerate(res_t):
506
+ down_block_res_samples[i].append(res)
507
+ down_block_res_samples = [torch.cat(res) for res in down_block_res_samples]
508
+ mid_block_res_sample = torch.cat(_mid_block_res_sample)
509
+
510
+ # reshape controlnet output to match the unet3d inputs
511
+ b = b // 2 if do_classifier_free_guidance else b
512
+ _down_block_res_samples = []
513
+ for sample in down_block_res_samples:
514
+ sample = rearrange(sample, '(b f) c h w -> b c f h w', b=b, f=f)
515
+ if do_classifier_free_guidance:
516
+ sample = sample.repeat(2, 1, 1, 1, 1)
517
+ _down_block_res_samples.append(sample)
518
+ down_block_res_samples = _down_block_res_samples
519
+ mid_block_res_sample = rearrange(mid_block_res_sample, '(b f) c h w -> b c f h w', b=b, f=f)
520
+ if do_classifier_free_guidance:
521
+ mid_block_res_sample = mid_block_res_sample.repeat(2, 1, 1, 1, 1)
522
+
523
+ return down_block_res_samples, mid_block_res_sample
524
+
525
+ @torch.no_grad()
526
+ def __call__(
527
+ self,
528
+ prompt: Union[str, List[str]],
529
+ video_length: Optional[int],
530
+ height: Optional[int] = None,
531
+ width: Optional[int] = None,
532
+ num_inference_steps: int = 50,
533
+ guidance_scale: float = 7.5,
534
+ negative_prompt: Optional[Union[str, List[str]]] = None,
535
+ num_videos_per_prompt: Optional[int] = 1,
536
+ eta: float = 0.0,
537
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
538
+ latents: Optional[torch.FloatTensor] = None,
539
+ output_type: Optional[str] = "tensor",
540
+ return_dict: bool = True,
541
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
542
+ callback_steps: Optional[int] = 1,
543
+ controlnet_condition: list = None,
544
+ controlnet_conditioning_scale: float = 1.0,
545
+ context_frames: int = 16,
546
+ context_stride: int = 1,
547
+ context_overlap: int = 4,
548
+ context_batch_size: int = 1,
549
+ context_schedule: str = "uniform",
550
+ init_latents: Optional[torch.FloatTensor] = None,
551
+ num_actual_inference_steps: Optional[int] = None,
552
+ appearance_encoder = None,
553
+ reference_control_writer = None,
554
+ reference_control_reader = None,
555
+ source_image: str = None,
556
+ decoder_consistency = None,
557
+ **kwargs,
558
+ ):
559
+ """
560
+ New args:
561
+ - controlnet_condition : condition map (e.g., depth, canny, keypoints) for controlnet
562
+ - controlnet_conditioning_scale : conditioning scale for controlnet
563
+ - init_latents : initial latents to begin with (used along with invert())
564
+ - num_actual_inference_steps : number of actual inference steps (while total steps is num_inference_steps)
565
+ """
566
+ controlnet = self.controlnet
567
+
568
+ # Default height and width to unet
569
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
570
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
571
+
572
+ # Check inputs. Raise error if not correct
573
+ self.check_inputs(prompt, height, width, callback_steps)
574
+
575
+ # Define call parameters
576
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
577
+ batch_size = 1
578
+ if latents is not None:
579
+ batch_size = latents.shape[0]
580
+ if isinstance(prompt, list):
581
+ batch_size = len(prompt)
582
+
583
+ device = self._execution_device
584
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
585
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
586
+ # corresponds to doing no classifier free guidance.
587
+ do_classifier_free_guidance = guidance_scale > 1.0
588
+
589
+ # Encode input prompt
590
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
591
+ if negative_prompt is not None:
592
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
593
+ text_embeddings = self._encode_prompt(
594
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
595
+ )
596
+ text_embeddings = torch.cat([text_embeddings] * context_batch_size)
597
+
598
+ reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', batch_size=context_batch_size)
599
+ reference_control_reader = ReferenceAttentionControl(self.unet, do_classifier_free_guidance=True, mode='read', batch_size=context_batch_size)
600
+
601
+ is_dist_initialized = kwargs.get("dist", False)
602
+ rank = kwargs.get("rank", 0)
603
+ world_size = kwargs.get("world_size", 1)
604
+
605
+ # Prepare video
606
+ assert num_videos_per_prompt == 1 # FIXME: verify if num_videos_per_prompt > 1 works
607
+ assert batch_size == 1 # FIXME: verify if batch_size > 1 works
608
+ control = self.prepare_condition(
609
+ condition=controlnet_condition,
610
+ device=device,
611
+ dtype=controlnet.dtype,
612
+ num_videos_per_prompt=num_videos_per_prompt,
613
+ do_classifier_free_guidance=do_classifier_free_guidance,
614
+ )
615
+ controlnet_uncond_images, controlnet_cond_images = control.chunk(2)
616
+
617
+ # Prepare timesteps
618
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
619
+ timesteps = self.scheduler.timesteps
620
+
621
+ # Prepare latent variables
622
+ if init_latents is not None:
623
+ latents = rearrange(init_latents, "(b f) c h w -> b c f h w", f=video_length)
624
+ else:
625
+ num_channels_latents = self.unet.config.in_channels
626
+ latents = self.prepare_latents(
627
+ batch_size * num_videos_per_prompt,
628
+ num_channels_latents,
629
+ video_length,
630
+ height,
631
+ width,
632
+ text_embeddings.dtype,
633
+ device,
634
+ generator,
635
+ latents,
636
+ )
637
+ latents_dtype = latents.dtype
638
+
639
+ # Prepare extra step kwargs.
640
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
641
+
642
+ # Prepare text embeddings for controlnet
643
+ controlnet_text_embeddings = text_embeddings.repeat_interleave(video_length, 0)
644
+ _, controlnet_text_embeddings_c = controlnet_text_embeddings.chunk(2)
645
+
646
+ controlnet_res_samples_cache_dict = {i:None for i in range(video_length)}
647
+
648
+ # For img2img setting
649
+ if num_actual_inference_steps is None:
650
+ num_actual_inference_steps = num_inference_steps
651
+
652
+ if isinstance(source_image, str):
653
+ ref_image_latents = self.images2latents(np.array(Image.open(source_image).resize((width, height)))[None, :], latents_dtype).to(device)
654
+ elif isinstance(source_image, np.ndarray):
655
+ ref_image_latents = self.images2latents(source_image[None, :], latents_dtype).to(device)
656
+
657
+ context_scheduler = get_context_scheduler(context_schedule)
658
+
659
+ # Denoising loop
660
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps), disable=(rank!=0)):
661
+ if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps:
662
+ continue
663
+
664
+ noise_pred = torch.zeros(
665
+ (latents.shape[0] * (2 if do_classifier_free_guidance else 1), *latents.shape[1:]),
666
+ device=latents.device,
667
+ dtype=latents.dtype,
668
+ )
669
+ counter = torch.zeros(
670
+ (1, 1, latents.shape[2], 1, 1), device=latents.device, dtype=latents.dtype
671
+ )
672
+
673
+ appearance_encoder(
674
+ ref_image_latents.repeat(context_batch_size * (2 if do_classifier_free_guidance else 1), 1, 1, 1),
675
+ t,
676
+ encoder_hidden_states=text_embeddings,
677
+ return_dict=False,
678
+ )
679
+
680
+ context_queue = list(context_scheduler(
681
+ 0, num_inference_steps, latents.shape[2], context_frames, context_stride, 0
682
+ ))
683
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
684
+ for i in range(num_context_batches):
685
+ context = context_queue[i*context_batch_size: (i+1)*context_batch_size]
686
+ # expand the latents if we are doing classifier free guidance
687
+ controlnet_latent_input = (
688
+ torch.cat([latents[:, :, c] for c in context])
689
+ .to(device)
690
+ )
691
+ controlnet_latent_input = self.scheduler.scale_model_input(controlnet_latent_input, t)
692
+
693
+ # prepare inputs for controlnet
694
+ b, c, f, h, w = controlnet_latent_input.shape
695
+ controlnet_latent_input = rearrange(controlnet_latent_input, "b c f h w -> (b f) c h w")
696
+
697
+ # controlnet inference
698
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
699
+ controlnet_latent_input,
700
+ t,
701
+ encoder_hidden_states=torch.cat([controlnet_text_embeddings_c[c] for c in context]),
702
+ controlnet_cond=torch.cat([controlnet_cond_images[c] for c in context]),
703
+ conditioning_scale=controlnet_conditioning_scale,
704
+ return_dict=False,
705
+ )
706
+
707
+ for j, k in enumerate(np.concatenate(np.array(context))):
708
+ controlnet_res_samples_cache_dict[k] = ([sample[j:j+1] for sample in down_block_res_samples], mid_block_res_sample[j:j+1])
709
+
710
+ context_queue = list(context_scheduler(
711
+ 0, num_inference_steps, latents.shape[2], context_frames, context_stride, context_overlap
712
+ ))
713
+
714
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
715
+ global_context = []
716
+ for i in range(num_context_batches):
717
+ global_context.append(context_queue[i*context_batch_size: (i+1)*context_batch_size])
718
+
719
+ for context in global_context[rank::world_size]:
720
+ # expand the latents if we are doing classifier free guidance
721
+ latent_model_input = (
722
+ torch.cat([latents[:, :, c] for c in context])
723
+ .to(device)
724
+ .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
725
+ )
726
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
727
+
728
+ b, c, f, h, w = latent_model_input.shape
729
+ down_block_res_samples, mid_block_res_sample = self.select_controlnet_res_samples(
730
+ controlnet_res_samples_cache_dict,
731
+ context,
732
+ do_classifier_free_guidance,
733
+ b, f
734
+ )
735
+
736
+ reference_control_reader.update(reference_control_writer)
737
+
738
+ # predict the noise residual
739
+ pred = self.unet(
740
+ latent_model_input,
741
+ t,
742
+ encoder_hidden_states=text_embeddings[:b],
743
+ down_block_additional_residuals=down_block_res_samples,
744
+ mid_block_additional_residual=mid_block_res_sample,
745
+ return_dict=False,
746
+ )[0]
747
+
748
+ reference_control_reader.clear()
749
+
750
+ pred_uc, pred_c = pred.chunk(2)
751
+ pred = torch.cat([pred_uc.unsqueeze(0), pred_c.unsqueeze(0)])
752
+ for j, c in enumerate(context):
753
+ noise_pred[:, :, c] = noise_pred[:, :, c] + pred[:, j]
754
+ counter[:, :, c] = counter[:, :, c] + 1
755
+
756
+ if is_dist_initialized:
757
+ noise_pred_gathered = [torch.zeros_like(noise_pred) for _ in range(world_size)]
758
+ if rank == 0:
759
+ dist.gather(tensor=noise_pred, gather_list=noise_pred_gathered, dst=0)
760
+ else:
761
+ dist.gather(tensor=noise_pred, gather_list=[], dst=0)
762
+ dist.barrier()
763
+
764
+ if rank == 0:
765
+ for k in range(1, world_size):
766
+ for context in global_context[k::world_size]:
767
+ for j, c in enumerate(context):
768
+ noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_gathered[k][:, :, c]
769
+ counter[:, :, c] = counter[:, :, c] + 1
770
+
771
+ # perform guidance
772
+ if do_classifier_free_guidance:
773
+ noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
774
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
775
+
776
+ # compute the previous noisy sample x_t -> x_t-1
777
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
778
+
779
+ if is_dist_initialized:
780
+ dist.broadcast(latents, 0)
781
+ dist.barrier()
782
+
783
+ reference_control_writer.clear()
784
+
785
+ interpolation_factor = 1
786
+ latents = self.interpolate_latents(latents, interpolation_factor, device)
787
+ # Post-processing
788
+ video = self.decode_latents(latents, rank, decoder_consistency=decoder_consistency)
789
+
790
+ if is_dist_initialized:
791
+ dist.barrier()
792
+
793
+ # Convert to tensor
794
+ if output_type == "tensor":
795
+ video = torch.from_numpy(video)
796
+
797
+ if not return_dict:
798
+ return video
799
+
800
+ return AnimationPipelineOutput(videos=video)
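A sketch of how this __call__ is typically driven once the pipeline and its appearance encoder are constructed. `pipeline`, `appearance_encoder`, `source_image`, and `motion_sequence` are placeholders supplied by the caller, not objects defined in this diff; the keyword arguments mirror the signature above, and the defaults chosen here are assumptions.

import numpy as np
import torch

def animate(pipeline, appearance_encoder, source_image: np.ndarray,
            motion_sequence: np.ndarray, prompt: str = "",
            steps: int = 25, guidance_scale: float = 7.5, seed: int = 1):
    # source_image: (H, W, 3) uint8; motion_sequence: (F, H, W, 3) uint8 condition frames
    generator = torch.Generator(device="cpu").manual_seed(seed)
    out = pipeline(
        prompt=prompt,
        video_length=motion_sequence.shape[0],
        height=source_image.shape[0],
        width=source_image.shape[1],
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        generator=generator,
        controlnet_condition=motion_sequence,
        source_image=source_image,
        appearance_encoder=appearance_encoder,
    )
    return out.videos  # (b, c, f, h, w) tensor with the default output_type="tensor"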
magicanimate/utils/dist_tools.py CHANGED
@@ -1,105 +1,105 @@
1
- # Copyright 2023 ByteDance and/or its affiliates.
2
- #
3
- # Copyright (2023) MagicAnimate Authors
4
- #
5
- # ByteDance, its affiliates and licensors retain all intellectual
6
- # property and proprietary rights in and to this material, related
7
- # documentation and any modifications thereto. Any use, reproduction,
8
- # disclosure or distribution of this material and related documentation
9
- # without an express license agreement from ByteDance or
10
- # its affiliates is strictly prohibited.
11
- import os
12
- import socket
13
- import warnings
14
- import torch
15
- from torch import distributed as dist
16
-
17
-
18
- def distributed_init(args):
19
-
20
- if dist.is_initialized():
21
- warnings.warn("Distributed is already initialized, cannot initialize twice!")
22
- args.rank = dist.get_rank()
23
- else:
24
- print(
25
- f"Distributed Init (Rank {args.rank}): "
26
- f"{args.init_method}"
27
- )
28
- dist.init_process_group(
29
- backend='nccl',
30
- init_method=args.init_method,
31
- world_size=args.world_size,
32
- rank=args.rank,
33
- )
34
- print(
35
- f"Initialized Host {socket.gethostname()} as Rank "
36
- f"{args.rank}"
37
- )
38
-
39
- if "MASTER_ADDR" not in os.environ or "MASTER_PORT" not in os.environ:
40
- # Set for onboxdataloader support
41
- split = args.init_method.split("//")
42
- assert len(split) == 2, (
43
- "host url for distributed should be split by '//' "
44
- + "into exactly two elements"
45
- )
46
-
47
- split = split[1].split(":")
48
- assert (
49
- len(split) == 2
50
- ), "host url should be of the form <host_url>:<host_port>"
51
- os.environ["MASTER_ADDR"] = split[0]
52
- os.environ["MASTER_PORT"] = split[1]
53
-
54
- # perform a dummy all-reduce to initialize the NCCL communicator
55
- dist.all_reduce(torch.zeros(1).cuda())
56
-
57
- suppress_output(is_master())
58
- args.rank = dist.get_rank()
59
- return args.rank
60
-
61
-
62
- def get_rank():
63
- if not dist.is_available():
64
- return 0
65
- if not dist.is_nccl_available():
66
- return 0
67
- if not dist.is_initialized():
68
- return 0
69
- return dist.get_rank()
70
-
71
-
72
- def is_master():
73
- return get_rank() == 0
74
-
75
-
76
- def synchronize():
77
- if dist.is_initialized():
78
- dist.barrier()
79
-
80
-
81
- def suppress_output(is_master):
82
- """Suppress printing on the current device. Force printing with `force=True`."""
83
- import builtins as __builtin__
84
-
85
- builtin_print = __builtin__.print
86
-
87
- def print(*args, **kwargs):
88
- force = kwargs.pop("force", False)
89
- if is_master or force:
90
- builtin_print(*args, **kwargs)
91
-
92
- __builtin__.print = print
93
-
94
- import warnings
95
-
96
- builtin_warn = warnings.warn
97
-
98
- def warn(*args, **kwargs):
99
- force = kwargs.pop("force", False)
100
- if is_master or force:
101
- builtin_warn(*args, **kwargs)
102
-
103
- # Log warnings only once
104
- warnings.warn = warn
105
  warnings.simplefilter("once", UserWarning)
 
1
+ # Copyright 2023 ByteDance and/or its affiliates.
2
+ #
3
+ # Copyright (2023) MagicAnimate Authors
4
+ #
5
+ # ByteDance, its affiliates and licensors retain all intellectual
6
+ # property and proprietary rights in and to this material, related
7
+ # documentation and any modifications thereto. Any use, reproduction,
8
+ # disclosure or distribution of this material and related documentation
9
+ # without an express license agreement from ByteDance or
10
+ # its affiliates is strictly prohibited.
11
+ import os
12
+ import socket
13
+ import warnings
14
+ import torch
15
+ from torch import distributed as dist
16
+
17
+
18
+ def distributed_init(args):
19
+
20
+ if dist.is_initialized():
21
+ warnings.warn("Distributed is already initialized, cannot initialize twice!")
22
+ args.rank = dist.get_rank()
23
+ else:
24
+ print(
25
+ f"Distributed Init (Rank {args.rank}): "
26
+ f"{args.init_method}"
27
+ )
28
+ dist.init_process_group(
29
+ backend='nccl',
30
+ init_method=args.init_method,
31
+ world_size=args.world_size,
32
+ rank=args.rank,
33
+ )
34
+ print(
35
+ f"Initialized Host {socket.gethostname()} as Rank "
36
+ f"{args.rank}"
37
+ )
38
+
39
+ if "MASTER_ADDR" not in os.environ or "MASTER_PORT" not in os.environ:
40
+ # Set for onboxdataloader support
41
+ split = args.init_method.split("//")
42
+ assert len(split) == 2, (
43
+ "host url for distributed should be split by '//' "
44
+ + "into exactly two elements"
45
+ )
46
+
47
+ split = split[1].split(":")
48
+ assert (
49
+ len(split) == 2
50
+ ), "host url should be of the form <host_url>:<host_port>"
51
+ os.environ["MASTER_ADDR"] = split[0]
52
+ os.environ["MASTER_PORT"] = split[1]
53
+
54
+ # perform a dummy all-reduce to initialize the NCCL communicator
55
+ dist.all_reduce(torch.zeros(1).cuda())
56
+
57
+ suppress_output(is_master())
58
+ args.rank = dist.get_rank()
59
+ return args.rank
60
+
61
+
62
+ def get_rank():
63
+ if not dist.is_available():
64
+ return 0
65
+ if not dist.is_nccl_available():
66
+ return 0
67
+ if not dist.is_initialized():
68
+ return 0
69
+ return dist.get_rank()
70
+
71
+
72
+ def is_master():
73
+ return get_rank() == 0
74
+
75
+
76
+ def synchronize():
77
+ if dist.is_initialized():
78
+ dist.barrier()
79
+
80
+
81
+ def suppress_output(is_master):
82
+ """Suppress printing on the current device. Force printing with `force=True`."""
83
+ import builtins as __builtin__
84
+
85
+ builtin_print = __builtin__.print
86
+
87
+ def print(*args, **kwargs):
88
+ force = kwargs.pop("force", False)
89
+ if is_master or force:
90
+ builtin_print(*args, **kwargs)
91
+
92
+ __builtin__.print = print
93
+
94
+ import warnings
95
+
96
+ builtin_warn = warnings.warn
97
+
98
+ def warn(*args, **kwargs):
99
+ force = kwargs.pop("force", False)
100
+ if is_master or force:
101
+ builtin_warn(*args, **kwargs)
102
+
103
+ # Log warnings only once
104
+ warnings.warn = warn
105
  warnings.simplefilter("once", UserWarning)
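One detail of distributed_init above worth spelling out: when MASTER_ADDR/MASTER_PORT are unset, it derives them from args.init_method. A standalone sketch of that parsing; the example URL is an assumption, not a value taken from this repo.

import os

def set_master_from_init_method(init_method: str) -> None:
    # Mirrors the parsing in distributed_init: '<scheme>://<host>:<port>'
    parts = init_method.split("//")
    assert len(parts) == 2, "host url for distributed should contain '//' exactly once"
    host, port = parts[1].split(":")
    os.environ["MASTER_ADDR"] = host
    os.environ["MASTER_PORT"] = port

set_master_from_init_method("tcp://127.0.0.1:29500")
print(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])   # 127.0.0.1 29500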
magicanimate/utils/util.py CHANGED
@@ -1,138 +1,138 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Adapted from https://github.com/guoyww/AnimateDiff
8
- import os
9
- import imageio
10
- import numpy as np
11
-
12
- import torch
13
- import torchvision
14
-
15
- from PIL import Image
16
- from typing import Union
17
- from tqdm import tqdm
18
- from einops import rearrange
19
-
20
-
21
- def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=25):
22
- videos = rearrange(videos, "b c t h w -> t b c h w")
23
- outputs = []
24
- for x in videos:
25
- x = torchvision.utils.make_grid(x, nrow=n_rows)
26
- x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
27
- if rescale:
28
- x = (x + 1.0) / 2.0 # -1,1 -> 0,1
29
- x = (x * 255).numpy().astype(np.uint8)
30
- outputs.append(x)
31
-
32
- os.makedirs(os.path.dirname(path), exist_ok=True)
33
- imageio.mimsave(path, outputs, fps=fps)
34
-
35
- def save_images_grid(images: torch.Tensor, path: str):
36
- assert images.shape[2] == 1 # no time dimension
37
- images = images.squeeze(2)
38
- grid = torchvision.utils.make_grid(images)
39
- grid = (grid * 255).numpy().transpose(1, 2, 0).astype(np.uint8)
40
- os.makedirs(os.path.dirname(path), exist_ok=True)
41
- Image.fromarray(grid).save(path)
42
-
43
- # DDIM Inversion
44
- @torch.no_grad()
45
- def init_prompt(prompt, pipeline):
46
- uncond_input = pipeline.tokenizer(
47
- [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
48
- return_tensors="pt"
49
- )
50
- uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
51
- text_input = pipeline.tokenizer(
52
- [prompt],
53
- padding="max_length",
54
- max_length=pipeline.tokenizer.model_max_length,
55
- truncation=True,
56
- return_tensors="pt",
57
- )
58
- text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
59
- context = torch.cat([uncond_embeddings, text_embeddings])
60
-
61
- return context
62
-
63
-
64
- def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
65
- sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
66
- timestep, next_timestep = min(
67
- timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
68
- alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
69
- alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
70
- beta_prod_t = 1 - alpha_prod_t
71
- next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
72
- next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
73
- next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
74
- return next_sample
75
-
76
-
77
- def get_noise_pred_single(latents, t, context, unet):
78
- noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
79
- return noise_pred
80
-
81
-
82
- @torch.no_grad()
83
- def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
84
- context = init_prompt(prompt, pipeline)
85
- uncond_embeddings, cond_embeddings = context.chunk(2)
86
- all_latent = [latent]
87
- latent = latent.clone().detach()
88
- for i in tqdm(range(num_inv_steps)):
89
- t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
90
- noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
91
- latent = next_step(noise_pred, t, latent, ddim_scheduler)
92
- all_latent.append(latent)
93
- return all_latent
94
-
95
-
96
- @torch.no_grad()
97
- def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
98
- ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
99
- return ddim_latents
100
-
101
-
102
- def video2images(path, step=4, length=16, start=0):
103
- reader = imageio.get_reader(path)
104
- frames = []
105
- for frame in reader:
106
- frames.append(np.array(frame))
107
- frames = frames[start::step][:length]
108
- return frames
109
-
110
-
111
- def images2video(video, path, fps=8):
112
- imageio.mimsave(path, video, fps=fps)
113
- return
114
-
115
-
116
- tensor_interpolation = None
117
-
118
- def get_tensor_interpolation_method():
119
- return tensor_interpolation
120
-
121
- def set_tensor_interpolation_method(is_slerp):
122
- global tensor_interpolation
123
- tensor_interpolation = slerp if is_slerp else linear
124
-
125
- def linear(v1, v2, t):
126
- return (1.0 - t) * v1 + t * v2
127
-
128
- def slerp(
129
- v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
130
- ) -> torch.Tensor:
131
- u0 = v0 / v0.norm()
132
- u1 = v1 / v1.norm()
133
- dot = (u0 * u1).sum()
134
- if dot.abs() > DOT_THRESHOLD:
135
- #logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
136
- return (1.0 - t) * v0 + t * v1
137
- omega = dot.acos()
138
  return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+ import os
9
+ import imageio
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torchvision
14
+
15
+ from PIL import Image
16
+ from typing import Union
17
+ from tqdm import tqdm
18
+ from einops import rearrange
19
+
20
+
21
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=25):
22
+ videos = rearrange(videos, "b c t h w -> t b c h w")
23
+ outputs = []
24
+ for x in videos:
25
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
26
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
27
+ if rescale:
28
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
29
+ x = (x * 255).numpy().astype(np.uint8)
30
+ outputs.append(x)
31
+
32
+ os.makedirs(os.path.dirname(path), exist_ok=True)
33
+ imageio.mimsave(path, outputs, fps=fps)
34
+
35
+ def save_images_grid(images: torch.Tensor, path: str):
36
+ assert images.shape[2] == 1 # no time dimension
37
+ images = images.squeeze(2)
38
+ grid = torchvision.utils.make_grid(images)
39
+ grid = (grid * 255).numpy().transpose(1, 2, 0).astype(np.uint8)
40
+ os.makedirs(os.path.dirname(path), exist_ok=True)
41
+ Image.fromarray(grid).save(path)
42
+
43
+ # DDIM Inversion
44
+ @torch.no_grad()
45
+ def init_prompt(prompt, pipeline):
46
+ uncond_input = pipeline.tokenizer(
47
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
48
+ return_tensors="pt"
49
+ )
50
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
51
+ text_input = pipeline.tokenizer(
52
+ [prompt],
53
+ padding="max_length",
54
+ max_length=pipeline.tokenizer.model_max_length,
55
+ truncation=True,
56
+ return_tensors="pt",
57
+ )
58
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
59
+ context = torch.cat([uncond_embeddings, text_embeddings])
60
+
61
+ return context
62
+
63
+
64
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
65
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
66
+ timestep, next_timestep = min(
67
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
68
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
69
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
70
+ beta_prod_t = 1 - alpha_prod_t
71
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
72
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
73
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
74
+ return next_sample
75
+
76
+
77
+ def get_noise_pred_single(latents, t, context, unet):
78
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
79
+ return noise_pred
80
+
81
+
82
+ @torch.no_grad()
83
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
84
+ context = init_prompt(prompt, pipeline)
85
+ uncond_embeddings, cond_embeddings = context.chunk(2)
86
+ all_latent = [latent]
87
+ latent = latent.clone().detach()
88
+ for i in tqdm(range(num_inv_steps)):
89
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
90
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
91
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
92
+ all_latent.append(latent)
93
+ return all_latent
94
+
95
+
96
+ @torch.no_grad()
97
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
98
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
99
+ return ddim_latents
100
+
101
+
102
+ def video2images(path, step=4, length=16, start=0):
103
+ reader = imageio.get_reader(path)
104
+ frames = []
105
+ for frame in reader:
106
+ frames.append(np.array(frame))
107
+ frames = frames[start::step][:length]
108
+ return frames
109
+
110
+
111
+ def images2video(video, path, fps=8):
112
+ imageio.mimsave(path, video, fps=fps)
113
+ return
114
+
115
+
116
+ tensor_interpolation = None
117
+
118
+ def get_tensor_interpolation_method():
119
+ return tensor_interpolation
120
+
121
+ def set_tensor_interpolation_method(is_slerp):
122
+ global tensor_interpolation
123
+ tensor_interpolation = slerp if is_slerp else linear
124
+
125
+ def linear(v1, v2, t):
126
+ return (1.0 - t) * v1 + t * v2
127
+
128
+ def slerp(
129
+ v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
130
+ ) -> torch.Tensor:
131
+ u0 = v0 / v0.norm()
132
+ u1 = v1 / v1.norm()
133
+ dot = (u0 * u1).sum()
134
+ if dot.abs() > DOT_THRESHOLD:
135
+ #logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
136
+ return (1.0 - t) * v0 + t * v1
137
+ omega = dot.acos()
138
  return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
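A quick numerical check of the slerp() defined in util.py: for orthogonal unit vectors, the spherical interpolant keeps unit norm at the midpoint, which is why it is offered as an alternative to linear() for latent interpolation. The reference function below is an illustration that re-derives the same formula (with an extra clamp for safety), not repo code.

import torch

def slerp_ref(v0: torch.Tensor, v1: torch.Tensor, t: float) -> torch.Tensor:
    # Same formula as slerp() above, without the near-parallel fallback.
    u0, u1 = v0 / v0.norm(), v1 / v1.norm()
    omega = (u0 * u1).sum().clamp(-1.0, 1.0).acos()
    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()

v0 = torch.tensor([1.0, 0.0])
v1 = torch.tensor([0.0, 1.0])
mid = slerp_ref(v0, v1, 0.5)
print(mid, mid.norm())   # ~[0.7071, 0.7071], norm 1.0; linear() would give norm ~0.707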
magicanimate/utils/videoreader.py CHANGED
@@ -1,157 +1,157 @@
1
- # *************************************************************************
2
- # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
- # ytedance Inc..
5
- # *************************************************************************
6
-
7
- # Copyright 2022 ByteDance and/or its affiliates.
8
- #
9
- # Copyright (2022) PV3D Authors
10
- #
11
- # ByteDance, its affiliates and licensors retain all intellectual
12
- # property and proprietary rights in and to this material, related
13
- # documentation and any modifications thereto. Any use, reproduction,
14
- # disclosure or distribution of this material and related documentation
15
- # without an express license agreement from ByteDance or
16
- # its affiliates is strictly prohibited.
17
- import av, gc
18
- import torch
19
- import warnings
20
- import numpy as np
21
-
22
-
23
- _CALLED_TIMES = 0
24
- _GC_COLLECTION_INTERVAL = 20
25
-
26
-
27
- # remove warnings
28
- av.logging.set_level(av.logging.ERROR)
29
-
30
-
31
- class VideoReader():
32
- """
33
- Simple wrapper around PyAV that exposes a few useful functions for
34
- dealing with video reading. PyAV is a pythonic binding for the ffmpeg libraries.
35
- Acknowledgement: Codes are borrowed from Bruno Korbar
36
- """
37
- def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False):
38
- """
39
- Arguments:
40
- video_path (str): path or byte of the video to be loaded
41
- """
42
- self.container = av.open(video)
43
- self.num_frames = num_frames
44
- self.bi_frame = bi_frame
45
-
46
- self.resampler = None
47
- if audio_resample_rate is not None:
48
- self.resampler = av.AudioResampler(rate=audio_resample_rate)
49
-
50
- if self.container.streams.video:
51
- # enable multi-threaded video decoding
52
- if decode_lossy:
53
- warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning)
54
- self.container.streams.video[0].thread_type = 'AUTO'
55
- self.video_stream = self.container.streams.video[0]
56
- else:
57
- self.video_stream = None
58
-
59
- self.fps = self._get_video_frame_rate()
60
-
61
- def seek(self, pts, backward=True, any_frame=False):
62
- stream = self.video_stream
63
- self.container.seek(pts, any_frame=any_frame, backward=backward, stream=stream)
64
-
65
- def _occasional_gc(self):
66
- # there are a lot of reference cycles in PyAV, so need to manually call
67
- # the garbage collector from time to time
68
- global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
69
- _CALLED_TIMES += 1
70
- if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
71
- gc.collect()
72
-
73
- def _read_video(self, offset):
74
- self._occasional_gc()
75
-
76
- pts = self.container.duration * offset
77
- time_ = pts / float(av.time_base)
78
- self.container.seek(int(pts))
79
-
80
- video_frames = []
81
- count = 0
82
- for _, frame in enumerate(self._iter_frames()):
83
- if frame.pts * frame.time_base >= time_:
84
- video_frames.append(frame)
85
- if count >= self.num_frames - 1:
86
- break
87
- count += 1
88
- return video_frames
89
-
90
- def _iter_frames(self):
91
- for packet in self.container.demux(self.video_stream):
92
- for frame in packet.decode():
93
- yield frame
94
-
95
- def _compute_video_stats(self):
96
- if self.video_stream is None or self.container is None:
97
- return 0
98
- num_of_frames = self.container.streams.video[0].frames
99
- if num_of_frames == 0:
100
- num_of_frames = self.fps * float(self.container.streams.video[0].duration*self.video_stream.time_base)
101
- self.seek(0, backward=False)
102
- count = 0
103
- time_base = 512
104
- for p in self.container.decode(video=0):
105
- count = count + 1
106
- if count == 1:
107
- start_pts = p.pts
108
- elif count == 2:
109
- time_base = p.pts - start_pts
110
- break
111
- return start_pts, time_base, num_of_frames
112
-
113
- def _get_video_frame_rate(self):
114
- return float(self.container.streams.video[0].guessed_rate)
115
-
116
- def sample(self, debug=False):
117
-
118
- if self.container is None:
119
- raise RuntimeError('video stream not found')
120
- sample = dict()
121
- _, _, total_num_frames = self._compute_video_stats()
122
- offset = torch.randint(max(1, total_num_frames-self.num_frames-1), [1]).item()
123
- video_frames = self._read_video(offset/total_num_frames)
124
- video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames])
125
- sample["frames"] = video_frames
126
- sample["frame_idx"] = [offset]
127
-
128
- if self.bi_frame:
129
- frames = [np.random.beta(2, 1, size=1), np.random.beta(1, 2, size=1)]
130
- frames = [int(frames[0] * self.num_frames), int(frames[1] * self.num_frames)]
131
- frames.sort()
132
- video_frames = np.array([video_frames[min(frames)], video_frames[max(frames)]])
133
- Ts= [min(frames) / (self.num_frames - 1), max(frames) / (self.num_frames - 1)]
134
- sample["frames"] = video_frames
135
- sample["real_t"] = torch.tensor(Ts, dtype=torch.float32)
136
- sample["frame_idx"] = [offset+min(frames), offset+max(frames)]
137
- return sample
138
-
139
- return sample
140
-
141
- def read_frames(self, frame_indices):
142
- self.num_frames = frame_indices[1] - frame_indices[0]
143
- video_frames = self._read_video(frame_indices[0]/self.get_num_frames())
144
- video_frames = np.array([
145
- np.uint8(video_frames[0].to_rgb().to_ndarray()),
146
- np.uint8(video_frames[-1].to_rgb().to_ndarray())
147
- ])
148
- return video_frames
149
-
150
- def read(self):
151
- video_frames = self._read_video(0)
152
- video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames])
153
- return video_frames
154
-
155
- def get_num_frames(self):
156
- _, _, total_num_frames = self._compute_video_stats()
157
  return total_num_frames
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2022 ByteDance and/or its affiliates.
8
+ #
9
+ # Copyright (2022) PV3D Authors
10
+ #
11
+ # ByteDance, its affiliates and licensors retain all intellectual
12
+ # property and proprietary rights in and to this material, related
13
+ # documentation and any modifications thereto. Any use, reproduction,
14
+ # disclosure or distribution of this material and related documentation
15
+ # without an express license agreement from ByteDance or
16
+ # its affiliates is strictly prohibited.
17
+ import av, gc
18
+ import torch
19
+ import warnings
20
+ import numpy as np
21
+
22
+
23
+ _CALLED_TIMES = 0
24
+ _GC_COLLECTION_INTERVAL = 20
25
+
26
+
27
+ # remove warnings
28
+ av.logging.set_level(av.logging.ERROR)
29
+
30
+
31
+ class VideoReader():
32
+ """
33
+ Simple wrapper around PyAV that exposes a few useful functions for
34
+ dealing with video reading. PyAV is a pythonic binding for the ffmpeg libraries.
35
+ Acknowledgement: Codes are borrowed from Bruno Korbar
36
+ """
37
+ def __init__(self, video, num_frames=float("inf"), decode_lossy=False, audio_resample_rate=None, bi_frame=False):
38
+ """
39
+ Arguments:
40
+ video_path (str): path or byte of the video to be loaded
41
+ """
42
+ self.container = av.open(video)
43
+ self.num_frames = num_frames
44
+ self.bi_frame = bi_frame
45
+
46
+ self.resampler = None
47
+ if audio_resample_rate is not None:
48
+ self.resampler = av.AudioResampler(rate=audio_resample_rate)
49
+
50
+ if self.container.streams.video:
51
+ # enable multi-threaded video decoding
52
+ if decode_lossy:
53
+ warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning)
54
+ self.container.streams.video[0].thread_type = 'AUTO'
55
+ self.video_stream = self.container.streams.video[0]
56
+ else:
57
+ self.video_stream = None
58
+
59
+ self.fps = self._get_video_frame_rate()
60
+
61
+ def seek(self, pts, backward=True, any_frame=False):
62
+ stream = self.video_stream
63
+ self.container.seek(pts, any_frame=any_frame, backward=backward, stream=stream)
64
+
65
+ def _occasional_gc(self):
66
+ # there are a lot of reference cycles in PyAV, so need to manually call
67
+ # the garbage collector from time to time
68
+ global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
69
+ _CALLED_TIMES += 1
70
+ if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
71
+ gc.collect()
72
+
73
+ def _read_video(self, offset):
74
+ self._occasional_gc()
75
+
76
+ pts = self.container.duration * offset
77
+ time_ = pts / float(av.time_base)
78
+ self.container.seek(int(pts))
79
+
80
+ video_frames = []
81
+ count = 0
82
+ for _, frame in enumerate(self._iter_frames()):
83
+ if frame.pts * frame.time_base >= time_:
84
+ video_frames.append(frame)
85
+ if count >= self.num_frames - 1:
86
+ break
87
+ count += 1
88
+ return video_frames
89
+
90
+ def _iter_frames(self):
91
+ for packet in self.container.demux(self.video_stream):
92
+ for frame in packet.decode():
93
+ yield frame
94
+
95
+ def _compute_video_stats(self):
96
+ if self.video_stream is None or self.container is None:
97
+ return 0
98
+ num_of_frames = self.container.streams.video[0].frames
99
+ if num_of_frames == 0:
100
+ num_of_frames = self.fps * float(self.container.streams.video[0].duration*self.video_stream.time_base)
101
+ self.seek(0, backward=False)
102
+ count = 0
103
+ time_base = 512
104
+ for p in self.container.decode(video=0):
105
+ count = count + 1
106
+ if count == 1:
107
+ start_pts = p.pts
108
+ elif count == 2:
109
+ time_base = p.pts - start_pts
110
+ break
111
+ return start_pts, time_base, num_of_frames
112
+
113
+ def _get_video_frame_rate(self):
114
+ return float(self.container.streams.video[0].guessed_rate)
115
+
116
+ def sample(self, debug=False):
117
+
118
+ if self.container is None:
119
+ raise RuntimeError('video stream not found')
120
+ sample = dict()
121
+ _, _, total_num_frames = self._compute_video_stats()
122
+ offset = torch.randint(max(1, total_num_frames-self.num_frames-1), [1]).item()
123
+ video_frames = self._read_video(offset/total_num_frames)
124
+ video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames])
125
+ sample["frames"] = video_frames
126
+ sample["frame_idx"] = [offset]
127
+
128
+ if self.bi_frame:
129
+ frames = [np.random.beta(2, 1, size=1), np.random.beta(1, 2, size=1)]
130
+ frames = [int(frames[0] * self.num_frames), int(frames[1] * self.num_frames)]
131
+ frames.sort()
132
+ video_frames = np.array([video_frames[min(frames)], video_frames[max(frames)]])
133
+ Ts= [min(frames) / (self.num_frames - 1), max(frames) / (self.num_frames - 1)]
134
+ sample["frames"] = video_frames
135
+ sample["real_t"] = torch.tensor(Ts, dtype=torch.float32)
136
+ sample["frame_idx"] = [offset+min(frames), offset+max(frames)]
137
+ return sample
138
+
139
+ return sample
140
+
141
+ def read_frames(self, frame_indices):
142
+ self.num_frames = frame_indices[1] - frame_indices[0]
143
+ video_frames = self._read_video(frame_indices[0]/self.get_num_frames())
144
+ video_frames = np.array([
145
+ np.uint8(video_frames[0].to_rgb().to_ndarray()),
146
+ np.uint8(video_frames[-1].to_rgb().to_ndarray())
147
+ ])
148
+ return video_frames
149
+
150
+ def read(self):
151
+ video_frames = self._read_video(0)
152
+ video_frames = np.array([np.uint8(f.to_rgb().to_ndarray()) for f in video_frames])
153
+ return video_frames
154
+
155
+ def get_num_frames(self):
156
+ _, _, total_num_frames = self._compute_video_stats()
157
  return total_num_frames
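A minimal usage sketch for the VideoReader wrapper above; "example.mp4" is a placeholder path, and the import assumes the package layout shown in this commit.

from magicanimate.utils.videoreader import VideoReader

reader = VideoReader("example.mp4")   # placeholder path to any local video
frames = reader.read()                # (F, H, W, 3) uint8 RGB frames
print(frames.shape, reader.fps)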
pre-requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pip>=23.0.0
requirements.txt CHANGED
@@ -1,117 +1,117 @@
1
- absl-py==1.4.0
2
- accelerate==0.22.0
3
- aiofiles==23.2.1
4
- aiohttp==3.8.5
5
- aiosignal==1.3.1
6
- altair==5.0.1
7
- annotated-types==0.5.0
8
- antlr4-python3-runtime==4.9.3
9
- anyio==3.7.1
10
- async-timeout==4.0.3
11
- attrs==23.1.0
12
- cachetools==5.3.1
13
- certifi==2023.7.22
14
- charset-normalizer==3.2.0
15
- click==8.1.7
16
- cmake==3.27.2
17
- contourpy==1.1.0
18
- cycler==0.11.0
19
- datasets==2.14.4
20
- dill==0.3.7
21
- einops==0.6.1
22
- exceptiongroup==1.1.3
23
- fastapi==0.103.0
24
- ffmpy==0.3.1
25
- filelock==3.12.2
26
- fonttools==4.42.1
27
- frozenlist==1.4.0
28
- fsspec==2023.6.0
29
- google-auth==2.22.0
30
- google-auth-oauthlib==1.0.0
31
- grpcio==1.57.0
32
- h11==0.14.0
33
  httpcore
34
  httpx
35
  huggingface-hub
36
- idna==3.4
37
- importlib-metadata==6.8.0
38
- importlib-resources==6.0.1
39
- jinja2==3.1.2
40
- joblib==1.3.2
41
- jsonschema==4.19.0
42
- jsonschema-specifications==2023.7.1
43
- kiwisolver==1.4.5
44
- lightning-utilities==0.9.0
45
- lit==16.0.6
46
- markdown==3.4.4
47
- markupsafe==2.1.3
48
- matplotlib==3.7.2
49
- mpmath==1.3.0
50
- multidict==6.0.4
51
- multiprocess==0.70.15
52
- networkx==3.1
53
- numpy==1.24.4
54
- nvidia-cublas-cu11==11.10.3.66
55
- nvidia-cuda-cupti-cu11==11.7.101
56
- nvidia-cuda-nvrtc-cu11==11.7.99
57
- nvidia-cuda-runtime-cu11==11.7.99
58
- nvidia-cudnn-cu11==8.5.0.96
59
- nvidia-cufft-cu11==10.9.0.58
60
- nvidia-curand-cu11==10.2.10.91
61
- nvidia-cusolver-cu11==11.4.0.1
62
- nvidia-cusparse-cu11==11.7.4.91
63
- nvidia-nccl-cu11==2.14.3
64
- nvidia-nvtx-cu11==11.7.91
65
- oauthlib==3.2.2
66
- omegaconf==2.3.0
67
- opencv-python==4.8.0.76
68
- orjson==3.9.5
69
- pandas==2.0.3
70
- pillow==9.5.0
71
- pkgutil-resolve-name==1.3.10
72
- protobuf==4.24.2
73
- psutil==5.9.5
74
- pyarrow==13.0.0
75
- pyasn1==0.5.0
76
- pyasn1-modules==0.3.0
77
- pydantic==2.3.0
78
- pydantic-core==2.6.3
79
- pydub==0.25.1
80
- pyparsing==3.0.9
81
- python-multipart==0.0.6
82
- pytorch-lightning==2.0.7
83
- pytz==2023.3
84
- pyyaml==6.0.1
85
- referencing==0.30.2
86
  regex
87
  requests
88
  requests-oauthlib
89
- rpds-py==0.9.2
90
- rsa==4.9
91
- safetensors==0.3.3
92
- semantic-version==2.10.0
93
- sniffio==1.3.0
94
- starlette==0.27.0
95
- sympy==1.12
96
- tensorboard==2.14.0
97
- tensorboard-data-server==0.7.1
98
- tokenizers==0.13.3
99
- toolz==0.12.0
100
- torchmetrics==1.1.0
101
  tqdm
102
- transformers==4.32.0
103
- triton==2.0.0
104
- tzdata==2023.3
105
- urllib3==1.26.16
106
- uvicorn==0.23.2
107
- websockets==11.0.3
108
- werkzeug==2.3.7
109
- xxhash==3.3.0
110
- yarl==1.9.2
111
- zipp==3.16.2
112
  decord
113
- imageio==2.9.0
114
- imageio-ffmpeg==0.4.3
115
  timm
116
  scipy
117
  scikit-image
@@ -119,6 +119,6 @@ av
119
  imgaug
120
  lpips
121
  ffmpeg-python
122
- torch==2.0.1
123
- torchvision==0.15.2
124
- diffusers==0.21.4
 
1
+ absl-py
2
+ accelerate
3
+ aiofiles
4
+ aiohttp
5
+ aiosignal
6
+ altair
7
+ annotated-types
8
+ antlr4-python3-runtime
9
+ anyio
10
+ async-timeout
11
+ attrs
12
+ cachetools
13
+ certifi
14
+ charset-normalizer
15
+ click
16
+ cmake
17
+ contourpy
18
+ cycler
19
+ datasets
20
+ dill
21
+ einops
22
+ exceptiongroup
23
+ fastapi
24
+ ffmpy
25
+ filelock
26
+ fonttools
27
+ frozenlist
28
+ fsspec
29
+ google-auth
30
+ google-auth-oauthlib
31
+ grpcio
32
+ h11
33
  httpcore
34
  httpx
35
  huggingface-hub
36
+ idna
37
+ importlib-metadata
38
+ importlib-resources
39
+ jinja2
40
+ joblib
41
+ jsonschema
42
+ jsonschema-specifications
43
+ kiwisolver
44
+ lightning-utilities
45
+ lit
46
+ markdown
47
+ markupsafe
48
+ matplotlib
49
+ mpmath
50
+ multidict
51
+ multiprocess
52
+ networkx
53
+ numpy
54
+ nvidia-cublas-cu11
55
+ nvidia-cuda-cupti-cu11
56
+ nvidia-cuda-nvrtc-cu11
57
+ nvidia-cuda-runtime-cu11
58
+ nvidia-cudnn-cu11
59
+ nvidia-cufft-cu11
60
+ nvidia-curand-cu11
61
+ nvidia-cusolver-cu11
62
+ nvidia-cusparse-cu11
63
+ nvidia-nccl-cu11
64
+ nvidia-nvtx-cu11
65
+ oauthlib
66
+ omegaconf
67
+ opencv-python
68
+ orjson
69
+ pandas
70
+ pillow
71
+ pkgutil-resolve-name
72
+ protobuf
73
+ psutil
74
+ pyarrow
75
+ pyasn1
76
+ pyasn1-modules
77
+ pydantic
78
+ pydantic-core
79
+ pydub
80
+ pyparsing
81
+ python-multipart
82
+ pytorch-lightning
83
+ pytz
84
+ pyyaml
85
+ referencing
86
  regex
87
  requests
88
  requests-oauthlib
89
+ rpds-py
90
+ rsa
91
+ safetensors
92
+ semantic-version
93
+ sniffio
94
+ starlette
95
+ sympy
96
+ tensorboard
97
+ tensorboard-data-server
98
+ tokenizers
99
+ toolz
100
+ torchmetrics
101
  tqdm
102
+ transformers
103
+ triton
104
+ tzdata
105
+ urllib3
106
+ uvicorn
107
+ websockets
108
+ werkzeug
109
+ xxhash
110
+ yarl
111
+ zipp
112
  decord
113
+ imageio
114
+ imageio-ffmpeg
115
  timm
116
  scipy
117
  scikit-image
 
119
  imgaug
120
  lpips
121
  ffmpeg-python
122
+ torch
123
+ torchvision
124
+ diffusers