Spaces:
Running
on
A10G
Running
on
A10G
Upload 17 files
Browse files- LICENSE +201 -0
- MotionDirector_inference.py +284 -0
- MotionDirector_inference_batch.py +290 -0
- MotionDirector_train.py +1021 -0
- README.md +364 -13
- app.py +100 -0
- demo/MotionDirector_gradio.py +92 -0
- demo/motiondirector.py +218 -0
- models/unet_3d_blocks.py +842 -0
- models/unet_3d_condition.py +500 -0
- requirements.txt +19 -0
- utils/bucketing.py +32 -0
- utils/convert_diffusers_to_original_ms_text_to_video.py +465 -0
- utils/dataset.py +578 -0
- utils/ddim_utils.py +65 -0
- utils/lora.py +1483 -0
- utils/lora_handler.py +269 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
MotionDirector_inference.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import platform
|
4 |
+
import re
|
5 |
+
import warnings
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from diffusers import DDIMScheduler, TextToVideoSDPipeline
|
10 |
+
from einops import rearrange
|
11 |
+
from torch import Tensor
|
12 |
+
from torch.nn.functional import interpolate
|
13 |
+
from tqdm import trange
|
14 |
+
import random
|
15 |
+
|
16 |
+
from MotionDirector_train import export_to_video, handle_memory_attention, load_primary_models, unet_and_text_g_c, freeze_models
|
17 |
+
from utils.lora_handler import LoraHandler
|
18 |
+
from utils.ddim_utils import ddim_inversion
|
19 |
+
import imageio
|
20 |
+
|
21 |
+
|
22 |
+
def initialize_pipeline(
|
23 |
+
model: str,
|
24 |
+
device: str = "cuda",
|
25 |
+
xformers: bool = False,
|
26 |
+
sdp: bool = False,
|
27 |
+
lora_path: str = "",
|
28 |
+
lora_rank: int = 64,
|
29 |
+
lora_scale: float = 1.0,
|
30 |
+
):
|
31 |
+
with warnings.catch_warnings():
|
32 |
+
warnings.simplefilter("ignore")
|
33 |
+
|
34 |
+
scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
|
35 |
+
|
36 |
+
# Freeze any necessary models
|
37 |
+
freeze_models([vae, text_encoder, unet])
|
38 |
+
|
39 |
+
# Enable xformers if available
|
40 |
+
handle_memory_attention(xformers, sdp, unet)
|
41 |
+
|
42 |
+
lora_manager_temporal = LoraHandler(
|
43 |
+
version="cloneofsimo",
|
44 |
+
use_unet_lora=True,
|
45 |
+
use_text_lora=False,
|
46 |
+
save_for_webui=False,
|
47 |
+
only_for_webui=False,
|
48 |
+
unet_replace_modules=["TransformerTemporalModel"],
|
49 |
+
text_encoder_replace_modules=None,
|
50 |
+
lora_bias=None
|
51 |
+
)
|
52 |
+
|
53 |
+
unet_lora_params, unet_negation = lora_manager_temporal.add_lora_to_model(
|
54 |
+
True, unet, lora_manager_temporal.unet_replace_modules, 0, lora_path, r=lora_rank, scale=lora_scale)
|
55 |
+
|
56 |
+
unet.eval()
|
57 |
+
text_encoder.eval()
|
58 |
+
unet_and_text_g_c(unet, text_encoder, False, False)
|
59 |
+
|
60 |
+
pipe = TextToVideoSDPipeline.from_pretrained(
|
61 |
+
pretrained_model_name_or_path=model,
|
62 |
+
scheduler=scheduler,
|
63 |
+
tokenizer=tokenizer,
|
64 |
+
text_encoder=text_encoder.to(device=device, dtype=torch.half),
|
65 |
+
vae=vae.to(device=device, dtype=torch.half),
|
66 |
+
unet=unet.to(device=device, dtype=torch.half),
|
67 |
+
)
|
68 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
69 |
+
|
70 |
+
return pipe
|
71 |
+
|
72 |
+
|
73 |
+
def inverse_video(pipe, latents, num_steps):
|
74 |
+
ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
75 |
+
ddim_inv_scheduler.set_timesteps(num_steps)
|
76 |
+
|
77 |
+
ddim_inv_latent = ddim_inversion(
|
78 |
+
pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device),
|
79 |
+
num_inv_steps=num_steps, prompt="")[-1]
|
80 |
+
return ddim_inv_latent
|
81 |
+
|
82 |
+
|
83 |
+
def prepare_input_latents(
|
84 |
+
pipe: TextToVideoSDPipeline,
|
85 |
+
batch_size: int,
|
86 |
+
num_frames: int,
|
87 |
+
height: int,
|
88 |
+
width: int,
|
89 |
+
latents_path:str,
|
90 |
+
noise_prior: float
|
91 |
+
):
|
92 |
+
# initialize with random gaussian noise
|
93 |
+
scale = pipe.vae_scale_factor
|
94 |
+
shape = (batch_size, pipe.unet.config.in_channels, num_frames, height // scale, width // scale)
|
95 |
+
if noise_prior > 0.:
|
96 |
+
cached_latents = torch.load(latents_path)
|
97 |
+
if 'inversion_noise' not in cached_latents:
|
98 |
+
latents = inverse_video(pipe, cached_latents['latents'].unsqueeze(0), 50).squeeze(0)
|
99 |
+
else:
|
100 |
+
latents = torch.load(latents_path)['inversion_noise'].unsqueeze(0)
|
101 |
+
if latents.shape[0] != batch_size:
|
102 |
+
latents = latents.repeat(batch_size, 1, 1, 1, 1)
|
103 |
+
if latents.shape != shape:
|
104 |
+
latents = interpolate(rearrange(latents, "b c f h w -> (b f) c h w", b=batch_size), (height // scale, width // scale), mode='bilinear')
|
105 |
+
latents = rearrange(latents, "(b f) c h w -> b c f h w", b=batch_size)
|
106 |
+
noise = torch.randn_like(latents, dtype=torch.half)
|
107 |
+
latents = (noise_prior) ** 0.5 * latents + (1 - noise_prior) ** 0.5 * noise
|
108 |
+
else:
|
109 |
+
latents = torch.randn(shape, dtype=torch.half)
|
110 |
+
|
111 |
+
return latents
|
112 |
+
|
113 |
+
|
114 |
+
def encode(pipe: TextToVideoSDPipeline, pixels: Tensor, batch_size: int = 8):
|
115 |
+
nf = pixels.shape[2]
|
116 |
+
pixels = rearrange(pixels, "b c f h w -> (b f) c h w")
|
117 |
+
|
118 |
+
latents = []
|
119 |
+
for idx in trange(
|
120 |
+
0, pixels.shape[0], batch_size, desc="Encoding to latents...", unit_scale=batch_size, unit="frame"
|
121 |
+
):
|
122 |
+
pixels_batch = pixels[idx : idx + batch_size].to(pipe.device, dtype=torch.half)
|
123 |
+
latents_batch = pipe.vae.encode(pixels_batch).latent_dist.sample()
|
124 |
+
latents_batch = latents_batch.mul(pipe.vae.config.scaling_factor).cpu()
|
125 |
+
latents.append(latents_batch)
|
126 |
+
latents = torch.cat(latents)
|
127 |
+
|
128 |
+
latents = rearrange(latents, "(b f) c h w -> b c f h w", f=nf)
|
129 |
+
|
130 |
+
return latents
|
131 |
+
|
132 |
+
|
133 |
+
@torch.inference_mode()
|
134 |
+
def inference(
|
135 |
+
model: str,
|
136 |
+
prompt: str,
|
137 |
+
negative_prompt: Optional[str] = None,
|
138 |
+
width: int = 256,
|
139 |
+
height: int = 256,
|
140 |
+
num_frames: int = 24,
|
141 |
+
num_steps: int = 50,
|
142 |
+
guidance_scale: float = 15,
|
143 |
+
device: str = "cuda",
|
144 |
+
xformers: bool = False,
|
145 |
+
sdp: bool = False,
|
146 |
+
lora_path: str = "",
|
147 |
+
lora_rank: int = 64,
|
148 |
+
lora_scale: float = 1.0,
|
149 |
+
seed: Optional[int] = None,
|
150 |
+
latents_path: str="",
|
151 |
+
noise_prior: float = 0.,
|
152 |
+
repeat_num: int = 1,
|
153 |
+
):
|
154 |
+
if seed is not None:
|
155 |
+
random_seed = seed
|
156 |
+
torch.manual_seed(seed)
|
157 |
+
|
158 |
+
with torch.autocast(device, dtype=torch.half):
|
159 |
+
# prepare models
|
160 |
+
pipe = initialize_pipeline(model, device, xformers, sdp, lora_path, lora_rank, lora_scale)
|
161 |
+
|
162 |
+
for i in range(repeat_num):
|
163 |
+
if seed is None:
|
164 |
+
random_seed = random.randint(100, 10000000)
|
165 |
+
torch.manual_seed(random_seed)
|
166 |
+
|
167 |
+
# prepare input latents
|
168 |
+
init_latents = prepare_input_latents(
|
169 |
+
pipe=pipe,
|
170 |
+
batch_size=len(prompt),
|
171 |
+
num_frames=num_frames,
|
172 |
+
height=height,
|
173 |
+
width=width,
|
174 |
+
latents_path=latents_path,
|
175 |
+
noise_prior=noise_prior
|
176 |
+
)
|
177 |
+
|
178 |
+
with torch.no_grad():
|
179 |
+
video_frames = pipe(
|
180 |
+
prompt=prompt,
|
181 |
+
negative_prompt=negative_prompt,
|
182 |
+
width=width,
|
183 |
+
height=height,
|
184 |
+
num_frames=num_frames,
|
185 |
+
num_inference_steps=num_steps,
|
186 |
+
guidance_scale=guidance_scale,
|
187 |
+
latents=init_latents
|
188 |
+
).frames
|
189 |
+
|
190 |
+
# =========================================
|
191 |
+
# ========= write outputs to file =========
|
192 |
+
# =========================================
|
193 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
194 |
+
|
195 |
+
# save to mp4
|
196 |
+
export_to_video(video_frames, f"{out_name}_{random_seed}.mp4", args.fps)
|
197 |
+
|
198 |
+
# # save to gif
|
199 |
+
file_name = f"{out_name}_{random_seed}.gif"
|
200 |
+
imageio.mimsave(file_name, video_frames, 'GIF', duration=1000 * 1 / args.fps, loop=0)
|
201 |
+
|
202 |
+
return video_frames
|
203 |
+
|
204 |
+
|
205 |
+
if __name__ == "__main__":
|
206 |
+
import decord
|
207 |
+
|
208 |
+
decord.bridge.set_bridge("torch")
|
209 |
+
|
210 |
+
# fmt: off
|
211 |
+
parser = argparse.ArgumentParser()
|
212 |
+
parser.add_argument("-m", "--model", type=str, required=True,
|
213 |
+
help="HuggingFace repository or path to model checkpoint directory")
|
214 |
+
parser.add_argument("-p", "--prompt", type=str, required=True, help="Text prompt to condition on")
|
215 |
+
parser.add_argument("-n", "--negative-prompt", type=str, default=None, help="Text prompt to condition against")
|
216 |
+
parser.add_argument("-o", "--output_dir", type=str, default="./outputs/inference", help="Directory to save output video to")
|
217 |
+
parser.add_argument("-B", "--batch-size", type=int, default=1, help="Batch size for inference")
|
218 |
+
parser.add_argument("-W", "--width", type=int, default=384, help="Width of output video")
|
219 |
+
parser.add_argument("-H", "--height", type=int, default=384, help="Height of output video")
|
220 |
+
parser.add_argument("-T", "--num-frames", type=int, default=16, help="Total number of frames to generate")
|
221 |
+
parser.add_argument("-s", "--num-steps", type=int, default=30, help="Number of diffusion steps to run per frame.")
|
222 |
+
parser.add_argument("-g", "--guidance-scale", type=float, default=12, help="Scale for guidance loss (higher values = more guidance, but possibly more artifacts).")
|
223 |
+
parser.add_argument("-f", "--fps", type=int, default=8, help="FPS of output video")
|
224 |
+
parser.add_argument("-d", "--device", type=str, default="cuda", help="Device to run inference on (defaults to cuda).")
|
225 |
+
parser.add_argument("-x", "--xformers", action="store_true", help="Use XFormers attnetion, a memory-efficient attention implementation (requires `pip install xformers`).")
|
226 |
+
parser.add_argument("-S", "--sdp", action="store_true", help="Use SDP attention, PyTorch's built-in memory-efficient attention implementation.")
|
227 |
+
parser.add_argument("-cf", "--checkpoint_folder", type=str, default=None, help="Path to Low Rank Adaptation checkpoint file (defaults to empty string, which uses no LoRA).")
|
228 |
+
parser.add_argument("-lr", "--lora_rank", type=int, default=32, help="Size of the LoRA checkpoint's projection matrix (defaults to 32).")
|
229 |
+
parser.add_argument("-ls", "--lora_scale", type=float, default=1.0, help="Scale of LoRAs.")
|
230 |
+
parser.add_argument("-r", "--seed", type=int, default=None, help="Random seed to make generations reproducible.")
|
231 |
+
parser.add_argument("-np", "--noise_prior", type=float, default=0., help="Scale of the influence of inversion noise.")
|
232 |
+
parser.add_argument("-ci", "--checkpoint_index", type=int, required=True,
|
233 |
+
help="The index of checkpoint, such as 300.")
|
234 |
+
parser.add_argument("-rn", "--repeat_num", type=int, default=1,
|
235 |
+
help="How many results to generate with the same prompt.")
|
236 |
+
|
237 |
+
args = parser.parse_args()
|
238 |
+
# fmt: on
|
239 |
+
|
240 |
+
# =========================================
|
241 |
+
# ====== validate and prepare inputs ======
|
242 |
+
# =========================================
|
243 |
+
|
244 |
+
out_name = f"{args.output_dir}/"
|
245 |
+
prompt = re.sub(r'[<>:"/\\|?*\x00-\x1F]', "_", args.prompt) if platform.system() == "Windows" else args.prompt
|
246 |
+
out_name += f"{prompt}".replace(' ','_').replace(',', '').replace('.', '')
|
247 |
+
|
248 |
+
args.prompt = [prompt] * args.batch_size
|
249 |
+
if args.negative_prompt is not None:
|
250 |
+
args.negative_prompt = [args.negative_prompt] * args.batch_size
|
251 |
+
|
252 |
+
# =========================================
|
253 |
+
# ============= sample videos =============
|
254 |
+
# =========================================
|
255 |
+
if args.checkpoint_index is not None:
|
256 |
+
lora_path = f"{args.checkpoint_folder}/checkpoint-{args.checkpoint_index}/temporal/lora"
|
257 |
+
else:
|
258 |
+
lora_path = f"{args.checkpoint_folder}/checkpoint-default/temporal/lora"
|
259 |
+
latents_folder = f"{args.checkpoint_folder}/cached_latents"
|
260 |
+
latents_path = f"{latents_folder}/{random.choice(os.listdir(latents_folder))}"
|
261 |
+
assert os.path.exists(lora_path)
|
262 |
+
video_frames = inference(
|
263 |
+
model=args.model,
|
264 |
+
prompt=args.prompt,
|
265 |
+
negative_prompt=args.negative_prompt,
|
266 |
+
width=args.width,
|
267 |
+
height=args.height,
|
268 |
+
num_frames=args.num_frames,
|
269 |
+
num_steps=args.num_steps,
|
270 |
+
guidance_scale=args.guidance_scale,
|
271 |
+
device=args.device,
|
272 |
+
xformers=args.xformers,
|
273 |
+
sdp=args.sdp,
|
274 |
+
lora_path=lora_path,
|
275 |
+
lora_rank=args.lora_rank,
|
276 |
+
lora_scale = args.lora_scale,
|
277 |
+
seed=args.seed,
|
278 |
+
latents_path=latents_path,
|
279 |
+
noise_prior=args.noise_prior,
|
280 |
+
repeat_num=args.repeat_num
|
281 |
+
)
|
282 |
+
|
283 |
+
|
284 |
+
|
MotionDirector_inference_batch.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import platform
|
4 |
+
import re
|
5 |
+
import warnings
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from diffusers import DDIMScheduler, TextToVideoSDPipeline
|
10 |
+
from einops import rearrange
|
11 |
+
from torch import Tensor
|
12 |
+
from torch.nn.functional import interpolate
|
13 |
+
from tqdm import trange
|
14 |
+
import random
|
15 |
+
|
16 |
+
from MotionDirector_train import export_to_video, handle_memory_attention, load_primary_models, unet_and_text_g_c, freeze_models
|
17 |
+
from utils.lora_handler import LoraHandler
|
18 |
+
from utils.ddim_utils import ddim_inversion
|
19 |
+
import imageio
|
20 |
+
|
21 |
+
|
22 |
+
def initialize_pipeline(
|
23 |
+
model: str,
|
24 |
+
device: str = "cuda",
|
25 |
+
xformers: bool = False,
|
26 |
+
sdp: bool = False,
|
27 |
+
lora_path: str = "",
|
28 |
+
lora_rank: int = 64,
|
29 |
+
lora_scale: float = 1.0,
|
30 |
+
):
|
31 |
+
with warnings.catch_warnings():
|
32 |
+
warnings.simplefilter("ignore")
|
33 |
+
|
34 |
+
scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
|
35 |
+
|
36 |
+
# Freeze any necessary models
|
37 |
+
freeze_models([vae, text_encoder, unet])
|
38 |
+
|
39 |
+
# Enable xformers if available
|
40 |
+
handle_memory_attention(xformers, sdp, unet)
|
41 |
+
|
42 |
+
lora_manager_temporal = LoraHandler(
|
43 |
+
version="cloneofsimo",
|
44 |
+
use_unet_lora=True,
|
45 |
+
use_text_lora=False,
|
46 |
+
save_for_webui=False,
|
47 |
+
only_for_webui=False,
|
48 |
+
unet_replace_modules=["TransformerTemporalModel"],
|
49 |
+
text_encoder_replace_modules=None,
|
50 |
+
lora_bias=None
|
51 |
+
)
|
52 |
+
|
53 |
+
unet_lora_params, unet_negation = lora_manager_temporal.add_lora_to_model(
|
54 |
+
True, unet, lora_manager_temporal.unet_replace_modules, 0, lora_path, r=lora_rank, scale=lora_scale)
|
55 |
+
|
56 |
+
unet.eval()
|
57 |
+
text_encoder.eval()
|
58 |
+
unet_and_text_g_c(unet, text_encoder, False, False)
|
59 |
+
|
60 |
+
pipe = TextToVideoSDPipeline.from_pretrained(
|
61 |
+
pretrained_model_name_or_path=model,
|
62 |
+
scheduler=scheduler,
|
63 |
+
tokenizer=tokenizer,
|
64 |
+
text_encoder=text_encoder.to(device=device, dtype=torch.half),
|
65 |
+
vae=vae.to(device=device, dtype=torch.half),
|
66 |
+
unet=unet.to(device=device, dtype=torch.half),
|
67 |
+
)
|
68 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
69 |
+
|
70 |
+
return pipe
|
71 |
+
|
72 |
+
|
73 |
+
def inverse_video(pipe, latents, num_steps):
|
74 |
+
ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
75 |
+
ddim_inv_scheduler.set_timesteps(num_steps)
|
76 |
+
|
77 |
+
ddim_inv_latent = ddim_inversion(
|
78 |
+
pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device),
|
79 |
+
num_inv_steps=num_steps, prompt="")[-1]
|
80 |
+
return ddim_inv_latent
|
81 |
+
|
82 |
+
|
83 |
+
def prepare_input_latents(
|
84 |
+
pipe: TextToVideoSDPipeline,
|
85 |
+
batch_size: int,
|
86 |
+
num_frames: int,
|
87 |
+
height: int,
|
88 |
+
width: int,
|
89 |
+
latents_path:str,
|
90 |
+
noise_prior: float
|
91 |
+
):
|
92 |
+
# initialize with random gaussian noise
|
93 |
+
scale = pipe.vae_scale_factor
|
94 |
+
shape = (batch_size, pipe.unet.config.in_channels, num_frames, height // scale, width // scale)
|
95 |
+
if noise_prior > 0.:
|
96 |
+
cached_latents = torch.load(latents_path)
|
97 |
+
if 'inversion_noise' not in cached_latents:
|
98 |
+
latents = inverse_video(pipe, cached_latents['latents'].unsqueeze(0), 50).squeeze(0)
|
99 |
+
else:
|
100 |
+
latents = torch.load(latents_path)['inversion_noise'].unsqueeze(0)
|
101 |
+
if latents.shape[0] != batch_size:
|
102 |
+
latents = latents.repeat(batch_size, 1, 1, 1, 1)
|
103 |
+
if latents.shape != shape:
|
104 |
+
latents = interpolate(rearrange(latents, "b c f h w -> (b f) c h w", b=batch_size), (height // scale, width // scale), mode='bilinear')
|
105 |
+
latents = rearrange(latents, "(b f) c h w -> b c f h w", b=batch_size)
|
106 |
+
noise = torch.randn_like(latents, dtype=torch.half)
|
107 |
+
latents = (noise_prior) ** 0.5 * latents + (1 - noise_prior) ** 0.5 * noise
|
108 |
+
else:
|
109 |
+
latents = torch.randn(shape, dtype=torch.half)
|
110 |
+
|
111 |
+
return latents
|
112 |
+
|
113 |
+
|
114 |
+
def encode(pipe: TextToVideoSDPipeline, pixels: Tensor, batch_size: int = 8):
|
115 |
+
nf = pixels.shape[2]
|
116 |
+
pixels = rearrange(pixels, "b c f h w -> (b f) c h w")
|
117 |
+
|
118 |
+
latents = []
|
119 |
+
for idx in trange(
|
120 |
+
0, pixels.shape[0], batch_size, desc="Encoding to latents...", unit_scale=batch_size, unit="frame"
|
121 |
+
):
|
122 |
+
pixels_batch = pixels[idx : idx + batch_size].to(pipe.device, dtype=torch.half)
|
123 |
+
latents_batch = pipe.vae.encode(pixels_batch).latent_dist.sample()
|
124 |
+
latents_batch = latents_batch.mul(pipe.vae.config.scaling_factor).cpu()
|
125 |
+
latents.append(latents_batch)
|
126 |
+
latents = torch.cat(latents)
|
127 |
+
|
128 |
+
latents = rearrange(latents, "(b f) c h w -> b c f h w", f=nf)
|
129 |
+
|
130 |
+
return latents
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
@torch.inference_mode()
|
136 |
+
def inference(
|
137 |
+
model: str,
|
138 |
+
prompt: str,
|
139 |
+
negative_prompt: Optional[str] = None,
|
140 |
+
width: int = 256,
|
141 |
+
height: int = 256,
|
142 |
+
num_frames: int = 24,
|
143 |
+
num_steps: int = 50,
|
144 |
+
guidance_scale: float = 15,
|
145 |
+
device: str = "cuda",
|
146 |
+
xformers: bool = False,
|
147 |
+
sdp: bool = False,
|
148 |
+
lora_path: str = "",
|
149 |
+
lora_rank: int = 64,
|
150 |
+
lora_scale: float = 1.0,
|
151 |
+
seed: Optional[int] = None,
|
152 |
+
latents_path: str="",
|
153 |
+
noise_prior: float = 0.,
|
154 |
+
repeat_num: int = 1,
|
155 |
+
):
|
156 |
+
|
157 |
+
with torch.autocast(device, dtype=torch.half):
|
158 |
+
# prepare models
|
159 |
+
pipe = initialize_pipeline(model, device, xformers, sdp, lora_path, lora_rank, lora_scale)
|
160 |
+
|
161 |
+
for i in range(repeat_num):
|
162 |
+
if seed is not None:
|
163 |
+
random_seed = seed
|
164 |
+
torch.manual_seed(seed)
|
165 |
+
else:
|
166 |
+
random_seed = random.randint(100, 10000000)
|
167 |
+
torch.manual_seed(random_seed)
|
168 |
+
|
169 |
+
# prepare input latents
|
170 |
+
init_latents = prepare_input_latents(
|
171 |
+
pipe=pipe,
|
172 |
+
batch_size=len(prompt),
|
173 |
+
num_frames=num_frames,
|
174 |
+
height=height,
|
175 |
+
width=width,
|
176 |
+
latents_path=latents_path,
|
177 |
+
noise_prior=noise_prior
|
178 |
+
)
|
179 |
+
|
180 |
+
video_frames = pipe(
|
181 |
+
prompt=prompt,
|
182 |
+
negative_prompt=negative_prompt,
|
183 |
+
width=width,
|
184 |
+
height=height,
|
185 |
+
num_frames=num_frames,
|
186 |
+
num_inference_steps=num_steps,
|
187 |
+
guidance_scale=guidance_scale,
|
188 |
+
latents=init_latents
|
189 |
+
).frames
|
190 |
+
# =========================================
|
191 |
+
# ========= write outputs to file =========
|
192 |
+
# =========================================
|
193 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
194 |
+
|
195 |
+
# save to mp4
|
196 |
+
export_to_video(video_frames, f"{out_name}_{random_seed}.mp4", args.fps)
|
197 |
+
|
198 |
+
# # save to gif
|
199 |
+
file_name = f"{out_name}_{random_seed}.gif"
|
200 |
+
imageio.mimsave(file_name, video_frames, 'GIF', duration=1000 * 1 / args.fps, loop=0)
|
201 |
+
|
202 |
+
return video_frames
|
203 |
+
|
204 |
+
|
205 |
+
if __name__ == "__main__":
|
206 |
+
import decord
|
207 |
+
|
208 |
+
decord.bridge.set_bridge("torch")
|
209 |
+
|
210 |
+
# fmt: off
|
211 |
+
parser = argparse.ArgumentParser()
|
212 |
+
parser.add_argument("-m", "--model", type=str, default='/Users/rui/data/models/zeroscope_v2_576w/',
|
213 |
+
help="HuggingFace repository or path to model checkpoint directory")
|
214 |
+
parser.add_argument("-p", "--prompt", type=str, default=None, help="Text prompt to condition on")
|
215 |
+
parser.add_argument("-n", "--negative-prompt", type=str, default=None, help="Text prompt to condition against")
|
216 |
+
parser.add_argument("-o", "--output_dir", type=str, default="./outputs/inference", help="Directory to save output video to")
|
217 |
+
parser.add_argument("-B", "--batch-size", type=int, default=1, help="Batch size for inference")
|
218 |
+
parser.add_argument("-W", "--width", type=int, default=384, help="Width of output video")
|
219 |
+
parser.add_argument("-H", "--height", type=int, default=384, help="Height of output video")
|
220 |
+
parser.add_argument("-T", "--num-frames", type=int, default=16, help="Total number of frames to generate")
|
221 |
+
parser.add_argument("-s", "--num-steps", type=int, default=30, help="Number of diffusion steps to run per frame.")
|
222 |
+
parser.add_argument("-g", "--guidance-scale", type=float, default=12, help="Scale for guidance loss (higher values = more guidance, but possibly more artifacts).")
|
223 |
+
parser.add_argument("-f", "--fps", type=int, default=8, help="FPS of output video")
|
224 |
+
parser.add_argument("-d", "--device", type=str, default="cuda", help="Device to run inference on (defaults to cuda).")
|
225 |
+
parser.add_argument("-x", "--xformers", action="store_true", help="Use XFormers attnetion, a memory-efficient attention implementation (requires `pip install xformers`).")
|
226 |
+
parser.add_argument("-S", "--sdp", action="store_true", help="Use SDP attention, PyTorch's built-in memory-efficient attention implementation.")
|
227 |
+
parser.add_argument("-cf", "--checkpoint_folder", type=str, default=None, help="Path to Low Rank Adaptation checkpoint file (defaults to empty string, which uses no LoRA).")
|
228 |
+
parser.add_argument("-lr", "--lora_rank", type=int, default=32, help="Size of the LoRA checkpoint's projection matrix (defaults to 32).")
|
229 |
+
parser.add_argument("-ls", "--lora_scale", type=float, default=1.0, help="Scale of LoRAs.")
|
230 |
+
parser.add_argument("-r", "--seed", type=int, default=None, help="Random seed to make generations reproducible.")
|
231 |
+
parser.add_argument("-np", "--noise_prior", type=float, default=0., help="Random seed to make generations reproducible.")
|
232 |
+
parser.add_argument("-ci", "--checkpoint_index", type=int, default=None,
|
233 |
+
help="Random seed to make generations reproducible.")
|
234 |
+
parser.add_argument("-rn", "--repeat_num", type=int, default=None,
|
235 |
+
help="Random seed to make generations reproducible.")
|
236 |
+
|
237 |
+
args = parser.parse_args()
|
238 |
+
# fmt: on
|
239 |
+
|
240 |
+
# =========================================
|
241 |
+
# ====== validate and prepare inputs ======
|
242 |
+
# =========================================
|
243 |
+
|
244 |
+
# args.prompt = ["A firefighter standing in front of a burning forest captured with a dolly zoom.",
|
245 |
+
# "A spaceman standing on the moon with earth behind him captured with a dolly zoom."]
|
246 |
+
args.prompt = "A person is riding a bicycle past the Eiffel Tower."
|
247 |
+
args.checkpoint_folder = './outputs/train/train_2023-12-02T11-45-22/'
|
248 |
+
args.checkpoint_index = 500
|
249 |
+
args.noise_prior = 0.
|
250 |
+
args.repeat_num = 10
|
251 |
+
|
252 |
+
out_name = f"{args.output_dir}/"
|
253 |
+
prompt = re.sub(r'[<>:"/\\|?*\x00-\x1F]', "_", args.prompt) if platform.system() == "Windows" else args.prompt
|
254 |
+
out_name += f"{prompt}".replace(' ','_').replace(',', '').replace('.', '')
|
255 |
+
|
256 |
+
args.prompt = [prompt] * args.batch_size
|
257 |
+
if args.negative_prompt is not None:
|
258 |
+
args.negative_prompt = [args.negative_prompt] * args.batch_size
|
259 |
+
|
260 |
+
# =========================================
|
261 |
+
# ============= sample videos =============
|
262 |
+
# =========================================
|
263 |
+
|
264 |
+
lora_path = f"{args.checkpoint_folder}/checkpoint-{args.checkpoint_index}/temporal/lora"
|
265 |
+
latents_folder = f"{args.checkpoint_folder}/cached_latents"
|
266 |
+
latents_path = f"{latents_folder}/{random.choice(os.listdir(latents_folder))}"
|
267 |
+
# if args.seed is None:
|
268 |
+
# args.seed = random.randint(100, 10000000)
|
269 |
+
assert os.path.exists(lora_path)
|
270 |
+
video_frames = inference(
|
271 |
+
model=args.model,
|
272 |
+
prompt=args.prompt,
|
273 |
+
negative_prompt=args.negative_prompt,
|
274 |
+
width=args.width,
|
275 |
+
height=args.height,
|
276 |
+
num_frames=args.num_frames,
|
277 |
+
num_steps=args.num_steps,
|
278 |
+
guidance_scale=args.guidance_scale,
|
279 |
+
device=args.device,
|
280 |
+
xformers=args.xformers,
|
281 |
+
sdp=args.sdp,
|
282 |
+
lora_path=lora_path,
|
283 |
+
lora_rank=args.lora_rank,
|
284 |
+
lora_scale = args.lora_scale,
|
285 |
+
seed=args.seed,
|
286 |
+
latents_path=latents_path,
|
287 |
+
noise_prior=args.noise_prior,
|
288 |
+
repeat_num=args.repeat_num
|
289 |
+
)
|
290 |
+
|
MotionDirector_train.py
ADDED
@@ -0,0 +1,1021 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import logging
|
4 |
+
import inspect
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import gc
|
9 |
+
import copy
|
10 |
+
|
11 |
+
from typing import Dict, Optional, Tuple
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch.utils.checkpoint
|
17 |
+
import diffusers
|
18 |
+
import transformers
|
19 |
+
|
20 |
+
from torchvision import transforms
|
21 |
+
from tqdm.auto import tqdm
|
22 |
+
|
23 |
+
from accelerate import Accelerator
|
24 |
+
from accelerate.logging import get_logger
|
25 |
+
from accelerate.utils import set_seed
|
26 |
+
|
27 |
+
from models.unet_3d_condition import UNet3DConditionModel
|
28 |
+
from diffusers.models import AutoencoderKL
|
29 |
+
from diffusers import DDIMScheduler, TextToVideoSDPipeline
|
30 |
+
from diffusers.optimization import get_scheduler
|
31 |
+
from diffusers.utils.import_utils import is_xformers_available
|
32 |
+
from diffusers.models.attention_processor import AttnProcessor2_0, Attention
|
33 |
+
from diffusers.models.attention import BasicTransformerBlock
|
34 |
+
|
35 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
36 |
+
from transformers.models.clip.modeling_clip import CLIPEncoder
|
37 |
+
from utils.dataset import VideoJsonDataset, SingleVideoDataset, \
|
38 |
+
ImageDataset, VideoFolderDataset, CachedDataset
|
39 |
+
from einops import rearrange, repeat
|
40 |
+
from utils.lora_handler import LoraHandler
|
41 |
+
from utils.lora import extract_lora_child_module
|
42 |
+
from utils.ddim_utils import ddim_inversion
|
43 |
+
import imageio
|
44 |
+
import numpy as np
|
45 |
+
|
46 |
+
|
47 |
+
already_printed_trainables = False
|
48 |
+
|
49 |
+
logger = get_logger(__name__, log_level="INFO")
|
50 |
+
|
51 |
+
|
52 |
+
def create_logging(logging, logger, accelerator):
|
53 |
+
logging.basicConfig(
|
54 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
55 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
56 |
+
level=logging.INFO,
|
57 |
+
)
|
58 |
+
logger.info(accelerator.state, main_process_only=False)
|
59 |
+
|
60 |
+
|
61 |
+
def accelerate_set_verbose(accelerator):
|
62 |
+
if accelerator.is_local_main_process:
|
63 |
+
transformers.utils.logging.set_verbosity_warning()
|
64 |
+
diffusers.utils.logging.set_verbosity_info()
|
65 |
+
else:
|
66 |
+
transformers.utils.logging.set_verbosity_error()
|
67 |
+
diffusers.utils.logging.set_verbosity_error()
|
68 |
+
|
69 |
+
|
70 |
+
def get_train_dataset(dataset_types, train_data, tokenizer):
|
71 |
+
train_datasets = []
|
72 |
+
|
73 |
+
# Loop through all available datasets, get the name, then add to list of data to process.
|
74 |
+
for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
|
75 |
+
for dataset in dataset_types:
|
76 |
+
if dataset == DataSet.__getname__():
|
77 |
+
train_datasets.append(DataSet(**train_data, tokenizer=tokenizer))
|
78 |
+
|
79 |
+
if len(train_datasets) > 0:
|
80 |
+
return train_datasets
|
81 |
+
else:
|
82 |
+
raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")
|
83 |
+
|
84 |
+
|
85 |
+
def extend_datasets(datasets, dataset_items, extend=False):
|
86 |
+
biggest_data_len = max(x.__len__() for x in datasets)
|
87 |
+
extended = []
|
88 |
+
for dataset in datasets:
|
89 |
+
if dataset.__len__() == 0:
|
90 |
+
del dataset
|
91 |
+
continue
|
92 |
+
if dataset.__len__() < biggest_data_len:
|
93 |
+
for item in dataset_items:
|
94 |
+
if extend and item not in extended and hasattr(dataset, item):
|
95 |
+
print(f"Extending {item}")
|
96 |
+
|
97 |
+
value = getattr(dataset, item)
|
98 |
+
value *= biggest_data_len
|
99 |
+
value = value[:biggest_data_len]
|
100 |
+
|
101 |
+
setattr(dataset, item, value)
|
102 |
+
|
103 |
+
print(f"New {item} dataset length: {dataset.__len__()}")
|
104 |
+
extended.append(item)
|
105 |
+
|
106 |
+
|
107 |
+
def export_to_video(video_frames, output_video_path, fps):
|
108 |
+
video_writer = imageio.get_writer(output_video_path, fps=fps)
|
109 |
+
for img in video_frames:
|
110 |
+
video_writer.append_data(np.array(img))
|
111 |
+
video_writer.close()
|
112 |
+
|
113 |
+
|
114 |
+
def create_output_folders(output_dir, config):
|
115 |
+
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
116 |
+
out_dir = os.path.join(output_dir, f"train_{now}")
|
117 |
+
|
118 |
+
os.makedirs(out_dir, exist_ok=True)
|
119 |
+
os.makedirs(f"{out_dir}/samples", exist_ok=True)
|
120 |
+
# OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))
|
121 |
+
|
122 |
+
return out_dir
|
123 |
+
|
124 |
+
|
125 |
+
def load_primary_models(pretrained_model_path):
|
126 |
+
noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
|
127 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
|
128 |
+
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
|
129 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
|
130 |
+
unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
|
131 |
+
|
132 |
+
return noise_scheduler, tokenizer, text_encoder, vae, unet
|
133 |
+
|
134 |
+
|
135 |
+
def unet_and_text_g_c(unet, text_encoder, unet_enable, text_enable):
|
136 |
+
unet._set_gradient_checkpointing(value=unet_enable)
|
137 |
+
text_encoder._set_gradient_checkpointing(CLIPEncoder, value=text_enable)
|
138 |
+
|
139 |
+
|
140 |
+
def freeze_models(models_to_freeze):
|
141 |
+
for model in models_to_freeze:
|
142 |
+
if model is not None: model.requires_grad_(False)
|
143 |
+
|
144 |
+
|
145 |
+
def is_attn(name):
|
146 |
+
return ('attn1' or 'attn2' == name.split('.')[-1])
|
147 |
+
|
148 |
+
|
149 |
+
def set_processors(attentions):
|
150 |
+
for attn in attentions: attn.set_processor(AttnProcessor2_0())
|
151 |
+
|
152 |
+
|
153 |
+
def set_torch_2_attn(unet):
|
154 |
+
optim_count = 0
|
155 |
+
|
156 |
+
for name, module in unet.named_modules():
|
157 |
+
if is_attn(name):
|
158 |
+
if isinstance(module, torch.nn.ModuleList):
|
159 |
+
for m in module:
|
160 |
+
if isinstance(m, BasicTransformerBlock):
|
161 |
+
set_processors([m.attn1, m.attn2])
|
162 |
+
optim_count += 1
|
163 |
+
if optim_count > 0:
|
164 |
+
print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
|
165 |
+
|
166 |
+
|
167 |
+
def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
|
168 |
+
try:
|
169 |
+
is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
|
170 |
+
enable_torch_2 = is_torch_2 and enable_torch_2_attn
|
171 |
+
|
172 |
+
if enable_xformers_memory_efficient_attention and not enable_torch_2:
|
173 |
+
if is_xformers_available():
|
174 |
+
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
175 |
+
unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
176 |
+
else:
|
177 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
178 |
+
|
179 |
+
if enable_torch_2:
|
180 |
+
set_torch_2_attn(unet)
|
181 |
+
|
182 |
+
except:
|
183 |
+
print("Could not enable memory efficient attention for xformers or Torch 2.0.")
|
184 |
+
|
185 |
+
|
186 |
+
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
|
187 |
+
extra_params = extra_params if len(extra_params.keys()) > 0 else None
|
188 |
+
return {
|
189 |
+
"model": model,
|
190 |
+
"condition": condition,
|
191 |
+
'extra_params': extra_params,
|
192 |
+
'is_lora': is_lora,
|
193 |
+
"negation": negation
|
194 |
+
}
|
195 |
+
|
196 |
+
|
197 |
+
def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
|
198 |
+
params = {
|
199 |
+
"name": name,
|
200 |
+
"params": params,
|
201 |
+
"lr": lr
|
202 |
+
}
|
203 |
+
if extra_params is not None:
|
204 |
+
for k, v in extra_params.items():
|
205 |
+
params[k] = v
|
206 |
+
|
207 |
+
return params
|
208 |
+
|
209 |
+
|
210 |
+
def negate_params(name, negation):
|
211 |
+
# We have to do this if we are co-training with LoRA.
|
212 |
+
# This ensures that parameter groups aren't duplicated.
|
213 |
+
if negation is None: return False
|
214 |
+
for n in negation:
|
215 |
+
if n in name and 'temp' not in name:
|
216 |
+
return True
|
217 |
+
return False
|
218 |
+
|
219 |
+
|
220 |
+
def create_optimizer_params(model_list, lr):
|
221 |
+
import itertools
|
222 |
+
optimizer_params = []
|
223 |
+
|
224 |
+
for optim in model_list:
|
225 |
+
model, condition, extra_params, is_lora, negation = optim.values()
|
226 |
+
# Check if we are doing LoRA training.
|
227 |
+
if is_lora and condition and isinstance(model, list):
|
228 |
+
params = create_optim_params(
|
229 |
+
params=itertools.chain(*model),
|
230 |
+
extra_params=extra_params
|
231 |
+
)
|
232 |
+
optimizer_params.append(params)
|
233 |
+
continue
|
234 |
+
|
235 |
+
if is_lora and condition and not isinstance(model, list):
|
236 |
+
for n, p in model.named_parameters():
|
237 |
+
if 'lora' in n:
|
238 |
+
params = create_optim_params(n, p, lr, extra_params)
|
239 |
+
optimizer_params.append(params)
|
240 |
+
continue
|
241 |
+
|
242 |
+
# If this is true, we can train it.
|
243 |
+
if condition:
|
244 |
+
for n, p in model.named_parameters():
|
245 |
+
should_negate = 'lora' in n and not is_lora
|
246 |
+
if should_negate: continue
|
247 |
+
|
248 |
+
params = create_optim_params(n, p, lr, extra_params)
|
249 |
+
optimizer_params.append(params)
|
250 |
+
|
251 |
+
return optimizer_params
|
252 |
+
|
253 |
+
|
254 |
+
def get_optimizer(use_8bit_adam):
|
255 |
+
if use_8bit_adam:
|
256 |
+
try:
|
257 |
+
import bitsandbytes as bnb
|
258 |
+
except ImportError:
|
259 |
+
raise ImportError(
|
260 |
+
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
261 |
+
)
|
262 |
+
|
263 |
+
return bnb.optim.AdamW8bit
|
264 |
+
else:
|
265 |
+
return torch.optim.AdamW
|
266 |
+
|
267 |
+
|
268 |
+
def is_mixed_precision(accelerator):
|
269 |
+
weight_dtype = torch.float32
|
270 |
+
|
271 |
+
if accelerator.mixed_precision == "fp16":
|
272 |
+
weight_dtype = torch.float16
|
273 |
+
|
274 |
+
elif accelerator.mixed_precision == "bf16":
|
275 |
+
weight_dtype = torch.bfloat16
|
276 |
+
|
277 |
+
return weight_dtype
|
278 |
+
|
279 |
+
|
280 |
+
def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
|
281 |
+
for model in model_list:
|
282 |
+
if model is not None: model.to(accelerator.device, dtype=weight_dtype)
|
283 |
+
|
284 |
+
|
285 |
+
def inverse_video(pipe, latents, num_steps):
|
286 |
+
ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
287 |
+
ddim_inv_scheduler.set_timesteps(num_steps)
|
288 |
+
|
289 |
+
ddim_inv_latent = ddim_inversion(
|
290 |
+
pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device),
|
291 |
+
num_inv_steps=num_steps, prompt="")[-1]
|
292 |
+
return ddim_inv_latent
|
293 |
+
|
294 |
+
|
295 |
+
def handle_cache_latents(
|
296 |
+
should_cache,
|
297 |
+
output_dir,
|
298 |
+
train_dataloader,
|
299 |
+
train_batch_size,
|
300 |
+
vae,
|
301 |
+
unet,
|
302 |
+
pretrained_model_path,
|
303 |
+
noise_prior,
|
304 |
+
cached_latent_dir=None,
|
305 |
+
):
|
306 |
+
# Cache latents by storing them in VRAM.
|
307 |
+
# Speeds up training and saves memory by not encoding during the train loop.
|
308 |
+
if not should_cache: return None
|
309 |
+
vae.to('cuda', dtype=torch.float16)
|
310 |
+
vae.enable_slicing()
|
311 |
+
|
312 |
+
pipe = TextToVideoSDPipeline.from_pretrained(
|
313 |
+
pretrained_model_path,
|
314 |
+
vae=vae,
|
315 |
+
unet=copy.deepcopy(unet).to('cuda', dtype=torch.float16)
|
316 |
+
)
|
317 |
+
pipe.text_encoder.to('cuda', dtype=torch.float16)
|
318 |
+
|
319 |
+
cached_latent_dir = (
|
320 |
+
os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
|
321 |
+
)
|
322 |
+
|
323 |
+
if cached_latent_dir is None:
|
324 |
+
cache_save_dir = f"{output_dir}/cached_latents"
|
325 |
+
os.makedirs(cache_save_dir, exist_ok=True)
|
326 |
+
|
327 |
+
for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):
|
328 |
+
|
329 |
+
save_name = f"cached_{i}"
|
330 |
+
full_out_path = f"{cache_save_dir}/{save_name}.pt"
|
331 |
+
|
332 |
+
pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
|
333 |
+
batch['latents'] = tensor_to_vae_latent(pixel_values, vae)
|
334 |
+
if noise_prior > 0.:
|
335 |
+
batch['inversion_noise'] = inverse_video(pipe, batch['latents'], 50)
|
336 |
+
for k, v in batch.items(): batch[k] = v[0]
|
337 |
+
|
338 |
+
torch.save(batch, full_out_path)
|
339 |
+
del pixel_values
|
340 |
+
del batch
|
341 |
+
|
342 |
+
# We do this to avoid fragmentation from casting latents between devices.
|
343 |
+
torch.cuda.empty_cache()
|
344 |
+
else:
|
345 |
+
cache_save_dir = cached_latent_dir
|
346 |
+
|
347 |
+
return torch.utils.data.DataLoader(
|
348 |
+
CachedDataset(cache_dir=cache_save_dir),
|
349 |
+
batch_size=train_batch_size,
|
350 |
+
shuffle=True,
|
351 |
+
num_workers=0
|
352 |
+
)
|
353 |
+
|
354 |
+
|
355 |
+
def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
|
356 |
+
global already_printed_trainables
|
357 |
+
|
358 |
+
# This can most definitely be refactored :-)
|
359 |
+
unfrozen_params = 0
|
360 |
+
if trainable_modules is not None:
|
361 |
+
for name, module in model.named_modules():
|
362 |
+
for tm in tuple(trainable_modules):
|
363 |
+
if tm == 'all':
|
364 |
+
model.requires_grad_(is_enabled)
|
365 |
+
unfrozen_params = len(list(model.parameters()))
|
366 |
+
break
|
367 |
+
|
368 |
+
if tm in name and 'lora' not in name:
|
369 |
+
for m in module.parameters():
|
370 |
+
m.requires_grad_(is_enabled)
|
371 |
+
if is_enabled: unfrozen_params += 1
|
372 |
+
|
373 |
+
if unfrozen_params > 0 and not already_printed_trainables:
|
374 |
+
already_printed_trainables = True
|
375 |
+
print(f"{unfrozen_params} params have been unfrozen for training.")
|
376 |
+
|
377 |
+
|
378 |
+
def tensor_to_vae_latent(t, vae):
|
379 |
+
video_length = t.shape[1]
|
380 |
+
|
381 |
+
t = rearrange(t, "b f c h w -> (b f) c h w")
|
382 |
+
latents = vae.encode(t).latent_dist.sample()
|
383 |
+
latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
|
384 |
+
latents = latents * 0.18215
|
385 |
+
|
386 |
+
return latents
|
387 |
+
|
388 |
+
|
389 |
+
def sample_noise(latents, noise_strength, use_offset_noise=False):
|
390 |
+
b, c, f, *_ = latents.shape
|
391 |
+
noise_latents = torch.randn_like(latents, device=latents.device)
|
392 |
+
|
393 |
+
if use_offset_noise:
|
394 |
+
offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
|
395 |
+
noise_latents = noise_latents + noise_strength * offset_noise
|
396 |
+
|
397 |
+
return noise_latents
|
398 |
+
|
399 |
+
|
400 |
+
def enforce_zero_terminal_snr(betas):
|
401 |
+
"""
|
402 |
+
Corrects noise in diffusion schedulers.
|
403 |
+
From: Common Diffusion Noise Schedules and Sample Steps are Flawed
|
404 |
+
https://arxiv.org/pdf/2305.08891.pdf
|
405 |
+
"""
|
406 |
+
# Convert betas to alphas_bar_sqrt
|
407 |
+
alphas = 1 - betas
|
408 |
+
alphas_bar = alphas.cumprod(0)
|
409 |
+
alphas_bar_sqrt = alphas_bar.sqrt()
|
410 |
+
|
411 |
+
# Store old values.
|
412 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
413 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
414 |
+
|
415 |
+
# Shift so the last timestep is zero.
|
416 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
417 |
+
|
418 |
+
# Scale so the first timestep is back to the old value.
|
419 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
|
420 |
+
alphas_bar_sqrt_0 - alphas_bar_sqrt_T
|
421 |
+
)
|
422 |
+
|
423 |
+
# Convert alphas_bar_sqrt to betas
|
424 |
+
alphas_bar = alphas_bar_sqrt ** 2
|
425 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
426 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
427 |
+
betas = 1 - alphas
|
428 |
+
|
429 |
+
return betas
|
430 |
+
|
431 |
+
|
432 |
+
def should_sample(global_step, validation_steps, validation_data):
|
433 |
+
return global_step % validation_steps == 0 and validation_data.sample_preview
|
434 |
+
|
435 |
+
|
436 |
+
def save_pipe(
|
437 |
+
path,
|
438 |
+
global_step,
|
439 |
+
accelerator,
|
440 |
+
unet,
|
441 |
+
text_encoder,
|
442 |
+
vae,
|
443 |
+
output_dir,
|
444 |
+
lora_manager_spatial: LoraHandler,
|
445 |
+
lora_manager_temporal: LoraHandler,
|
446 |
+
unet_target_replace_module=None,
|
447 |
+
text_target_replace_module=None,
|
448 |
+
is_checkpoint=False,
|
449 |
+
save_pretrained_model=True
|
450 |
+
):
|
451 |
+
if is_checkpoint:
|
452 |
+
save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
|
453 |
+
os.makedirs(save_path, exist_ok=True)
|
454 |
+
else:
|
455 |
+
save_path = output_dir
|
456 |
+
|
457 |
+
# Save the dtypes so we can continue training at the same precision.
|
458 |
+
u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype
|
459 |
+
|
460 |
+
# Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
|
461 |
+
unet_out = copy.deepcopy(accelerator.unwrap_model(unet.cpu(), keep_fp32_wrapper=False))
|
462 |
+
text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder.cpu(), keep_fp32_wrapper=False))
|
463 |
+
|
464 |
+
pipeline = TextToVideoSDPipeline.from_pretrained(
|
465 |
+
path,
|
466 |
+
unet=unet_out,
|
467 |
+
text_encoder=text_encoder_out,
|
468 |
+
vae=vae,
|
469 |
+
).to(torch_dtype=torch.float32)
|
470 |
+
|
471 |
+
lora_manager_spatial.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path+'/spatial', step=global_step)
|
472 |
+
lora_manager_temporal.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path+'/temporal', step=global_step)
|
473 |
+
|
474 |
+
if save_pretrained_model:
|
475 |
+
pipeline.save_pretrained(save_path)
|
476 |
+
|
477 |
+
if is_checkpoint:
|
478 |
+
unet, text_encoder = accelerator.prepare(unet, text_encoder)
|
479 |
+
models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
|
480 |
+
[x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]
|
481 |
+
|
482 |
+
logger.info(f"Saved model at {save_path} on step {global_step}")
|
483 |
+
|
484 |
+
del pipeline
|
485 |
+
del unet_out
|
486 |
+
del text_encoder_out
|
487 |
+
torch.cuda.empty_cache()
|
488 |
+
gc.collect()
|
489 |
+
|
490 |
+
|
491 |
+
def main(
|
492 |
+
pretrained_model_path: str,
|
493 |
+
output_dir: str,
|
494 |
+
train_data: Dict,
|
495 |
+
validation_data: Dict,
|
496 |
+
extra_train_data: list = [],
|
497 |
+
dataset_types: Tuple[str] = ('json'),
|
498 |
+
validation_steps: int = 100,
|
499 |
+
trainable_modules: Tuple[str] = None, # Eg: ("attn1", "attn2")
|
500 |
+
extra_unet_params=None,
|
501 |
+
train_batch_size: int = 1,
|
502 |
+
max_train_steps: int = 500,
|
503 |
+
learning_rate: float = 5e-5,
|
504 |
+
lr_scheduler: str = "constant",
|
505 |
+
lr_warmup_steps: int = 0,
|
506 |
+
adam_beta1: float = 0.9,
|
507 |
+
adam_beta2: float = 0.999,
|
508 |
+
adam_weight_decay: float = 1e-2,
|
509 |
+
adam_epsilon: float = 1e-08,
|
510 |
+
gradient_accumulation_steps: int = 1,
|
511 |
+
gradient_checkpointing: bool = False,
|
512 |
+
text_encoder_gradient_checkpointing: bool = False,
|
513 |
+
checkpointing_steps: int = 500,
|
514 |
+
resume_from_checkpoint: Optional[str] = None,
|
515 |
+
resume_step: Optional[int] = None,
|
516 |
+
mixed_precision: Optional[str] = "fp16",
|
517 |
+
use_8bit_adam: bool = False,
|
518 |
+
enable_xformers_memory_efficient_attention: bool = True,
|
519 |
+
enable_torch_2_attn: bool = False,
|
520 |
+
seed: Optional[int] = None,
|
521 |
+
use_offset_noise: bool = False,
|
522 |
+
rescale_schedule: bool = False,
|
523 |
+
offset_noise_strength: float = 0.1,
|
524 |
+
extend_dataset: bool = False,
|
525 |
+
cache_latents: bool = False,
|
526 |
+
cached_latent_dir=None,
|
527 |
+
use_unet_lora: bool = False,
|
528 |
+
unet_lora_modules: Tuple[str] = [],
|
529 |
+
text_encoder_lora_modules: Tuple[str] = [],
|
530 |
+
save_pretrained_model: bool = True,
|
531 |
+
lora_rank: int = 16,
|
532 |
+
lora_path: str = '',
|
533 |
+
lora_unet_dropout: float = 0.1,
|
534 |
+
logger_type: str = 'tensorboard',
|
535 |
+
**kwargs
|
536 |
+
):
|
537 |
+
*_, config = inspect.getargvalues(inspect.currentframe())
|
538 |
+
|
539 |
+
accelerator = Accelerator(
|
540 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
541 |
+
mixed_precision=mixed_precision,
|
542 |
+
log_with=logger_type,
|
543 |
+
project_dir=output_dir
|
544 |
+
)
|
545 |
+
|
546 |
+
# Make one log on every process with the configuration for debugging.
|
547 |
+
create_logging(logging, logger, accelerator)
|
548 |
+
|
549 |
+
# Initialize accelerate, transformers, and diffusers warnings
|
550 |
+
accelerate_set_verbose(accelerator)
|
551 |
+
|
552 |
+
# Handle the output folder creation
|
553 |
+
if accelerator.is_main_process:
|
554 |
+
output_dir = create_output_folders(output_dir, config)
|
555 |
+
|
556 |
+
# Load scheduler, tokenizer and models.
|
557 |
+
noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(pretrained_model_path)
|
558 |
+
|
559 |
+
# Freeze any necessary models
|
560 |
+
freeze_models([vae, text_encoder, unet])
|
561 |
+
|
562 |
+
# Enable xformers if available
|
563 |
+
handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)
|
564 |
+
|
565 |
+
# Initialize the optimizer
|
566 |
+
optimizer_cls = get_optimizer(use_8bit_adam)
|
567 |
+
|
568 |
+
# Get the training dataset based on types (json, single_video, image)
|
569 |
+
train_datasets = get_train_dataset(dataset_types, train_data, tokenizer)
|
570 |
+
|
571 |
+
# If you have extra train data, you can add a list of however many you would like.
|
572 |
+
# Eg: extra_train_data: [{: {dataset_types, train_data: {etc...}}}]
|
573 |
+
try:
|
574 |
+
if extra_train_data is not None and len(extra_train_data) > 0:
|
575 |
+
for dataset in extra_train_data:
|
576 |
+
d_t, t_d = dataset['dataset_types'], dataset['train_data']
|
577 |
+
train_datasets += get_train_dataset(d_t, t_d, tokenizer)
|
578 |
+
|
579 |
+
except Exception as e:
|
580 |
+
print(f"Could not process extra train datasets due to an error : {e}")
|
581 |
+
|
582 |
+
# Extend datasets that are less than the greatest one. This allows for more balanced training.
|
583 |
+
attrs = ['train_data', 'frames', 'image_dir', 'video_files']
|
584 |
+
extend_datasets(train_datasets, attrs, extend=extend_dataset)
|
585 |
+
|
586 |
+
# Process one dataset
|
587 |
+
if len(train_datasets) == 1:
|
588 |
+
train_dataset = train_datasets[0]
|
589 |
+
|
590 |
+
# Process many datasets
|
591 |
+
else:
|
592 |
+
train_dataset = torch.utils.data.ConcatDataset(train_datasets)
|
593 |
+
|
594 |
+
# Create parameters to optimize over with a condition (if "condition" is true, optimize it)
|
595 |
+
extra_unet_params = extra_unet_params if extra_unet_params is not None else {}
|
596 |
+
extra_text_encoder_params = extra_unet_params if extra_unet_params is not None else {}
|
597 |
+
|
598 |
+
# Use LoRA if enabled.
|
599 |
+
# one temporal lora
|
600 |
+
lora_manager_temporal = LoraHandler(use_unet_lora=use_unet_lora, unet_replace_modules=["TransformerTemporalModel"])
|
601 |
+
|
602 |
+
unet_lora_params_temporal, unet_negation_temporal = lora_manager_temporal.add_lora_to_model(
|
603 |
+
use_unet_lora, unet, lora_manager_temporal.unet_replace_modules, lora_unet_dropout,
|
604 |
+
lora_path + '/temporal/lora/', r=lora_rank)
|
605 |
+
|
606 |
+
optimizer_temporal = optimizer_cls(
|
607 |
+
create_optimizer_params([param_optim(unet_lora_params_temporal, use_unet_lora, is_lora=True,
|
608 |
+
extra_params={**{"lr": learning_rate}, **extra_text_encoder_params}
|
609 |
+
)], learning_rate),
|
610 |
+
lr=learning_rate,
|
611 |
+
betas=(adam_beta1, adam_beta2),
|
612 |
+
weight_decay=adam_weight_decay,
|
613 |
+
eps=adam_epsilon,
|
614 |
+
)
|
615 |
+
|
616 |
+
lr_scheduler_temporal = get_scheduler(
|
617 |
+
lr_scheduler,
|
618 |
+
optimizer=optimizer_temporal,
|
619 |
+
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
620 |
+
num_training_steps=max_train_steps * gradient_accumulation_steps,
|
621 |
+
)
|
622 |
+
|
623 |
+
# one spatial lora for each video
|
624 |
+
if 'folder' in dataset_types:
|
625 |
+
spatial_lora_num = train_dataset.__len__()
|
626 |
+
else:
|
627 |
+
spatial_lora_num = 1
|
628 |
+
|
629 |
+
lora_manager_spatials = []
|
630 |
+
unet_lora_params_spatial_list = []
|
631 |
+
optimizer_spatial_list = []
|
632 |
+
lr_scheduler_spatial_list = []
|
633 |
+
for i in range(spatial_lora_num):
|
634 |
+
lora_manager_spatial = LoraHandler(use_unet_lora=use_unet_lora, unet_replace_modules=["Transformer2DModel"])
|
635 |
+
lora_manager_spatials.append(lora_manager_spatial)
|
636 |
+
unet_lora_params_spatial, unet_negation_spatial = lora_manager_spatial.add_lora_to_model(
|
637 |
+
use_unet_lora, unet, lora_manager_spatial.unet_replace_modules, lora_unet_dropout,
|
638 |
+
lora_path + '/spatial/lora/', r=lora_rank)
|
639 |
+
|
640 |
+
unet_lora_params_spatial_list.append(unet_lora_params_spatial)
|
641 |
+
|
642 |
+
optimizer_spatial = optimizer_cls(
|
643 |
+
create_optimizer_params([param_optim(unet_lora_params_spatial, use_unet_lora, is_lora=True,
|
644 |
+
extra_params={**{"lr": learning_rate}, **extra_text_encoder_params}
|
645 |
+
)], learning_rate),
|
646 |
+
lr=learning_rate,
|
647 |
+
betas=(adam_beta1, adam_beta2),
|
648 |
+
weight_decay=adam_weight_decay,
|
649 |
+
eps=adam_epsilon,
|
650 |
+
)
|
651 |
+
|
652 |
+
optimizer_spatial_list.append(optimizer_spatial)
|
653 |
+
|
654 |
+
# Scheduler
|
655 |
+
lr_scheduler_spatial = get_scheduler(
|
656 |
+
lr_scheduler,
|
657 |
+
optimizer=optimizer_spatial,
|
658 |
+
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
659 |
+
num_training_steps=max_train_steps * gradient_accumulation_steps,
|
660 |
+
)
|
661 |
+
lr_scheduler_spatial_list.append(lr_scheduler_spatial)
|
662 |
+
|
663 |
+
unet_negation_all = unet_negation_spatial + unet_negation_temporal
|
664 |
+
|
665 |
+
# DataLoaders creation:
|
666 |
+
train_dataloader = torch.utils.data.DataLoader(
|
667 |
+
train_dataset,
|
668 |
+
batch_size=train_batch_size,
|
669 |
+
shuffle=True
|
670 |
+
)
|
671 |
+
|
672 |
+
# Latents caching
|
673 |
+
cached_data_loader = handle_cache_latents(
|
674 |
+
cache_latents,
|
675 |
+
output_dir,
|
676 |
+
train_dataloader,
|
677 |
+
train_batch_size,
|
678 |
+
vae,
|
679 |
+
unet,
|
680 |
+
pretrained_model_path,
|
681 |
+
validation_data.noise_prior,
|
682 |
+
cached_latent_dir,
|
683 |
+
)
|
684 |
+
|
685 |
+
if cached_data_loader is not None:
|
686 |
+
train_dataloader = cached_data_loader
|
687 |
+
|
688 |
+
# Prepare everything with our `accelerator`.
|
689 |
+
unet, optimizer_spatial_list, optimizer_temporal, train_dataloader, lr_scheduler_spatial_list, lr_scheduler_temporal, text_encoder = accelerator.prepare(
|
690 |
+
unet,
|
691 |
+
optimizer_spatial_list, optimizer_temporal,
|
692 |
+
train_dataloader,
|
693 |
+
lr_scheduler_spatial_list, lr_scheduler_temporal,
|
694 |
+
text_encoder
|
695 |
+
)
|
696 |
+
|
697 |
+
# Use Gradient Checkpointing if enabled.
|
698 |
+
unet_and_text_g_c(
|
699 |
+
unet,
|
700 |
+
text_encoder,
|
701 |
+
gradient_checkpointing,
|
702 |
+
text_encoder_gradient_checkpointing
|
703 |
+
)
|
704 |
+
|
705 |
+
# Enable VAE slicing to save memory.
|
706 |
+
vae.enable_slicing()
|
707 |
+
|
708 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
709 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
710 |
+
weight_dtype = is_mixed_precision(accelerator)
|
711 |
+
|
712 |
+
# Move text encoders, and VAE to GPU
|
713 |
+
models_to_cast = [text_encoder, vae]
|
714 |
+
cast_to_gpu_and_type(models_to_cast, accelerator, weight_dtype)
|
715 |
+
|
716 |
+
# Fix noise schedules to predcit light and dark areas if available.
|
717 |
+
if not use_offset_noise and rescale_schedule:
|
718 |
+
noise_scheduler.betas = enforce_zero_terminal_snr(noise_scheduler.betas)
|
719 |
+
|
720 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
721 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
722 |
+
|
723 |
+
# Afterwards we recalculate our number of training epochs
|
724 |
+
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
725 |
+
|
726 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
727 |
+
# The trackers initializes automatically on the main process.
|
728 |
+
if accelerator.is_main_process:
|
729 |
+
accelerator.init_trackers("text2video-fine-tune")
|
730 |
+
|
731 |
+
# Train!
|
732 |
+
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
733 |
+
|
734 |
+
logger.info("***** Running training *****")
|
735 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
736 |
+
logger.info(f" Num Epochs = {num_train_epochs}")
|
737 |
+
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
738 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
739 |
+
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
740 |
+
logger.info(f" Total optimization steps = {max_train_steps}")
|
741 |
+
global_step = 0
|
742 |
+
first_epoch = 0
|
743 |
+
|
744 |
+
# Only show the progress bar once on each machine.
|
745 |
+
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
|
746 |
+
progress_bar.set_description("Steps")
|
747 |
+
|
748 |
+
def finetune_unet(batch, step, mask_spatial_lora=False, mask_temporal_lora=False):
|
749 |
+
nonlocal use_offset_noise
|
750 |
+
nonlocal rescale_schedule
|
751 |
+
|
752 |
+
# Unfreeze UNET Layers
|
753 |
+
if global_step == 0:
|
754 |
+
already_printed_trainables = False
|
755 |
+
unet.train()
|
756 |
+
handle_trainable_modules(
|
757 |
+
unet,
|
758 |
+
trainable_modules,
|
759 |
+
is_enabled=True,
|
760 |
+
negation=unet_negation_all
|
761 |
+
)
|
762 |
+
|
763 |
+
# Convert videos to latent space
|
764 |
+
if not cache_latents:
|
765 |
+
latents = tensor_to_vae_latent(batch["pixel_values"], vae)
|
766 |
+
else:
|
767 |
+
latents = batch["latents"]
|
768 |
+
|
769 |
+
# Sample noise that we'll add to the latents
|
770 |
+
use_offset_noise = use_offset_noise and not rescale_schedule
|
771 |
+
noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
|
772 |
+
bsz = latents.shape[0]
|
773 |
+
|
774 |
+
# Sample a random timestep for each video
|
775 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
776 |
+
timesteps = timesteps.long()
|
777 |
+
|
778 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
779 |
+
# (this is the forward diffusion process)
|
780 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
781 |
+
|
782 |
+
# *Potentially* Fixes gradient checkpointing training.
|
783 |
+
# See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
|
784 |
+
if kwargs.get('eval_train', False):
|
785 |
+
unet.eval()
|
786 |
+
text_encoder.eval()
|
787 |
+
|
788 |
+
# Encode text embeddings
|
789 |
+
token_ids = batch['prompt_ids']
|
790 |
+
encoder_hidden_states = text_encoder(token_ids)[0]
|
791 |
+
detached_encoder_state = encoder_hidden_states.clone().detach()
|
792 |
+
|
793 |
+
# Get the target for loss depending on the prediction type
|
794 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
795 |
+
target = noise
|
796 |
+
|
797 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
798 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
799 |
+
|
800 |
+
else:
|
801 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
802 |
+
|
803 |
+
encoder_hidden_states = detached_encoder_state
|
804 |
+
|
805 |
+
if mask_spatial_lora:
|
806 |
+
loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
|
807 |
+
for lora_i in loras:
|
808 |
+
lora_i.scale = 0.
|
809 |
+
loss_spatial = None
|
810 |
+
else:
|
811 |
+
loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
|
812 |
+
for lora_i in loras:
|
813 |
+
lora_i.scale = 1.
|
814 |
+
|
815 |
+
for lora_idx in range(0, len(loras), spatial_lora_num):
|
816 |
+
loras[lora_idx + step].scale = 1.
|
817 |
+
|
818 |
+
loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"])
|
819 |
+
for lora_i in loras:
|
820 |
+
lora_i.scale = 0.
|
821 |
+
|
822 |
+
ran_idx = torch.randint(0, noisy_latents.shape[2], (1,)).item()
|
823 |
+
|
824 |
+
if random.uniform(0, 1) < -0.5:
|
825 |
+
pixel_values_spatial = transforms.functional.hflip(
|
826 |
+
batch["pixel_values"][:, ran_idx, :, :, :]).unsqueeze(1)
|
827 |
+
latents_spatial = tensor_to_vae_latent(pixel_values_spatial, vae)
|
828 |
+
noise_spatial = sample_noise(latents_spatial, offset_noise_strength, use_offset_noise)
|
829 |
+
noisy_latents_input = noise_scheduler.add_noise(latents_spatial, noise_spatial, timesteps)
|
830 |
+
target_spatial = noise_spatial
|
831 |
+
model_pred_spatial = unet(noisy_latents_input, timesteps,
|
832 |
+
encoder_hidden_states=encoder_hidden_states).sample
|
833 |
+
loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(),
|
834 |
+
target_spatial[:, :, 0, :, :].float(), reduction="mean")
|
835 |
+
else:
|
836 |
+
noisy_latents_input = noisy_latents[:, :, ran_idx, :, :]
|
837 |
+
target_spatial = target[:, :, ran_idx, :, :]
|
838 |
+
model_pred_spatial = unet(noisy_latents_input.unsqueeze(2), timesteps,
|
839 |
+
encoder_hidden_states=encoder_hidden_states).sample
|
840 |
+
loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(),
|
841 |
+
target_spatial.float(), reduction="mean")
|
842 |
+
|
843 |
+
if mask_temporal_lora:
|
844 |
+
loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"])
|
845 |
+
for lora_i in loras:
|
846 |
+
lora_i.scale = 0.
|
847 |
+
loss_temporal = None
|
848 |
+
else:
|
849 |
+
loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"])
|
850 |
+
for lora_i in loras:
|
851 |
+
lora_i.scale = 1.
|
852 |
+
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
|
853 |
+
loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
854 |
+
|
855 |
+
beta = 1
|
856 |
+
alpha = (beta ** 2 + 1) ** 0.5
|
857 |
+
ran_idx = torch.randint(0, model_pred.shape[2], (1,)).item()
|
858 |
+
model_pred_decent = alpha * model_pred - beta * model_pred[:, :, ran_idx, :, :].unsqueeze(2)
|
859 |
+
target_decent = alpha * target - beta * target[:, :, ran_idx, :, :].unsqueeze(2)
|
860 |
+
loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean")
|
861 |
+
loss_temporal = loss_temporal + loss_ad_temporal
|
862 |
+
|
863 |
+
return loss_spatial, loss_temporal, latents, noise
|
864 |
+
|
865 |
+
for epoch in range(first_epoch, num_train_epochs):
|
866 |
+
train_loss_spatial = 0.0
|
867 |
+
train_loss_temporal = 0.0
|
868 |
+
|
869 |
+
for step, batch in enumerate(train_dataloader):
|
870 |
+
# Skip steps until we reach the resumed step
|
871 |
+
if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
872 |
+
if step % gradient_accumulation_steps == 0:
|
873 |
+
progress_bar.update(1)
|
874 |
+
continue
|
875 |
+
|
876 |
+
with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
|
877 |
+
|
878 |
+
text_prompt = batch['text_prompt'][0]
|
879 |
+
|
880 |
+
for optimizer_spatial in optimizer_spatial_list:
|
881 |
+
optimizer_spatial.zero_grad(set_to_none=True)
|
882 |
+
|
883 |
+
optimizer_temporal.zero_grad(set_to_none=True)
|
884 |
+
|
885 |
+
mask_temporal_lora = False
|
886 |
+
# mask_spatial_lora = False
|
887 |
+
mask_spatial_lora = random.uniform(0, 1) < 0.1 and not mask_temporal_lora
|
888 |
+
|
889 |
+
with accelerator.autocast():
|
890 |
+
loss_spatial, loss_temporal, latents, init_noise = finetune_unet(batch, step, mask_spatial_lora=mask_spatial_lora, mask_temporal_lora=mask_temporal_lora)
|
891 |
+
|
892 |
+
# Gather the losses across all processes for logging (if we use distributed training).
|
893 |
+
if not mask_spatial_lora:
|
894 |
+
avg_loss_spatial = accelerator.gather(loss_spatial.repeat(train_batch_size)).mean()
|
895 |
+
train_loss_spatial += avg_loss_spatial.item() / gradient_accumulation_steps
|
896 |
+
|
897 |
+
if not mask_temporal_lora:
|
898 |
+
avg_loss_temporal = accelerator.gather(loss_temporal.repeat(train_batch_size)).mean()
|
899 |
+
train_loss_temporal += avg_loss_temporal.item() / gradient_accumulation_steps
|
900 |
+
|
901 |
+
# Backpropagate
|
902 |
+
if not mask_spatial_lora:
|
903 |
+
accelerator.backward(loss_spatial, retain_graph = True)
|
904 |
+
optimizer_spatial_list[step].step()
|
905 |
+
|
906 |
+
if not mask_temporal_lora:
|
907 |
+
accelerator.backward(loss_temporal)
|
908 |
+
optimizer_temporal.step()
|
909 |
+
|
910 |
+
lr_scheduler_spatial_list[step].step()
|
911 |
+
lr_scheduler_temporal.step()
|
912 |
+
|
913 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
914 |
+
if accelerator.sync_gradients:
|
915 |
+
progress_bar.update(1)
|
916 |
+
global_step += 1
|
917 |
+
accelerator.log({"train_loss": train_loss_temporal}, step=global_step)
|
918 |
+
train_loss_temporal = 0.0
|
919 |
+
if global_step % checkpointing_steps == 0 and global_step > 0:
|
920 |
+
save_pipe(
|
921 |
+
pretrained_model_path,
|
922 |
+
global_step,
|
923 |
+
accelerator,
|
924 |
+
unet,
|
925 |
+
text_encoder,
|
926 |
+
vae,
|
927 |
+
output_dir,
|
928 |
+
lora_manager_spatial,
|
929 |
+
lora_manager_temporal,
|
930 |
+
unet_lora_modules,
|
931 |
+
text_encoder_lora_modules,
|
932 |
+
is_checkpoint=True,
|
933 |
+
save_pretrained_model=save_pretrained_model
|
934 |
+
)
|
935 |
+
|
936 |
+
if should_sample(global_step, validation_steps, validation_data):
|
937 |
+
if accelerator.is_main_process:
|
938 |
+
with accelerator.autocast():
|
939 |
+
unet.eval()
|
940 |
+
text_encoder.eval()
|
941 |
+
unet_and_text_g_c(unet, text_encoder, False, False)
|
942 |
+
loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
|
943 |
+
for lora_i in loras:
|
944 |
+
lora_i.scale = validation_data.spatial_scale
|
945 |
+
|
946 |
+
if validation_data.noise_prior > 0:
|
947 |
+
preset_noise = (validation_data.noise_prior) ** 0.5 * batch['inversion_noise'] + (
|
948 |
+
1-validation_data.noise_prior) ** 0.5 * torch.randn_like(batch['inversion_noise'])
|
949 |
+
else:
|
950 |
+
preset_noise = None
|
951 |
+
|
952 |
+
pipeline = TextToVideoSDPipeline.from_pretrained(
|
953 |
+
pretrained_model_path,
|
954 |
+
text_encoder=text_encoder,
|
955 |
+
vae=vae,
|
956 |
+
unet=unet
|
957 |
+
)
|
958 |
+
|
959 |
+
diffusion_scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
|
960 |
+
pipeline.scheduler = diffusion_scheduler
|
961 |
+
|
962 |
+
prompt_list = text_prompt if len(validation_data.prompt) <= 0 else validation_data.prompt
|
963 |
+
for prompt in prompt_list:
|
964 |
+
save_filename = f"{global_step}_{prompt.replace('.', '')}"
|
965 |
+
|
966 |
+
out_file = f"{output_dir}/samples/{save_filename}.mp4"
|
967 |
+
|
968 |
+
with torch.no_grad():
|
969 |
+
video_frames = pipeline(
|
970 |
+
prompt,
|
971 |
+
width=validation_data.width,
|
972 |
+
height=validation_data.height,
|
973 |
+
num_frames=validation_data.num_frames,
|
974 |
+
num_inference_steps=validation_data.num_inference_steps,
|
975 |
+
guidance_scale=validation_data.guidance_scale,
|
976 |
+
latents=preset_noise
|
977 |
+
).frames
|
978 |
+
export_to_video(video_frames, out_file, train_data.get('fps', 8))
|
979 |
+
logger.info(f"Saved a new sample to {out_file}")
|
980 |
+
del pipeline
|
981 |
+
torch.cuda.empty_cache()
|
982 |
+
|
983 |
+
unet_and_text_g_c(
|
984 |
+
unet,
|
985 |
+
text_encoder,
|
986 |
+
gradient_checkpointing,
|
987 |
+
text_encoder_gradient_checkpointing
|
988 |
+
)
|
989 |
+
|
990 |
+
accelerator.log({"loss_temporal": loss_temporal.detach().item()}, step=step)
|
991 |
+
|
992 |
+
if global_step >= max_train_steps:
|
993 |
+
break
|
994 |
+
|
995 |
+
# Create the pipeline using the trained modules and save it.
|
996 |
+
accelerator.wait_for_everyone()
|
997 |
+
if accelerator.is_main_process:
|
998 |
+
save_pipe(
|
999 |
+
pretrained_model_path,
|
1000 |
+
global_step,
|
1001 |
+
accelerator,
|
1002 |
+
unet,
|
1003 |
+
text_encoder,
|
1004 |
+
vae,
|
1005 |
+
output_dir,
|
1006 |
+
lora_manager_spatial,
|
1007 |
+
lora_manager_temporal,
|
1008 |
+
unet_lora_modules,
|
1009 |
+
text_encoder_lora_modules,
|
1010 |
+
is_checkpoint=False,
|
1011 |
+
save_pretrained_model=save_pretrained_model
|
1012 |
+
)
|
1013 |
+
accelerator.end_training()
|
1014 |
+
|
1015 |
+
|
1016 |
+
if __name__ == "__main__":
|
1017 |
+
parser = argparse.ArgumentParser()
|
1018 |
+
parser.add_argument("--config", type=str, default='./configs/config_multi_videos.yaml')
|
1019 |
+
args = parser.parse_args()
|
1020 |
+
main(**OmegaConf.load(args.config))
|
1021 |
+
|
README.md
CHANGED
@@ -1,13 +1,364 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MotionDirector
|
2 |
+
|
3 |
+
This is the official repository of [MotionDirector](https://showlab.github.io/MotionDirector).
|
4 |
+
|
5 |
+
**MotionDirector: Motion Customization of Text-to-Video Diffusion Models.**
|
6 |
+
<br/>
|
7 |
+
[Rui Zhao](https://ruizhaocv.github.io/),
|
8 |
+
[Yuchao Gu](https://ycgu.site/),
|
9 |
+
[Jay Zhangjie Wu](https://zhangjiewu.github.io/),
|
10 |
+
[David Junhao Zhang](https://junhaozhang98.github.io/),
|
11 |
+
[Jiawei Liu](https://jia-wei-liu.github.io/),
|
12 |
+
[Weijia Wu](https://weijiawu.github.io/),
|
13 |
+
[Jussi Keppo](https://www.jussikeppo.com/),
|
14 |
+
[Mike Zheng Shou](https://sites.google.com/view/showlab)
|
15 |
+
<br/>
|
16 |
+
|
17 |
+
[![Project Page](https://img.shields.io/badge/Project-Website-orange)](https://showlab.github.io/MotionDirector)
|
18 |
+
[![arXiv](https://img.shields.io/badge/arXiv-MotionDirector-b31b1b.svg)](https://arxiv.org/abs/2310.08465)
|
19 |
+
|
20 |
+
<p align="center">
|
21 |
+
<img src="https://github.com/showlab/MotionDirector/blob/page/assets/teaser.gif" width="1080px"/>
|
22 |
+
<br>
|
23 |
+
<em>MotionDirector can customize text-to-video diffusion models to generate videos with desired motions.</em>
|
24 |
+
</p>
|
25 |
+
|
26 |
+
<table class="center">
|
27 |
+
<tr>
|
28 |
+
<td style="text-align:center;" colspan="4"><b>Astronaut's daily life on Mars (Motion concepts learned by MotionDirector)</b></td>
|
29 |
+
</tr>
|
30 |
+
<tr>
|
31 |
+
<td style="text-align:center;"><b>Lifting Weights</b></td>
|
32 |
+
<td style="text-align:center;"><b>Playing Golf</b></td>
|
33 |
+
<td style="text-align:center;"><b>Riding Horse</b></td>
|
34 |
+
<td style="text-align:center;"><b>Riding Bicycle</b></td>
|
35 |
+
</tr>
|
36 |
+
<tr>
|
37 |
+
<td><img src=assets/astronaut_mars/An_astronaut_is_lifting_weights_on_Mars_4K_high_quailty_highly_detailed_4008521.gif></td>
|
38 |
+
<td><img src=assets/astronaut_mars/Astronaut_playing_golf_on_Mars_659514.gif></td>
|
39 |
+
<td><img src=assets/astronaut_mars/An_astronaut_is_riding_a_horse_on_Mars_4K_high_quailty_highly_detailed_1913261.gif></td>
|
40 |
+
<td><img src=assets/astronaut_mars/An_astronaut_is_riding_a_bicycle_past_the_pyramids_Mars_4K_high_quailty_highly_detailed_5532778.gif></td>
|
41 |
+
</tr>
|
42 |
+
<tr>
|
43 |
+
<td width=25% style="text-align:center;">"An astronaut is lifting weights on Mars, 4K, high quailty, highly detailed.” </br> seed: 4008521</td>
|
44 |
+
<td width=25% style="text-align:center;">"Astronaut playing golf on Mars” </br> seed: 659514</td>
|
45 |
+
<td width=25% style="text-align:center;">"An astronaut is riding a horse on Mars, 4K, high quailty, highly detailed." </br> seed: 1913261</td>
|
46 |
+
<td width=25% style="text-align:center;">"An astronaut is riding a bicycle past the pyramids Mars, 4K, high quailty, highly detailed." </br> seed: 5532778</td>
|
47 |
+
<tr>
|
48 |
+
</table>
|
49 |
+
|
50 |
+
## News
|
51 |
+
- [2023.12.06] [MotionDirector for Sports](#MotionDirector_for_Sports) released! Lifting weights, riding horse, palying golf, etc.
|
52 |
+
- [2023.12.05] [Colab demo](https://github.com/camenduru/MotionDirector-colab) is available. Thanks to [Camenduru](https://twitter.com/camenduru).
|
53 |
+
- [2023.12.04] [MotionDirector for Cinematic Shots](#MotionDirector_for_Cinematic_Shots) released. Now, you can make AI films with professional cinematic shots!
|
54 |
+
- [2023.12.02] Code and model weights released!
|
55 |
+
|
56 |
+
## ToDo
|
57 |
+
- [ ] Gradio Demo
|
58 |
+
- [ ] More trained weights of MotionDirector
|
59 |
+
|
60 |
+
## Setup
|
61 |
+
### Requirements
|
62 |
+
|
63 |
+
```shell
|
64 |
+
# create virtual environment
|
65 |
+
conda create -n motiondirector python=3.8
|
66 |
+
conda activate motiondirector
|
67 |
+
# install packages
|
68 |
+
pip install -r requirements.txt
|
69 |
+
```
|
70 |
+
|
71 |
+
### Weights of Foundation Models
|
72 |
+
```shell
|
73 |
+
git lfs install
|
74 |
+
## You can choose the ModelScopeT2V or ZeroScope, etc., as the foundation model.
|
75 |
+
## ZeroScope
|
76 |
+
git clone https://huggingface.co/cerspense/zeroscope_v2_576w ./models/zeroscope_v2_576w/
|
77 |
+
## ModelScopeT2V
|
78 |
+
git clone https://huggingface.co/damo-vilab/text-to-video-ms-1.7b ./models/model_scope/
|
79 |
+
```
|
80 |
+
### Weights of trained MotionDirector <a name="download_weights"></a>
|
81 |
+
```shell
|
82 |
+
# Make sure you have git-lfs installed (https://git-lfs.com)
|
83 |
+
git lfs install
|
84 |
+
git clone https://huggingface.co/ruizhaocv/MotionDirector_weights ./outputs
|
85 |
+
```
|
86 |
+
|
87 |
+
## Usage
|
88 |
+
### Training
|
89 |
+
|
90 |
+
#### Train MotionDirector on multiple videos:
|
91 |
+
```bash
|
92 |
+
python MotionDirector_train.py --config ./configs/config_multi_videos.yaml
|
93 |
+
```
|
94 |
+
#### Train MotionDirector on a single video:
|
95 |
+
```bash
|
96 |
+
python MotionDirector_train.py --config ./configs/config_single_video.yaml
|
97 |
+
```
|
98 |
+
|
99 |
+
Note:
|
100 |
+
- Before running the above command,
|
101 |
+
make sure you replace the path to foundational model weights and training data with your own in the config files `config_multi_videos.yaml` or `config_single_video.yaml`.
|
102 |
+
- Generally, training on multiple 16-frame videos usually takes `300~500` steps, about `9~16` minutes using one A5000 GPU. Training on a single video takes `50~150` steps, about `1.5~4.5` minutes using one A5000 GPU. The required VRAM for training is around `14GB`.
|
103 |
+
- Reduce `n_sample_frames` if your GPU memory is limited.
|
104 |
+
- Reduce the learning rate and increase the training steps for better performance.
|
105 |
+
|
106 |
+
|
107 |
+
### Inference
|
108 |
+
```bash
|
109 |
+
python MotionDirector_inference.py --model /path/to/the/foundation/model --prompt "Your prompt" --checkpoint_folder /path/to/the/trained/MotionDirector --checkpoint_index 300 --noise_prior 0.
|
110 |
+
```
|
111 |
+
Note:
|
112 |
+
- Replace `/path/to/the/foundation/model` with your own path to the foundation model, like ZeroScope.
|
113 |
+
- The value of `checkpoint_index` means the checkpoint saved at which the training step is selected.
|
114 |
+
- The value of `noise_prior` indicates how much the inversion noise of the reference video affects the generation.
|
115 |
+
We recommend setting it to `0` for MotionDirector trained on multiple videos to achieve the highest diverse generation, while setting it to `0.1~0.5` for MotionDirector trained on a single video for faster convergence and better alignment with the reference video.
|
116 |
+
|
117 |
+
|
118 |
+
## Inference with pre-trained MotionDirector
|
119 |
+
All available weights are at official [Huggingface Repo](https://huggingface.co/ruizhaocv/MotionDirector_weights).
|
120 |
+
Run the [download command](#download_weights), the weights will be downloaded to the folder `outputs`, then run the following inference command to generate videos.
|
121 |
+
|
122 |
+
### MotionDirector trained on multiple videos:
|
123 |
+
```bash
|
124 |
+
python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A person is riding a bicycle past the Eiffel Tower." --checkpoint_folder ./outputs/train/riding_bicycle/ --checkpoint_index 300 --noise_prior 0. --seed 7192280
|
125 |
+
```
|
126 |
+
Note:
|
127 |
+
- Replace `/path/to/the/ZeroScope` with your own path to the foundation model, i.e. the ZeroScope.
|
128 |
+
- Change the `prompt` to generate different videos.
|
129 |
+
- The `seed` is set to a random value by default. Set it to a specific value will obtain certain results, as provided in the table below.
|
130 |
+
|
131 |
+
Results:
|
132 |
+
|
133 |
+
<table class="center">
|
134 |
+
<tr>
|
135 |
+
<td style="text-align:center;"><b>Reference Videos</b></td>
|
136 |
+
<td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
|
137 |
+
</tr>
|
138 |
+
<tr>
|
139 |
+
<td><img src=assets/multi_videos_results/reference_videos.gif></td>
|
140 |
+
<td><img src=assets/multi_videos_results/A_person_is_riding_a_bicycle_past_the_Eiffel_Tower_7192280.gif></td>
|
141 |
+
<td><img src=assets/multi_videos_results/A_panda_is_riding_a_bicycle_in_a_garden_2178639.gif></td>
|
142 |
+
<td><img src=assets/multi_videos_results/An_alien_is_riding_a_bicycle_on_Mars_2390886.gif></td>
|
143 |
+
</tr>
|
144 |
+
<tr>
|
145 |
+
<td width=25% style="text-align:center;color:gray;">"A person is riding a bicycle."</td>
|
146 |
+
<td width=25% style="text-align:center;">"A person is riding a bicycle past the Eiffel Tower.” </br> seed: 7192280</td>
|
147 |
+
<td width=25% style="text-align:center;">"A panda is riding a bicycle in a garden." </br> seed: 2178639</td>
|
148 |
+
<td width=25% style="text-align:center;">"An alien is riding a bicycle on Mars." </br> seed: 2390886</td>
|
149 |
+
</table>
|
150 |
+
|
151 |
+
### MotionDirector trained on a single video:
|
152 |
+
16 frames:
|
153 |
+
```bash
|
154 |
+
python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A tank is running on the moon." --checkpoint_folder ./outputs/train/car_16/ --checkpoint_index 150 --noise_prior 0.5 --seed 8551187
|
155 |
+
```
|
156 |
+
<table class="center">
|
157 |
+
<tr>
|
158 |
+
<td style="text-align:center;"><b>Reference Video</b></td>
|
159 |
+
<td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
|
160 |
+
</tr>
|
161 |
+
<tr>
|
162 |
+
<td><img src=assets/single_video_results/reference_video.gif></td>
|
163 |
+
<td><img src=assets/single_video_results/A_tank_is_running_on_the_moon_8551187.gif></td>
|
164 |
+
<td><img src=assets/single_video_results/A_lion_is_running_past_the_pyramids_431554.gif></td>
|
165 |
+
<td><img src=assets/single_video_results/A_spaceship_is_flying_past_Mars_8808231.gif></td>
|
166 |
+
</tr>
|
167 |
+
<tr>
|
168 |
+
<td width=25% style="text-align:center;color:gray;">"A car is running on the road."</td>
|
169 |
+
<td width=25% style="text-align:center;">"A tank is running on the moon.” </br> seed: 8551187</td>
|
170 |
+
<td width=25% style="text-align:center;">"A lion is running past the pyramids." </br> seed: 431554</td>
|
171 |
+
<td width=25% style="text-align:center;">"A spaceship is flying past Mars." </br> seed: 8808231</td>
|
172 |
+
</tr>
|
173 |
+
</table>
|
174 |
+
|
175 |
+
24 frames:
|
176 |
+
```bash
|
177 |
+
python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A truck is running past the Arc de Triomphe." --checkpoint_folder ./outputs/train/car_24/ --checkpoint_index 150 --noise_prior 0.5 --width 576 --height 320 --num-frames 24 --seed 34543
|
178 |
+
```
|
179 |
+
<table class="center">
|
180 |
+
<tr>
|
181 |
+
<td style="text-align:center;"><b>Reference Video</b></td>
|
182 |
+
<td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
|
183 |
+
</tr>
|
184 |
+
<tr>
|
185 |
+
<td><img src=assets/single_video_results/24_frames/reference_video.gif></td>
|
186 |
+
<td><img src=assets/single_video_results/24_frames/A_truck_is_running_past_the_Arc_de_Triomphe_34543.gif></td>
|
187 |
+
<td><img src=assets/single_video_results/24_frames/An_elephant_is_running_in_a_forest_2171736.gif></td>
|
188 |
+
</tr>
|
189 |
+
<tr>
|
190 |
+
<td width=25% style="text-align:center;color:gray;">"A car is running on the road."</td>
|
191 |
+
<td width=25% style="text-align:center;">"A truck is running past the Arc de Triomphe.” </br> seed: 34543</td>
|
192 |
+
<td width=25% style="text-align:center;">"An elephant is running in a forest." </br> seed: 2171736</td>
|
193 |
+
</tr>
|
194 |
+
<tr>
|
195 |
+
<td><img src=assets/single_video_results/24_frames/reference_video.gif></td>
|
196 |
+
<td><img src=assets/single_video_results/24_frames/A_person_on_a_camel_is_running_past_the_pyramids_4904126.gif></td>
|
197 |
+
<td><img src=assets/single_video_results/24_frames/A_spacecraft_is_flying_past_the_Milky_Way_galaxy_3235677.gif></td>
|
198 |
+
</tr>
|
199 |
+
<tr>
|
200 |
+
<td width=25% style="text-align:center;color:gray;">"A car is running on the road."</td>
|
201 |
+
<td width=25% style="text-align:center;">"A person on a camel is running past the pyramids." </br> seed: 4904126</td>
|
202 |
+
<td width=25% style="text-align:center;">"A spacecraft is flying past the Milky Way galaxy." </br> seed: 3235677</td>
|
203 |
+
</tr>
|
204 |
+
</table>
|
205 |
+
|
206 |
+
## MotionDirector for Sports <a name="MotionDirector_for_Sports"></a>
|
207 |
+
|
208 |
+
```bash
|
209 |
+
python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A panda is lifting weights in a garden." --checkpoint_folder ./outputs/train/lifting_weights/ --checkpoint_index 300 --noise_prior 0. --seed 9365597
|
210 |
+
```
|
211 |
+
<table class="center">
|
212 |
+
<tr>
|
213 |
+
<td style="text-align:center;" colspan="4"><b>Videos Generated by MotionDirector</b></td>
|
214 |
+
</tr>
|
215 |
+
<tr>
|
216 |
+
<td style="text-align:center;" colspan="2"><b>Lifting Weights</b></td>
|
217 |
+
<td style="text-align:center;" colspan="2"><b>Riding Bicycle</b></td>
|
218 |
+
</tr>
|
219 |
+
<tr>
|
220 |
+
<td><img src=assets/sports_results/lifting_weights/A_panda_is_lifting_weights_in_a_garden_1699276.gif></td>
|
221 |
+
<td><img src=assets/sports_results/lifting_weights/A_police_officer_is_lifting_weights_in_front_of_the_police_station_6804745.gif></td>
|
222 |
+
<td><img src=assets/multi_videos_results/A_panda_is_riding_a_bicycle_in_a_garden_2178639.gif></td>
|
223 |
+
<td><img src=assets/multi_videos_results/An_alien_is_riding_a_bicycle_on_Mars_2390886.gif></td>
|
224 |
+
</tr>
|
225 |
+
<tr>
|
226 |
+
<td width=25% style="text-align:center;">"A panda is lifting weights in a garden.” </br> seed: 1699276</td>
|
227 |
+
<td width=25% style="text-align:center;">"A police officer is lifting weights in front of the police station.” </br> seed: 6804745</td>
|
228 |
+
<td width=25% style="text-align:center;">"A panda is riding a bicycle in a garden." </br> seed: 2178639</td>
|
229 |
+
<td width=25% style="text-align:center;">"An alien is riding a bicycle on Mars." </br> seed: 2390886</td>
|
230 |
+
<tr>
|
231 |
+
<td style="text-align:center;" colspan="2"><b>Riding Horse</b></td>
|
232 |
+
<td style="text-align:center;" colspan="2"><b>Playing Golf</b></td>
|
233 |
+
</tr>
|
234 |
+
<tr>
|
235 |
+
<td><img src=assets/sports_results/riding_horse/A_Royal_Guard_riding_a_horse_in_front_of_Buckingham_Palace_4490970.gif></td>
|
236 |
+
<td><img src=assets/sports_results/riding_horse/A_man_riding_an_elephant_through_the_jungle_6230765.gif></td>
|
237 |
+
<td><img src=assets/sports_results/playing_golf/A_man_is_playing_golf_in_front_of_the_White_House_8870450.gif></td>
|
238 |
+
<td><img src=assets/sports_results/playing_golf/A_monkey_is_playing_golf_on_a_field_full_of_flowers_2989633.gif></td>
|
239 |
+
</tr>
|
240 |
+
<tr>
|
241 |
+
<td width=25% style="text-align:center;">"A Royal Guard riding a horse in front of Buckingham Palace.” </br> seed: 4490970</td>
|
242 |
+
<td width=25% style="text-align:center;">"A man riding an elephant through the jungle.” </br> seed: 6230765</td>
|
243 |
+
<td width=25% style="text-align:center;">"A man is playing golf in front of the White House." </br> seed: 8870450</td>
|
244 |
+
<td width=25% style="text-align:center;">"A monkey is playing golf on a field full of flowers." </br> seed: 2989633</td>
|
245 |
+
<tr>
|
246 |
+
</table>
|
247 |
+
|
248 |
+
More sports, to be continued ...
|
249 |
+
|
250 |
+
## MotionDirector for Cinematic Shots <a name="MotionDirector_for_Cinematic_Shots"></a>
|
251 |
+
|
252 |
+
### 1. Zoom
|
253 |
+
#### 1.1 Dolly Zoom (Hitchcockian Zoom)
|
254 |
+
```bash
|
255 |
+
python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A firefighter standing in front of a burning forest captured with a dolly zoom." --checkpoint_folder ./outputs/train/dolly_zoom/ --checkpoint_index 150 --noise_prior 0.5 --seed 9365597
|
256 |
+
```
|
257 |
+
<table class="center">
|
258 |
+
<tr>
|
259 |
+
<td style="text-align:center;"><b>Reference Video</b></td>
|
260 |
+
<td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
|
261 |
+
</tr>
|
262 |
+
<tr>
|
263 |
+
<td><img src=assets/cinematic_shots_results/dolly_zoom_16.gif></td>
|
264 |
+
<td><img src=assets/cinematic_shots_results/A_firefighter_standing_in_front_of_a_burning_forest_captured_with_a_dolly_zoom_9365597.gif></td>
|
265 |
+
<td><img src=assets/cinematic_shots_results/A_lion_sitting_on_top_of_a_cliff_captured_with_a_dolly_zoom_1675932.gif></td>
|
266 |
+
<td><img src=assets/cinematic_shots_results/A_Roman_soldier_standing_in_front_of_the_Colosseum_captured_with_a_dolly_zoom_2310805.gif></td>
|
267 |
+
</tr>
|
268 |
+
<tr>
|
269 |
+
<td width=25% style="text-align:center;color:gray;">"A man standing in room captured with a dolly zoom."</td>
|
270 |
+
<td width=25% style="text-align:center;">"A firefighter standing in front of a burning forest captured with a dolly zoom." </br> seed: 9365597 </br> noise_prior: 0.5</td>
|
271 |
+
<td width=25% style="text-align:center;">"A lion sitting on top of a cliff captured with a dolly zoom." </br> seed: 1675932 </br> noise_prior: 0.5</td>
|
272 |
+
<td width=25% style="text-align:center;">"A Roman soldier standing in front of the Colosseum captured with a dolly zoom." </br> seed: 2310805 </br> noise_prior: 0.5 </td>
|
273 |
+
</tr>
|
274 |
+
<tr>
|
275 |
+
<td><img src=assets/cinematic_shots_results/dolly_zoom_16.gif></td>
|
276 |
+
<td><img src=assets/cinematic_shots_results/A_firefighter_standing_in_front_of_a_burning_forest_captured_with_a_dolly_zoom_4615820.gif></td>
|
277 |
+
<td><img src=assets/cinematic_shots_results/A_lion_sitting_on_top_of_a_cliff_captured_with_a_dolly_zoom_4114896.gif></td>
|
278 |
+
<td><img src=assets/cinematic_shots_results/A_Roman_soldier_standing_in_front_of_the_Colosseum_captured_with_a_dolly_zoom_7492004.gif></td>
|
279 |
+
</tr>
|
280 |
+
<tr>
|
281 |
+
<td width=25% style="text-align:center;color:gray;">"A man standing in room captured with a dolly zoom."</td>
|
282 |
+
<td width=25% style="text-align:center;">"A firefighter standing in front of a burning forest captured with a dolly zoom." </br> seed: 4615820 </br> noise_prior: 0.3</td>
|
283 |
+
<td width=25% style="text-align:center;">"A lion sitting on top of a cliff captured with a dolly zoom." </br> seed: 4114896 </br> noise_prior: 0.3</td>
|
284 |
+
<td width=25% style="text-align:center;">"A Roman soldier standing in front of the Colosseum captured with a dolly zoom." </br> seed: 7492004</td>
|
285 |
+
</tr>
|
286 |
+
</table>
|
287 |
+
|
288 |
+
#### 1.2 Zoom In
|
289 |
+
The reference video is shot with my own water cup. You can also pick up your cup or any other object to practice camera movements and turn it into imaginative videos. Create your AI films with customized camera movements!
|
290 |
+
|
291 |
+
```bash
|
292 |
+
python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A firefighter standing in front of a burning forest captured with a zoom in." --checkpoint_folder ./outputs/train/zoom_in/ --checkpoint_index 150 --noise_prior 0.3 --seed 1429227
|
293 |
+
```
|
294 |
+
<table class="center">
|
295 |
+
<tr>
|
296 |
+
<td style="text-align:center;"><b>Reference Video</b></td>
|
297 |
+
<td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
|
298 |
+
</tr>
|
299 |
+
<tr>
|
300 |
+
<td><img src=assets/cinematic_shots_results/zoom_in_16.gif></td>
|
301 |
+
<td><img src=assets/cinematic_shots_results/A_firefighter_standing_in_front_of_a_burning_forest_captured_with_a_zoom_in_1429227.gif></td>
|
302 |
+
<td><img src=assets/cinematic_shots_results/A_lion_sitting_on_top_of_a_cliff_captured_with_a_zoom_in_487239.gif></td>
|
303 |
+
<td><img src=assets/cinematic_shots_results/A_Roman_soldier_standing_in_front_of_the_Colosseum_captured_with_a_zoom_in_1393184.gif></td>
|
304 |
+
</tr>
|
305 |
+
<tr>
|
306 |
+
<td width=25% style="text-align:center;color:gray;">"A cup in a lab captured with a zoom in."</td>
|
307 |
+
<td width=25% style="text-align:center;">"A firefighter standing in front of a burning forest captured with a zoom in." </br> seed: 1429227</td>
|
308 |
+
<td width=25% style="text-align:center;">"A lion sitting on top of a cliff captured with a zoom in." </br> seed: 487239 </td>
|
309 |
+
<td width=25% style="text-align:center;">"A Roman soldier standing in front of the Colosseum captured with a zoom in." </br> seed: 1393184</td>
|
310 |
+
</tr>
|
311 |
+
</table>
|
312 |
+
|
313 |
+
#### 1.3 Zoom Out
|
314 |
+
```bash
|
315 |
+
python MotionDirector_inference.py --model /path/to/the/ZeroScope --prompt "A firefighter standing in front of a burning forest captured with a zoom out." --checkpoint_folder ./outputs/train/zoom_out/ --checkpoint_index 150 --noise_prior 0.3 --seed 4971910
|
316 |
+
```
|
317 |
+
<table class="center">
|
318 |
+
<tr>
|
319 |
+
<td style="text-align:center;"><b>Reference Video</b></td>
|
320 |
+
<td style="text-align:center;" colspan="3"><b>Videos Generated by MotionDirector</b></td>
|
321 |
+
</tr>
|
322 |
+
<tr>
|
323 |
+
<td><img src=assets/cinematic_shots_results/zoom_out_16.gif></td>
|
324 |
+
<td><img src=assets/cinematic_shots_results/A_firefighter_standing_in_front_of_a_burning_forest_captured_with_a_zoom_out_4971910.gif></td>
|
325 |
+
<td><img src=assets/cinematic_shots_results/A_lion_sitting_on_top_of_a_cliff_captured_with_a_zoom_out_1767994.gif></td>
|
326 |
+
<td><img src=assets/cinematic_shots_results/A_Roman_soldier_standing_in_front_of_the_Colosseum_captured_with_a_zoom_out_8203639.gif></td>
|
327 |
+
</tr>
|
328 |
+
<tr>
|
329 |
+
<td width=25% style="text-align:center;color:gray;">"A cup in a lab captured with a zoom out."</td>
|
330 |
+
<td width=25% style="text-align:center;">"A firefighter standing in front of a burning forest captured with a zoom out." </br> seed: 4971910</td>
|
331 |
+
<td width=25% style="text-align:center;">"A lion sitting on top of a cliff captured with a zoom out." </br> seed: 1767994 </td>
|
332 |
+
<td width=25% style="text-align:center;">"A Roman soldier standing in front of the Colosseum captured with a zoom out." </br> seed: 8203639</td>
|
333 |
+
</tr>
|
334 |
+
</table>
|
335 |
+
|
336 |
+
More Cinematic Shots, to be continued ....
|
337 |
+
|
338 |
+
## More results
|
339 |
+
|
340 |
+
If you have a more impressive MotionDirector or generated videos, please feel free to open an issue and share them with us. We would greatly appreciate it.
|
341 |
+
Improvements to the code are also highly welcome.
|
342 |
+
|
343 |
+
Please refer to [Project Page](https://showlab.github.io/MotionDirector) for more results.
|
344 |
+
|
345 |
+
|
346 |
+
## Citation
|
347 |
+
|
348 |
+
|
349 |
+
```bibtex
|
350 |
+
|
351 |
+
@article{zhao2023motiondirector,
|
352 |
+
title={MotionDirector: Motion Customization of Text-to-Video Diffusion Models},
|
353 |
+
author={Zhao, Rui and Gu, Yuchao and Wu, Jay Zhangjie and Zhang, David Junhao and Liu, Jiawei and Wu, Weijia and Keppo, Jussi and Shou, Mike Zheng},
|
354 |
+
journal={arXiv preprint arXiv:2310.08465},
|
355 |
+
year={2023}
|
356 |
+
}
|
357 |
+
|
358 |
+
```
|
359 |
+
|
360 |
+
## Shoutouts
|
361 |
+
|
362 |
+
- This code builds on [diffusers](https://github.com/huggingface/diffusers) and [Text-To-Video-Finetuning](https://github.com/ExponentialML/Text-To-Video-Finetuning). Thanks for open-sourcing!
|
363 |
+
- Thanks to [camenduru](https://twitter.com/camenduru) for the [colab demo](https://github.com/camenduru/MotionDirector-colab).
|
364 |
+
- Thanks to [yhyu13](https://github.com/yhyu13) for the [Huggingface Repo](https://huggingface.co/Yhyu13/MotionDirector_LoRA).
|
app.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
from demo.motiondirector import MotionDirector
|
4 |
+
|
5 |
+
from huggingface_hub import snapshot_download
|
6 |
+
|
7 |
+
snapshot_download(repo_id="cerspense/zeroscope_v2_576w", local_dir="./zeroscope_v2_576w/")
|
8 |
+
snapshot_download(repo_id="ruizhaocv/MotionDirector", local_dir="./MotionDirector_pretrained")
|
9 |
+
|
10 |
+
is_spaces = True if "SPACE_ID" in os.environ else False
|
11 |
+
true_for_shared_ui = False # This will be true only if you are in a shared UI
|
12 |
+
if (is_spaces):
|
13 |
+
true_for_shared_ui = True if "ruizhaocv/MotionDirector" in os.environ['SPACE_ID'] else False
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
runner = MotionDirector()
|
18 |
+
|
19 |
+
|
20 |
+
def motiondirector(model_select, text_pormpt, neg_text_pormpt, random_seed=1, steps=25, guidance_scale=7.5, baseline_select=False):
|
21 |
+
return runner(model_select, text_pormpt, neg_text_pormpt, int(random_seed) if random_seed != "" else 1, int(steps), float(guidance_scale), baseline_select)
|
22 |
+
|
23 |
+
|
24 |
+
with gr.Blocks() as demo:
|
25 |
+
gr.HTML(
|
26 |
+
"""
|
27 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
28 |
+
<a href="https://github.com/showlab/MotionDirector" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
|
29 |
+
</a>
|
30 |
+
<div>
|
31 |
+
<h1 >MotionDirector: Motion Customization of Text-to-Video Diffusion Models</h1>
|
32 |
+
<h5 style="margin: 0;">More MotionDirectors are on the way. Stay tuned 🔥! Give us a star ✨ on Github for the latest update.</h5>
|
33 |
+
</br>
|
34 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
|
35 |
+
<a href="https://arxiv.org/abs/2310.08465"><img src="https://img.shields.io/badge/arXiv-MotionDirector-b31b1b.svg"></a>
|
36 |
+
<a href='https://showlab.github.io/MotionDirector'><img src='https://img.shields.io/badge/Project_Page-MotionDirector-green'></a>
|
37 |
+
<a href='https://github.com/showlab/MotionDirector'><img src='https://img.shields.io/badge/Github-MotionDirector-blue'></a>
|
38 |
+
</div>
|
39 |
+
</div>
|
40 |
+
</div>
|
41 |
+
""")
|
42 |
+
with gr.Row():
|
43 |
+
generated_video_baseline = gr.Video(format="mp4", label="Video Generated by base model (ZeroScope with same seed)")
|
44 |
+
generated_video = gr.Video(format="mp4", label="Video Generated by MotionDirector")
|
45 |
+
|
46 |
+
with gr.Column():
|
47 |
+
baseline_select = gr.Checkbox(label="Compare with baseline (ZeroScope with same seed)", info="Run baseline? Note: Inference time will be doubled.")
|
48 |
+
random_seed = gr.Textbox(label="Random seed", value=1, info="default: 1")
|
49 |
+
sampling_steps = gr.Textbox(label="Sampling steps", value=30, info="default: 30")
|
50 |
+
guidance_scale = gr.Textbox(label="Guidance scale", value=12, info="default: 12")
|
51 |
+
|
52 |
+
with gr.Row():
|
53 |
+
model_select = gr.Dropdown(
|
54 |
+
["1-1: [Cinematic Shots] -- Dolly Zoom (Hitchcockian Zoom)",
|
55 |
+
"1-2: [Cinematic Shots] -- Zoom In",
|
56 |
+
"1-3: [Cinematic Shots] -- Zoom Out",
|
57 |
+
"2-1: [Object Trajectory] -- Right to Left",
|
58 |
+
"2-2: [Object Trajectory] -- Left to Right",
|
59 |
+
"3-1: [Sports Concepts] -- Riding Bicycle",
|
60 |
+
"3-2: [Sports Concepts] -- Riding Horse",
|
61 |
+
"3-3: [Sports Concepts] -- Lifting Weights",
|
62 |
+
"3-4: [Sports Concepts] -- Playing Golf"
|
63 |
+
],
|
64 |
+
label="MotionDirector",
|
65 |
+
info="Which MotionDirector would you like to use!"
|
66 |
+
)
|
67 |
+
|
68 |
+
text_pormpt = gr.Textbox(label="Text Prompt", value='', placeholder="Input your text prompt here!")
|
69 |
+
neg_text_pormpt = gr.Textbox(label="Negative Text Prompt", value='', placeholder="default: None")
|
70 |
+
|
71 |
+
submit = gr.Button("Generate")
|
72 |
+
|
73 |
+
# when the `submit` button is clicked
|
74 |
+
submit.click(
|
75 |
+
motiondirector,
|
76 |
+
[model_select, text_pormpt, neg_text_pormpt, random_seed, sampling_steps, guidance_scale, baseline_select],
|
77 |
+
[generated_video, generated_video_baseline]
|
78 |
+
)
|
79 |
+
|
80 |
+
# Examples
|
81 |
+
gr.Markdown("## Examples")
|
82 |
+
gr.Examples(
|
83 |
+
fn=motiondirector,
|
84 |
+
examples=[
|
85 |
+
["1-1: [Cinematic Shots] -- Dolly Zoom (Hitchcockian Zoom)", "A lion sitting on top of a cliff captured with a dolly zoom.", 1675932],
|
86 |
+
["1-2: [Cinematic Shots] -- Zoom In", "A firefighter standing in front of a burning forest captured with a zoom in.", 1429227],
|
87 |
+
["1-3: [Cinematic Shots] -- Zoom Out", "A lion sitting on top of a cliff captured with a zoom out.", 1767994],
|
88 |
+
["2-1: [Object Trajectory] -- Right to Left", "A tank is running on the moon.", 8551187],
|
89 |
+
["2-2: [Object Trajectory] -- Left to Right", "A tiger is running in the forest.", 3463673],
|
90 |
+
["3-1: [Sports Concepts] -- Riding Bicycle", "An astronaut is riding a bicycle past the pyramids Mars 4K high quailty highly detailed.", 4422954],
|
91 |
+
["3-2: [Sports Concepts] -- Riding Horse", "A man riding an elephant through the jungle.", 6230765],
|
92 |
+
["3-3: [Sports Concepts] -- Lifting Weights", "A panda is lifting weights in a garden.", 1699276],
|
93 |
+
["3-4: [Sports Concepts] -- Playing Golf", "A man is playing golf in front of the White House.", 8870450],
|
94 |
+
],
|
95 |
+
inputs=[model_select, text_pormpt, random_seed],
|
96 |
+
outputs=generated_video,
|
97 |
+
)
|
98 |
+
|
99 |
+
demo.queue(max_size=15)
|
100 |
+
demo.launch(share=True)
|
demo/MotionDirector_gradio.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import imageio
|
3 |
+
import numpy as np
|
4 |
+
import gradio as gr
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
from demo.motiondirector import MotionDirector
|
8 |
+
|
9 |
+
runner = MotionDirector()
|
10 |
+
|
11 |
+
|
12 |
+
def motiondirector(model_select, text_pormpt, neg_text_pormpt, random_seed=1, steps=25, guidance_scale=7.5, baseline_select=False):
|
13 |
+
return runner(model_select, text_pormpt, neg_text_pormpt, int(random_seed) if random_seed != "" else 1, int(steps), float(guidance_scale), baseline_select)
|
14 |
+
|
15 |
+
|
16 |
+
with gr.Blocks() as demo:
|
17 |
+
gr.HTML(
|
18 |
+
"""
|
19 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
20 |
+
<a href="https://github.com/showlab/MotionDirector" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
|
21 |
+
</a>
|
22 |
+
<div>
|
23 |
+
<h1 >MotionDirector: Motion Customization of Text-to-Video Diffusion Models</h1>
|
24 |
+
<h5 style="margin: 0;">More MotionDirectors are on the way. Stay tuned 🔥! Give us a star ✨ on Github for the latest update.</h5>
|
25 |
+
</br>
|
26 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
|
27 |
+
<a href="https://arxiv.org/abs/2310.08465"><img src="https://img.shields.io/badge/arXiv-MotionDirector-b31b1b.svg"></a>
|
28 |
+
<a href='https://showlab.github.io/MotionDirector'><img src='https://img.shields.io/badge/Project_Page-MotionDirector-green'></a>
|
29 |
+
<a href='https://github.com/showlab/MotionDirector'><img src='https://img.shields.io/badge/Github-MotionDirector-blue'></a>
|
30 |
+
</div>
|
31 |
+
</div>
|
32 |
+
</div>
|
33 |
+
""")
|
34 |
+
with gr.Row():
|
35 |
+
generated_video_baseline = gr.Video(format="mp4", label="Video Generated by base model (ZeroScope with same seed)")
|
36 |
+
generated_video = gr.Video(format="mp4", label="Video Generated by MotionDirector")
|
37 |
+
|
38 |
+
with gr.Column():
|
39 |
+
baseline_select = gr.Checkbox(label="Compare with baseline (ZeroScope with same seed)", info="Run baseline? Note: Inference time will be doubled.")
|
40 |
+
random_seed = gr.Textbox(label="Random seed", value=1, info="default: 1")
|
41 |
+
sampling_steps = gr.Textbox(label="Sampling steps", value=30, info="default: 30")
|
42 |
+
guidance_scale = gr.Textbox(label="Guidance scale", value=12, info="default: 12")
|
43 |
+
|
44 |
+
with gr.Row():
|
45 |
+
model_select = gr.Dropdown(
|
46 |
+
["1-1: [Cinematic Shots] -- Dolly Zoom (Hitchcockian Zoom)",
|
47 |
+
"1-2: [Cinematic Shots] -- Zoom In",
|
48 |
+
"1-3: [Cinematic Shots] -- Zoom Out",
|
49 |
+
"2-1: [Object Trajectory] -- Right to Left",
|
50 |
+
"2-2: [Object Trajectory] -- Left to Right",
|
51 |
+
"3-1: [Sports Concepts] -- Riding Bicycle",
|
52 |
+
"3-2: [Sports Concepts] -- Riding Horse",
|
53 |
+
"3-3: [Sports Concepts] -- Lifting Weights",
|
54 |
+
"3-4: [Sports Concepts] -- Playing Golf"
|
55 |
+
],
|
56 |
+
label="MotionDirector",
|
57 |
+
info="Which MotionDirector would you like to use!"
|
58 |
+
)
|
59 |
+
|
60 |
+
text_pormpt = gr.Textbox(label="Text Prompt", value='', placeholder="Input your text prompt here!")
|
61 |
+
neg_text_pormpt = gr.Textbox(label="Negative Text Prompt", value='', placeholder="default: None")
|
62 |
+
|
63 |
+
submit = gr.Button("Generate")
|
64 |
+
|
65 |
+
# when the `submit` button is clicked
|
66 |
+
submit.click(
|
67 |
+
motiondirector,
|
68 |
+
[model_select, text_pormpt, neg_text_pormpt, random_seed, sampling_steps, guidance_scale, baseline_select],
|
69 |
+
[generated_video, generated_video_baseline]
|
70 |
+
)
|
71 |
+
|
72 |
+
# Examples
|
73 |
+
gr.Markdown("## Examples")
|
74 |
+
gr.Examples(
|
75 |
+
fn=motiondirector,
|
76 |
+
examples=[
|
77 |
+
["1-1: [Cinematic Shots] -- Dolly Zoom (Hitchcockian Zoom)", "A lion sitting on top of a cliff captured with a dolly zoom.", 1675932],
|
78 |
+
["1-2: [Cinematic Shots] -- Zoom In", "A firefighter standing in front of a burning forest captured with a zoom in.", 1429227],
|
79 |
+
["1-3: [Cinematic Shots] -- Zoom Out", "A lion sitting on top of a cliff captured with a zoom out.", 1767994],
|
80 |
+
["2-1: [Object Trajectory] -- Right to Left", "A tank is running on the moon.", 8551187],
|
81 |
+
["2-2: [Object Trajectory] -- Left to Right", "A tiger is running in the forest.", 3463673],
|
82 |
+
["3-1: [Sports Concepts] -- Riding Bicycle", "An astronaut is riding a bicycle past the pyramids Mars 4K high quailty highly detailed.", 4422954],
|
83 |
+
["3-2: [Sports Concepts] -- Riding Horse", "A man riding an elephant through the jungle.", 6230765],
|
84 |
+
["3-3: [Sports Concepts] -- Lifting Weights", "A panda is lifting weights in a garden.", 1699276],
|
85 |
+
["3-4: [Sports Concepts] -- Playing Golf", "A man is playing golf in front of the White House.", 8870450],
|
86 |
+
],
|
87 |
+
inputs=[model_select, text_pormpt, random_seed],
|
88 |
+
outputs=generated_video,
|
89 |
+
)
|
90 |
+
|
91 |
+
demo.queue(max_size=15)
|
92 |
+
demo.launch(share=True)
|
demo/motiondirector.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from diffusers import DDIMScheduler, TextToVideoSDPipeline
|
7 |
+
from einops import rearrange
|
8 |
+
from torch import Tensor
|
9 |
+
from torch.nn.functional import interpolate
|
10 |
+
from tqdm import trange
|
11 |
+
import random
|
12 |
+
|
13 |
+
from MotionDirector_train import export_to_video, handle_memory_attention, load_primary_models, unet_and_text_g_c, freeze_models
|
14 |
+
from utils.lora_handler import LoraHandler
|
15 |
+
from utils.ddim_utils import ddim_inversion
|
16 |
+
from utils.lora import extract_lora_child_module
|
17 |
+
import imageio
|
18 |
+
|
19 |
+
|
20 |
+
def initialize_pipeline(
|
21 |
+
model: str,
|
22 |
+
device: str = "cuda",
|
23 |
+
xformers: bool = True,
|
24 |
+
sdp: bool = True,
|
25 |
+
lora_path: str = "",
|
26 |
+
lora_rank: int = 32,
|
27 |
+
lora_scale: float = 1.0,
|
28 |
+
):
|
29 |
+
with warnings.catch_warnings():
|
30 |
+
warnings.simplefilter("ignore")
|
31 |
+
|
32 |
+
scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
|
33 |
+
|
34 |
+
# Freeze any necessary models
|
35 |
+
freeze_models([vae, text_encoder, unet])
|
36 |
+
|
37 |
+
# Enable xformers if available
|
38 |
+
handle_memory_attention(xformers, sdp, unet)
|
39 |
+
|
40 |
+
lora_manager_temporal = LoraHandler(
|
41 |
+
version="cloneofsimo",
|
42 |
+
use_unet_lora=True,
|
43 |
+
use_text_lora=False,
|
44 |
+
save_for_webui=False,
|
45 |
+
only_for_webui=False,
|
46 |
+
unet_replace_modules=["TransformerTemporalModel"],
|
47 |
+
text_encoder_replace_modules=None,
|
48 |
+
lora_bias=None
|
49 |
+
)
|
50 |
+
|
51 |
+
unet_lora_params, unet_negation = lora_manager_temporal.add_lora_to_model(
|
52 |
+
True, unet, lora_manager_temporal.unet_replace_modules, 0, lora_path, r=lora_rank, scale=lora_scale)
|
53 |
+
|
54 |
+
unet.eval()
|
55 |
+
text_encoder.eval()
|
56 |
+
unet_and_text_g_c(unet, text_encoder, False, False)
|
57 |
+
|
58 |
+
pipe = TextToVideoSDPipeline.from_pretrained(
|
59 |
+
pretrained_model_name_or_path=model,
|
60 |
+
scheduler=scheduler,
|
61 |
+
tokenizer=tokenizer,
|
62 |
+
text_encoder=text_encoder.to(device=device, dtype=torch.half),
|
63 |
+
vae=vae.to(device=device, dtype=torch.half),
|
64 |
+
unet=unet.to(device=device, dtype=torch.half),
|
65 |
+
)
|
66 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
67 |
+
|
68 |
+
return pipe
|
69 |
+
|
70 |
+
|
71 |
+
def inverse_video(pipe, latents, num_steps):
|
72 |
+
ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
73 |
+
ddim_inv_scheduler.set_timesteps(num_steps)
|
74 |
+
|
75 |
+
ddim_inv_latent = ddim_inversion(
|
76 |
+
pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device),
|
77 |
+
num_inv_steps=num_steps, prompt="")[-1]
|
78 |
+
return ddim_inv_latent
|
79 |
+
|
80 |
+
|
81 |
+
def prepare_input_latents(
|
82 |
+
pipe: TextToVideoSDPipeline,
|
83 |
+
batch_size: int,
|
84 |
+
num_frames: int,
|
85 |
+
height: int,
|
86 |
+
width: int,
|
87 |
+
latents_path:str,
|
88 |
+
noise_prior: float
|
89 |
+
):
|
90 |
+
# initialize with random gaussian noise
|
91 |
+
scale = pipe.vae_scale_factor
|
92 |
+
shape = (batch_size, pipe.unet.config.in_channels, num_frames, height // scale, width // scale)
|
93 |
+
if noise_prior > 0.:
|
94 |
+
cached_latents = torch.load(latents_path)
|
95 |
+
if 'inversion_noise' not in cached_latents:
|
96 |
+
latents = inverse_video(pipe, cached_latents['latents'].unsqueeze(0), 50).squeeze(0)
|
97 |
+
else:
|
98 |
+
latents = torch.load(latents_path)['inversion_noise'].unsqueeze(0)
|
99 |
+
if latents.shape[0] != batch_size:
|
100 |
+
latents = latents.repeat(batch_size, 1, 1, 1, 1)
|
101 |
+
if latents.shape != shape:
|
102 |
+
latents = interpolate(rearrange(latents, "b c f h w -> (b f) c h w", b=batch_size), (height // scale, width // scale), mode='bilinear')
|
103 |
+
latents = rearrange(latents, "(b f) c h w -> b c f h w", b=batch_size)
|
104 |
+
noise = torch.randn_like(latents, dtype=torch.half)
|
105 |
+
latents_base = noise
|
106 |
+
latents = (noise_prior) ** 0.5 * latents + (1 - noise_prior) ** 0.5 * noise
|
107 |
+
else:
|
108 |
+
latents = torch.randn(shape, dtype=torch.half)
|
109 |
+
latents_base = latents
|
110 |
+
|
111 |
+
return latents, latents_base
|
112 |
+
|
113 |
+
|
114 |
+
class MotionDirector():
|
115 |
+
def __init__(self):
|
116 |
+
self.version = "0.0.0"
|
117 |
+
self.foundation_model_path = "./zeroscope_v2_576w/"
|
118 |
+
self.lora_path = "./MotionDirector_pretrained/dolly_zoom_(hitchcockian_zoom)/checkpoint-default/temporal/lora"
|
119 |
+
with torch.autocast("cuda", dtype=torch.half):
|
120 |
+
self.pipe = initialize_pipeline(model=self.foundation_model_path, lora_path=self.lora_path, lora_scale=1)
|
121 |
+
|
122 |
+
def reload_lora(self, lora_path):
|
123 |
+
if lora_path != self.lora_path:
|
124 |
+
self.lora_path = lora_path
|
125 |
+
with torch.autocast("cuda", dtype=torch.half):
|
126 |
+
self.pipe = initialize_pipeline(model=self.foundation_model_path, lora_path=self.lora_path)
|
127 |
+
|
128 |
+
def __call__(self, model_select, text_pormpt, neg_text_pormpt, random_seed, steps, guidance_scale, baseline_select):
|
129 |
+
model_select = str(model_select)
|
130 |
+
out_name = f"./outputs/inference"
|
131 |
+
out_name += f"{text_pormpt}".replace(' ', '_').replace(',', '').replace('.', '')
|
132 |
+
|
133 |
+
model_select_type = model_select.split('--')[1].strip()
|
134 |
+
model_select_type = model_select_type.lower().replace(' ', '_')
|
135 |
+
|
136 |
+
lora_path = f"./MotionDirector_pretrained/{model_select_type}/checkpoint-default/temporal/lora"
|
137 |
+
self.reload_lora(lora_path)
|
138 |
+
latents_folder = f"./MotionDirector_pretrained/{model_select_type}/cached_latents"
|
139 |
+
latents_path = f"{latents_folder}/{random.choice(os.listdir(latents_folder))}"
|
140 |
+
assert os.path.exists(lora_path)
|
141 |
+
|
142 |
+
if '3-' in model_select:
|
143 |
+
noise_prior = 0.
|
144 |
+
elif '2-' in model_select:
|
145 |
+
noise_prior = 0.5
|
146 |
+
else:
|
147 |
+
noise_prior = 0.3
|
148 |
+
|
149 |
+
if random_seed > 1000:
|
150 |
+
torch.manual_seed(random_seed)
|
151 |
+
else:
|
152 |
+
random_seed = random.randint(100, 10000000)
|
153 |
+
torch.manual_seed(random_seed)
|
154 |
+
device = "cuda"
|
155 |
+
with torch.autocast(device, dtype=torch.half):
|
156 |
+
# prepare input latents
|
157 |
+
with torch.no_grad():
|
158 |
+
init_latents,init_latents_base = prepare_input_latents(
|
159 |
+
pipe=self.pipe,
|
160 |
+
batch_size=1,
|
161 |
+
num_frames=16,
|
162 |
+
height=384,
|
163 |
+
width=384,
|
164 |
+
latents_path=latents_path,
|
165 |
+
noise_prior=noise_prior
|
166 |
+
)
|
167 |
+
video_frames = self.pipe(
|
168 |
+
prompt=text_pormpt,
|
169 |
+
negative_prompt=neg_text_pormpt,
|
170 |
+
width=384,
|
171 |
+
height=384,
|
172 |
+
num_frames=16,
|
173 |
+
num_inference_steps=steps,
|
174 |
+
guidance_scale=guidance_scale,
|
175 |
+
latents=init_latents
|
176 |
+
).frames
|
177 |
+
|
178 |
+
|
179 |
+
out_file = f"{out_name}_{random_seed}.mp4"
|
180 |
+
os.makedirs(os.path.dirname(out_file), exist_ok=True)
|
181 |
+
export_to_video(video_frames, out_file, 8)
|
182 |
+
|
183 |
+
if baseline_select:
|
184 |
+
with torch.autocast("cuda", dtype=torch.half):
|
185 |
+
|
186 |
+
loras = extract_lora_child_module(self.pipe.unet, target_replace_module=["TransformerTemporalModel"])
|
187 |
+
for lora_i in loras:
|
188 |
+
lora_i.scale = 0.
|
189 |
+
|
190 |
+
# self.pipe = initialize_pipeline(model=self.foundation_model_path, lora_path=self.lora_path,
|
191 |
+
# lora_scale=0.)
|
192 |
+
with torch.no_grad():
|
193 |
+
video_frames = self.pipe(
|
194 |
+
prompt=text_pormpt,
|
195 |
+
negative_prompt=neg_text_pormpt,
|
196 |
+
width=384,
|
197 |
+
height=384,
|
198 |
+
num_frames=16,
|
199 |
+
num_inference_steps=steps,
|
200 |
+
guidance_scale=guidance_scale,
|
201 |
+
latents=init_latents_base,
|
202 |
+
).frames
|
203 |
+
|
204 |
+
out_file_baseline = f"{out_name}_{random_seed}_baseline.mp4"
|
205 |
+
os.makedirs(os.path.dirname(out_file_baseline), exist_ok=True)
|
206 |
+
export_to_video(video_frames, out_file_baseline, 8)
|
207 |
+
# with torch.autocast("cuda", dtype=torch.half):
|
208 |
+
# self.pipe = initialize_pipeline(model=self.foundation_model_path, lora_path=self.lora_path,
|
209 |
+
# lora_scale=1.)
|
210 |
+
loras = extract_lora_child_module(self.pipe.unet,
|
211 |
+
target_replace_module=["TransformerTemporalModel"])
|
212 |
+
for lora_i in loras:
|
213 |
+
lora_i.scale = 1.
|
214 |
+
|
215 |
+
else:
|
216 |
+
out_file_baseline = None
|
217 |
+
|
218 |
+
return [out_file, out_file_baseline]
|
models/unet_3d_blocks.py
ADDED
@@ -0,0 +1,842 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.utils.checkpoint as checkpoint
|
17 |
+
from torch import nn
|
18 |
+
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
|
19 |
+
from diffusers.models.transformer_2d import Transformer2DModel
|
20 |
+
from diffusers.models.transformer_temporal import TransformerTemporalModel
|
21 |
+
|
22 |
+
# Assign gradient checkpoint function to simple variable for readability.
|
23 |
+
g_c = checkpoint.checkpoint
|
24 |
+
|
25 |
+
def use_temporal(module, num_frames, x):
|
26 |
+
if num_frames == 1:
|
27 |
+
if isinstance(module, TransformerTemporalModel):
|
28 |
+
return {"sample": x}
|
29 |
+
else:
|
30 |
+
return x
|
31 |
+
|
32 |
+
def custom_checkpoint(module, mode=None):
|
33 |
+
if mode == None: raise ValueError('Mode for gradient checkpointing cannot be none.')
|
34 |
+
custom_forward = None
|
35 |
+
|
36 |
+
if mode == 'resnet':
|
37 |
+
def custom_forward(hidden_states, temb):
|
38 |
+
inputs = module(hidden_states, temb)
|
39 |
+
return inputs
|
40 |
+
|
41 |
+
if mode == 'attn':
|
42 |
+
def custom_forward(
|
43 |
+
hidden_states,
|
44 |
+
encoder_hidden_states=None,
|
45 |
+
cross_attention_kwargs=None
|
46 |
+
):
|
47 |
+
inputs = module(
|
48 |
+
hidden_states,
|
49 |
+
encoder_hidden_states,
|
50 |
+
cross_attention_kwargs
|
51 |
+
)
|
52 |
+
return inputs
|
53 |
+
|
54 |
+
if mode == 'temp':
|
55 |
+
def custom_forward(hidden_states, num_frames=None):
|
56 |
+
inputs = use_temporal(module, num_frames, hidden_states)
|
57 |
+
if inputs is None: inputs = module(
|
58 |
+
hidden_states,
|
59 |
+
num_frames=num_frames
|
60 |
+
)
|
61 |
+
return inputs
|
62 |
+
|
63 |
+
return custom_forward
|
64 |
+
|
65 |
+
def transformer_g_c(transformer, sample, num_frames):
|
66 |
+
sample = g_c(custom_checkpoint(transformer, mode='temp'),
|
67 |
+
sample, num_frames, use_reentrant=False
|
68 |
+
)['sample']
|
69 |
+
|
70 |
+
return sample
|
71 |
+
|
72 |
+
def cross_attn_g_c(
|
73 |
+
attn,
|
74 |
+
temp_attn,
|
75 |
+
resnet,
|
76 |
+
temp_conv,
|
77 |
+
hidden_states,
|
78 |
+
encoder_hidden_states,
|
79 |
+
cross_attention_kwargs,
|
80 |
+
temb,
|
81 |
+
num_frames,
|
82 |
+
inverse_temp=False
|
83 |
+
):
|
84 |
+
|
85 |
+
def ordered_g_c(idx):
|
86 |
+
|
87 |
+
# Self and CrossAttention
|
88 |
+
if idx == 0: return g_c(custom_checkpoint(attn, mode='attn'),
|
89 |
+
hidden_states, encoder_hidden_states,cross_attention_kwargs, use_reentrant=False
|
90 |
+
)['sample']
|
91 |
+
|
92 |
+
# Temporal Self and CrossAttention
|
93 |
+
if idx == 1: return g_c(custom_checkpoint(temp_attn, mode='temp'),
|
94 |
+
hidden_states, num_frames, use_reentrant=False)['sample']
|
95 |
+
|
96 |
+
# Resnets
|
97 |
+
if idx == 2: return g_c(custom_checkpoint(resnet, mode='resnet'),
|
98 |
+
hidden_states, temb, use_reentrant=False)
|
99 |
+
|
100 |
+
# Temporal Convolutions
|
101 |
+
if idx == 3: return g_c(custom_checkpoint(temp_conv, mode='temp'),
|
102 |
+
hidden_states, num_frames, use_reentrant=False
|
103 |
+
)
|
104 |
+
|
105 |
+
# Here we call the function depending on the order in which they are called.
|
106 |
+
# For some layers, the orders are different, so we access the appropriate one by index.
|
107 |
+
|
108 |
+
if not inverse_temp:
|
109 |
+
for idx in [0,1,2,3]: hidden_states = ordered_g_c(idx)
|
110 |
+
else:
|
111 |
+
for idx in [2,3,0,1]: hidden_states = ordered_g_c(idx)
|
112 |
+
|
113 |
+
return hidden_states
|
114 |
+
|
115 |
+
def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames):
|
116 |
+
hidden_states = g_c(custom_checkpoint(resnet, mode='resnet'), hidden_states, temb, use_reentrant=False)
|
117 |
+
hidden_states = g_c(custom_checkpoint(temp_conv, mode='temp'),
|
118 |
+
hidden_states, num_frames, use_reentrant=False
|
119 |
+
)
|
120 |
+
return hidden_states
|
121 |
+
|
122 |
+
def get_down_block(
|
123 |
+
down_block_type,
|
124 |
+
num_layers,
|
125 |
+
in_channels,
|
126 |
+
out_channels,
|
127 |
+
temb_channels,
|
128 |
+
add_downsample,
|
129 |
+
resnet_eps,
|
130 |
+
resnet_act_fn,
|
131 |
+
attn_num_head_channels,
|
132 |
+
resnet_groups=None,
|
133 |
+
cross_attention_dim=None,
|
134 |
+
downsample_padding=None,
|
135 |
+
dual_cross_attention=False,
|
136 |
+
use_linear_projection=True,
|
137 |
+
only_cross_attention=False,
|
138 |
+
upcast_attention=False,
|
139 |
+
resnet_time_scale_shift="default",
|
140 |
+
):
|
141 |
+
if down_block_type == "DownBlock3D":
|
142 |
+
return DownBlock3D(
|
143 |
+
num_layers=num_layers,
|
144 |
+
in_channels=in_channels,
|
145 |
+
out_channels=out_channels,
|
146 |
+
temb_channels=temb_channels,
|
147 |
+
add_downsample=add_downsample,
|
148 |
+
resnet_eps=resnet_eps,
|
149 |
+
resnet_act_fn=resnet_act_fn,
|
150 |
+
resnet_groups=resnet_groups,
|
151 |
+
downsample_padding=downsample_padding,
|
152 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
153 |
+
)
|
154 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
155 |
+
if cross_attention_dim is None:
|
156 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
157 |
+
return CrossAttnDownBlock3D(
|
158 |
+
num_layers=num_layers,
|
159 |
+
in_channels=in_channels,
|
160 |
+
out_channels=out_channels,
|
161 |
+
temb_channels=temb_channels,
|
162 |
+
add_downsample=add_downsample,
|
163 |
+
resnet_eps=resnet_eps,
|
164 |
+
resnet_act_fn=resnet_act_fn,
|
165 |
+
resnet_groups=resnet_groups,
|
166 |
+
downsample_padding=downsample_padding,
|
167 |
+
cross_attention_dim=cross_attention_dim,
|
168 |
+
attn_num_head_channels=attn_num_head_channels,
|
169 |
+
dual_cross_attention=dual_cross_attention,
|
170 |
+
use_linear_projection=use_linear_projection,
|
171 |
+
only_cross_attention=only_cross_attention,
|
172 |
+
upcast_attention=upcast_attention,
|
173 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
174 |
+
)
|
175 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
176 |
+
|
177 |
+
|
178 |
+
def get_up_block(
|
179 |
+
up_block_type,
|
180 |
+
num_layers,
|
181 |
+
in_channels,
|
182 |
+
out_channels,
|
183 |
+
prev_output_channel,
|
184 |
+
temb_channels,
|
185 |
+
add_upsample,
|
186 |
+
resnet_eps,
|
187 |
+
resnet_act_fn,
|
188 |
+
attn_num_head_channels,
|
189 |
+
resnet_groups=None,
|
190 |
+
cross_attention_dim=None,
|
191 |
+
dual_cross_attention=False,
|
192 |
+
use_linear_projection=True,
|
193 |
+
only_cross_attention=False,
|
194 |
+
upcast_attention=False,
|
195 |
+
resnet_time_scale_shift="default",
|
196 |
+
):
|
197 |
+
if up_block_type == "UpBlock3D":
|
198 |
+
return UpBlock3D(
|
199 |
+
num_layers=num_layers,
|
200 |
+
in_channels=in_channels,
|
201 |
+
out_channels=out_channels,
|
202 |
+
prev_output_channel=prev_output_channel,
|
203 |
+
temb_channels=temb_channels,
|
204 |
+
add_upsample=add_upsample,
|
205 |
+
resnet_eps=resnet_eps,
|
206 |
+
resnet_act_fn=resnet_act_fn,
|
207 |
+
resnet_groups=resnet_groups,
|
208 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
209 |
+
)
|
210 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
211 |
+
if cross_attention_dim is None:
|
212 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
213 |
+
return CrossAttnUpBlock3D(
|
214 |
+
num_layers=num_layers,
|
215 |
+
in_channels=in_channels,
|
216 |
+
out_channels=out_channels,
|
217 |
+
prev_output_channel=prev_output_channel,
|
218 |
+
temb_channels=temb_channels,
|
219 |
+
add_upsample=add_upsample,
|
220 |
+
resnet_eps=resnet_eps,
|
221 |
+
resnet_act_fn=resnet_act_fn,
|
222 |
+
resnet_groups=resnet_groups,
|
223 |
+
cross_attention_dim=cross_attention_dim,
|
224 |
+
attn_num_head_channels=attn_num_head_channels,
|
225 |
+
dual_cross_attention=dual_cross_attention,
|
226 |
+
use_linear_projection=use_linear_projection,
|
227 |
+
only_cross_attention=only_cross_attention,
|
228 |
+
upcast_attention=upcast_attention,
|
229 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
230 |
+
)
|
231 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
232 |
+
|
233 |
+
|
234 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
235 |
+
def __init__(
|
236 |
+
self,
|
237 |
+
in_channels: int,
|
238 |
+
temb_channels: int,
|
239 |
+
dropout: float = 0.0,
|
240 |
+
num_layers: int = 1,
|
241 |
+
resnet_eps: float = 1e-6,
|
242 |
+
resnet_time_scale_shift: str = "default",
|
243 |
+
resnet_act_fn: str = "swish",
|
244 |
+
resnet_groups: int = 32,
|
245 |
+
resnet_pre_norm: bool = True,
|
246 |
+
attn_num_head_channels=1,
|
247 |
+
output_scale_factor=1.0,
|
248 |
+
cross_attention_dim=1280,
|
249 |
+
dual_cross_attention=False,
|
250 |
+
use_linear_projection=True,
|
251 |
+
upcast_attention=False,
|
252 |
+
):
|
253 |
+
super().__init__()
|
254 |
+
|
255 |
+
self.gradient_checkpointing = False
|
256 |
+
self.has_cross_attention = True
|
257 |
+
self.attn_num_head_channels = attn_num_head_channels
|
258 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
259 |
+
|
260 |
+
# there is always at least one resnet
|
261 |
+
resnets = [
|
262 |
+
ResnetBlock2D(
|
263 |
+
in_channels=in_channels,
|
264 |
+
out_channels=in_channels,
|
265 |
+
temb_channels=temb_channels,
|
266 |
+
eps=resnet_eps,
|
267 |
+
groups=resnet_groups,
|
268 |
+
dropout=dropout,
|
269 |
+
time_embedding_norm=resnet_time_scale_shift,
|
270 |
+
non_linearity=resnet_act_fn,
|
271 |
+
output_scale_factor=output_scale_factor,
|
272 |
+
pre_norm=resnet_pre_norm,
|
273 |
+
)
|
274 |
+
]
|
275 |
+
temp_convs = [
|
276 |
+
TemporalConvLayer(
|
277 |
+
in_channels,
|
278 |
+
in_channels,
|
279 |
+
dropout=0.1
|
280 |
+
)
|
281 |
+
]
|
282 |
+
attentions = []
|
283 |
+
temp_attentions = []
|
284 |
+
|
285 |
+
for _ in range(num_layers):
|
286 |
+
attentions.append(
|
287 |
+
Transformer2DModel(
|
288 |
+
in_channels // attn_num_head_channels,
|
289 |
+
attn_num_head_channels,
|
290 |
+
in_channels=in_channels,
|
291 |
+
num_layers=1,
|
292 |
+
cross_attention_dim=cross_attention_dim,
|
293 |
+
norm_num_groups=resnet_groups,
|
294 |
+
use_linear_projection=use_linear_projection,
|
295 |
+
upcast_attention=upcast_attention,
|
296 |
+
)
|
297 |
+
)
|
298 |
+
temp_attentions.append(
|
299 |
+
TransformerTemporalModel(
|
300 |
+
in_channels // attn_num_head_channels,
|
301 |
+
attn_num_head_channels,
|
302 |
+
in_channels=in_channels,
|
303 |
+
num_layers=1,
|
304 |
+
cross_attention_dim=cross_attention_dim,
|
305 |
+
norm_num_groups=resnet_groups,
|
306 |
+
)
|
307 |
+
)
|
308 |
+
resnets.append(
|
309 |
+
ResnetBlock2D(
|
310 |
+
in_channels=in_channels,
|
311 |
+
out_channels=in_channels,
|
312 |
+
temb_channels=temb_channels,
|
313 |
+
eps=resnet_eps,
|
314 |
+
groups=resnet_groups,
|
315 |
+
dropout=dropout,
|
316 |
+
time_embedding_norm=resnet_time_scale_shift,
|
317 |
+
non_linearity=resnet_act_fn,
|
318 |
+
output_scale_factor=output_scale_factor,
|
319 |
+
pre_norm=resnet_pre_norm,
|
320 |
+
)
|
321 |
+
)
|
322 |
+
temp_convs.append(
|
323 |
+
TemporalConvLayer(
|
324 |
+
in_channels,
|
325 |
+
in_channels,
|
326 |
+
dropout=0.1
|
327 |
+
)
|
328 |
+
)
|
329 |
+
|
330 |
+
self.resnets = nn.ModuleList(resnets)
|
331 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
332 |
+
self.attentions = nn.ModuleList(attentions)
|
333 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
334 |
+
|
335 |
+
def forward(
|
336 |
+
self,
|
337 |
+
hidden_states,
|
338 |
+
temb=None,
|
339 |
+
encoder_hidden_states=None,
|
340 |
+
attention_mask=None,
|
341 |
+
num_frames=1,
|
342 |
+
cross_attention_kwargs=None,
|
343 |
+
):
|
344 |
+
if self.gradient_checkpointing:
|
345 |
+
hidden_states = up_down_g_c(
|
346 |
+
self.resnets[0],
|
347 |
+
self.temp_convs[0],
|
348 |
+
hidden_states,
|
349 |
+
temb,
|
350 |
+
num_frames
|
351 |
+
)
|
352 |
+
else:
|
353 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
354 |
+
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
|
355 |
+
|
356 |
+
for attn, temp_attn, resnet, temp_conv in zip(
|
357 |
+
self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
|
358 |
+
):
|
359 |
+
if self.gradient_checkpointing:
|
360 |
+
hidden_states = cross_attn_g_c(
|
361 |
+
attn,
|
362 |
+
temp_attn,
|
363 |
+
resnet,
|
364 |
+
temp_conv,
|
365 |
+
hidden_states,
|
366 |
+
encoder_hidden_states,
|
367 |
+
cross_attention_kwargs,
|
368 |
+
temb,
|
369 |
+
num_frames
|
370 |
+
)
|
371 |
+
else:
|
372 |
+
hidden_states = attn(
|
373 |
+
hidden_states,
|
374 |
+
encoder_hidden_states=encoder_hidden_states,
|
375 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
376 |
+
).sample
|
377 |
+
|
378 |
+
if num_frames > 1:
|
379 |
+
hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
|
380 |
+
|
381 |
+
hidden_states = resnet(hidden_states, temb)
|
382 |
+
|
383 |
+
if num_frames > 1:
|
384 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
385 |
+
|
386 |
+
return hidden_states
|
387 |
+
|
388 |
+
|
389 |
+
class CrossAttnDownBlock3D(nn.Module):
|
390 |
+
def __init__(
|
391 |
+
self,
|
392 |
+
in_channels: int,
|
393 |
+
out_channels: int,
|
394 |
+
temb_channels: int,
|
395 |
+
dropout: float = 0.0,
|
396 |
+
num_layers: int = 1,
|
397 |
+
resnet_eps: float = 1e-6,
|
398 |
+
resnet_time_scale_shift: str = "default",
|
399 |
+
resnet_act_fn: str = "swish",
|
400 |
+
resnet_groups: int = 32,
|
401 |
+
resnet_pre_norm: bool = True,
|
402 |
+
attn_num_head_channels=1,
|
403 |
+
cross_attention_dim=1280,
|
404 |
+
output_scale_factor=1.0,
|
405 |
+
downsample_padding=1,
|
406 |
+
add_downsample=True,
|
407 |
+
dual_cross_attention=False,
|
408 |
+
use_linear_projection=False,
|
409 |
+
only_cross_attention=False,
|
410 |
+
upcast_attention=False,
|
411 |
+
):
|
412 |
+
super().__init__()
|
413 |
+
resnets = []
|
414 |
+
attentions = []
|
415 |
+
temp_attentions = []
|
416 |
+
temp_convs = []
|
417 |
+
|
418 |
+
self.gradient_checkpointing = False
|
419 |
+
self.has_cross_attention = True
|
420 |
+
self.attn_num_head_channels = attn_num_head_channels
|
421 |
+
|
422 |
+
for i in range(num_layers):
|
423 |
+
in_channels = in_channels if i == 0 else out_channels
|
424 |
+
resnets.append(
|
425 |
+
ResnetBlock2D(
|
426 |
+
in_channels=in_channels,
|
427 |
+
out_channels=out_channels,
|
428 |
+
temb_channels=temb_channels,
|
429 |
+
eps=resnet_eps,
|
430 |
+
groups=resnet_groups,
|
431 |
+
dropout=dropout,
|
432 |
+
time_embedding_norm=resnet_time_scale_shift,
|
433 |
+
non_linearity=resnet_act_fn,
|
434 |
+
output_scale_factor=output_scale_factor,
|
435 |
+
pre_norm=resnet_pre_norm,
|
436 |
+
)
|
437 |
+
)
|
438 |
+
temp_convs.append(
|
439 |
+
TemporalConvLayer(
|
440 |
+
out_channels,
|
441 |
+
out_channels,
|
442 |
+
dropout=0.1
|
443 |
+
)
|
444 |
+
)
|
445 |
+
attentions.append(
|
446 |
+
Transformer2DModel(
|
447 |
+
out_channels // attn_num_head_channels,
|
448 |
+
attn_num_head_channels,
|
449 |
+
in_channels=out_channels,
|
450 |
+
num_layers=1,
|
451 |
+
cross_attention_dim=cross_attention_dim,
|
452 |
+
norm_num_groups=resnet_groups,
|
453 |
+
use_linear_projection=use_linear_projection,
|
454 |
+
only_cross_attention=only_cross_attention,
|
455 |
+
upcast_attention=upcast_attention,
|
456 |
+
)
|
457 |
+
)
|
458 |
+
temp_attentions.append(
|
459 |
+
TransformerTemporalModel(
|
460 |
+
out_channels // attn_num_head_channels,
|
461 |
+
attn_num_head_channels,
|
462 |
+
in_channels=out_channels,
|
463 |
+
num_layers=1,
|
464 |
+
cross_attention_dim=cross_attention_dim,
|
465 |
+
norm_num_groups=resnet_groups,
|
466 |
+
)
|
467 |
+
)
|
468 |
+
self.resnets = nn.ModuleList(resnets)
|
469 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
470 |
+
self.attentions = nn.ModuleList(attentions)
|
471 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
472 |
+
|
473 |
+
if add_downsample:
|
474 |
+
self.downsamplers = nn.ModuleList(
|
475 |
+
[
|
476 |
+
Downsample2D(
|
477 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
478 |
+
)
|
479 |
+
]
|
480 |
+
)
|
481 |
+
else:
|
482 |
+
self.downsamplers = None
|
483 |
+
|
484 |
+
def forward(
|
485 |
+
self,
|
486 |
+
hidden_states,
|
487 |
+
temb=None,
|
488 |
+
encoder_hidden_states=None,
|
489 |
+
attention_mask=None,
|
490 |
+
num_frames=1,
|
491 |
+
cross_attention_kwargs=None,
|
492 |
+
):
|
493 |
+
# TODO(Patrick, William) - attention mask is not used
|
494 |
+
output_states = ()
|
495 |
+
|
496 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
497 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
498 |
+
):
|
499 |
+
|
500 |
+
if self.gradient_checkpointing:
|
501 |
+
hidden_states = cross_attn_g_c(
|
502 |
+
attn,
|
503 |
+
temp_attn,
|
504 |
+
resnet,
|
505 |
+
temp_conv,
|
506 |
+
hidden_states,
|
507 |
+
encoder_hidden_states,
|
508 |
+
cross_attention_kwargs,
|
509 |
+
temb,
|
510 |
+
num_frames,
|
511 |
+
inverse_temp=True
|
512 |
+
)
|
513 |
+
else:
|
514 |
+
hidden_states = resnet(hidden_states, temb)
|
515 |
+
|
516 |
+
if num_frames > 1:
|
517 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
518 |
+
|
519 |
+
hidden_states = attn(
|
520 |
+
hidden_states,
|
521 |
+
encoder_hidden_states=encoder_hidden_states,
|
522 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
523 |
+
).sample
|
524 |
+
|
525 |
+
if num_frames > 1:
|
526 |
+
hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
|
527 |
+
|
528 |
+
output_states += (hidden_states,)
|
529 |
+
|
530 |
+
if self.downsamplers is not None:
|
531 |
+
for downsampler in self.downsamplers:
|
532 |
+
hidden_states = downsampler(hidden_states)
|
533 |
+
|
534 |
+
output_states += (hidden_states,)
|
535 |
+
|
536 |
+
return hidden_states, output_states
|
537 |
+
|
538 |
+
|
539 |
+
class DownBlock3D(nn.Module):
|
540 |
+
def __init__(
|
541 |
+
self,
|
542 |
+
in_channels: int,
|
543 |
+
out_channels: int,
|
544 |
+
temb_channels: int,
|
545 |
+
dropout: float = 0.0,
|
546 |
+
num_layers: int = 1,
|
547 |
+
resnet_eps: float = 1e-6,
|
548 |
+
resnet_time_scale_shift: str = "default",
|
549 |
+
resnet_act_fn: str = "swish",
|
550 |
+
resnet_groups: int = 32,
|
551 |
+
resnet_pre_norm: bool = True,
|
552 |
+
output_scale_factor=1.0,
|
553 |
+
add_downsample=True,
|
554 |
+
downsample_padding=1,
|
555 |
+
):
|
556 |
+
super().__init__()
|
557 |
+
resnets = []
|
558 |
+
temp_convs = []
|
559 |
+
|
560 |
+
self.gradient_checkpointing = False
|
561 |
+
for i in range(num_layers):
|
562 |
+
in_channels = in_channels if i == 0 else out_channels
|
563 |
+
resnets.append(
|
564 |
+
ResnetBlock2D(
|
565 |
+
in_channels=in_channels,
|
566 |
+
out_channels=out_channels,
|
567 |
+
temb_channels=temb_channels,
|
568 |
+
eps=resnet_eps,
|
569 |
+
groups=resnet_groups,
|
570 |
+
dropout=dropout,
|
571 |
+
time_embedding_norm=resnet_time_scale_shift,
|
572 |
+
non_linearity=resnet_act_fn,
|
573 |
+
output_scale_factor=output_scale_factor,
|
574 |
+
pre_norm=resnet_pre_norm,
|
575 |
+
)
|
576 |
+
)
|
577 |
+
temp_convs.append(
|
578 |
+
TemporalConvLayer(
|
579 |
+
out_channels,
|
580 |
+
out_channels,
|
581 |
+
dropout=0.1
|
582 |
+
)
|
583 |
+
)
|
584 |
+
|
585 |
+
self.resnets = nn.ModuleList(resnets)
|
586 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
587 |
+
|
588 |
+
if add_downsample:
|
589 |
+
self.downsamplers = nn.ModuleList(
|
590 |
+
[
|
591 |
+
Downsample2D(
|
592 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
593 |
+
)
|
594 |
+
]
|
595 |
+
)
|
596 |
+
else:
|
597 |
+
self.downsamplers = None
|
598 |
+
|
599 |
+
def forward(self, hidden_states, temb=None, num_frames=1):
|
600 |
+
output_states = ()
|
601 |
+
|
602 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
603 |
+
if self.gradient_checkpointing:
|
604 |
+
hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames)
|
605 |
+
else:
|
606 |
+
hidden_states = resnet(hidden_states, temb)
|
607 |
+
|
608 |
+
if num_frames > 1:
|
609 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
610 |
+
|
611 |
+
output_states += (hidden_states,)
|
612 |
+
|
613 |
+
if self.downsamplers is not None:
|
614 |
+
for downsampler in self.downsamplers:
|
615 |
+
hidden_states = downsampler(hidden_states)
|
616 |
+
|
617 |
+
output_states += (hidden_states,)
|
618 |
+
|
619 |
+
return hidden_states, output_states
|
620 |
+
|
621 |
+
|
622 |
+
class CrossAttnUpBlock3D(nn.Module):
|
623 |
+
def __init__(
|
624 |
+
self,
|
625 |
+
in_channels: int,
|
626 |
+
out_channels: int,
|
627 |
+
prev_output_channel: int,
|
628 |
+
temb_channels: int,
|
629 |
+
dropout: float = 0.0,
|
630 |
+
num_layers: int = 1,
|
631 |
+
resnet_eps: float = 1e-6,
|
632 |
+
resnet_time_scale_shift: str = "default",
|
633 |
+
resnet_act_fn: str = "swish",
|
634 |
+
resnet_groups: int = 32,
|
635 |
+
resnet_pre_norm: bool = True,
|
636 |
+
attn_num_head_channels=1,
|
637 |
+
cross_attention_dim=1280,
|
638 |
+
output_scale_factor=1.0,
|
639 |
+
add_upsample=True,
|
640 |
+
dual_cross_attention=False,
|
641 |
+
use_linear_projection=False,
|
642 |
+
only_cross_attention=False,
|
643 |
+
upcast_attention=False,
|
644 |
+
):
|
645 |
+
super().__init__()
|
646 |
+
resnets = []
|
647 |
+
temp_convs = []
|
648 |
+
attentions = []
|
649 |
+
temp_attentions = []
|
650 |
+
|
651 |
+
self.gradient_checkpointing = False
|
652 |
+
self.has_cross_attention = True
|
653 |
+
self.attn_num_head_channels = attn_num_head_channels
|
654 |
+
|
655 |
+
for i in range(num_layers):
|
656 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
657 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
658 |
+
|
659 |
+
resnets.append(
|
660 |
+
ResnetBlock2D(
|
661 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
662 |
+
out_channels=out_channels,
|
663 |
+
temb_channels=temb_channels,
|
664 |
+
eps=resnet_eps,
|
665 |
+
groups=resnet_groups,
|
666 |
+
dropout=dropout,
|
667 |
+
time_embedding_norm=resnet_time_scale_shift,
|
668 |
+
non_linearity=resnet_act_fn,
|
669 |
+
output_scale_factor=output_scale_factor,
|
670 |
+
pre_norm=resnet_pre_norm,
|
671 |
+
)
|
672 |
+
)
|
673 |
+
temp_convs.append(
|
674 |
+
TemporalConvLayer(
|
675 |
+
out_channels,
|
676 |
+
out_channels,
|
677 |
+
dropout=0.1
|
678 |
+
)
|
679 |
+
)
|
680 |
+
attentions.append(
|
681 |
+
Transformer2DModel(
|
682 |
+
out_channels // attn_num_head_channels,
|
683 |
+
attn_num_head_channels,
|
684 |
+
in_channels=out_channels,
|
685 |
+
num_layers=1,
|
686 |
+
cross_attention_dim=cross_attention_dim,
|
687 |
+
norm_num_groups=resnet_groups,
|
688 |
+
use_linear_projection=use_linear_projection,
|
689 |
+
only_cross_attention=only_cross_attention,
|
690 |
+
upcast_attention=upcast_attention,
|
691 |
+
)
|
692 |
+
)
|
693 |
+
temp_attentions.append(
|
694 |
+
TransformerTemporalModel(
|
695 |
+
out_channels // attn_num_head_channels,
|
696 |
+
attn_num_head_channels,
|
697 |
+
in_channels=out_channels,
|
698 |
+
num_layers=1,
|
699 |
+
cross_attention_dim=cross_attention_dim,
|
700 |
+
norm_num_groups=resnet_groups,
|
701 |
+
)
|
702 |
+
)
|
703 |
+
self.resnets = nn.ModuleList(resnets)
|
704 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
705 |
+
self.attentions = nn.ModuleList(attentions)
|
706 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
707 |
+
|
708 |
+
if add_upsample:
|
709 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
710 |
+
else:
|
711 |
+
self.upsamplers = None
|
712 |
+
|
713 |
+
def forward(
|
714 |
+
self,
|
715 |
+
hidden_states,
|
716 |
+
res_hidden_states_tuple,
|
717 |
+
temb=None,
|
718 |
+
encoder_hidden_states=None,
|
719 |
+
upsample_size=None,
|
720 |
+
attention_mask=None,
|
721 |
+
num_frames=1,
|
722 |
+
cross_attention_kwargs=None,
|
723 |
+
):
|
724 |
+
# TODO(Patrick, William) - attention mask is not used
|
725 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
726 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
727 |
+
):
|
728 |
+
# pop res hidden states
|
729 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
730 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
731 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
732 |
+
|
733 |
+
if self.gradient_checkpointing:
|
734 |
+
hidden_states = cross_attn_g_c(
|
735 |
+
attn,
|
736 |
+
temp_attn,
|
737 |
+
resnet,
|
738 |
+
temp_conv,
|
739 |
+
hidden_states,
|
740 |
+
encoder_hidden_states,
|
741 |
+
cross_attention_kwargs,
|
742 |
+
temb,
|
743 |
+
num_frames,
|
744 |
+
inverse_temp=True
|
745 |
+
)
|
746 |
+
else:
|
747 |
+
hidden_states = resnet(hidden_states, temb)
|
748 |
+
|
749 |
+
if num_frames > 1:
|
750 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
751 |
+
|
752 |
+
hidden_states = attn(
|
753 |
+
hidden_states,
|
754 |
+
encoder_hidden_states=encoder_hidden_states,
|
755 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
756 |
+
).sample
|
757 |
+
|
758 |
+
if num_frames > 1:
|
759 |
+
hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
|
760 |
+
|
761 |
+
if self.upsamplers is not None:
|
762 |
+
for upsampler in self.upsamplers:
|
763 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
764 |
+
|
765 |
+
return hidden_states
|
766 |
+
|
767 |
+
|
768 |
+
class UpBlock3D(nn.Module):
|
769 |
+
def __init__(
|
770 |
+
self,
|
771 |
+
in_channels: int,
|
772 |
+
prev_output_channel: int,
|
773 |
+
out_channels: int,
|
774 |
+
temb_channels: int,
|
775 |
+
dropout: float = 0.0,
|
776 |
+
num_layers: int = 1,
|
777 |
+
resnet_eps: float = 1e-6,
|
778 |
+
resnet_time_scale_shift: str = "default",
|
779 |
+
resnet_act_fn: str = "swish",
|
780 |
+
resnet_groups: int = 32,
|
781 |
+
resnet_pre_norm: bool = True,
|
782 |
+
output_scale_factor=1.0,
|
783 |
+
add_upsample=True,
|
784 |
+
):
|
785 |
+
super().__init__()
|
786 |
+
resnets = []
|
787 |
+
temp_convs = []
|
788 |
+
self.gradient_checkpointing = False
|
789 |
+
for i in range(num_layers):
|
790 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
791 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
792 |
+
|
793 |
+
resnets.append(
|
794 |
+
ResnetBlock2D(
|
795 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
796 |
+
out_channels=out_channels,
|
797 |
+
temb_channels=temb_channels,
|
798 |
+
eps=resnet_eps,
|
799 |
+
groups=resnet_groups,
|
800 |
+
dropout=dropout,
|
801 |
+
time_embedding_norm=resnet_time_scale_shift,
|
802 |
+
non_linearity=resnet_act_fn,
|
803 |
+
output_scale_factor=output_scale_factor,
|
804 |
+
pre_norm=resnet_pre_norm,
|
805 |
+
)
|
806 |
+
)
|
807 |
+
temp_convs.append(
|
808 |
+
TemporalConvLayer(
|
809 |
+
out_channels,
|
810 |
+
out_channels,
|
811 |
+
dropout=0.1
|
812 |
+
)
|
813 |
+
)
|
814 |
+
|
815 |
+
self.resnets = nn.ModuleList(resnets)
|
816 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
817 |
+
|
818 |
+
if add_upsample:
|
819 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
820 |
+
else:
|
821 |
+
self.upsamplers = None
|
822 |
+
|
823 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
|
824 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
825 |
+
# pop res hidden states
|
826 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
827 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
828 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
829 |
+
|
830 |
+
if self.gradient_checkpointing:
|
831 |
+
hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames)
|
832 |
+
else:
|
833 |
+
hidden_states = resnet(hidden_states, temb)
|
834 |
+
|
835 |
+
if num_frames > 1:
|
836 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
837 |
+
|
838 |
+
if self.upsamplers is not None:
|
839 |
+
for upsampler in self.upsamplers:
|
840 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
841 |
+
|
842 |
+
return hidden_states
|
models/unet_3d_condition.py
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
|
2 |
+
# Copyright 2023 The ModelScope Team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
|
22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
25 |
+
from diffusers.models.modeling_utils import ModelMixin
|
26 |
+
from diffusers.models.transformer_temporal import TransformerTemporalModel
|
27 |
+
from .unet_3d_blocks import (
|
28 |
+
CrossAttnDownBlock3D,
|
29 |
+
CrossAttnUpBlock3D,
|
30 |
+
DownBlock3D,
|
31 |
+
UNetMidBlock3DCrossAttn,
|
32 |
+
UpBlock3D,
|
33 |
+
get_down_block,
|
34 |
+
get_up_block,
|
35 |
+
transformer_g_c
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
40 |
+
|
41 |
+
|
42 |
+
@dataclass
|
43 |
+
class UNet3DConditionOutput(BaseOutput):
|
44 |
+
"""
|
45 |
+
Args:
|
46 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
|
47 |
+
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
48 |
+
"""
|
49 |
+
|
50 |
+
sample: torch.FloatTensor
|
51 |
+
|
52 |
+
|
53 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
54 |
+
r"""
|
55 |
+
UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
|
56 |
+
and returns sample shaped output.
|
57 |
+
|
58 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
59 |
+
implements for all the models (such as downloading or saving, etc.)
|
60 |
+
|
61 |
+
Parameters:
|
62 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
63 |
+
Height and width of input/output sample.
|
64 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
65 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
66 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
67 |
+
The tuple of downsample blocks to use.
|
68 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
69 |
+
The tuple of upsample blocks to use.
|
70 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
71 |
+
The tuple of output channels for each block.
|
72 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
73 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
74 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
75 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
76 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
77 |
+
If `None`, it will skip the normalization and activation layers in post-processing
|
78 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
79 |
+
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
80 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
81 |
+
"""
|
82 |
+
|
83 |
+
_supports_gradient_checkpointing = True
|
84 |
+
|
85 |
+
@register_to_config
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
sample_size: Optional[int] = None,
|
89 |
+
in_channels: int = 4,
|
90 |
+
out_channels: int = 4,
|
91 |
+
down_block_types: Tuple[str] = (
|
92 |
+
"CrossAttnDownBlock3D",
|
93 |
+
"CrossAttnDownBlock3D",
|
94 |
+
"CrossAttnDownBlock3D",
|
95 |
+
"DownBlock3D",
|
96 |
+
),
|
97 |
+
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
98 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
99 |
+
layers_per_block: int = 2,
|
100 |
+
downsample_padding: int = 1,
|
101 |
+
mid_block_scale_factor: float = 1,
|
102 |
+
act_fn: str = "silu",
|
103 |
+
norm_num_groups: Optional[int] = 32,
|
104 |
+
norm_eps: float = 1e-5,
|
105 |
+
cross_attention_dim: int = 1024,
|
106 |
+
attention_head_dim: Union[int, Tuple[int]] = 64,
|
107 |
+
):
|
108 |
+
super().__init__()
|
109 |
+
|
110 |
+
self.sample_size = sample_size
|
111 |
+
self.gradient_checkpointing = False
|
112 |
+
# Check inputs
|
113 |
+
if len(down_block_types) != len(up_block_types):
|
114 |
+
raise ValueError(
|
115 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
116 |
+
)
|
117 |
+
|
118 |
+
if len(block_out_channels) != len(down_block_types):
|
119 |
+
raise ValueError(
|
120 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
121 |
+
)
|
122 |
+
|
123 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
124 |
+
raise ValueError(
|
125 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
126 |
+
)
|
127 |
+
|
128 |
+
# input
|
129 |
+
conv_in_kernel = 3
|
130 |
+
conv_out_kernel = 3
|
131 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
132 |
+
self.conv_in = nn.Conv2d(
|
133 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
134 |
+
)
|
135 |
+
|
136 |
+
# time
|
137 |
+
time_embed_dim = block_out_channels[0] * 4
|
138 |
+
self.time_proj = Timesteps(block_out_channels[0], True, 0)
|
139 |
+
timestep_input_dim = block_out_channels[0]
|
140 |
+
|
141 |
+
self.time_embedding = TimestepEmbedding(
|
142 |
+
timestep_input_dim,
|
143 |
+
time_embed_dim,
|
144 |
+
act_fn=act_fn,
|
145 |
+
)
|
146 |
+
|
147 |
+
self.transformer_in = TransformerTemporalModel(
|
148 |
+
num_attention_heads=8,
|
149 |
+
attention_head_dim=attention_head_dim,
|
150 |
+
in_channels=block_out_channels[0],
|
151 |
+
num_layers=1,
|
152 |
+
)
|
153 |
+
|
154 |
+
# class embedding
|
155 |
+
self.down_blocks = nn.ModuleList([])
|
156 |
+
self.up_blocks = nn.ModuleList([])
|
157 |
+
|
158 |
+
if isinstance(attention_head_dim, int):
|
159 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
160 |
+
|
161 |
+
# down
|
162 |
+
output_channel = block_out_channels[0]
|
163 |
+
for i, down_block_type in enumerate(down_block_types):
|
164 |
+
input_channel = output_channel
|
165 |
+
output_channel = block_out_channels[i]
|
166 |
+
is_final_block = i == len(block_out_channels) - 1
|
167 |
+
|
168 |
+
down_block = get_down_block(
|
169 |
+
down_block_type,
|
170 |
+
num_layers=layers_per_block,
|
171 |
+
in_channels=input_channel,
|
172 |
+
out_channels=output_channel,
|
173 |
+
temb_channels=time_embed_dim,
|
174 |
+
add_downsample=not is_final_block,
|
175 |
+
resnet_eps=norm_eps,
|
176 |
+
resnet_act_fn=act_fn,
|
177 |
+
resnet_groups=norm_num_groups,
|
178 |
+
cross_attention_dim=cross_attention_dim,
|
179 |
+
attn_num_head_channels=attention_head_dim[i],
|
180 |
+
downsample_padding=downsample_padding,
|
181 |
+
dual_cross_attention=False,
|
182 |
+
)
|
183 |
+
self.down_blocks.append(down_block)
|
184 |
+
|
185 |
+
# mid
|
186 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
187 |
+
in_channels=block_out_channels[-1],
|
188 |
+
temb_channels=time_embed_dim,
|
189 |
+
resnet_eps=norm_eps,
|
190 |
+
resnet_act_fn=act_fn,
|
191 |
+
output_scale_factor=mid_block_scale_factor,
|
192 |
+
cross_attention_dim=cross_attention_dim,
|
193 |
+
attn_num_head_channels=attention_head_dim[-1],
|
194 |
+
resnet_groups=norm_num_groups,
|
195 |
+
dual_cross_attention=False,
|
196 |
+
)
|
197 |
+
|
198 |
+
# count how many layers upsample the images
|
199 |
+
self.num_upsamplers = 0
|
200 |
+
|
201 |
+
# up
|
202 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
203 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
204 |
+
|
205 |
+
output_channel = reversed_block_out_channels[0]
|
206 |
+
for i, up_block_type in enumerate(up_block_types):
|
207 |
+
is_final_block = i == len(block_out_channels) - 1
|
208 |
+
|
209 |
+
prev_output_channel = output_channel
|
210 |
+
output_channel = reversed_block_out_channels[i]
|
211 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
212 |
+
|
213 |
+
# add upsample block for all BUT final layer
|
214 |
+
if not is_final_block:
|
215 |
+
add_upsample = True
|
216 |
+
self.num_upsamplers += 1
|
217 |
+
else:
|
218 |
+
add_upsample = False
|
219 |
+
|
220 |
+
up_block = get_up_block(
|
221 |
+
up_block_type,
|
222 |
+
num_layers=layers_per_block + 1,
|
223 |
+
in_channels=input_channel,
|
224 |
+
out_channels=output_channel,
|
225 |
+
prev_output_channel=prev_output_channel,
|
226 |
+
temb_channels=time_embed_dim,
|
227 |
+
add_upsample=add_upsample,
|
228 |
+
resnet_eps=norm_eps,
|
229 |
+
resnet_act_fn=act_fn,
|
230 |
+
resnet_groups=norm_num_groups,
|
231 |
+
cross_attention_dim=cross_attention_dim,
|
232 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
233 |
+
dual_cross_attention=False,
|
234 |
+
)
|
235 |
+
self.up_blocks.append(up_block)
|
236 |
+
prev_output_channel = output_channel
|
237 |
+
|
238 |
+
# out
|
239 |
+
if norm_num_groups is not None:
|
240 |
+
self.conv_norm_out = nn.GroupNorm(
|
241 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
242 |
+
)
|
243 |
+
self.conv_act = nn.SiLU()
|
244 |
+
else:
|
245 |
+
self.conv_norm_out = None
|
246 |
+
self.conv_act = None
|
247 |
+
|
248 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
249 |
+
self.conv_out = nn.Conv2d(
|
250 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
251 |
+
)
|
252 |
+
|
253 |
+
def set_attention_slice(self, slice_size):
|
254 |
+
r"""
|
255 |
+
Enable sliced attention computation.
|
256 |
+
|
257 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
258 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
262 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
263 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
264 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
265 |
+
must be a multiple of `slice_size`.
|
266 |
+
"""
|
267 |
+
sliceable_head_dims = []
|
268 |
+
|
269 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
270 |
+
if hasattr(module, "set_attention_slice"):
|
271 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
272 |
+
|
273 |
+
for child in module.children():
|
274 |
+
fn_recursive_retrieve_slicable_dims(child)
|
275 |
+
|
276 |
+
# retrieve number of attention layers
|
277 |
+
for module in self.children():
|
278 |
+
fn_recursive_retrieve_slicable_dims(module)
|
279 |
+
|
280 |
+
num_slicable_layers = len(sliceable_head_dims)
|
281 |
+
|
282 |
+
if slice_size == "auto":
|
283 |
+
# half the attention head size is usually a good trade-off between
|
284 |
+
# speed and memory
|
285 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
286 |
+
elif slice_size == "max":
|
287 |
+
# make smallest slice possible
|
288 |
+
slice_size = num_slicable_layers * [1]
|
289 |
+
|
290 |
+
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
291 |
+
|
292 |
+
if len(slice_size) != len(sliceable_head_dims):
|
293 |
+
raise ValueError(
|
294 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
295 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
296 |
+
)
|
297 |
+
|
298 |
+
for i in range(len(slice_size)):
|
299 |
+
size = slice_size[i]
|
300 |
+
dim = sliceable_head_dims[i]
|
301 |
+
if size is not None and size > dim:
|
302 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
303 |
+
|
304 |
+
# Recursively walk through all the children.
|
305 |
+
# Any children which exposes the set_attention_slice method
|
306 |
+
# gets the message
|
307 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
308 |
+
if hasattr(module, "set_attention_slice"):
|
309 |
+
module.set_attention_slice(slice_size.pop())
|
310 |
+
|
311 |
+
for child in module.children():
|
312 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
313 |
+
|
314 |
+
reversed_slice_size = list(reversed(slice_size))
|
315 |
+
for module in self.children():
|
316 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
317 |
+
|
318 |
+
def _set_gradient_checkpointing(self, value=False):
|
319 |
+
self.gradient_checkpointing = value
|
320 |
+
self.mid_block.gradient_checkpointing = value
|
321 |
+
for module in self.down_blocks + self.up_blocks:
|
322 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
323 |
+
module.gradient_checkpointing = value
|
324 |
+
|
325 |
+
def forward(
|
326 |
+
self,
|
327 |
+
sample: torch.FloatTensor,
|
328 |
+
timestep: Union[torch.Tensor, float, int],
|
329 |
+
encoder_hidden_states: torch.Tensor,
|
330 |
+
class_labels: Optional[torch.Tensor] = None,
|
331 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
332 |
+
attention_mask: Optional[torch.Tensor] = None,
|
333 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
334 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
335 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
336 |
+
return_dict: bool = True,
|
337 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
338 |
+
r"""
|
339 |
+
Args:
|
340 |
+
sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor
|
341 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
342 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
343 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
344 |
+
Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple.
|
345 |
+
cross_attention_kwargs (`dict`, *optional*):
|
346 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
347 |
+
`self.processor` in
|
348 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
[`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`:
|
352 |
+
[`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
353 |
+
returning a tuple, the first element is the sample tensor.
|
354 |
+
"""
|
355 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
356 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
357 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
358 |
+
# on the fly if necessary.
|
359 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
360 |
+
|
361 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
362 |
+
forward_upsample_size = False
|
363 |
+
upsample_size = None
|
364 |
+
|
365 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
366 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
367 |
+
forward_upsample_size = True
|
368 |
+
|
369 |
+
# prepare attention_mask
|
370 |
+
if attention_mask is not None:
|
371 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
372 |
+
attention_mask = attention_mask.unsqueeze(1)
|
373 |
+
|
374 |
+
# 1. time
|
375 |
+
timesteps = timestep
|
376 |
+
if not torch.is_tensor(timesteps):
|
377 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
378 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
379 |
+
is_mps = sample.device.type == "mps"
|
380 |
+
if isinstance(timestep, float):
|
381 |
+
dtype = torch.float32 if is_mps else torch.float64
|
382 |
+
else:
|
383 |
+
dtype = torch.int32 if is_mps else torch.int64
|
384 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
385 |
+
elif len(timesteps.shape) == 0:
|
386 |
+
timesteps = timesteps[None].to(sample.device)
|
387 |
+
|
388 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
389 |
+
num_frames = sample.shape[2]
|
390 |
+
timesteps = timesteps.expand(sample.shape[0])
|
391 |
+
|
392 |
+
t_emb = self.time_proj(timesteps)
|
393 |
+
|
394 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
395 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
396 |
+
# there might be better ways to encapsulate this.
|
397 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
398 |
+
|
399 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
400 |
+
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
401 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
402 |
+
|
403 |
+
# 2. pre-process
|
404 |
+
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
|
405 |
+
sample = self.conv_in(sample)
|
406 |
+
|
407 |
+
if num_frames > 1:
|
408 |
+
if self.gradient_checkpointing:
|
409 |
+
sample = transformer_g_c(self.transformer_in, sample, num_frames)
|
410 |
+
else:
|
411 |
+
sample = self.transformer_in(sample, num_frames=num_frames).sample
|
412 |
+
|
413 |
+
# 3. down
|
414 |
+
down_block_res_samples = (sample,)
|
415 |
+
for downsample_block in self.down_blocks:
|
416 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
417 |
+
sample, res_samples = downsample_block(
|
418 |
+
hidden_states=sample,
|
419 |
+
temb=emb,
|
420 |
+
encoder_hidden_states=encoder_hidden_states,
|
421 |
+
attention_mask=attention_mask,
|
422 |
+
num_frames=num_frames,
|
423 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
424 |
+
)
|
425 |
+
else:
|
426 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
|
427 |
+
|
428 |
+
down_block_res_samples += res_samples
|
429 |
+
|
430 |
+
if down_block_additional_residuals is not None:
|
431 |
+
new_down_block_res_samples = ()
|
432 |
+
|
433 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
434 |
+
down_block_res_samples, down_block_additional_residuals
|
435 |
+
):
|
436 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
437 |
+
new_down_block_res_samples += (down_block_res_sample,)
|
438 |
+
|
439 |
+
down_block_res_samples = new_down_block_res_samples
|
440 |
+
|
441 |
+
# 4. mid
|
442 |
+
if self.mid_block is not None:
|
443 |
+
sample = self.mid_block(
|
444 |
+
sample,
|
445 |
+
emb,
|
446 |
+
encoder_hidden_states=encoder_hidden_states,
|
447 |
+
attention_mask=attention_mask,
|
448 |
+
num_frames=num_frames,
|
449 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
450 |
+
)
|
451 |
+
|
452 |
+
if mid_block_additional_residual is not None:
|
453 |
+
sample = sample + mid_block_additional_residual
|
454 |
+
|
455 |
+
# 5. up
|
456 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
457 |
+
is_final_block = i == len(self.up_blocks) - 1
|
458 |
+
|
459 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
460 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
461 |
+
|
462 |
+
# if we have not reached the final block and need to forward the
|
463 |
+
# upsample size, we do it here
|
464 |
+
if not is_final_block and forward_upsample_size:
|
465 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
466 |
+
|
467 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
468 |
+
sample = upsample_block(
|
469 |
+
hidden_states=sample,
|
470 |
+
temb=emb,
|
471 |
+
res_hidden_states_tuple=res_samples,
|
472 |
+
encoder_hidden_states=encoder_hidden_states,
|
473 |
+
upsample_size=upsample_size,
|
474 |
+
attention_mask=attention_mask,
|
475 |
+
num_frames=num_frames,
|
476 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
477 |
+
)
|
478 |
+
else:
|
479 |
+
sample = upsample_block(
|
480 |
+
hidden_states=sample,
|
481 |
+
temb=emb,
|
482 |
+
res_hidden_states_tuple=res_samples,
|
483 |
+
upsample_size=upsample_size,
|
484 |
+
num_frames=num_frames,
|
485 |
+
)
|
486 |
+
|
487 |
+
# 6. post-process
|
488 |
+
if self.conv_norm_out:
|
489 |
+
sample = self.conv_norm_out(sample)
|
490 |
+
sample = self.conv_act(sample)
|
491 |
+
|
492 |
+
sample = self.conv_out(sample)
|
493 |
+
|
494 |
+
# reshape to (batch, channel, framerate, width, height)
|
495 |
+
sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
|
496 |
+
|
497 |
+
if not return_dict:
|
498 |
+
return (sample,)
|
499 |
+
|
500 |
+
return UNet3DConditionOutput(sample=sample)
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.18.0
|
2 |
+
decord==0.6.0
|
3 |
+
deepspeed==0.10.0
|
4 |
+
diffusers==0.18.0
|
5 |
+
huggingface-hub==0.16.4
|
6 |
+
lora-diffusion @ git+https://github.com/cloneofsimo/lora.git@bdd51b04c49fa90a88919a19850ec3b4cf3c5ecd
|
7 |
+
loralib==0.1.0
|
8 |
+
numpy==1.24.0
|
9 |
+
omegaconf==2.3.0
|
10 |
+
opencv-python==4.8.0.74
|
11 |
+
torch==2.0.1
|
12 |
+
torchaudio==2.0.2
|
13 |
+
torchvision==0.15.2
|
14 |
+
tqdm==4.65.0
|
15 |
+
transformers==4.27.4
|
16 |
+
einops==0.7.0
|
17 |
+
imageio==2.33.0
|
18 |
+
imageio-ffmpeg==0.4.9
|
19 |
+
gradio==3.26.0
|
utils/bucketing.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
|
3 |
+
def min_res(size, min_size): return 192 if size < 192 else size
|
4 |
+
|
5 |
+
def up_down_bucket(m_size, in_size, direction):
|
6 |
+
if direction == 'down': return abs(int(m_size - in_size))
|
7 |
+
if direction == 'up': return abs(int(m_size + in_size))
|
8 |
+
|
9 |
+
def get_bucket_sizes(size, direction: 'down', min_size):
|
10 |
+
multipliers = [64, 128]
|
11 |
+
for i, m in enumerate(multipliers):
|
12 |
+
res = up_down_bucket(m, size, direction)
|
13 |
+
multipliers[i] = min_res(res, min_size=min_size)
|
14 |
+
return multipliers
|
15 |
+
|
16 |
+
def closest_bucket(m_size, size, direction, min_size):
|
17 |
+
lst = get_bucket_sizes(m_size, direction, min_size)
|
18 |
+
return lst[min(range(len(lst)), key=lambda i: abs(lst[i]-size))]
|
19 |
+
|
20 |
+
def resolve_bucket(i,h,w): return (i / (h / w))
|
21 |
+
|
22 |
+
def sensible_buckets(m_width, m_height, w, h, min_size=192):
|
23 |
+
if h > w:
|
24 |
+
w = resolve_bucket(m_width, h, w)
|
25 |
+
w = closest_bucket(m_width, w, 'down', min_size=min_size)
|
26 |
+
return w, m_height
|
27 |
+
if h < w:
|
28 |
+
h = resolve_bucket(m_height, w, h)
|
29 |
+
h = closest_bucket(m_height, h, 'down', min_size=min_size)
|
30 |
+
return m_width, h
|
31 |
+
|
32 |
+
return m_width, m_height
|
utils/convert_diffusers_to_original_ms_text_to_video.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
|
2 |
+
# *Only* converts the UNet, and Text Encoder.
|
3 |
+
# Does not convert optimizer state or any other thing.
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import os.path as osp
|
7 |
+
import re
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from safetensors.torch import load_file, save_file
|
11 |
+
|
12 |
+
# =================#
|
13 |
+
# UNet Conversion #
|
14 |
+
# =================#
|
15 |
+
|
16 |
+
print ('Initializing the conversion map')
|
17 |
+
|
18 |
+
unet_conversion_map = [
|
19 |
+
# (ModelScope, HF Diffusers)
|
20 |
+
|
21 |
+
# from Vanilla ModelScope/StableDiffusion
|
22 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
23 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
24 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
25 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
26 |
+
|
27 |
+
|
28 |
+
# from Vanilla ModelScope/StableDiffusion
|
29 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
30 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
31 |
+
|
32 |
+
|
33 |
+
# from Vanilla ModelScope/StableDiffusion
|
34 |
+
("out.0.weight", "conv_norm_out.weight"),
|
35 |
+
("out.0.bias", "conv_norm_out.bias"),
|
36 |
+
("out.2.weight", "conv_out.weight"),
|
37 |
+
("out.2.bias", "conv_out.bias"),
|
38 |
+
]
|
39 |
+
|
40 |
+
unet_conversion_map_resnet = [
|
41 |
+
# (ModelScope, HF Diffusers)
|
42 |
+
|
43 |
+
# SD
|
44 |
+
("in_layers.0", "norm1"),
|
45 |
+
("in_layers.2", "conv1"),
|
46 |
+
("out_layers.0", "norm2"),
|
47 |
+
("out_layers.3", "conv2"),
|
48 |
+
("emb_layers.1", "time_emb_proj"),
|
49 |
+
("skip_connection", "conv_shortcut"),
|
50 |
+
|
51 |
+
# MS
|
52 |
+
#("temopral_conv", "temp_convs"), # ROFL, they have a typo here --kabachuha
|
53 |
+
]
|
54 |
+
|
55 |
+
unet_conversion_map_layer = []
|
56 |
+
|
57 |
+
# Convert input TemporalTransformer
|
58 |
+
unet_conversion_map_layer.append(('input_blocks.0.1', 'transformer_in'))
|
59 |
+
|
60 |
+
# Reference for the default settings
|
61 |
+
|
62 |
+
# "model_cfg": {
|
63 |
+
# "unet_in_dim": 4,
|
64 |
+
# "unet_dim": 320,
|
65 |
+
# "unet_y_dim": 768,
|
66 |
+
# "unet_context_dim": 1024,
|
67 |
+
# "unet_out_dim": 4,
|
68 |
+
# "unet_dim_mult": [1, 2, 4, 4],
|
69 |
+
# "unet_num_heads": 8,
|
70 |
+
# "unet_head_dim": 64,
|
71 |
+
# "unet_res_blocks": 2,
|
72 |
+
# "unet_attn_scales": [1, 0.5, 0.25],
|
73 |
+
# "unet_dropout": 0.1,
|
74 |
+
# "temporal_attention": "True",
|
75 |
+
# "num_timesteps": 1000,
|
76 |
+
# "mean_type": "eps",
|
77 |
+
# "var_type": "fixed_small",
|
78 |
+
# "loss_type": "mse"
|
79 |
+
# }
|
80 |
+
|
81 |
+
# hardcoded number of downblocks and resnets/attentions...
|
82 |
+
# would need smarter logic for other networks.
|
83 |
+
for i in range(4):
|
84 |
+
# loop over downblocks/upblocks
|
85 |
+
|
86 |
+
for j in range(2):
|
87 |
+
# loop over resnets/attentions for downblocks
|
88 |
+
|
89 |
+
# Spacial SD stuff
|
90 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
91 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
92 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
93 |
+
|
94 |
+
if i < 3:
|
95 |
+
# no attention layers in down_blocks.3
|
96 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
97 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
98 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
99 |
+
|
100 |
+
# Temporal MS stuff
|
101 |
+
hf_down_res_prefix = f"down_blocks.{i}.temp_convs.{j}."
|
102 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0.temopral_conv."
|
103 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
104 |
+
|
105 |
+
if i < 3:
|
106 |
+
# no attention layers in down_blocks.3
|
107 |
+
hf_down_atn_prefix = f"down_blocks.{i}.temp_attentions.{j}."
|
108 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.2."
|
109 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
110 |
+
|
111 |
+
for j in range(3):
|
112 |
+
# loop over resnets/attentions for upblocks
|
113 |
+
|
114 |
+
# Spacial SD stuff
|
115 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
116 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
117 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
118 |
+
|
119 |
+
if i > 0:
|
120 |
+
# no attention layers in up_blocks.0
|
121 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
122 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
123 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
124 |
+
|
125 |
+
# loop over resnets/attentions for upblocks
|
126 |
+
hf_up_res_prefix = f"up_blocks.{i}.temp_convs.{j}."
|
127 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0.temopral_conv."
|
128 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
129 |
+
|
130 |
+
if i > 0:
|
131 |
+
# no attention layers in up_blocks.0
|
132 |
+
hf_up_atn_prefix = f"up_blocks.{i}.temp_attentions.{j}."
|
133 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.2."
|
134 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
135 |
+
|
136 |
+
# Up/Downsamplers are 2D, so don't need to touch them
|
137 |
+
if i < 3:
|
138 |
+
# no downsample in down_blocks.3
|
139 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
140 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.op."
|
141 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
142 |
+
|
143 |
+
# no upsample in up_blocks.3
|
144 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
145 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 3}."
|
146 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
147 |
+
|
148 |
+
|
149 |
+
# Handle the middle block
|
150 |
+
|
151 |
+
# Spacial
|
152 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
153 |
+
sd_mid_atn_prefix = "middle_block.1."
|
154 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
155 |
+
|
156 |
+
for j in range(2):
|
157 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
158 |
+
sd_mid_res_prefix = f"middle_block.{3*j}."
|
159 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
160 |
+
|
161 |
+
# Temporal
|
162 |
+
hf_mid_atn_prefix = "mid_block.temp_attentions.0."
|
163 |
+
sd_mid_atn_prefix = "middle_block.2."
|
164 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
165 |
+
|
166 |
+
for j in range(2):
|
167 |
+
hf_mid_res_prefix = f"mid_block.temp_convs.{j}."
|
168 |
+
sd_mid_res_prefix = f"middle_block.{3*j}.temopral_conv."
|
169 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
170 |
+
|
171 |
+
# The pipeline
|
172 |
+
def convert_unet_state_dict(unet_state_dict, strict_mapping=False):
|
173 |
+
print ('Converting the UNET')
|
174 |
+
# buyer beware: this is a *brittle* function,
|
175 |
+
# and correct output requires that all of these pieces interact in
|
176 |
+
# the exact order in which I have arranged them.
|
177 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
178 |
+
|
179 |
+
for sd_name, hf_name in unet_conversion_map:
|
180 |
+
if strict_mapping:
|
181 |
+
if hf_name in mapping:
|
182 |
+
mapping[hf_name] = sd_name
|
183 |
+
else:
|
184 |
+
mapping[hf_name] = sd_name
|
185 |
+
for k, v in mapping.items():
|
186 |
+
if "resnets" in k:
|
187 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
188 |
+
v = v.replace(hf_part, sd_part)
|
189 |
+
mapping[k] = v
|
190 |
+
# elif "temp_convs" in k:
|
191 |
+
# for sd_part, hf_part in unet_conversion_map_resnet:
|
192 |
+
# v = v.replace(hf_part, sd_part)
|
193 |
+
# mapping[k] = v
|
194 |
+
for k, v in mapping.items():
|
195 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
196 |
+
v = v.replace(hf_part, sd_part)
|
197 |
+
mapping[k] = v
|
198 |
+
|
199 |
+
|
200 |
+
# there must be a pattern, but I don't want to bother atm
|
201 |
+
do_not_unsqueeze = [f'output_blocks.{i}.1.proj_out.weight' for i in range(3, 12)] + [f'output_blocks.{i}.1.proj_in.weight' for i in range(3, 12)] + ['middle_block.1.proj_in.weight', 'middle_block.1.proj_out.weight'] + [f'input_blocks.{i}.1.proj_out.weight' for i in [1, 2, 4, 5, 7, 8]] + [f'input_blocks.{i}.1.proj_in.weight' for i in [1, 2, 4, 5, 7, 8]]
|
202 |
+
print (do_not_unsqueeze)
|
203 |
+
|
204 |
+
new_state_dict = {v: (unet_state_dict[k].unsqueeze(-1) if ('proj_' in k and ('bias' not in k) and (k not in do_not_unsqueeze)) else unet_state_dict[k]) for k, v in mapping.items()}
|
205 |
+
# HACK: idk why the hell it does not work with list comprehension
|
206 |
+
for k, v in new_state_dict.items():
|
207 |
+
has_k = False
|
208 |
+
for n in do_not_unsqueeze:
|
209 |
+
if k == n:
|
210 |
+
has_k = True
|
211 |
+
|
212 |
+
if has_k:
|
213 |
+
v = v.squeeze(-1)
|
214 |
+
new_state_dict[k] = v
|
215 |
+
|
216 |
+
return new_state_dict
|
217 |
+
|
218 |
+
# TODO: VAE conversion. We doesn't train it in the most cases, but may be handy for the future --kabachuha
|
219 |
+
|
220 |
+
# =========================#
|
221 |
+
# Text Encoder Conversion #
|
222 |
+
# =========================#
|
223 |
+
|
224 |
+
# IT IS THE SAME CLIP ENCODER, SO JUST COPYPASTING IT --kabachuha
|
225 |
+
|
226 |
+
# =========================#
|
227 |
+
# Text Encoder Conversion #
|
228 |
+
# =========================#
|
229 |
+
|
230 |
+
|
231 |
+
textenc_conversion_lst = [
|
232 |
+
# (stable-diffusion, HF Diffusers)
|
233 |
+
("resblocks.", "text_model.encoder.layers."),
|
234 |
+
("ln_1", "layer_norm1"),
|
235 |
+
("ln_2", "layer_norm2"),
|
236 |
+
(".c_fc.", ".fc1."),
|
237 |
+
(".c_proj.", ".fc2."),
|
238 |
+
(".attn", ".self_attn"),
|
239 |
+
("ln_final.", "transformer.text_model.final_layer_norm."),
|
240 |
+
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
241 |
+
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
242 |
+
]
|
243 |
+
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
|
244 |
+
textenc_pattern = re.compile("|".join(protected.keys()))
|
245 |
+
|
246 |
+
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
247 |
+
code2idx = {"q": 0, "k": 1, "v": 2}
|
248 |
+
|
249 |
+
|
250 |
+
def convert_text_enc_state_dict_v20(text_enc_dict):
|
251 |
+
#print ('Converting the text encoder')
|
252 |
+
new_state_dict = {}
|
253 |
+
capture_qkv_weight = {}
|
254 |
+
capture_qkv_bias = {}
|
255 |
+
for k, v in text_enc_dict.items():
|
256 |
+
if (
|
257 |
+
k.endswith(".self_attn.q_proj.weight")
|
258 |
+
or k.endswith(".self_attn.k_proj.weight")
|
259 |
+
or k.endswith(".self_attn.v_proj.weight")
|
260 |
+
):
|
261 |
+
k_pre = k[: -len(".q_proj.weight")]
|
262 |
+
k_code = k[-len("q_proj.weight")]
|
263 |
+
if k_pre not in capture_qkv_weight:
|
264 |
+
capture_qkv_weight[k_pre] = [None, None, None]
|
265 |
+
capture_qkv_weight[k_pre][code2idx[k_code]] = v
|
266 |
+
continue
|
267 |
+
|
268 |
+
if (
|
269 |
+
k.endswith(".self_attn.q_proj.bias")
|
270 |
+
or k.endswith(".self_attn.k_proj.bias")
|
271 |
+
or k.endswith(".self_attn.v_proj.bias")
|
272 |
+
):
|
273 |
+
k_pre = k[: -len(".q_proj.bias")]
|
274 |
+
k_code = k[-len("q_proj.bias")]
|
275 |
+
if k_pre not in capture_qkv_bias:
|
276 |
+
capture_qkv_bias[k_pre] = [None, None, None]
|
277 |
+
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
278 |
+
continue
|
279 |
+
|
280 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
281 |
+
new_state_dict[relabelled_key] = v
|
282 |
+
|
283 |
+
for k_pre, tensors in capture_qkv_weight.items():
|
284 |
+
if None in tensors:
|
285 |
+
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
286 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
287 |
+
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
288 |
+
|
289 |
+
for k_pre, tensors in capture_qkv_bias.items():
|
290 |
+
if None in tensors:
|
291 |
+
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
292 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
293 |
+
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
294 |
+
|
295 |
+
return new_state_dict
|
296 |
+
|
297 |
+
|
298 |
+
def convert_text_enc_state_dict(text_enc_dict):
|
299 |
+
return text_enc_dict
|
300 |
+
|
301 |
+
textenc_conversion_lst = [
|
302 |
+
# (stable-diffusion, HF Diffusers)
|
303 |
+
("resblocks.", "text_model.encoder.layers."),
|
304 |
+
("ln_1", "layer_norm1"),
|
305 |
+
("ln_2", "layer_norm2"),
|
306 |
+
(".c_fc.", ".fc1."),
|
307 |
+
(".c_proj.", ".fc2."),
|
308 |
+
(".attn", ".self_attn"),
|
309 |
+
("ln_final.", "transformer.text_model.final_layer_norm."),
|
310 |
+
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
311 |
+
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
312 |
+
]
|
313 |
+
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
|
314 |
+
textenc_pattern = re.compile("|".join(protected.keys()))
|
315 |
+
|
316 |
+
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
317 |
+
code2idx = {"q": 0, "k": 1, "v": 2}
|
318 |
+
|
319 |
+
|
320 |
+
def convert_text_enc_state_dict_v20(text_enc_dict):
|
321 |
+
new_state_dict = {}
|
322 |
+
capture_qkv_weight = {}
|
323 |
+
capture_qkv_bias = {}
|
324 |
+
for k, v in text_enc_dict.items():
|
325 |
+
if (
|
326 |
+
k.endswith(".self_attn.q_proj.weight")
|
327 |
+
or k.endswith(".self_attn.k_proj.weight")
|
328 |
+
or k.endswith(".self_attn.v_proj.weight")
|
329 |
+
):
|
330 |
+
k_pre = k[: -len(".q_proj.weight")]
|
331 |
+
k_code = k[-len("q_proj.weight")]
|
332 |
+
if k_pre not in capture_qkv_weight:
|
333 |
+
capture_qkv_weight[k_pre] = [None, None, None]
|
334 |
+
capture_qkv_weight[k_pre][code2idx[k_code]] = v
|
335 |
+
continue
|
336 |
+
|
337 |
+
if (
|
338 |
+
k.endswith(".self_attn.q_proj.bias")
|
339 |
+
or k.endswith(".self_attn.k_proj.bias")
|
340 |
+
or k.endswith(".self_attn.v_proj.bias")
|
341 |
+
):
|
342 |
+
k_pre = k[: -len(".q_proj.bias")]
|
343 |
+
k_code = k[-len("q_proj.bias")]
|
344 |
+
if k_pre not in capture_qkv_bias:
|
345 |
+
capture_qkv_bias[k_pre] = [None, None, None]
|
346 |
+
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
347 |
+
continue
|
348 |
+
|
349 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
350 |
+
new_state_dict[relabelled_key] = v
|
351 |
+
|
352 |
+
for k_pre, tensors in capture_qkv_weight.items():
|
353 |
+
if None in tensors:
|
354 |
+
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
355 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
356 |
+
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
357 |
+
|
358 |
+
for k_pre, tensors in capture_qkv_bias.items():
|
359 |
+
if None in tensors:
|
360 |
+
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
361 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
362 |
+
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
363 |
+
|
364 |
+
return new_state_dict
|
365 |
+
|
366 |
+
|
367 |
+
def convert_text_enc_state_dict(text_enc_dict):
|
368 |
+
return text_enc_dict
|
369 |
+
|
370 |
+
if __name__ == "__main__":
|
371 |
+
parser = argparse.ArgumentParser()
|
372 |
+
|
373 |
+
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
|
374 |
+
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
375 |
+
parser.add_argument("--clip_checkpoint_path", default=None, type=str, help="Path to the output CLIP model.")
|
376 |
+
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
377 |
+
parser.add_argument(
|
378 |
+
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
|
379 |
+
)
|
380 |
+
|
381 |
+
args = parser.parse_args()
|
382 |
+
|
383 |
+
assert args.model_path is not None, "Must provide a model path!"
|
384 |
+
|
385 |
+
assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
|
386 |
+
|
387 |
+
assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!"
|
388 |
+
|
389 |
+
# Path for safetensors
|
390 |
+
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
|
391 |
+
#vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
|
392 |
+
text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
|
393 |
+
|
394 |
+
# Load models from safetensors if it exists, if it doesn't pytorch
|
395 |
+
if osp.exists(unet_path):
|
396 |
+
unet_state_dict = load_file(unet_path, device="cpu")
|
397 |
+
else:
|
398 |
+
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
|
399 |
+
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
400 |
+
|
401 |
+
# if osp.exists(vae_path):
|
402 |
+
# vae_state_dict = load_file(vae_path, device="cpu")
|
403 |
+
# else:
|
404 |
+
# vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
|
405 |
+
# vae_state_dict = torch.load(vae_path, map_location="cpu")
|
406 |
+
|
407 |
+
if osp.exists(text_enc_path):
|
408 |
+
text_enc_dict = load_file(text_enc_path, device="cpu")
|
409 |
+
else:
|
410 |
+
text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
|
411 |
+
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
412 |
+
|
413 |
+
# Convert the UNet model
|
414 |
+
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
415 |
+
#unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
416 |
+
|
417 |
+
# Convert the VAE model
|
418 |
+
# vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
419 |
+
# vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
420 |
+
|
421 |
+
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
422 |
+
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
423 |
+
|
424 |
+
if is_v20_model:
|
425 |
+
|
426 |
+
# MODELSCOPE always uses the 2.X encoder, btw --kabachuha
|
427 |
+
|
428 |
+
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
|
429 |
+
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
|
430 |
+
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
|
431 |
+
#text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
|
432 |
+
else:
|
433 |
+
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
434 |
+
#text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
435 |
+
|
436 |
+
# DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLITTED FORM --kabachuha
|
437 |
+
# Save CLIP and the Diffusion model to their own files
|
438 |
+
|
439 |
+
#state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
440 |
+
print ('Saving UNET')
|
441 |
+
state_dict = {**unet_state_dict}
|
442 |
+
|
443 |
+
if args.half:
|
444 |
+
state_dict = {k: v.half() for k, v in state_dict.items()}
|
445 |
+
|
446 |
+
if args.use_safetensors:
|
447 |
+
save_file(state_dict, args.checkpoint_path)
|
448 |
+
else:
|
449 |
+
#state_dict = {"state_dict": state_dict}
|
450 |
+
torch.save(state_dict, args.checkpoint_path)
|
451 |
+
|
452 |
+
# TODO: CLIP conversion doesn't work atm
|
453 |
+
# print ('Saving CLIP')
|
454 |
+
# state_dict = {**text_enc_dict}
|
455 |
+
|
456 |
+
# if args.half:
|
457 |
+
# state_dict = {k: v.half() for k, v in state_dict.items()}
|
458 |
+
|
459 |
+
# if args.use_safetensors:
|
460 |
+
# save_file(state_dict, args.checkpoint_path)
|
461 |
+
# else:
|
462 |
+
# #state_dict = {"state_dict": state_dict}
|
463 |
+
# torch.save(state_dict, args.clip_checkpoint_path)
|
464 |
+
|
465 |
+
print('Operation successfull')
|
utils/dataset.py
ADDED
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import decord
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
import json
|
6 |
+
import torchvision
|
7 |
+
import torchvision.transforms as T
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from glob import glob
|
11 |
+
from PIL import Image
|
12 |
+
from itertools import islice
|
13 |
+
from pathlib import Path
|
14 |
+
from .bucketing import sensible_buckets
|
15 |
+
|
16 |
+
decord.bridge.set_bridge('torch')
|
17 |
+
|
18 |
+
from torch.utils.data import Dataset
|
19 |
+
from einops import rearrange, repeat
|
20 |
+
|
21 |
+
|
22 |
+
def get_prompt_ids(prompt, tokenizer):
|
23 |
+
prompt_ids = tokenizer(
|
24 |
+
prompt,
|
25 |
+
truncation=True,
|
26 |
+
padding="max_length",
|
27 |
+
max_length=tokenizer.model_max_length,
|
28 |
+
return_tensors="pt",
|
29 |
+
).input_ids
|
30 |
+
|
31 |
+
return prompt_ids
|
32 |
+
|
33 |
+
|
34 |
+
def read_caption_file(caption_file):
|
35 |
+
with open(caption_file, 'r', encoding="utf8") as t:
|
36 |
+
return t.read()
|
37 |
+
|
38 |
+
|
39 |
+
def get_text_prompt(
|
40 |
+
text_prompt: str = '',
|
41 |
+
fallback_prompt: str= '',
|
42 |
+
file_path:str = '',
|
43 |
+
ext_types=['.mp4'],
|
44 |
+
use_caption=False
|
45 |
+
):
|
46 |
+
try:
|
47 |
+
if use_caption:
|
48 |
+
if len(text_prompt) > 1: return text_prompt
|
49 |
+
caption_file = ''
|
50 |
+
# Use caption on per-video basis (One caption PER video)
|
51 |
+
for ext in ext_types:
|
52 |
+
maybe_file = file_path.replace(ext, '.txt')
|
53 |
+
if maybe_file.endswith(ext_types): continue
|
54 |
+
if os.path.exists(maybe_file):
|
55 |
+
caption_file = maybe_file
|
56 |
+
break
|
57 |
+
|
58 |
+
if os.path.exists(caption_file):
|
59 |
+
return read_caption_file(caption_file)
|
60 |
+
|
61 |
+
# Return fallback prompt if no conditions are met.
|
62 |
+
return fallback_prompt
|
63 |
+
|
64 |
+
return text_prompt
|
65 |
+
except:
|
66 |
+
print(f"Couldn't read prompt caption for {file_path}. Using fallback.")
|
67 |
+
return fallback_prompt
|
68 |
+
|
69 |
+
|
70 |
+
def get_video_frames(vr, start_idx, sample_rate=1, max_frames=24):
|
71 |
+
max_range = len(vr)
|
72 |
+
frame_number = sorted((0, start_idx, max_range))[1]
|
73 |
+
|
74 |
+
frame_range = range(frame_number, max_range, sample_rate)
|
75 |
+
frame_range_indices = list(frame_range)[:max_frames]
|
76 |
+
|
77 |
+
return frame_range_indices
|
78 |
+
|
79 |
+
|
80 |
+
def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_batch):
|
81 |
+
if use_bucketing:
|
82 |
+
vr = decord.VideoReader(vid_path)
|
83 |
+
resize = get_frame_buckets(vr)
|
84 |
+
video = get_frame_batch(vr, resize=resize)
|
85 |
+
|
86 |
+
else:
|
87 |
+
vr = decord.VideoReader(vid_path, width=w, height=h)
|
88 |
+
video = get_frame_batch(vr)
|
89 |
+
|
90 |
+
return video, vr
|
91 |
+
|
92 |
+
|
93 |
+
# https://github.com/ExponentialML/Video-BLIP2-Preprocessor
|
94 |
+
class VideoJsonDataset(Dataset):
|
95 |
+
def __init__(
|
96 |
+
self,
|
97 |
+
tokenizer = None,
|
98 |
+
width: int = 256,
|
99 |
+
height: int = 256,
|
100 |
+
n_sample_frames: int = 4,
|
101 |
+
sample_start_idx: int = 1,
|
102 |
+
frame_step: int = 1,
|
103 |
+
json_path: str ="",
|
104 |
+
json_data = None,
|
105 |
+
vid_data_key: str = "video_path",
|
106 |
+
preprocessed: bool = False,
|
107 |
+
use_bucketing: bool = False,
|
108 |
+
**kwargs
|
109 |
+
):
|
110 |
+
self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
|
111 |
+
self.use_bucketing = use_bucketing
|
112 |
+
self.tokenizer = tokenizer
|
113 |
+
self.preprocessed = preprocessed
|
114 |
+
|
115 |
+
self.vid_data_key = vid_data_key
|
116 |
+
self.train_data = self.load_from_json(json_path, json_data)
|
117 |
+
|
118 |
+
self.width = width
|
119 |
+
self.height = height
|
120 |
+
|
121 |
+
self.n_sample_frames = n_sample_frames
|
122 |
+
self.sample_start_idx = sample_start_idx
|
123 |
+
self.frame_step = frame_step
|
124 |
+
|
125 |
+
def build_json(self, json_data):
|
126 |
+
extended_data = []
|
127 |
+
for data in json_data['data']:
|
128 |
+
for nested_data in data['data']:
|
129 |
+
self.build_json_dict(
|
130 |
+
data,
|
131 |
+
nested_data,
|
132 |
+
extended_data
|
133 |
+
)
|
134 |
+
json_data = extended_data
|
135 |
+
return json_data
|
136 |
+
|
137 |
+
def build_json_dict(self, data, nested_data, extended_data):
|
138 |
+
clip_path = nested_data['clip_path'] if 'clip_path' in nested_data else None
|
139 |
+
|
140 |
+
extended_data.append({
|
141 |
+
self.vid_data_key: data[self.vid_data_key],
|
142 |
+
'frame_index': nested_data['frame_index'],
|
143 |
+
'prompt': nested_data['prompt'],
|
144 |
+
'clip_path': clip_path
|
145 |
+
})
|
146 |
+
|
147 |
+
def load_from_json(self, path, json_data):
|
148 |
+
try:
|
149 |
+
with open(path) as jpath:
|
150 |
+
print(f"Loading JSON from {path}")
|
151 |
+
json_data = json.load(jpath)
|
152 |
+
|
153 |
+
return self.build_json(json_data)
|
154 |
+
|
155 |
+
except:
|
156 |
+
self.train_data = []
|
157 |
+
print("Non-existant JSON path. Skipping.")
|
158 |
+
|
159 |
+
def validate_json(self, base_path, path):
|
160 |
+
return os.path.exists(f"{base_path}/{path}")
|
161 |
+
|
162 |
+
def get_frame_range(self, vr):
|
163 |
+
return get_video_frames(
|
164 |
+
vr,
|
165 |
+
self.sample_start_idx,
|
166 |
+
self.frame_step,
|
167 |
+
self.n_sample_frames
|
168 |
+
)
|
169 |
+
|
170 |
+
def get_vid_idx(self, vr, vid_data=None):
|
171 |
+
frames = self.n_sample_frames
|
172 |
+
|
173 |
+
if vid_data is not None:
|
174 |
+
idx = vid_data['frame_index']
|
175 |
+
else:
|
176 |
+
idx = self.sample_start_idx
|
177 |
+
|
178 |
+
return idx
|
179 |
+
|
180 |
+
def get_frame_buckets(self, vr):
|
181 |
+
_, h, w = vr[0].shape
|
182 |
+
width, height = sensible_buckets(self.width, self.height, h, w)
|
183 |
+
# width, height = self.width, self.height
|
184 |
+
resize = T.transforms.Resize((height, width), antialias=True)
|
185 |
+
|
186 |
+
return resize
|
187 |
+
|
188 |
+
def get_frame_batch(self, vr, resize=None):
|
189 |
+
frame_range = self.get_frame_range(vr)
|
190 |
+
frames = vr.get_batch(frame_range)
|
191 |
+
video = rearrange(frames, "f h w c -> f c h w")
|
192 |
+
|
193 |
+
if resize is not None: video = resize(video)
|
194 |
+
return video
|
195 |
+
|
196 |
+
def process_video_wrapper(self, vid_path):
|
197 |
+
video, vr = process_video(
|
198 |
+
vid_path,
|
199 |
+
self.use_bucketing,
|
200 |
+
self.width,
|
201 |
+
self.height,
|
202 |
+
self.get_frame_buckets,
|
203 |
+
self.get_frame_batch
|
204 |
+
)
|
205 |
+
|
206 |
+
return video, vr
|
207 |
+
|
208 |
+
def train_data_batch(self, index):
|
209 |
+
|
210 |
+
# If we are training on individual clips.
|
211 |
+
if 'clip_path' in self.train_data[index] and \
|
212 |
+
self.train_data[index]['clip_path'] is not None:
|
213 |
+
|
214 |
+
vid_data = self.train_data[index]
|
215 |
+
|
216 |
+
clip_path = vid_data['clip_path']
|
217 |
+
|
218 |
+
# Get video prompt
|
219 |
+
prompt = vid_data['prompt']
|
220 |
+
|
221 |
+
video, _ = self.process_video_wrapper(clip_path)
|
222 |
+
|
223 |
+
prompt_ids = get_prompt_ids(prompt, self.tokenizer)
|
224 |
+
|
225 |
+
return video, prompt, prompt_ids
|
226 |
+
|
227 |
+
# Assign train data
|
228 |
+
train_data = self.train_data[index]
|
229 |
+
|
230 |
+
# Get the frame of the current index.
|
231 |
+
self.sample_start_idx = train_data['frame_index']
|
232 |
+
|
233 |
+
# Initialize resize
|
234 |
+
resize = None
|
235 |
+
|
236 |
+
video, vr = self.process_video_wrapper(train_data[self.vid_data_key])
|
237 |
+
|
238 |
+
# Get video prompt
|
239 |
+
prompt = train_data['prompt']
|
240 |
+
vr.seek(0)
|
241 |
+
|
242 |
+
prompt_ids = get_prompt_ids(prompt, self.tokenizer)
|
243 |
+
|
244 |
+
return video, prompt, prompt_ids
|
245 |
+
|
246 |
+
@staticmethod
|
247 |
+
def __getname__(): return 'json'
|
248 |
+
|
249 |
+
def __len__(self):
|
250 |
+
if self.train_data is not None:
|
251 |
+
return len(self.train_data)
|
252 |
+
else:
|
253 |
+
return 0
|
254 |
+
|
255 |
+
def __getitem__(self, index):
|
256 |
+
|
257 |
+
# Initialize variables
|
258 |
+
video = None
|
259 |
+
prompt = None
|
260 |
+
prompt_ids = None
|
261 |
+
|
262 |
+
# Use default JSON training
|
263 |
+
if self.train_data is not None:
|
264 |
+
video, prompt, prompt_ids = self.train_data_batch(index)
|
265 |
+
|
266 |
+
example = {
|
267 |
+
"pixel_values": (video / 127.5 - 1.0),
|
268 |
+
"prompt_ids": prompt_ids[0],
|
269 |
+
"text_prompt": prompt,
|
270 |
+
'dataset': self.__getname__()
|
271 |
+
}
|
272 |
+
|
273 |
+
return example
|
274 |
+
|
275 |
+
|
276 |
+
class SingleVideoDataset(Dataset):
|
277 |
+
def __init__(
|
278 |
+
self,
|
279 |
+
tokenizer = None,
|
280 |
+
width: int = 256,
|
281 |
+
height: int = 256,
|
282 |
+
n_sample_frames: int = 4,
|
283 |
+
frame_step: int = 1,
|
284 |
+
single_video_path: str = "",
|
285 |
+
single_video_prompt: str = "",
|
286 |
+
use_caption: bool = False,
|
287 |
+
use_bucketing: bool = False,
|
288 |
+
**kwargs
|
289 |
+
):
|
290 |
+
self.tokenizer = tokenizer
|
291 |
+
self.use_bucketing = use_bucketing
|
292 |
+
self.frames = []
|
293 |
+
self.index = 1
|
294 |
+
|
295 |
+
self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
|
296 |
+
self.n_sample_frames = n_sample_frames
|
297 |
+
self.frame_step = frame_step
|
298 |
+
|
299 |
+
self.single_video_path = single_video_path
|
300 |
+
self.single_video_prompt = single_video_prompt
|
301 |
+
|
302 |
+
self.width = width
|
303 |
+
self.height = height
|
304 |
+
def create_video_chunks(self):
|
305 |
+
vr = decord.VideoReader(self.single_video_path)
|
306 |
+
vr_range = range(0, len(vr), self.frame_step)
|
307 |
+
|
308 |
+
self.frames = list(self.chunk(vr_range, self.n_sample_frames))
|
309 |
+
return self.frames
|
310 |
+
|
311 |
+
def chunk(self, it, size):
|
312 |
+
it = iter(it)
|
313 |
+
return iter(lambda: tuple(islice(it, size)), ())
|
314 |
+
|
315 |
+
def get_frame_batch(self, vr, resize=None):
|
316 |
+
index = self.index
|
317 |
+
frames = vr.get_batch(self.frames[self.index])
|
318 |
+
video = rearrange(frames, "f h w c -> f c h w")
|
319 |
+
|
320 |
+
if resize is not None: video = resize(video)
|
321 |
+
return video
|
322 |
+
|
323 |
+
def get_frame_buckets(self, vr):
|
324 |
+
_, h, w = vr[0].shape
|
325 |
+
# width, height = sensible_buckets(self.width, self.height, h, w)
|
326 |
+
width, height = self.width, self.height
|
327 |
+
resize = T.transforms.Resize((height, width), antialias=True)
|
328 |
+
|
329 |
+
return resize
|
330 |
+
|
331 |
+
def process_video_wrapper(self, vid_path):
|
332 |
+
video, vr = process_video(
|
333 |
+
vid_path,
|
334 |
+
self.use_bucketing,
|
335 |
+
self.width,
|
336 |
+
self.height,
|
337 |
+
self.get_frame_buckets,
|
338 |
+
self.get_frame_batch
|
339 |
+
)
|
340 |
+
|
341 |
+
return video, vr
|
342 |
+
|
343 |
+
def single_video_batch(self, index):
|
344 |
+
train_data = self.single_video_path
|
345 |
+
self.index = index
|
346 |
+
|
347 |
+
if train_data.endswith(self.vid_types):
|
348 |
+
video, _ = self.process_video_wrapper(train_data)
|
349 |
+
|
350 |
+
prompt = self.single_video_prompt
|
351 |
+
prompt_ids = get_prompt_ids(prompt, self.tokenizer)
|
352 |
+
|
353 |
+
return video, prompt, prompt_ids
|
354 |
+
else:
|
355 |
+
raise ValueError(f"Single video is not a video type. Types: {self.vid_types}")
|
356 |
+
|
357 |
+
@staticmethod
|
358 |
+
def __getname__(): return 'single_video'
|
359 |
+
|
360 |
+
def __len__(self):
|
361 |
+
|
362 |
+
return len(self.create_video_chunks())
|
363 |
+
|
364 |
+
def __getitem__(self, index):
|
365 |
+
|
366 |
+
video, prompt, prompt_ids = self.single_video_batch(index)
|
367 |
+
|
368 |
+
example = {
|
369 |
+
"pixel_values": (video / 127.5 - 1.0),
|
370 |
+
"prompt_ids": prompt_ids[0],
|
371 |
+
"text_prompt": prompt,
|
372 |
+
'dataset': self.__getname__()
|
373 |
+
}
|
374 |
+
|
375 |
+
return example
|
376 |
+
|
377 |
+
|
378 |
+
class ImageDataset(Dataset):
|
379 |
+
|
380 |
+
def __init__(
|
381 |
+
self,
|
382 |
+
tokenizer = None,
|
383 |
+
width: int = 256,
|
384 |
+
height: int = 256,
|
385 |
+
base_width: int = 256,
|
386 |
+
base_height: int = 256,
|
387 |
+
use_caption: bool = False,
|
388 |
+
image_dir: str = '',
|
389 |
+
single_img_prompt: str = '',
|
390 |
+
use_bucketing: bool = False,
|
391 |
+
fallback_prompt: str = '',
|
392 |
+
**kwargs
|
393 |
+
):
|
394 |
+
self.tokenizer = tokenizer
|
395 |
+
self.img_types = (".png", ".jpg", ".jpeg", '.bmp')
|
396 |
+
self.use_bucketing = use_bucketing
|
397 |
+
|
398 |
+
self.image_dir = self.get_images_list(image_dir)
|
399 |
+
self.fallback_prompt = fallback_prompt
|
400 |
+
|
401 |
+
self.use_caption = use_caption
|
402 |
+
self.single_img_prompt = single_img_prompt
|
403 |
+
|
404 |
+
self.width = width
|
405 |
+
self.height = height
|
406 |
+
|
407 |
+
def get_images_list(self, image_dir):
|
408 |
+
if os.path.exists(image_dir):
|
409 |
+
imgs = [x for x in os.listdir(image_dir) if x.endswith(self.img_types)]
|
410 |
+
full_img_dir = []
|
411 |
+
|
412 |
+
for img in imgs:
|
413 |
+
full_img_dir.append(f"{image_dir}/{img}")
|
414 |
+
|
415 |
+
return sorted(full_img_dir)
|
416 |
+
|
417 |
+
return ['']
|
418 |
+
|
419 |
+
def image_batch(self, index):
|
420 |
+
train_data = self.image_dir[index]
|
421 |
+
img = train_data
|
422 |
+
|
423 |
+
try:
|
424 |
+
img = torchvision.io.read_image(img, mode=torchvision.io.ImageReadMode.RGB)
|
425 |
+
except:
|
426 |
+
img = T.transforms.PILToTensor()(Image.open(img).convert("RGB"))
|
427 |
+
|
428 |
+
width = self.width
|
429 |
+
height = self.height
|
430 |
+
|
431 |
+
if self.use_bucketing:
|
432 |
+
_, h, w = img.shape
|
433 |
+
width, height = sensible_buckets(width, height, w, h)
|
434 |
+
|
435 |
+
resize = T.transforms.Resize((height, width), antialias=True)
|
436 |
+
|
437 |
+
img = resize(img)
|
438 |
+
img = repeat(img, 'c h w -> f c h w', f=1)
|
439 |
+
|
440 |
+
prompt = get_text_prompt(
|
441 |
+
file_path=train_data,
|
442 |
+
text_prompt=self.single_img_prompt,
|
443 |
+
fallback_prompt=self.fallback_prompt,
|
444 |
+
ext_types=self.img_types,
|
445 |
+
use_caption=True
|
446 |
+
)
|
447 |
+
prompt_ids = get_prompt_ids(prompt, self.tokenizer)
|
448 |
+
|
449 |
+
return img, prompt, prompt_ids
|
450 |
+
|
451 |
+
@staticmethod
|
452 |
+
def __getname__(): return 'image'
|
453 |
+
|
454 |
+
def __len__(self):
|
455 |
+
# Image directory
|
456 |
+
if os.path.exists(self.image_dir[0]):
|
457 |
+
return len(self.image_dir)
|
458 |
+
else:
|
459 |
+
return 0
|
460 |
+
|
461 |
+
def __getitem__(self, index):
|
462 |
+
img, prompt, prompt_ids = self.image_batch(index)
|
463 |
+
example = {
|
464 |
+
"pixel_values": (img / 127.5 - 1.0),
|
465 |
+
"prompt_ids": prompt_ids[0],
|
466 |
+
"text_prompt": prompt,
|
467 |
+
'dataset': self.__getname__()
|
468 |
+
}
|
469 |
+
|
470 |
+
return example
|
471 |
+
|
472 |
+
|
473 |
+
class VideoFolderDataset(Dataset):
|
474 |
+
def __init__(
|
475 |
+
self,
|
476 |
+
tokenizer=None,
|
477 |
+
width: int = 256,
|
478 |
+
height: int = 256,
|
479 |
+
n_sample_frames: int = 16,
|
480 |
+
fps: int = 8,
|
481 |
+
path: str = "./data",
|
482 |
+
fallback_prompt: str = "",
|
483 |
+
use_bucketing: bool = False,
|
484 |
+
**kwargs
|
485 |
+
):
|
486 |
+
self.tokenizer = tokenizer
|
487 |
+
self.use_bucketing = use_bucketing
|
488 |
+
|
489 |
+
self.fallback_prompt = fallback_prompt
|
490 |
+
|
491 |
+
self.video_files = glob(f"{path}/*.mp4")
|
492 |
+
|
493 |
+
self.width = width
|
494 |
+
self.height = height
|
495 |
+
|
496 |
+
self.n_sample_frames = n_sample_frames
|
497 |
+
self.fps = fps
|
498 |
+
|
499 |
+
def get_frame_buckets(self, vr):
|
500 |
+
_, h, w = vr[0].shape
|
501 |
+
width, height = sensible_buckets(self.width, self.height, h, w)
|
502 |
+
# width, height = self.width, self.height
|
503 |
+
resize = T.transforms.Resize((height, width), antialias=True)
|
504 |
+
|
505 |
+
return resize
|
506 |
+
|
507 |
+
def get_frame_batch(self, vr, resize=None):
|
508 |
+
n_sample_frames = self.n_sample_frames
|
509 |
+
native_fps = vr.get_avg_fps()
|
510 |
+
|
511 |
+
every_nth_frame = max(1, round(native_fps / self.fps))
|
512 |
+
every_nth_frame = min(len(vr), every_nth_frame)
|
513 |
+
|
514 |
+
effective_length = len(vr) // every_nth_frame
|
515 |
+
if effective_length < n_sample_frames:
|
516 |
+
n_sample_frames = effective_length
|
517 |
+
|
518 |
+
effective_idx = random.randint(0, (effective_length - n_sample_frames))
|
519 |
+
idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)
|
520 |
+
|
521 |
+
video = vr.get_batch(idxs)
|
522 |
+
video = rearrange(video, "f h w c -> f c h w")
|
523 |
+
|
524 |
+
if resize is not None: video = resize(video)
|
525 |
+
return video, vr
|
526 |
+
|
527 |
+
def process_video_wrapper(self, vid_path):
|
528 |
+
video, vr = process_video(
|
529 |
+
vid_path,
|
530 |
+
self.use_bucketing,
|
531 |
+
self.width,
|
532 |
+
self.height,
|
533 |
+
self.get_frame_buckets,
|
534 |
+
self.get_frame_batch
|
535 |
+
)
|
536 |
+
return video, vr
|
537 |
+
|
538 |
+
def get_prompt_ids(self, prompt):
|
539 |
+
return self.tokenizer(
|
540 |
+
prompt,
|
541 |
+
truncation=True,
|
542 |
+
padding="max_length",
|
543 |
+
max_length=self.tokenizer.model_max_length,
|
544 |
+
return_tensors="pt",
|
545 |
+
).input_ids
|
546 |
+
|
547 |
+
@staticmethod
|
548 |
+
def __getname__(): return 'folder'
|
549 |
+
|
550 |
+
def __len__(self):
|
551 |
+
return len(self.video_files)
|
552 |
+
|
553 |
+
def __getitem__(self, index):
|
554 |
+
|
555 |
+
video, _ = self.process_video_wrapper(self.video_files[index])
|
556 |
+
|
557 |
+
prompt = self.fallback_prompt
|
558 |
+
|
559 |
+
prompt_ids = self.get_prompt_ids(prompt)
|
560 |
+
|
561 |
+
return {"pixel_values": (video[0] / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, 'dataset': self.__getname__()}
|
562 |
+
|
563 |
+
|
564 |
+
class CachedDataset(Dataset):
|
565 |
+
def __init__(self,cache_dir: str = ''):
|
566 |
+
self.cache_dir = cache_dir
|
567 |
+
self.cached_data_list = self.get_files_list()
|
568 |
+
|
569 |
+
def get_files_list(self):
|
570 |
+
tensors_list = [f"{self.cache_dir}/{x}" for x in os.listdir(self.cache_dir) if x.endswith('.pt')]
|
571 |
+
return sorted(tensors_list)
|
572 |
+
|
573 |
+
def __len__(self):
|
574 |
+
return len(self.cached_data_list)
|
575 |
+
|
576 |
+
def __getitem__(self, index):
|
577 |
+
cached_latent = torch.load(self.cached_data_list[index], map_location='cuda:0')
|
578 |
+
return cached_latent
|
utils/ddim_utils.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
# DDIM Inversion
|
10 |
+
@torch.no_grad()
|
11 |
+
def init_prompt(prompt, pipeline):
|
12 |
+
uncond_input = pipeline.tokenizer(
|
13 |
+
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
|
14 |
+
return_tensors="pt"
|
15 |
+
)
|
16 |
+
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
|
17 |
+
text_input = pipeline.tokenizer(
|
18 |
+
[prompt],
|
19 |
+
padding="max_length",
|
20 |
+
max_length=pipeline.tokenizer.model_max_length,
|
21 |
+
truncation=True,
|
22 |
+
return_tensors="pt",
|
23 |
+
)
|
24 |
+
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
|
25 |
+
context = torch.cat([uncond_embeddings, text_embeddings])
|
26 |
+
|
27 |
+
return context
|
28 |
+
|
29 |
+
|
30 |
+
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
|
31 |
+
sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
|
32 |
+
timestep, next_timestep = min(
|
33 |
+
timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
|
34 |
+
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
|
35 |
+
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
|
36 |
+
beta_prod_t = 1 - alpha_prod_t
|
37 |
+
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
38 |
+
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
|
39 |
+
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
|
40 |
+
return next_sample
|
41 |
+
|
42 |
+
|
43 |
+
def get_noise_pred_single(latents, t, context, unet):
|
44 |
+
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
|
45 |
+
return noise_pred
|
46 |
+
|
47 |
+
|
48 |
+
@torch.no_grad()
|
49 |
+
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
|
50 |
+
context = init_prompt(prompt, pipeline)
|
51 |
+
uncond_embeddings, cond_embeddings = context.chunk(2)
|
52 |
+
all_latent = [latent]
|
53 |
+
latent = latent.clone().detach()
|
54 |
+
for i in tqdm(range(num_inv_steps)):
|
55 |
+
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
|
56 |
+
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
|
57 |
+
latent = next_step(noise_pred, t, latent, ddim_scheduler)
|
58 |
+
all_latent.append(latent)
|
59 |
+
return all_latent
|
60 |
+
|
61 |
+
|
62 |
+
@torch.no_grad()
|
63 |
+
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
|
64 |
+
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
|
65 |
+
return ddim_latents
|
utils/lora.py
ADDED
@@ -0,0 +1,1483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import math
|
3 |
+
from itertools import groupby
|
4 |
+
import os
|
5 |
+
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import PIL
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
try:
|
14 |
+
from safetensors.torch import safe_open
|
15 |
+
from safetensors.torch import save_file as safe_save
|
16 |
+
|
17 |
+
safetensors_available = True
|
18 |
+
except ImportError:
|
19 |
+
from .safe_open import safe_open
|
20 |
+
|
21 |
+
def safe_save(
|
22 |
+
tensors: Dict[str, torch.Tensor],
|
23 |
+
filename: str,
|
24 |
+
metadata: Optional[Dict[str, str]] = None,
|
25 |
+
) -> None:
|
26 |
+
raise EnvironmentError(
|
27 |
+
"Saving safetensors requires the safetensors library. Please install with pip or similar."
|
28 |
+
)
|
29 |
+
|
30 |
+
safetensors_available = False
|
31 |
+
|
32 |
+
|
33 |
+
class LoraInjectedLinear(nn.Module):
|
34 |
+
def __init__(
|
35 |
+
self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
if r > min(in_features, out_features):
|
40 |
+
#raise ValueError(
|
41 |
+
# f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
|
42 |
+
#)
|
43 |
+
print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
|
44 |
+
r = min(in_features, out_features)
|
45 |
+
|
46 |
+
self.r = r
|
47 |
+
self.linear = nn.Linear(in_features, out_features, bias)
|
48 |
+
self.lora_down = nn.Linear(in_features, r, bias=False)
|
49 |
+
self.dropout = nn.Dropout(dropout_p)
|
50 |
+
self.lora_up = nn.Linear(r, out_features, bias=False)
|
51 |
+
self.scale = scale
|
52 |
+
self.selector = nn.Identity()
|
53 |
+
|
54 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
55 |
+
nn.init.zeros_(self.lora_up.weight)
|
56 |
+
|
57 |
+
def forward(self, input):
|
58 |
+
return (
|
59 |
+
self.linear(input)
|
60 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
61 |
+
* self.scale
|
62 |
+
)
|
63 |
+
|
64 |
+
def realize_as_lora(self):
|
65 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
66 |
+
|
67 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
68 |
+
# diag is a 1D tensor of size (r,)
|
69 |
+
assert diag.shape == (self.r,)
|
70 |
+
self.selector = nn.Linear(self.r, self.r, bias=False)
|
71 |
+
self.selector.weight.data = torch.diag(diag)
|
72 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
73 |
+
self.lora_up.weight.device
|
74 |
+
).to(self.lora_up.weight.dtype)
|
75 |
+
|
76 |
+
|
77 |
+
class MultiLoraInjectedLinear(nn.Module):
|
78 |
+
def __init__(
|
79 |
+
self, in_features, out_features, bias=False, r=4, dropout_p=0.1, lora_num=1, scales=[1.0]
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
if r > min(in_features, out_features):
|
84 |
+
#raise ValueError(
|
85 |
+
# f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
|
86 |
+
#)
|
87 |
+
print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
|
88 |
+
r = min(in_features, out_features)
|
89 |
+
|
90 |
+
self.r = r
|
91 |
+
self.linear = nn.Linear(in_features, out_features, bias)
|
92 |
+
|
93 |
+
for i in range(lora_num):
|
94 |
+
if i==0:
|
95 |
+
self.lora_down =[nn.Linear(in_features, r, bias=False)]
|
96 |
+
self.dropout = [nn.Dropout(dropout_p)]
|
97 |
+
self.lora_up = [nn.Linear(r, out_features, bias=False)]
|
98 |
+
self.scale = scales[i]
|
99 |
+
self.selector = [nn.Identity()]
|
100 |
+
else:
|
101 |
+
self.lora_down.append(nn.Linear(in_features, r, bias=False))
|
102 |
+
self.dropout.append( nn.Dropout(dropout_p))
|
103 |
+
self.lora_up.append( nn.Linear(r, out_features, bias=False))
|
104 |
+
self.scale.append(scales[i])
|
105 |
+
|
106 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
107 |
+
nn.init.zeros_(self.lora_up.weight)
|
108 |
+
|
109 |
+
def forward(self, input):
|
110 |
+
return (
|
111 |
+
self.linear(input)
|
112 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
113 |
+
* self.scale
|
114 |
+
)
|
115 |
+
|
116 |
+
def realize_as_lora(self):
|
117 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
118 |
+
|
119 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
120 |
+
# diag is a 1D tensor of size (r,)
|
121 |
+
assert diag.shape == (self.r,)
|
122 |
+
self.selector = nn.Linear(self.r, self.r, bias=False)
|
123 |
+
self.selector.weight.data = torch.diag(diag)
|
124 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
125 |
+
self.lora_up.weight.device
|
126 |
+
).to(self.lora_up.weight.dtype)
|
127 |
+
|
128 |
+
|
129 |
+
class LoraInjectedConv2d(nn.Module):
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
in_channels: int,
|
133 |
+
out_channels: int,
|
134 |
+
kernel_size,
|
135 |
+
stride=1,
|
136 |
+
padding=0,
|
137 |
+
dilation=1,
|
138 |
+
groups: int = 1,
|
139 |
+
bias: bool = True,
|
140 |
+
r: int = 4,
|
141 |
+
dropout_p: float = 0.1,
|
142 |
+
scale: float = 1.0,
|
143 |
+
):
|
144 |
+
super().__init__()
|
145 |
+
if r > min(in_channels, out_channels):
|
146 |
+
print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
|
147 |
+
r = min(in_channels, out_channels)
|
148 |
+
|
149 |
+
self.r = r
|
150 |
+
self.conv = nn.Conv2d(
|
151 |
+
in_channels=in_channels,
|
152 |
+
out_channels=out_channels,
|
153 |
+
kernel_size=kernel_size,
|
154 |
+
stride=stride,
|
155 |
+
padding=padding,
|
156 |
+
dilation=dilation,
|
157 |
+
groups=groups,
|
158 |
+
bias=bias,
|
159 |
+
)
|
160 |
+
|
161 |
+
self.lora_down = nn.Conv2d(
|
162 |
+
in_channels=in_channels,
|
163 |
+
out_channels=r,
|
164 |
+
kernel_size=kernel_size,
|
165 |
+
stride=stride,
|
166 |
+
padding=padding,
|
167 |
+
dilation=dilation,
|
168 |
+
groups=groups,
|
169 |
+
bias=False,
|
170 |
+
)
|
171 |
+
self.dropout = nn.Dropout(dropout_p)
|
172 |
+
self.lora_up = nn.Conv2d(
|
173 |
+
in_channels=r,
|
174 |
+
out_channels=out_channels,
|
175 |
+
kernel_size=1,
|
176 |
+
stride=1,
|
177 |
+
padding=0,
|
178 |
+
bias=False,
|
179 |
+
)
|
180 |
+
self.selector = nn.Identity()
|
181 |
+
self.scale = scale
|
182 |
+
|
183 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
184 |
+
nn.init.zeros_(self.lora_up.weight)
|
185 |
+
|
186 |
+
def forward(self, input):
|
187 |
+
return (
|
188 |
+
self.conv(input)
|
189 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
190 |
+
* self.scale
|
191 |
+
)
|
192 |
+
|
193 |
+
def realize_as_lora(self):
|
194 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
195 |
+
|
196 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
197 |
+
# diag is a 1D tensor of size (r,)
|
198 |
+
assert diag.shape == (self.r,)
|
199 |
+
self.selector = nn.Conv2d(
|
200 |
+
in_channels=self.r,
|
201 |
+
out_channels=self.r,
|
202 |
+
kernel_size=1,
|
203 |
+
stride=1,
|
204 |
+
padding=0,
|
205 |
+
bias=False,
|
206 |
+
)
|
207 |
+
self.selector.weight.data = torch.diag(diag)
|
208 |
+
|
209 |
+
# same device + dtype as lora_up
|
210 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
211 |
+
self.lora_up.weight.device
|
212 |
+
).to(self.lora_up.weight.dtype)
|
213 |
+
|
214 |
+
class LoraInjectedConv3d(nn.Module):
|
215 |
+
def __init__(
|
216 |
+
self,
|
217 |
+
in_channels: int,
|
218 |
+
out_channels: int,
|
219 |
+
kernel_size: (3, 1, 1),
|
220 |
+
padding: (1, 0, 0),
|
221 |
+
bias: bool = False,
|
222 |
+
r: int = 4,
|
223 |
+
dropout_p: float = 0,
|
224 |
+
scale: float = 1.0,
|
225 |
+
):
|
226 |
+
super().__init__()
|
227 |
+
if r > min(in_channels, out_channels):
|
228 |
+
print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
|
229 |
+
r = min(in_channels, out_channels)
|
230 |
+
|
231 |
+
self.r = r
|
232 |
+
self.kernel_size = kernel_size
|
233 |
+
self.padding = padding
|
234 |
+
self.conv = nn.Conv3d(
|
235 |
+
in_channels=in_channels,
|
236 |
+
out_channels=out_channels,
|
237 |
+
kernel_size=kernel_size,
|
238 |
+
padding=padding,
|
239 |
+
)
|
240 |
+
|
241 |
+
self.lora_down = nn.Conv3d(
|
242 |
+
in_channels=in_channels,
|
243 |
+
out_channels=r,
|
244 |
+
kernel_size=kernel_size,
|
245 |
+
bias=False,
|
246 |
+
padding=padding
|
247 |
+
)
|
248 |
+
self.dropout = nn.Dropout(dropout_p)
|
249 |
+
self.lora_up = nn.Conv3d(
|
250 |
+
in_channels=r,
|
251 |
+
out_channels=out_channels,
|
252 |
+
kernel_size=1,
|
253 |
+
stride=1,
|
254 |
+
padding=0,
|
255 |
+
bias=False,
|
256 |
+
)
|
257 |
+
self.selector = nn.Identity()
|
258 |
+
self.scale = scale
|
259 |
+
|
260 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
261 |
+
nn.init.zeros_(self.lora_up.weight)
|
262 |
+
|
263 |
+
def forward(self, input):
|
264 |
+
return (
|
265 |
+
self.conv(input)
|
266 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
267 |
+
* self.scale
|
268 |
+
)
|
269 |
+
|
270 |
+
def realize_as_lora(self):
|
271 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
272 |
+
|
273 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
274 |
+
# diag is a 1D tensor of size (r,)
|
275 |
+
assert diag.shape == (self.r,)
|
276 |
+
self.selector = nn.Conv3d(
|
277 |
+
in_channels=self.r,
|
278 |
+
out_channels=self.r,
|
279 |
+
kernel_size=1,
|
280 |
+
stride=1,
|
281 |
+
padding=0,
|
282 |
+
bias=False,
|
283 |
+
)
|
284 |
+
self.selector.weight.data = torch.diag(diag)
|
285 |
+
|
286 |
+
# same device + dtype as lora_up
|
287 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
288 |
+
self.lora_up.weight.device
|
289 |
+
).to(self.lora_up.weight.dtype)
|
290 |
+
|
291 |
+
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
|
292 |
+
|
293 |
+
UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
|
294 |
+
|
295 |
+
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
|
296 |
+
|
297 |
+
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
|
298 |
+
|
299 |
+
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
|
300 |
+
|
301 |
+
EMBED_FLAG = "<embed>"
|
302 |
+
|
303 |
+
|
304 |
+
def _find_children(
|
305 |
+
model,
|
306 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
307 |
+
):
|
308 |
+
"""
|
309 |
+
Find all modules of a certain class (or union of classes).
|
310 |
+
|
311 |
+
Returns all matching modules, along with the parent of those moduless and the
|
312 |
+
names they are referenced by.
|
313 |
+
"""
|
314 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
315 |
+
for parent in model.modules():
|
316 |
+
for name, module in parent.named_children():
|
317 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
318 |
+
yield parent, name, module
|
319 |
+
|
320 |
+
|
321 |
+
def _find_modules_v2(
|
322 |
+
model,
|
323 |
+
ancestor_class: Optional[Set[str]] = None,
|
324 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
325 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = None,
|
326 |
+
# [
|
327 |
+
# LoraInjectedLinear,
|
328 |
+
# LoraInjectedConv2d,
|
329 |
+
# LoraInjectedConv3d
|
330 |
+
# ],
|
331 |
+
):
|
332 |
+
"""
|
333 |
+
Find all modules of a certain class (or union of classes) that are direct or
|
334 |
+
indirect descendants of other modules of a certain class (or union of classes).
|
335 |
+
|
336 |
+
Returns all matching modules, along with the parent of those moduless and the
|
337 |
+
names they are referenced by.
|
338 |
+
"""
|
339 |
+
|
340 |
+
# Get the targets we should replace all linears under
|
341 |
+
if ancestor_class is not None:
|
342 |
+
ancestors = (
|
343 |
+
module
|
344 |
+
for name, module in model.named_modules()
|
345 |
+
if module.__class__.__name__ in ancestor_class # and ('transformer_in' not in name)
|
346 |
+
)
|
347 |
+
else:
|
348 |
+
# this, incase you want to naively iterate over all modules.
|
349 |
+
ancestors = [module for module in model.modules()]
|
350 |
+
|
351 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
352 |
+
for ancestor in ancestors:
|
353 |
+
for fullname, module in ancestor.named_modules():
|
354 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
355 |
+
continue_flag = True
|
356 |
+
if 'Transformer2DModel' in ancestor_class and ('attn1' in fullname or 'ff' in fullname):
|
357 |
+
continue_flag = False
|
358 |
+
if 'TransformerTemporalModel' in ancestor_class and ('attn1' in fullname or 'attn2' in fullname or 'ff' in fullname):
|
359 |
+
continue_flag = False
|
360 |
+
if continue_flag:
|
361 |
+
continue
|
362 |
+
# Find the direct parent if this is a descendant, not a child, of target
|
363 |
+
*path, name = fullname.split(".")
|
364 |
+
parent = ancestor
|
365 |
+
while path:
|
366 |
+
parent = parent.get_submodule(path.pop(0))
|
367 |
+
# Skip this linear if it's a child of a LoraInjectedLinear
|
368 |
+
if exclude_children_of and any(
|
369 |
+
[isinstance(parent, _class) for _class in exclude_children_of]
|
370 |
+
):
|
371 |
+
continue
|
372 |
+
if name in ['lora_up', 'dropout', 'lora_down']:
|
373 |
+
continue
|
374 |
+
# Otherwise, yield it
|
375 |
+
yield parent, name, module
|
376 |
+
|
377 |
+
|
378 |
+
def _find_modules_old(
|
379 |
+
model,
|
380 |
+
ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
|
381 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
382 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
|
383 |
+
):
|
384 |
+
ret = []
|
385 |
+
for _module in model.modules():
|
386 |
+
if _module.__class__.__name__ in ancestor_class:
|
387 |
+
|
388 |
+
for name, _child_module in _module.named_modules():
|
389 |
+
if _child_module.__class__ in search_class:
|
390 |
+
ret.append((_module, name, _child_module))
|
391 |
+
print(ret)
|
392 |
+
return ret
|
393 |
+
|
394 |
+
|
395 |
+
_find_modules = _find_modules_v2
|
396 |
+
|
397 |
+
|
398 |
+
def inject_trainable_lora(
|
399 |
+
model: nn.Module,
|
400 |
+
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
|
401 |
+
r: int = 4,
|
402 |
+
loras=None, # path to lora .pt
|
403 |
+
verbose: bool = False,
|
404 |
+
dropout_p: float = 0.0,
|
405 |
+
scale: float = 1.0,
|
406 |
+
):
|
407 |
+
"""
|
408 |
+
inject lora into model, and returns lora parameter groups.
|
409 |
+
"""
|
410 |
+
|
411 |
+
require_grad_params = []
|
412 |
+
names = []
|
413 |
+
|
414 |
+
if loras != None:
|
415 |
+
loras = torch.load(loras)
|
416 |
+
|
417 |
+
for _module, name, _child_module in _find_modules(
|
418 |
+
model, target_replace_module, search_class=[nn.Linear]
|
419 |
+
):
|
420 |
+
weight = _child_module.weight
|
421 |
+
bias = _child_module.bias
|
422 |
+
if verbose:
|
423 |
+
print("LoRA Injection : injecting lora into ", name)
|
424 |
+
print("LoRA Injection : weight shape", weight.shape)
|
425 |
+
_tmp = LoraInjectedLinear(
|
426 |
+
_child_module.in_features,
|
427 |
+
_child_module.out_features,
|
428 |
+
_child_module.bias is not None,
|
429 |
+
r=r,
|
430 |
+
dropout_p=dropout_p,
|
431 |
+
scale=scale,
|
432 |
+
)
|
433 |
+
_tmp.linear.weight = weight
|
434 |
+
if bias is not None:
|
435 |
+
_tmp.linear.bias = bias
|
436 |
+
|
437 |
+
# switch the module
|
438 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
439 |
+
_module._modules[name] = _tmp
|
440 |
+
|
441 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
442 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
443 |
+
|
444 |
+
if loras != None:
|
445 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
446 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
447 |
+
|
448 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
449 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
450 |
+
names.append(name)
|
451 |
+
|
452 |
+
return require_grad_params, names
|
453 |
+
|
454 |
+
|
455 |
+
def inject_trainable_lora_extended(
|
456 |
+
model: nn.Module,
|
457 |
+
target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
|
458 |
+
r: int = 4,
|
459 |
+
loras=None, # path to lora .pt
|
460 |
+
dropout_p: float = 0.0,
|
461 |
+
scale: float = 1.0,
|
462 |
+
):
|
463 |
+
"""
|
464 |
+
inject lora into model, and returns lora parameter groups.
|
465 |
+
"""
|
466 |
+
|
467 |
+
require_grad_params = []
|
468 |
+
names = []
|
469 |
+
|
470 |
+
if loras != None:
|
471 |
+
loras = torch.load(loras)
|
472 |
+
if True:
|
473 |
+
for target_replace_module_i in target_replace_module:
|
474 |
+
for _module, name, _child_module in _find_modules(
|
475 |
+
model, [target_replace_module_i], search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
|
476 |
+
):
|
477 |
+
# if name == 'to_q':
|
478 |
+
# continue
|
479 |
+
if _child_module.__class__ == nn.Linear:
|
480 |
+
weight = _child_module.weight
|
481 |
+
bias = _child_module.bias
|
482 |
+
_tmp = LoraInjectedLinear(
|
483 |
+
_child_module.in_features,
|
484 |
+
_child_module.out_features,
|
485 |
+
_child_module.bias is not None,
|
486 |
+
r=r,
|
487 |
+
dropout_p=dropout_p,
|
488 |
+
scale=scale,
|
489 |
+
)
|
490 |
+
_tmp.linear.weight = weight
|
491 |
+
if bias is not None:
|
492 |
+
_tmp.linear.bias = bias
|
493 |
+
elif _child_module.__class__ == nn.Conv2d:
|
494 |
+
weight = _child_module.weight
|
495 |
+
bias = _child_module.bias
|
496 |
+
_tmp = LoraInjectedConv2d(
|
497 |
+
_child_module.in_channels,
|
498 |
+
_child_module.out_channels,
|
499 |
+
_child_module.kernel_size,
|
500 |
+
_child_module.stride,
|
501 |
+
_child_module.padding,
|
502 |
+
_child_module.dilation,
|
503 |
+
_child_module.groups,
|
504 |
+
_child_module.bias is not None,
|
505 |
+
r=r,
|
506 |
+
dropout_p=dropout_p,
|
507 |
+
scale=scale,
|
508 |
+
)
|
509 |
+
|
510 |
+
_tmp.conv.weight = weight
|
511 |
+
if bias is not None:
|
512 |
+
_tmp.conv.bias = bias
|
513 |
+
|
514 |
+
elif _child_module.__class__ == nn.Conv3d:
|
515 |
+
weight = _child_module.weight
|
516 |
+
bias = _child_module.bias
|
517 |
+
_tmp = LoraInjectedConv3d(
|
518 |
+
_child_module.in_channels,
|
519 |
+
_child_module.out_channels,
|
520 |
+
bias=_child_module.bias is not None,
|
521 |
+
kernel_size=_child_module.kernel_size,
|
522 |
+
padding=_child_module.padding,
|
523 |
+
r=r,
|
524 |
+
dropout_p=dropout_p,
|
525 |
+
scale=scale,
|
526 |
+
)
|
527 |
+
|
528 |
+
_tmp.conv.weight = weight
|
529 |
+
if bias is not None:
|
530 |
+
_tmp.conv.bias = bias
|
531 |
+
# switch the module
|
532 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
533 |
+
if bias is not None:
|
534 |
+
_tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
|
535 |
+
|
536 |
+
_module._modules[name] = _tmp
|
537 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
538 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
539 |
+
|
540 |
+
if loras != None:
|
541 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
542 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
543 |
+
|
544 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
545 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
546 |
+
names.append(name)
|
547 |
+
else:
|
548 |
+
for _module, name, _child_module in _find_modules(
|
549 |
+
model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
|
550 |
+
):
|
551 |
+
if _child_module.__class__ == nn.Linear:
|
552 |
+
weight = _child_module.weight
|
553 |
+
bias = _child_module.bias
|
554 |
+
_tmp = LoraInjectedLinear(
|
555 |
+
_child_module.in_features,
|
556 |
+
_child_module.out_features,
|
557 |
+
_child_module.bias is not None,
|
558 |
+
r=r,
|
559 |
+
dropout_p=dropout_p,
|
560 |
+
scale=scale,
|
561 |
+
)
|
562 |
+
_tmp.linear.weight = weight
|
563 |
+
if bias is not None:
|
564 |
+
_tmp.linear.bias = bias
|
565 |
+
elif _child_module.__class__ == nn.Conv2d:
|
566 |
+
weight = _child_module.weight
|
567 |
+
bias = _child_module.bias
|
568 |
+
_tmp = LoraInjectedConv2d(
|
569 |
+
_child_module.in_channels,
|
570 |
+
_child_module.out_channels,
|
571 |
+
_child_module.kernel_size,
|
572 |
+
_child_module.stride,
|
573 |
+
_child_module.padding,
|
574 |
+
_child_module.dilation,
|
575 |
+
_child_module.groups,
|
576 |
+
_child_module.bias is not None,
|
577 |
+
r=r,
|
578 |
+
dropout_p=dropout_p,
|
579 |
+
scale=scale,
|
580 |
+
)
|
581 |
+
|
582 |
+
_tmp.conv.weight = weight
|
583 |
+
if bias is not None:
|
584 |
+
_tmp.conv.bias = bias
|
585 |
+
|
586 |
+
elif _child_module.__class__ == nn.Conv3d:
|
587 |
+
weight = _child_module.weight
|
588 |
+
bias = _child_module.bias
|
589 |
+
_tmp = LoraInjectedConv3d(
|
590 |
+
_child_module.in_channels,
|
591 |
+
_child_module.out_channels,
|
592 |
+
bias=_child_module.bias is not None,
|
593 |
+
kernel_size=_child_module.kernel_size,
|
594 |
+
padding=_child_module.padding,
|
595 |
+
r=r,
|
596 |
+
dropout_p=dropout_p,
|
597 |
+
scale=scale,
|
598 |
+
)
|
599 |
+
|
600 |
+
_tmp.conv.weight = weight
|
601 |
+
if bias is not None:
|
602 |
+
_tmp.conv.bias = bias
|
603 |
+
# switch the module
|
604 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
605 |
+
if bias is not None:
|
606 |
+
_tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
|
607 |
+
|
608 |
+
_module._modules[name] = _tmp
|
609 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
610 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
611 |
+
|
612 |
+
if loras != None:
|
613 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
614 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
615 |
+
|
616 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
617 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
618 |
+
names.append(name)
|
619 |
+
|
620 |
+
return require_grad_params, names
|
621 |
+
|
622 |
+
|
623 |
+
def inject_inferable_lora(
|
624 |
+
model,
|
625 |
+
lora_path='',
|
626 |
+
unet_replace_modules=["UNet3DConditionModel"],
|
627 |
+
text_encoder_replace_modules=["CLIPEncoderLayer"],
|
628 |
+
is_extended=False,
|
629 |
+
r=16
|
630 |
+
):
|
631 |
+
from transformers.models.clip import CLIPTextModel
|
632 |
+
from diffusers import UNet3DConditionModel
|
633 |
+
|
634 |
+
def is_text_model(f): return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel)
|
635 |
+
def is_unet(f): return 'unet' in f and model.unet.__class__.__name__ == "UNet3DConditionModel"
|
636 |
+
|
637 |
+
if os.path.exists(lora_path):
|
638 |
+
try:
|
639 |
+
for f in os.listdir(lora_path):
|
640 |
+
if f.endswith('.pt'):
|
641 |
+
lora_file = os.path.join(lora_path, f)
|
642 |
+
|
643 |
+
if is_text_model(f):
|
644 |
+
monkeypatch_or_replace_lora(
|
645 |
+
model.text_encoder,
|
646 |
+
torch.load(lora_file),
|
647 |
+
target_replace_module=text_encoder_replace_modules,
|
648 |
+
r=r
|
649 |
+
)
|
650 |
+
print("Successfully loaded Text Encoder LoRa.")
|
651 |
+
continue
|
652 |
+
|
653 |
+
if is_unet(f):
|
654 |
+
monkeypatch_or_replace_lora_extended(
|
655 |
+
model.unet,
|
656 |
+
torch.load(lora_file),
|
657 |
+
target_replace_module=unet_replace_modules,
|
658 |
+
r=r
|
659 |
+
)
|
660 |
+
print("Successfully loaded UNET LoRa.")
|
661 |
+
continue
|
662 |
+
|
663 |
+
print("Found a .pt file, but doesn't have the correct name format. (unet.pt, text_encoder.pt)")
|
664 |
+
|
665 |
+
except Exception as e:
|
666 |
+
print(e)
|
667 |
+
print("Couldn't inject LoRA's due to an error.")
|
668 |
+
|
669 |
+
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
|
670 |
+
|
671 |
+
loras = []
|
672 |
+
|
673 |
+
for target_replace_module_i in target_replace_module:
|
674 |
+
|
675 |
+
for _m, _n, _child_module in _find_modules(
|
676 |
+
model,
|
677 |
+
[target_replace_module_i],
|
678 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
|
679 |
+
):
|
680 |
+
loras.append((_child_module.lora_up, _child_module.lora_down))
|
681 |
+
|
682 |
+
if len(loras) == 0:
|
683 |
+
raise ValueError("No lora injected.")
|
684 |
+
|
685 |
+
return loras
|
686 |
+
|
687 |
+
|
688 |
+
def extract_lora_child_module(model, target_replace_module=DEFAULT_TARGET_REPLACE):
|
689 |
+
|
690 |
+
loras = []
|
691 |
+
|
692 |
+
for target_replace_module_i in target_replace_module:
|
693 |
+
|
694 |
+
for _m, _n, _child_module in _find_modules(
|
695 |
+
model,
|
696 |
+
[target_replace_module_i],
|
697 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
|
698 |
+
):
|
699 |
+
loras.append(_child_module)
|
700 |
+
|
701 |
+
if len(loras) == 0:
|
702 |
+
raise ValueError("No lora injected.")
|
703 |
+
|
704 |
+
return loras
|
705 |
+
|
706 |
+
def extract_lora_as_tensor(
|
707 |
+
model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
|
708 |
+
):
|
709 |
+
|
710 |
+
loras = []
|
711 |
+
|
712 |
+
for _m, _n, _child_module in _find_modules(
|
713 |
+
model,
|
714 |
+
target_replace_module,
|
715 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
|
716 |
+
):
|
717 |
+
up, down = _child_module.realize_as_lora()
|
718 |
+
if as_fp16:
|
719 |
+
up = up.to(torch.float16)
|
720 |
+
down = down.to(torch.float16)
|
721 |
+
|
722 |
+
loras.append((up, down))
|
723 |
+
|
724 |
+
if len(loras) == 0:
|
725 |
+
raise ValueError("No lora injected.")
|
726 |
+
|
727 |
+
return loras
|
728 |
+
|
729 |
+
|
730 |
+
def save_lora_weight(
|
731 |
+
model,
|
732 |
+
path="./lora.pt",
|
733 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
734 |
+
flag=None
|
735 |
+
):
|
736 |
+
weights = []
|
737 |
+
for _up, _down in extract_lora_ups_down(
|
738 |
+
model, target_replace_module=target_replace_module
|
739 |
+
):
|
740 |
+
weights.append(_up.weight.to("cpu").to(torch.float32))
|
741 |
+
weights.append(_down.weight.to("cpu").to(torch.float32))
|
742 |
+
if not flag:
|
743 |
+
torch.save(weights, path)
|
744 |
+
else:
|
745 |
+
weights_new=[]
|
746 |
+
for i in range(0, len(weights), 4):
|
747 |
+
subset = weights[i+(flag-1)*2:i+(flag-1)*2+2]
|
748 |
+
weights_new.extend(subset)
|
749 |
+
torch.save(weights_new, path)
|
750 |
+
|
751 |
+
def save_lora_as_json(model, path="./lora.json"):
|
752 |
+
weights = []
|
753 |
+
for _up, _down in extract_lora_ups_down(model):
|
754 |
+
weights.append(_up.weight.detach().cpu().numpy().tolist())
|
755 |
+
weights.append(_down.weight.detach().cpu().numpy().tolist())
|
756 |
+
|
757 |
+
import json
|
758 |
+
|
759 |
+
with open(path, "w") as f:
|
760 |
+
json.dump(weights, f)
|
761 |
+
|
762 |
+
|
763 |
+
def save_safeloras_with_embeds(
|
764 |
+
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
|
765 |
+
embeds: Dict[str, torch.Tensor] = {},
|
766 |
+
outpath="./lora.safetensors",
|
767 |
+
):
|
768 |
+
"""
|
769 |
+
Saves the Lora from multiple modules in a single safetensor file.
|
770 |
+
|
771 |
+
modelmap is a dictionary of {
|
772 |
+
"module name": (module, target_replace_module)
|
773 |
+
}
|
774 |
+
"""
|
775 |
+
weights = {}
|
776 |
+
metadata = {}
|
777 |
+
|
778 |
+
for name, (model, target_replace_module) in modelmap.items():
|
779 |
+
metadata[name] = json.dumps(list(target_replace_module))
|
780 |
+
|
781 |
+
for i, (_up, _down) in enumerate(
|
782 |
+
extract_lora_as_tensor(model, target_replace_module)
|
783 |
+
):
|
784 |
+
rank = _down.shape[0]
|
785 |
+
|
786 |
+
metadata[f"{name}:{i}:rank"] = str(rank)
|
787 |
+
weights[f"{name}:{i}:up"] = _up
|
788 |
+
weights[f"{name}:{i}:down"] = _down
|
789 |
+
|
790 |
+
for token, tensor in embeds.items():
|
791 |
+
metadata[token] = EMBED_FLAG
|
792 |
+
weights[token] = tensor
|
793 |
+
|
794 |
+
print(f"Saving weights to {outpath}")
|
795 |
+
safe_save(weights, outpath, metadata)
|
796 |
+
|
797 |
+
|
798 |
+
def save_safeloras(
|
799 |
+
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
|
800 |
+
outpath="./lora.safetensors",
|
801 |
+
):
|
802 |
+
return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
803 |
+
|
804 |
+
|
805 |
+
def convert_loras_to_safeloras_with_embeds(
|
806 |
+
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
|
807 |
+
embeds: Dict[str, torch.Tensor] = {},
|
808 |
+
outpath="./lora.safetensors",
|
809 |
+
):
|
810 |
+
"""
|
811 |
+
Converts the Lora from multiple pytorch .pt files into a single safetensor file.
|
812 |
+
|
813 |
+
modelmap is a dictionary of {
|
814 |
+
"module name": (pytorch_model_path, target_replace_module, rank)
|
815 |
+
}
|
816 |
+
"""
|
817 |
+
|
818 |
+
weights = {}
|
819 |
+
metadata = {}
|
820 |
+
|
821 |
+
for name, (path, target_replace_module, r) in modelmap.items():
|
822 |
+
metadata[name] = json.dumps(list(target_replace_module))
|
823 |
+
|
824 |
+
lora = torch.load(path)
|
825 |
+
for i, weight in enumerate(lora):
|
826 |
+
is_up = i % 2 == 0
|
827 |
+
i = i // 2
|
828 |
+
|
829 |
+
if is_up:
|
830 |
+
metadata[f"{name}:{i}:rank"] = str(r)
|
831 |
+
weights[f"{name}:{i}:up"] = weight
|
832 |
+
else:
|
833 |
+
weights[f"{name}:{i}:down"] = weight
|
834 |
+
|
835 |
+
for token, tensor in embeds.items():
|
836 |
+
metadata[token] = EMBED_FLAG
|
837 |
+
weights[token] = tensor
|
838 |
+
|
839 |
+
print(f"Saving weights to {outpath}")
|
840 |
+
safe_save(weights, outpath, metadata)
|
841 |
+
|
842 |
+
|
843 |
+
def convert_loras_to_safeloras(
|
844 |
+
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
|
845 |
+
outpath="./lora.safetensors",
|
846 |
+
):
|
847 |
+
convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
848 |
+
|
849 |
+
|
850 |
+
def parse_safeloras(
|
851 |
+
safeloras,
|
852 |
+
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
|
853 |
+
"""
|
854 |
+
Converts a loaded safetensor file that contains a set of module Loras
|
855 |
+
into Parameters and other information
|
856 |
+
|
857 |
+
Output is a dictionary of {
|
858 |
+
"module name": (
|
859 |
+
[list of weights],
|
860 |
+
[list of ranks],
|
861 |
+
target_replacement_modules
|
862 |
+
)
|
863 |
+
}
|
864 |
+
"""
|
865 |
+
loras = {}
|
866 |
+
metadata = safeloras.metadata()
|
867 |
+
|
868 |
+
get_name = lambda k: k.split(":")[0]
|
869 |
+
|
870 |
+
keys = list(safeloras.keys())
|
871 |
+
keys.sort(key=get_name)
|
872 |
+
|
873 |
+
for name, module_keys in groupby(keys, get_name):
|
874 |
+
info = metadata.get(name)
|
875 |
+
|
876 |
+
if not info:
|
877 |
+
raise ValueError(
|
878 |
+
f"Tensor {name} has no metadata - is this a Lora safetensor?"
|
879 |
+
)
|
880 |
+
|
881 |
+
# Skip Textual Inversion embeds
|
882 |
+
if info == EMBED_FLAG:
|
883 |
+
continue
|
884 |
+
|
885 |
+
# Handle Loras
|
886 |
+
# Extract the targets
|
887 |
+
target = json.loads(info)
|
888 |
+
|
889 |
+
# Build the result lists - Python needs us to preallocate lists to insert into them
|
890 |
+
module_keys = list(module_keys)
|
891 |
+
ranks = [4] * (len(module_keys) // 2)
|
892 |
+
weights = [None] * len(module_keys)
|
893 |
+
|
894 |
+
for key in module_keys:
|
895 |
+
# Split the model name and index out of the key
|
896 |
+
_, idx, direction = key.split(":")
|
897 |
+
idx = int(idx)
|
898 |
+
|
899 |
+
# Add the rank
|
900 |
+
ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
|
901 |
+
|
902 |
+
# Insert the weight into the list
|
903 |
+
idx = idx * 2 + (1 if direction == "down" else 0)
|
904 |
+
weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
|
905 |
+
|
906 |
+
loras[name] = (weights, ranks, target)
|
907 |
+
|
908 |
+
return loras
|
909 |
+
|
910 |
+
|
911 |
+
def parse_safeloras_embeds(
|
912 |
+
safeloras,
|
913 |
+
) -> Dict[str, torch.Tensor]:
|
914 |
+
"""
|
915 |
+
Converts a loaded safetensor file that contains Textual Inversion embeds into
|
916 |
+
a dictionary of embed_token: Tensor
|
917 |
+
"""
|
918 |
+
embeds = {}
|
919 |
+
metadata = safeloras.metadata()
|
920 |
+
|
921 |
+
for key in safeloras.keys():
|
922 |
+
# Only handle Textual Inversion embeds
|
923 |
+
meta = metadata.get(key)
|
924 |
+
if not meta or meta != EMBED_FLAG:
|
925 |
+
continue
|
926 |
+
|
927 |
+
embeds[key] = safeloras.get_tensor(key)
|
928 |
+
|
929 |
+
return embeds
|
930 |
+
|
931 |
+
|
932 |
+
def load_safeloras(path, device="cpu"):
|
933 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
934 |
+
return parse_safeloras(safeloras)
|
935 |
+
|
936 |
+
|
937 |
+
def load_safeloras_embeds(path, device="cpu"):
|
938 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
939 |
+
return parse_safeloras_embeds(safeloras)
|
940 |
+
|
941 |
+
|
942 |
+
def load_safeloras_both(path, device="cpu"):
|
943 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
944 |
+
return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
|
945 |
+
|
946 |
+
|
947 |
+
def collapse_lora(model, alpha=1.0):
|
948 |
+
|
949 |
+
for _module, name, _child_module in _find_modules(
|
950 |
+
model,
|
951 |
+
UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
|
952 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
|
953 |
+
):
|
954 |
+
|
955 |
+
if isinstance(_child_module, LoraInjectedLinear):
|
956 |
+
print("Collapsing Lin Lora in", name)
|
957 |
+
|
958 |
+
_child_module.linear.weight = nn.Parameter(
|
959 |
+
_child_module.linear.weight.data
|
960 |
+
+ alpha
|
961 |
+
* (
|
962 |
+
_child_module.lora_up.weight.data
|
963 |
+
@ _child_module.lora_down.weight.data
|
964 |
+
)
|
965 |
+
.type(_child_module.linear.weight.dtype)
|
966 |
+
.to(_child_module.linear.weight.device)
|
967 |
+
)
|
968 |
+
|
969 |
+
else:
|
970 |
+
print("Collapsing Conv Lora in", name)
|
971 |
+
_child_module.conv.weight = nn.Parameter(
|
972 |
+
_child_module.conv.weight.data
|
973 |
+
+ alpha
|
974 |
+
* (
|
975 |
+
_child_module.lora_up.weight.data.flatten(start_dim=1)
|
976 |
+
@ _child_module.lora_down.weight.data.flatten(start_dim=1)
|
977 |
+
)
|
978 |
+
.reshape(_child_module.conv.weight.data.shape)
|
979 |
+
.type(_child_module.conv.weight.dtype)
|
980 |
+
.to(_child_module.conv.weight.device)
|
981 |
+
)
|
982 |
+
|
983 |
+
|
984 |
+
def monkeypatch_or_replace_lora(
|
985 |
+
model,
|
986 |
+
loras,
|
987 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
988 |
+
r: Union[int, List[int]] = 4,
|
989 |
+
):
|
990 |
+
for _module, name, _child_module in _find_modules(
|
991 |
+
model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
|
992 |
+
):
|
993 |
+
_source = (
|
994 |
+
_child_module.linear
|
995 |
+
if isinstance(_child_module, LoraInjectedLinear)
|
996 |
+
else _child_module
|
997 |
+
)
|
998 |
+
|
999 |
+
weight = _source.weight
|
1000 |
+
bias = _source.bias
|
1001 |
+
_tmp = LoraInjectedLinear(
|
1002 |
+
_source.in_features,
|
1003 |
+
_source.out_features,
|
1004 |
+
_source.bias is not None,
|
1005 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
1006 |
+
)
|
1007 |
+
_tmp.linear.weight = weight
|
1008 |
+
|
1009 |
+
if bias is not None:
|
1010 |
+
_tmp.linear.bias = bias
|
1011 |
+
|
1012 |
+
# switch the module
|
1013 |
+
_module._modules[name] = _tmp
|
1014 |
+
|
1015 |
+
up_weight = loras.pop(0)
|
1016 |
+
down_weight = loras.pop(0)
|
1017 |
+
|
1018 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
1019 |
+
up_weight.type(weight.dtype)
|
1020 |
+
)
|
1021 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
1022 |
+
down_weight.type(weight.dtype)
|
1023 |
+
)
|
1024 |
+
|
1025 |
+
_module._modules[name].to(weight.device)
|
1026 |
+
|
1027 |
+
|
1028 |
+
def monkeypatch_or_replace_lora_extended(
|
1029 |
+
model,
|
1030 |
+
loras,
|
1031 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
1032 |
+
r: Union[int, List[int]] = 4,
|
1033 |
+
):
|
1034 |
+
for _module, name, _child_module in _find_modules(
|
1035 |
+
model,
|
1036 |
+
target_replace_module,
|
1037 |
+
search_class=[
|
1038 |
+
nn.Linear,
|
1039 |
+
nn.Conv2d,
|
1040 |
+
nn.Conv3d,
|
1041 |
+
LoraInjectedLinear,
|
1042 |
+
LoraInjectedConv2d,
|
1043 |
+
LoraInjectedConv3d,
|
1044 |
+
],
|
1045 |
+
):
|
1046 |
+
|
1047 |
+
if (_child_module.__class__ == nn.Linear) or (
|
1048 |
+
_child_module.__class__ == LoraInjectedLinear
|
1049 |
+
):
|
1050 |
+
if len(loras[0].shape) != 2:
|
1051 |
+
continue
|
1052 |
+
|
1053 |
+
_source = (
|
1054 |
+
_child_module.linear
|
1055 |
+
if isinstance(_child_module, LoraInjectedLinear)
|
1056 |
+
else _child_module
|
1057 |
+
)
|
1058 |
+
|
1059 |
+
weight = _source.weight
|
1060 |
+
bias = _source.bias
|
1061 |
+
_tmp = LoraInjectedLinear(
|
1062 |
+
_source.in_features,
|
1063 |
+
_source.out_features,
|
1064 |
+
_source.bias is not None,
|
1065 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
1066 |
+
)
|
1067 |
+
_tmp.linear.weight = weight
|
1068 |
+
|
1069 |
+
if bias is not None:
|
1070 |
+
_tmp.linear.bias = bias
|
1071 |
+
|
1072 |
+
elif (_child_module.__class__ == nn.Conv2d) or (
|
1073 |
+
_child_module.__class__ == LoraInjectedConv2d
|
1074 |
+
):
|
1075 |
+
if len(loras[0].shape) != 4:
|
1076 |
+
continue
|
1077 |
+
_source = (
|
1078 |
+
_child_module.conv
|
1079 |
+
if isinstance(_child_module, LoraInjectedConv2d)
|
1080 |
+
else _child_module
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
weight = _source.weight
|
1084 |
+
bias = _source.bias
|
1085 |
+
_tmp = LoraInjectedConv2d(
|
1086 |
+
_source.in_channels,
|
1087 |
+
_source.out_channels,
|
1088 |
+
_source.kernel_size,
|
1089 |
+
_source.stride,
|
1090 |
+
_source.padding,
|
1091 |
+
_source.dilation,
|
1092 |
+
_source.groups,
|
1093 |
+
_source.bias is not None,
|
1094 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
_tmp.conv.weight = weight
|
1098 |
+
|
1099 |
+
if bias is not None:
|
1100 |
+
_tmp.conv.bias = bias
|
1101 |
+
|
1102 |
+
elif _child_module.__class__ == nn.Conv3d or(
|
1103 |
+
_child_module.__class__ == LoraInjectedConv3d
|
1104 |
+
):
|
1105 |
+
|
1106 |
+
if len(loras[0].shape) != 5:
|
1107 |
+
continue
|
1108 |
+
|
1109 |
+
_source = (
|
1110 |
+
_child_module.conv
|
1111 |
+
if isinstance(_child_module, LoraInjectedConv3d)
|
1112 |
+
else _child_module
|
1113 |
+
)
|
1114 |
+
|
1115 |
+
weight = _source.weight
|
1116 |
+
bias = _source.bias
|
1117 |
+
_tmp = LoraInjectedConv3d(
|
1118 |
+
_source.in_channels,
|
1119 |
+
_source.out_channels,
|
1120 |
+
bias=_source.bias is not None,
|
1121 |
+
kernel_size=_source.kernel_size,
|
1122 |
+
padding=_source.padding,
|
1123 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
_tmp.conv.weight = weight
|
1127 |
+
|
1128 |
+
if bias is not None:
|
1129 |
+
_tmp.conv.bias = bias
|
1130 |
+
|
1131 |
+
# switch the module
|
1132 |
+
_module._modules[name] = _tmp
|
1133 |
+
|
1134 |
+
up_weight = loras.pop(0)
|
1135 |
+
down_weight = loras.pop(0)
|
1136 |
+
|
1137 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
1138 |
+
up_weight.type(weight.dtype)
|
1139 |
+
)
|
1140 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
1141 |
+
down_weight.type(weight.dtype)
|
1142 |
+
)
|
1143 |
+
|
1144 |
+
_module._modules[name].to(weight.device)
|
1145 |
+
|
1146 |
+
|
1147 |
+
def monkeypatch_or_replace_safeloras(models, safeloras):
|
1148 |
+
loras = parse_safeloras(safeloras)
|
1149 |
+
|
1150 |
+
for name, (lora, ranks, target) in loras.items():
|
1151 |
+
model = getattr(models, name, None)
|
1152 |
+
|
1153 |
+
if not model:
|
1154 |
+
print(f"No model provided for {name}, contained in Lora")
|
1155 |
+
continue
|
1156 |
+
|
1157 |
+
monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
|
1158 |
+
|
1159 |
+
|
1160 |
+
def monkeypatch_remove_lora(model):
|
1161 |
+
for _module, name, _child_module in _find_modules(
|
1162 |
+
model, search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d]
|
1163 |
+
):
|
1164 |
+
if isinstance(_child_module, LoraInjectedLinear):
|
1165 |
+
_source = _child_module.linear
|
1166 |
+
weight, bias = _source.weight, _source.bias
|
1167 |
+
|
1168 |
+
_tmp = nn.Linear(
|
1169 |
+
_source.in_features, _source.out_features, bias is not None
|
1170 |
+
)
|
1171 |
+
|
1172 |
+
_tmp.weight = weight
|
1173 |
+
if bias is not None:
|
1174 |
+
_tmp.bias = bias
|
1175 |
+
|
1176 |
+
else:
|
1177 |
+
_source = _child_module.conv
|
1178 |
+
weight, bias = _source.weight, _source.bias
|
1179 |
+
|
1180 |
+
if isinstance(_source, nn.Conv2d):
|
1181 |
+
_tmp = nn.Conv2d(
|
1182 |
+
in_channels=_source.in_channels,
|
1183 |
+
out_channels=_source.out_channels,
|
1184 |
+
kernel_size=_source.kernel_size,
|
1185 |
+
stride=_source.stride,
|
1186 |
+
padding=_source.padding,
|
1187 |
+
dilation=_source.dilation,
|
1188 |
+
groups=_source.groups,
|
1189 |
+
bias=bias is not None,
|
1190 |
+
)
|
1191 |
+
|
1192 |
+
_tmp.weight = weight
|
1193 |
+
if bias is not None:
|
1194 |
+
_tmp.bias = bias
|
1195 |
+
|
1196 |
+
if isinstance(_source, nn.Conv3d):
|
1197 |
+
_tmp = nn.Conv3d(
|
1198 |
+
_source.in_channels,
|
1199 |
+
_source.out_channels,
|
1200 |
+
bias=_source.bias is not None,
|
1201 |
+
kernel_size=_source.kernel_size,
|
1202 |
+
padding=_source.padding,
|
1203 |
+
)
|
1204 |
+
|
1205 |
+
_tmp.weight = weight
|
1206 |
+
if bias is not None:
|
1207 |
+
_tmp.bias = bias
|
1208 |
+
|
1209 |
+
_module._modules[name] = _tmp
|
1210 |
+
|
1211 |
+
|
1212 |
+
def monkeypatch_add_lora(
|
1213 |
+
model,
|
1214 |
+
loras,
|
1215 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
1216 |
+
alpha: float = 1.0,
|
1217 |
+
beta: float = 1.0,
|
1218 |
+
):
|
1219 |
+
for _module, name, _child_module in _find_modules(
|
1220 |
+
model, target_replace_module, search_class=[LoraInjectedLinear]
|
1221 |
+
):
|
1222 |
+
weight = _child_module.linear.weight
|
1223 |
+
|
1224 |
+
up_weight = loras.pop(0)
|
1225 |
+
down_weight = loras.pop(0)
|
1226 |
+
|
1227 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
1228 |
+
up_weight.type(weight.dtype).to(weight.device) * alpha
|
1229 |
+
+ _module._modules[name].lora_up.weight.to(weight.device) * beta
|
1230 |
+
)
|
1231 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
1232 |
+
down_weight.type(weight.dtype).to(weight.device) * alpha
|
1233 |
+
+ _module._modules[name].lora_down.weight.to(weight.device) * beta
|
1234 |
+
)
|
1235 |
+
|
1236 |
+
_module._modules[name].to(weight.device)
|
1237 |
+
|
1238 |
+
|
1239 |
+
def tune_lora_scale(model, alpha: float = 1.0):
|
1240 |
+
for _module in model.modules():
|
1241 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
|
1242 |
+
_module.scale = alpha
|
1243 |
+
|
1244 |
+
|
1245 |
+
def set_lora_diag(model, diag: torch.Tensor):
|
1246 |
+
for _module in model.modules():
|
1247 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
|
1248 |
+
_module.set_selector_from_diag(diag)
|
1249 |
+
|
1250 |
+
|
1251 |
+
def _text_lora_path(path: str) -> str:
|
1252 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
1253 |
+
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
|
1254 |
+
|
1255 |
+
|
1256 |
+
def _ti_lora_path(path: str) -> str:
|
1257 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
1258 |
+
return ".".join(path.split(".")[:-1] + ["ti", "pt"])
|
1259 |
+
|
1260 |
+
|
1261 |
+
def apply_learned_embed_in_clip(
|
1262 |
+
learned_embeds,
|
1263 |
+
text_encoder,
|
1264 |
+
tokenizer,
|
1265 |
+
token: Optional[Union[str, List[str]]] = None,
|
1266 |
+
idempotent=False,
|
1267 |
+
):
|
1268 |
+
if isinstance(token, str):
|
1269 |
+
trained_tokens = [token]
|
1270 |
+
elif isinstance(token, list):
|
1271 |
+
assert len(learned_embeds.keys()) == len(
|
1272 |
+
token
|
1273 |
+
), "The number of tokens and the number of embeds should be the same"
|
1274 |
+
trained_tokens = token
|
1275 |
+
else:
|
1276 |
+
trained_tokens = list(learned_embeds.keys())
|
1277 |
+
|
1278 |
+
for token in trained_tokens:
|
1279 |
+
print(token)
|
1280 |
+
embeds = learned_embeds[token]
|
1281 |
+
|
1282 |
+
# cast to dtype of text_encoder
|
1283 |
+
dtype = text_encoder.get_input_embeddings().weight.dtype
|
1284 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
1285 |
+
|
1286 |
+
i = 1
|
1287 |
+
if not idempotent:
|
1288 |
+
while num_added_tokens == 0:
|
1289 |
+
print(f"The tokenizer already contains the token {token}.")
|
1290 |
+
token = f"{token[:-1]}-{i}>"
|
1291 |
+
print(f"Attempting to add the token {token}.")
|
1292 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
1293 |
+
i += 1
|
1294 |
+
elif num_added_tokens == 0 and idempotent:
|
1295 |
+
print(f"The tokenizer already contains the token {token}.")
|
1296 |
+
print(f"Replacing {token} embedding.")
|
1297 |
+
|
1298 |
+
# resize the token embeddings
|
1299 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
1300 |
+
|
1301 |
+
# get the id for the token and assign the embeds
|
1302 |
+
token_id = tokenizer.convert_tokens_to_ids(token)
|
1303 |
+
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
1304 |
+
return token
|
1305 |
+
|
1306 |
+
|
1307 |
+
def load_learned_embed_in_clip(
|
1308 |
+
learned_embeds_path,
|
1309 |
+
text_encoder,
|
1310 |
+
tokenizer,
|
1311 |
+
token: Optional[Union[str, List[str]]] = None,
|
1312 |
+
idempotent=False,
|
1313 |
+
):
|
1314 |
+
learned_embeds = torch.load(learned_embeds_path)
|
1315 |
+
apply_learned_embed_in_clip(
|
1316 |
+
learned_embeds, text_encoder, tokenizer, token, idempotent
|
1317 |
+
)
|
1318 |
+
|
1319 |
+
|
1320 |
+
def patch_pipe(
|
1321 |
+
pipe,
|
1322 |
+
maybe_unet_path,
|
1323 |
+
token: Optional[str] = None,
|
1324 |
+
r: int = 4,
|
1325 |
+
patch_unet=True,
|
1326 |
+
patch_text=True,
|
1327 |
+
patch_ti=True,
|
1328 |
+
idempotent_token=True,
|
1329 |
+
unet_target_replace_module=DEFAULT_TARGET_REPLACE,
|
1330 |
+
text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
1331 |
+
):
|
1332 |
+
if maybe_unet_path.endswith(".pt"):
|
1333 |
+
# torch format
|
1334 |
+
|
1335 |
+
if maybe_unet_path.endswith(".ti.pt"):
|
1336 |
+
unet_path = maybe_unet_path[:-6] + ".pt"
|
1337 |
+
elif maybe_unet_path.endswith(".text_encoder.pt"):
|
1338 |
+
unet_path = maybe_unet_path[:-16] + ".pt"
|
1339 |
+
else:
|
1340 |
+
unet_path = maybe_unet_path
|
1341 |
+
|
1342 |
+
ti_path = _ti_lora_path(unet_path)
|
1343 |
+
text_path = _text_lora_path(unet_path)
|
1344 |
+
|
1345 |
+
if patch_unet:
|
1346 |
+
print("LoRA : Patching Unet")
|
1347 |
+
monkeypatch_or_replace_lora(
|
1348 |
+
pipe.unet,
|
1349 |
+
torch.load(unet_path),
|
1350 |
+
r=r,
|
1351 |
+
target_replace_module=unet_target_replace_module,
|
1352 |
+
)
|
1353 |
+
|
1354 |
+
if patch_text:
|
1355 |
+
print("LoRA : Patching text encoder")
|
1356 |
+
monkeypatch_or_replace_lora(
|
1357 |
+
pipe.text_encoder,
|
1358 |
+
torch.load(text_path),
|
1359 |
+
target_replace_module=text_target_replace_module,
|
1360 |
+
r=r,
|
1361 |
+
)
|
1362 |
+
if patch_ti:
|
1363 |
+
print("LoRA : Patching token input")
|
1364 |
+
token = load_learned_embed_in_clip(
|
1365 |
+
ti_path,
|
1366 |
+
pipe.text_encoder,
|
1367 |
+
pipe.tokenizer,
|
1368 |
+
token=token,
|
1369 |
+
idempotent=idempotent_token,
|
1370 |
+
)
|
1371 |
+
|
1372 |
+
elif maybe_unet_path.endswith(".safetensors"):
|
1373 |
+
safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
|
1374 |
+
monkeypatch_or_replace_safeloras(pipe, safeloras)
|
1375 |
+
tok_dict = parse_safeloras_embeds(safeloras)
|
1376 |
+
if patch_ti:
|
1377 |
+
apply_learned_embed_in_clip(
|
1378 |
+
tok_dict,
|
1379 |
+
pipe.text_encoder,
|
1380 |
+
pipe.tokenizer,
|
1381 |
+
token=token,
|
1382 |
+
idempotent=idempotent_token,
|
1383 |
+
)
|
1384 |
+
return tok_dict
|
1385 |
+
|
1386 |
+
|
1387 |
+
def train_patch_pipe(pipe, patch_unet, patch_text):
|
1388 |
+
if patch_unet:
|
1389 |
+
print("LoRA : Patching Unet")
|
1390 |
+
collapse_lora(pipe.unet)
|
1391 |
+
monkeypatch_remove_lora(pipe.unet)
|
1392 |
+
|
1393 |
+
if patch_text:
|
1394 |
+
print("LoRA : Patching text encoder")
|
1395 |
+
|
1396 |
+
collapse_lora(pipe.text_encoder)
|
1397 |
+
monkeypatch_remove_lora(pipe.text_encoder)
|
1398 |
+
|
1399 |
+
@torch.no_grad()
|
1400 |
+
def inspect_lora(model):
|
1401 |
+
moved = {}
|
1402 |
+
|
1403 |
+
for name, _module in model.named_modules():
|
1404 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
|
1405 |
+
ups = _module.lora_up.weight.data.clone()
|
1406 |
+
downs = _module.lora_down.weight.data.clone()
|
1407 |
+
|
1408 |
+
wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
|
1409 |
+
|
1410 |
+
dist = wght.flatten().abs().mean().item()
|
1411 |
+
if name in moved:
|
1412 |
+
moved[name].append(dist)
|
1413 |
+
else:
|
1414 |
+
moved[name] = [dist]
|
1415 |
+
|
1416 |
+
return moved
|
1417 |
+
|
1418 |
+
|
1419 |
+
def save_all(
|
1420 |
+
unet,
|
1421 |
+
text_encoder,
|
1422 |
+
save_path,
|
1423 |
+
placeholder_token_ids=None,
|
1424 |
+
placeholder_tokens=None,
|
1425 |
+
save_lora=True,
|
1426 |
+
save_ti=True,
|
1427 |
+
target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
1428 |
+
target_replace_module_unet=DEFAULT_TARGET_REPLACE,
|
1429 |
+
safe_form=True,
|
1430 |
+
):
|
1431 |
+
if not safe_form:
|
1432 |
+
# save ti
|
1433 |
+
if save_ti:
|
1434 |
+
ti_path = _ti_lora_path(save_path)
|
1435 |
+
learned_embeds_dict = {}
|
1436 |
+
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
|
1437 |
+
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
|
1438 |
+
print(
|
1439 |
+
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
|
1440 |
+
learned_embeds[:4],
|
1441 |
+
)
|
1442 |
+
learned_embeds_dict[tok] = learned_embeds.detach().cpu()
|
1443 |
+
|
1444 |
+
torch.save(learned_embeds_dict, ti_path)
|
1445 |
+
print("Ti saved to ", ti_path)
|
1446 |
+
|
1447 |
+
# save text encoder
|
1448 |
+
if save_lora:
|
1449 |
+
save_lora_weight(
|
1450 |
+
unet, save_path, target_replace_module=target_replace_module_unet
|
1451 |
+
)
|
1452 |
+
print("Unet saved to ", save_path)
|
1453 |
+
|
1454 |
+
save_lora_weight(
|
1455 |
+
text_encoder,
|
1456 |
+
_text_lora_path(save_path),
|
1457 |
+
target_replace_module=target_replace_module_text,
|
1458 |
+
)
|
1459 |
+
print("Text Encoder saved to ", _text_lora_path(save_path))
|
1460 |
+
|
1461 |
+
else:
|
1462 |
+
assert save_path.endswith(
|
1463 |
+
".safetensors"
|
1464 |
+
), f"Save path : {save_path} should end with .safetensors"
|
1465 |
+
|
1466 |
+
loras = {}
|
1467 |
+
embeds = {}
|
1468 |
+
|
1469 |
+
if save_lora:
|
1470 |
+
|
1471 |
+
loras["unet"] = (unet, target_replace_module_unet)
|
1472 |
+
loras["text_encoder"] = (text_encoder, target_replace_module_text)
|
1473 |
+
|
1474 |
+
if save_ti:
|
1475 |
+
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
|
1476 |
+
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
|
1477 |
+
print(
|
1478 |
+
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
|
1479 |
+
learned_embeds[:4],
|
1480 |
+
)
|
1481 |
+
embeds[tok] = learned_embeds.detach().cpu()
|
1482 |
+
|
1483 |
+
save_safeloras_with_embeds(loras, embeds, save_path)
|
utils/lora_handler.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from logging import warnings
|
3 |
+
import torch
|
4 |
+
from typing import Union
|
5 |
+
from types import SimpleNamespace
|
6 |
+
from models.unet_3d_condition import UNet3DConditionModel
|
7 |
+
from transformers import CLIPTextModel
|
8 |
+
from utils.convert_diffusers_to_original_ms_text_to_video import convert_unet_state_dict, convert_text_enc_state_dict_v20
|
9 |
+
|
10 |
+
from .lora import (
|
11 |
+
extract_lora_ups_down,
|
12 |
+
inject_trainable_lora_extended,
|
13 |
+
save_lora_weight,
|
14 |
+
train_patch_pipe,
|
15 |
+
monkeypatch_or_replace_lora,
|
16 |
+
monkeypatch_or_replace_lora_extended
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
FILE_BASENAMES = ['unet', 'text_encoder']
|
21 |
+
LORA_FILE_TYPES = ['.pt', '.safetensors']
|
22 |
+
CLONE_OF_SIMO_KEYS = ['model', 'loras', 'target_replace_module', 'r']
|
23 |
+
STABLE_LORA_KEYS = ['model', 'target_module', 'search_class', 'r', 'dropout', 'lora_bias']
|
24 |
+
|
25 |
+
lora_versions = dict(
|
26 |
+
stable_lora = "stable_lora",
|
27 |
+
cloneofsimo = "cloneofsimo"
|
28 |
+
)
|
29 |
+
|
30 |
+
lora_func_types = dict(
|
31 |
+
loader = "loader",
|
32 |
+
injector = "injector"
|
33 |
+
)
|
34 |
+
|
35 |
+
lora_args = dict(
|
36 |
+
model = None,
|
37 |
+
loras = None,
|
38 |
+
target_replace_module = [],
|
39 |
+
target_module = [],
|
40 |
+
r = 4,
|
41 |
+
search_class = [torch.nn.Linear],
|
42 |
+
dropout = 0,
|
43 |
+
lora_bias = 'none'
|
44 |
+
)
|
45 |
+
|
46 |
+
LoraVersions = SimpleNamespace(**lora_versions)
|
47 |
+
LoraFuncTypes = SimpleNamespace(**lora_func_types)
|
48 |
+
|
49 |
+
LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
|
50 |
+
LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]
|
51 |
+
|
52 |
+
def filter_dict(_dict, keys=[]):
|
53 |
+
if len(keys) == 0:
|
54 |
+
assert "Keys cannot empty for filtering return dict."
|
55 |
+
|
56 |
+
for k in keys:
|
57 |
+
if k not in lora_args.keys():
|
58 |
+
assert f"{k} does not exist in available LoRA arguments"
|
59 |
+
|
60 |
+
return {k: v for k, v in _dict.items() if k in keys}
|
61 |
+
|
62 |
+
class LoraHandler(object):
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
version: LORA_VERSIONS = LoraVersions.cloneofsimo,
|
66 |
+
use_unet_lora: bool = False,
|
67 |
+
use_text_lora: bool = False,
|
68 |
+
save_for_webui: bool = False,
|
69 |
+
only_for_webui: bool = False,
|
70 |
+
lora_bias: str = 'none',
|
71 |
+
unet_replace_modules: list = None,
|
72 |
+
text_encoder_replace_modules: list = None
|
73 |
+
):
|
74 |
+
self.version = version
|
75 |
+
self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
|
76 |
+
self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
|
77 |
+
self.lora_bias = lora_bias
|
78 |
+
self.use_unet_lora = use_unet_lora
|
79 |
+
self.use_text_lora = use_text_lora
|
80 |
+
self.save_for_webui = save_for_webui
|
81 |
+
self.only_for_webui = only_for_webui
|
82 |
+
self.unet_replace_modules = unet_replace_modules
|
83 |
+
self.text_encoder_replace_modules = text_encoder_replace_modules
|
84 |
+
self.use_lora = any([use_text_lora, use_unet_lora])
|
85 |
+
|
86 |
+
def is_cloneofsimo_lora(self):
|
87 |
+
return self.version == LoraVersions.cloneofsimo
|
88 |
+
|
89 |
+
|
90 |
+
def get_lora_func(self, func_type: LORA_FUNC_TYPES = LoraFuncTypes.loader):
|
91 |
+
|
92 |
+
if self.is_cloneofsimo_lora():
|
93 |
+
|
94 |
+
if func_type == LoraFuncTypes.loader:
|
95 |
+
return monkeypatch_or_replace_lora_extended
|
96 |
+
|
97 |
+
if func_type == LoraFuncTypes.injector:
|
98 |
+
return inject_trainable_lora_extended
|
99 |
+
|
100 |
+
assert "LoRA Version does not exist."
|
101 |
+
|
102 |
+
def check_lora_ext(self, lora_file: str):
|
103 |
+
return lora_file.endswith(tuple(LORA_FILE_TYPES))
|
104 |
+
|
105 |
+
def get_lora_file_path(
|
106 |
+
self,
|
107 |
+
lora_path: str,
|
108 |
+
model: Union[UNet3DConditionModel, CLIPTextModel]
|
109 |
+
):
|
110 |
+
if os.path.exists(lora_path):
|
111 |
+
lora_filenames = [fns for fns in os.listdir(lora_path)]
|
112 |
+
is_lora = self.check_lora_ext(lora_path)
|
113 |
+
|
114 |
+
is_unet = isinstance(model, UNet3DConditionModel)
|
115 |
+
is_text = isinstance(model, CLIPTextModel)
|
116 |
+
idx = 0 if is_unet else 1
|
117 |
+
|
118 |
+
base_name = FILE_BASENAMES[idx]
|
119 |
+
|
120 |
+
for lora_filename in lora_filenames:
|
121 |
+
is_lora = self.check_lora_ext(lora_filename)
|
122 |
+
if not is_lora:
|
123 |
+
continue
|
124 |
+
|
125 |
+
if base_name in lora_filename:
|
126 |
+
return os.path.join(lora_path, lora_filename)
|
127 |
+
|
128 |
+
return None
|
129 |
+
|
130 |
+
def handle_lora_load(self, file_name:str, lora_loader_args: dict = None):
|
131 |
+
self.lora_loader(**lora_loader_args)
|
132 |
+
print(f"Successfully loaded LoRA from: {file_name}")
|
133 |
+
|
134 |
+
def load_lora(self, model, lora_path: str = '', lora_loader_args: dict = None,):
|
135 |
+
try:
|
136 |
+
lora_file = self.get_lora_file_path(lora_path, model)
|
137 |
+
|
138 |
+
if lora_file is not None:
|
139 |
+
lora_loader_args.update({"lora_path": lora_file})
|
140 |
+
self.handle_lora_load(lora_file, lora_loader_args)
|
141 |
+
|
142 |
+
else:
|
143 |
+
print(f"Could not load LoRAs for {model.__class__.__name__}. Injecting new ones instead...")
|
144 |
+
|
145 |
+
except Exception as e:
|
146 |
+
print(f"An error occured while loading a LoRA file: {e}")
|
147 |
+
|
148 |
+
def get_lora_func_args(self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias, scale):
|
149 |
+
return_dict = lora_args.copy()
|
150 |
+
|
151 |
+
if self.is_cloneofsimo_lora():
|
152 |
+
return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
|
153 |
+
return_dict.update({
|
154 |
+
"model": model,
|
155 |
+
"loras": self.get_lora_file_path(lora_path, model),
|
156 |
+
"target_replace_module": replace_modules,
|
157 |
+
"r": r,
|
158 |
+
"scale": scale,
|
159 |
+
"dropout_p": dropout,
|
160 |
+
})
|
161 |
+
|
162 |
+
return return_dict
|
163 |
+
|
164 |
+
def do_lora_injection(
|
165 |
+
self,
|
166 |
+
model,
|
167 |
+
replace_modules,
|
168 |
+
bias='none',
|
169 |
+
dropout=0,
|
170 |
+
r=4,
|
171 |
+
lora_loader_args=None,
|
172 |
+
):
|
173 |
+
REPLACE_MODULES = replace_modules
|
174 |
+
|
175 |
+
params = None
|
176 |
+
negation = None
|
177 |
+
is_injection_hybrid = False
|
178 |
+
|
179 |
+
if self.is_cloneofsimo_lora():
|
180 |
+
is_injection_hybrid = True
|
181 |
+
injector_args = lora_loader_args
|
182 |
+
|
183 |
+
params, negation = self.lora_injector(**injector_args) # inject_trainable_lora_extended
|
184 |
+
for _up, _down in extract_lora_ups_down(
|
185 |
+
model,
|
186 |
+
target_replace_module=REPLACE_MODULES):
|
187 |
+
|
188 |
+
if all(x is not None for x in [_up, _down]):
|
189 |
+
print(f"Lora successfully injected into {model.__class__.__name__}.")
|
190 |
+
|
191 |
+
break
|
192 |
+
|
193 |
+
return params, negation, is_injection_hybrid
|
194 |
+
|
195 |
+
return params, negation, is_injection_hybrid
|
196 |
+
|
197 |
+
def add_lora_to_model(self, use_lora, model, replace_modules, dropout=0.0, lora_path='', r=16, scale=1.0):
|
198 |
+
|
199 |
+
params = None
|
200 |
+
negation = None
|
201 |
+
|
202 |
+
lora_loader_args = self.get_lora_func_args(
|
203 |
+
lora_path,
|
204 |
+
use_lora,
|
205 |
+
model,
|
206 |
+
replace_modules,
|
207 |
+
r,
|
208 |
+
dropout,
|
209 |
+
self.lora_bias,
|
210 |
+
scale
|
211 |
+
)
|
212 |
+
|
213 |
+
if use_lora:
|
214 |
+
params, negation, is_injection_hybrid = self.do_lora_injection(
|
215 |
+
model,
|
216 |
+
replace_modules,
|
217 |
+
bias=self.lora_bias,
|
218 |
+
lora_loader_args=lora_loader_args,
|
219 |
+
dropout=dropout,
|
220 |
+
r=r
|
221 |
+
)
|
222 |
+
|
223 |
+
if not is_injection_hybrid:
|
224 |
+
self.load_lora(model, lora_path=lora_path, lora_loader_args=lora_loader_args)
|
225 |
+
|
226 |
+
params = model if params is None else params
|
227 |
+
return params, negation
|
228 |
+
|
229 |
+
def save_cloneofsimo_lora(self, model, save_path, step, flag):
|
230 |
+
|
231 |
+
def save_lora(model, name, condition, replace_modules, step, save_path, flag=None):
|
232 |
+
if condition and replace_modules is not None:
|
233 |
+
save_path = f"{save_path}/{step}_{name}.pt"
|
234 |
+
save_lora_weight(model, save_path, replace_modules, flag)
|
235 |
+
|
236 |
+
save_lora(
|
237 |
+
model.unet,
|
238 |
+
FILE_BASENAMES[0],
|
239 |
+
self.use_unet_lora,
|
240 |
+
self.unet_replace_modules,
|
241 |
+
step,
|
242 |
+
save_path,
|
243 |
+
flag
|
244 |
+
)
|
245 |
+
save_lora(
|
246 |
+
model.text_encoder,
|
247 |
+
FILE_BASENAMES[1],
|
248 |
+
self.use_text_lora,
|
249 |
+
self.text_encoder_replace_modules,
|
250 |
+
step,
|
251 |
+
save_path,
|
252 |
+
flag
|
253 |
+
)
|
254 |
+
|
255 |
+
# train_patch_pipe(model, self.use_unet_lora, self.use_text_lora)
|
256 |
+
|
257 |
+
def save_lora_weights(self, model: None, save_path: str ='',step: str = '', flag=None):
|
258 |
+
save_path = f"{save_path}/lora"
|
259 |
+
os.makedirs(save_path, exist_ok=True)
|
260 |
+
|
261 |
+
if self.is_cloneofsimo_lora():
|
262 |
+
if any([self.save_for_webui, self.only_for_webui]):
|
263 |
+
warnings.warn(
|
264 |
+
"""
|
265 |
+
You have 'save_for_webui' enabled, but are using cloneofsimo's LoRA implemention.
|
266 |
+
Only 'stable_lora' is supported for saving to a compatible webui file.
|
267 |
+
"""
|
268 |
+
)
|
269 |
+
self.save_cloneofsimo_lora(model, save_path, step, flag)
|