Joeythemonster Linoy Tsaban committed on
Commit
f05bb07
·
0 Parent(s):

Duplicate from weizmannscience/tokenflow

Browse files

Co-authored-by: Linoy Tsaban <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/woman-running.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ examples/running_dog.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Tokenflow
3
+ emoji: 🐠
4
+ colorFrom: purple
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.41.2
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: weizmannscience/tokenflow
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
4
+ from utils import video_to_frames, add_dict_to_yaml_file, save_video, seed_everything
5
+ # from diffusers.utils import export_to_video
6
+ from tokenflow_pnp import TokenFlow
7
+ from preprocess_utils import *
8
+ from tokenflow_utils import *
9
# Load the Stable Diffusion weights once at import time so that every request
# served by the demo reuses the same models.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "stabilityai/stable-diffusion-2-1-base"

# Components handed to the Preprocess (DDIM inversion) stage.
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", revision="fp16",
                                    torch_dtype=torch.float16).to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision="fp16",
                                             torch_dtype=torch.float16).to(device)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision="fp16",
                                            torch_dtype=torch.float16).to(device)

# Full pipeline consumed by TokenFlow for the editing stage. It is placed on
# `device`, like the components above, rather than on a hard-coded "cuda".
tokenflow_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
# NOTE(review): xformers attention presumably still requires a CUDA device - confirm.
tokenflow_pipe.enable_xformers_memory_efficient_attention()
26
+
27
def randomize_seed_fn():
    """Return a random integer seed in the inclusive range [0, int32 max]."""
    # Imported locally so this function does not depend on `random` / `np`
    # leaking into the module through the file's wildcard imports.
    import random

    import numpy as np

    return random.randint(0, np.iinfo(np.int32).max)
30
+
31
def reset_do_inversion():
    """Return the value stored in the `do_inversion` state to force a new inversion."""
    needs_inversion = True
    return needs_inversion
33
+
34
def get_example():
    """Return the example rows for the demo, one single-element row per video path."""
    example_videos = (
        'examples/wolf.mp4',
        'examples/woman-running.mp4',
        'examples/cutting_bread.mp4',
        'examples/running_dog.mp4',
    )
    return [[video_path] for video_path in example_videos]
50
+
51
+
52
def prep(config):
    """Run DDIM inversion for the video described by ``config``.

    Args:
        config: dict read for the keys "sd_version", "save_steps", "frames",
            "steps", "batch_size" and "inversion_prompt"; when "frames" is
            empty it is also read for "save_dir", "data_path" and "n_frames".

    Returns:
        The 4-tuple returned by ``Preprocess.extract_latents``:
        ``(frames, latents, total_inverted_latents, rgb_reconstruction)``.

    Raises:
        ValueError: if ``config["sd_version"]`` is not a supported version.
    """
    model_keys = {
        '2.1': "stabilityai/stable-diffusion-2-1-base",
        '2.0': "stabilityai/stable-diffusion-2-base",
        '1.5': "runwayml/stable-diffusion-v1-5",
        'ControlNet': "runwayml/stable-diffusion-v1-5",
        'depth': "stabilityai/stable-diffusion-2-depth",
    }
    sd_version = config["sd_version"]
    # Fail with a clear message instead of an UnboundLocalError on `model_key`.
    if sd_version not in model_keys:
        raise ValueError(f'Stable-diffusion version {sd_version} not supported.')
    model_key = model_keys[sd_version]

    # timesteps to save: a throw-away scheduler is used only to enumerate them
    toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
    toy_scheduler.set_timesteps(config["save_steps"])
    print("config[save_steps]", config["save_steps"])
    timesteps_to_save, _ = get_timesteps(toy_scheduler, num_inference_steps=config["save_steps"],
                                         strength=1.0,
                                         device=device)
    print("YOOOO timesteps to save", timesteps_to_save)

    if not config["frames"]:  # original non demo setting
        save_path = os.path.join(config["save_dir"],
                                 f'sd_{config["sd_version"]}',
                                 Path(config["data_path"]).stem,
                                 f'steps_{config["steps"]}',
                                 f'nframes_{config["n_frames"]}')
        os.makedirs(os.path.join(save_path, 'latents'), exist_ok=True)
        add_dict_to_yaml_file(os.path.join(config["save_dir"], 'inversion_prompts.yaml'),
                              Path(config["data_path"]).stem, config["inversion_prompt"])
        # save inversion prompt in a txt file
        with open(os.path.join(save_path, 'inversion_prompt.txt'), 'w') as f:
            f.write(config["inversion_prompt"])
    else:
        # Demo setting: frames are passed in memory, nothing is written to disk.
        save_path = None

    # Reuse the module-level models instead of reloading them per request.
    model = Preprocess(device, config,
                       vae=vae,
                       text_encoder=text_encoder,
                       scheduler=scheduler,
                       tokenizer=tokenizer,
                       unet=unet)
    print(type(model.config["batch_size"]))
    frames, latents, total_inverted_latents, rgb_reconstruction = model.extract_latents(
        num_steps=model.config["steps"],
        save_path=save_path,
        batch_size=model.config["batch_size"],
        timesteps_to_save=timesteps_to_save,
        inversion_prompt=model.config["inversion_prompt"],
    )

    return frames, latents, total_inverted_latents, rgb_reconstruction
102
+
103
def preprocess_and_invert(input_video,
                          frames,
                          latents,
                          inverted_latents,
                          seed,
                          randomize_seed,
                          do_inversion,
                          steps,
                          n_timesteps = 50,
                          batch_size: int = 8,
                          n_frames: int = 40,
                          inversion_prompt:str = '',
                          ):
    """Extract frames from ``input_video`` and DDIM-invert them when required.

    The inversion runs only when ``do_inversion`` or ``randomize_seed`` is
    truthy; otherwise the incoming state values are returned unchanged.

    Args:
        input_video: path of the input video (a string).
        frames, latents, inverted_latents: current state values; returned
            as-is when no inversion is performed.
        seed: seed applied via ``seed_everything`` before the inversion.
        randomize_seed: when truthy, ``seed`` is replaced by a random one.
        do_inversion: whether the cached inversion is stale.
        steps: number of inversion steps.
        n_timesteps: number of timesteps whose latents are kept.
        batch_size: batch size used for the inversion.
        n_frames: number of frames to use.
        inversion_prompt: prompt used for the inversion.

    Returns:
        ``(frames, latents, inverted_latents, do_inversion)``. After an
        inversion the first three are ``gr.State`` objects wrapping the new
        values and ``do_inversion`` is ``False``.
    """
    import os

    sd_version = "2.1"
    height = 512
    width = 512
    print("n timesteps", n_timesteps)
    if do_inversion or randomize_seed:
        # Pick the seed first so the config records the seed actually used.
        if randomize_seed:
            seed = randomize_seed_fn()

        preprocess_config = {
            'H': height,
            'W': width,
            'save_dir': 'latents',
            'sd_version': sd_version,
            # Values are cast to int like `n_timesteps`, because they are
            # later used as step counts and slice bounds.
            # NOTE(review): presumably the UI sliders can deliver floats - confirm.
            'steps': int(steps),
            'batch_size': int(batch_size),
            'save_steps': int(n_timesteps),
            'n_frames': int(n_frames),
            'seed': seed,
            'inversion_prompt': inversion_prompt,
            'frames': video_to_frames(input_video),
            # splitext only strips the extension; split(".")[0] would also cut
            # the path at any earlier dot (e.g. in a directory name).
            'data_path': os.path.splitext(input_video)[0],
        }

        seed_everything(seed)

        frames, latents, total_inverted_latents, rgb_reconstruction = prep(preprocess_config)
        print(total_inverted_latents.keys())
        print(len(total_inverted_latents.keys()))
        # The results are wrapped in gr.State; edit_with_pnp reads them back
        # through `.value`.
        frames = gr.State(value=frames)
        latents = gr.State(value=latents)
        inverted_latents = gr.State(value=total_inverted_latents)
        do_inversion = False

    return frames, latents, inverted_latents, do_inversion
151
+
152
+
153
def edit_with_pnp(input_video,
                  frames,
                  latents,
                  inverted_latents,
                  seed,
                  randomize_seed,
                  do_inversion,
                  steps,
                  prompt: str = "a marble sculpture of a woman running, Venus de Milo",
                  pnp_attn_t: float = 0.5,
                  pnp_f_t: float = 0.8,
                  batch_size: int = 8, #needs to be the same as for preprocess
                  n_frames: int = 40,#needs to be the same as for preprocess
                  n_timesteps: int = 50,
                  gudiance_scale: float = 7.5,
                  inversion_prompt: str = "", #needs to be the same as for preprocess
                  n_fps: int = 10,
                  progress=gr.Progress(track_tqdm=True)
                  ):
    """Edit the input video with TokenFlow and write the result to an mp4 file.

    When ``do_inversion`` is truthy the video is first (re-)inverted through
    ``preprocess_and_invert``.

    Args:
        input_video: path of the input video.
        frames, latents, inverted_latents: state values; ``frames`` and
            ``inverted_latents`` are read through ``.value``.
        seed, randomize_seed: seed handling; a random seed replaces ``seed``
            when ``randomize_seed`` is truthy.
        do_inversion: whether inversion must run before editing.
        steps: number of inversion steps.
        prompt: text describing the edited video.
        pnp_attn_t, pnp_f_t: plug-and-play thresholds passed to TokenFlow.
        batch_size, n_frames, inversion_prompt: must match the values used
            for preprocessing.
        n_timesteps: number of diffusion steps.
        gudiance_scale: guidance scale passed to TokenFlow.
        n_fps: frame rate of the saved video.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, frames, latents, inverted_latents, do_inversion)``.
    """
    config = {
        "sd_version": "2.1",
        "device": device,
        "n_timesteps": int(n_timesteps),
        "n_frames": n_frames,
        "batch_size": batch_size,
        "guidance_scale": gudiance_scale,
        "prompt": prompt,
        # A plain string: a stray trailing comma previously turned this value
        # into a 1-tuple.
        "negative_prompt": "ugly, blurry, low res, unrealistic, unaesthetic",
        "pnp_attn_t": pnp_attn_t,
        "pnp_f_t": pnp_f_t,
        "pnp_inversion_prompt": inversion_prompt,
    }

    if do_inversion:
        frames, latents, inverted_latents, do_inversion = preprocess_and_invert(
            input_video,
            frames,
            latents,
            inverted_latents,
            seed,
            randomize_seed,
            do_inversion,
            steps,
            n_timesteps,
            batch_size,
            n_frames,
            inversion_prompt)
        do_inversion = False

    if randomize_seed:
        seed = randomize_seed_fn()
    seed_everything(seed)

    editor = TokenFlow(config=config, pipe=tokenflow_pipe, frames=frames.value,
                       inverted_latents=inverted_latents.value)
    edited_frames = editor.edit_video()

    # The file name is fixed; the actual frame rate is `n_fps`.
    output_path = 'tokenflow_PnP_fps_30.mp4'
    save_video(edited_frames, output_path, fps=n_fps)
    return output_path, frames, latents, inverted_latents, do_inversion
216
+
217
+ ########
218
+ # demo #
219
+ ########
220
+
221
+
222
+ intro = """
223
+ <div style="text-align:center">
224
+ <h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
225
+ TokenFlow - <small>Temporally consistent video editing</small>
226
+ </h1>
227
+ <span>[<a target="_blank" href="https://diffusion-tokenflow.github.io">Project page</a>], [<a target="_blank" href="https://github.com/omerbt/TokenFlow">GitHub</a>], [<a target="_blank" href="https://huggingface.co/papers/2307.10373">Paper</a>]</span>
228
+ <div style="display:flex; justify-content: center;margin-top: 0.5em">Each edit takes ~5 min <a href="https://huggingface.co/weizmannscience/tokenflow?duplicate=true" target="_blank">
229
+ <img style="margin-top: 0em; margin-bottom: 0em; margin-left: 0.5em" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a></div>
230
+ </div>
231
+ """
232
+
233
+
234
+
235
# Gradio UI: layout, per-session state and event wiring for the demo.
with gr.Blocks(css="style.css") as demo:

    gr.HTML(intro)

    # Per-session state shared between the inversion and the editing callbacks.
    frames = gr.State()
    inverted_latents = gr.State()
    latents = gr.State()
    do_inversion = gr.State(value=True)

    with gr.Row():
        input_video = gr.Video(label="Input Video", interactive=True, elem_id="input_video")
        output_video = gr.Video(label="Edited Video", interactive=False, elem_id="output_video")
        input_video.style(height=365, width=365)
        output_video.style(height=365, width=365)

    with gr.Row():
        prompt = gr.Textbox(
            label="Describe your edited video",
            max_lines=1, value=""
        )

    with gr.Row():
        run_button = gr.Button("Edit your video!", visible=True)

    with gr.Accordion("Advanced Options", open=False):
        with gr.Tabs() as tabs:
            with gr.TabItem('General options'):
                with gr.Row():
                    with gr.Column(min_width=100):
                        seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
                        randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
                        gudiance_scale = gr.Slider(label='Guidance Scale', minimum=1, maximum=30,
                                                   value=7.5, step=0.5, interactive=True)
                        steps = gr.Slider(label='Inversion steps', minimum=10, maximum=500,
                                          value=500, step=1, interactive=True)

                    with gr.Column(min_width=100):
                        inversion_prompt = gr.Textbox(lines=1, label="Inversion prompt", interactive=True,
                                                      placeholder="")
                        batch_size = gr.Slider(label='Batch size', minimum=1, maximum=10,
                                               value=8, step=1, interactive=True)
                        n_frames = gr.Slider(label='Num frames', minimum=2, maximum=200,
                                             value=24, step=1, interactive=True)
                        n_timesteps = gr.Slider(label='Diffusion steps', minimum=25, maximum=100,
                                                value=50, step=25, interactive=True)
                        n_fps = gr.Slider(label='Frames per second', minimum=1, maximum=60,
                                          value=10, step=1, interactive=True)

            with gr.TabItem('Plug-and-Play Parameters'):
                with gr.Column(min_width=100):
                    # step matches the feature-threshold slider below; the
                    # previous step of 0.5 only allowed the values 0, 0.5 and 1.
                    pnp_attn_t = gr.Slider(label='pnp attention threshold', minimum=0, maximum=1,
                                           value=0.5, step=0.05, interactive=True)
                    pnp_f_t = gr.Slider(label='pnp feature threshold', minimum=0, maximum=1,
                                        value=0.8, step=0.05, interactive=True)

    # Changing any input that affects the inversion marks the cached one stale.
    for component in (input_video, inversion_prompt, randomize_seed, seed):
        component.change(
            fn = reset_do_inversion,
            outputs = [do_inversion],
            queue = False)

    # Invert as soon as a video is uploaded, so a later edit can reuse the result.
    input_video.upload(
        fn = reset_do_inversion,
        outputs = [do_inversion],
        queue = False).then(fn = preprocess_and_invert,
                            inputs = [input_video,
                                      frames,
                                      latents,
                                      inverted_latents,
                                      seed,
                                      randomize_seed,
                                      do_inversion,
                                      steps,
                                      n_timesteps,
                                      batch_size,
                                      n_frames,
                                      inversion_prompt
                                      ],
                            outputs = [frames,
                                       latents,
                                       inverted_latents,
                                       do_inversion
                                       ])

    run_button.click(fn = edit_with_pnp,
                     inputs = [input_video,
                               frames,
                               latents,
                               inverted_latents,
                               seed,
                               randomize_seed,
                               do_inversion,
                               steps,
                               prompt,
                               pnp_attn_t,
                               pnp_f_t,
                               batch_size,
                               n_frames,
                               n_timesteps,
                               gudiance_scale,
                               inversion_prompt,
                               n_fps ],
                     outputs = [output_video, frames, latents, inverted_latents, do_inversion]
                     )

    gr.Examples(
        examples=get_example(),
        label='Examples',
        inputs=[input_video],
        outputs=[output_video]
    )

demo.queue()
demo.launch()
examples/cutting_bread.mp4 ADDED
Binary file (848 kB). View file
 
examples/rocket_kittens.mp4 ADDED
Binary file (561 kB). View file
 
examples/running_dog.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:904a4f165d70dd46164e27f3b279869debcfb01797aa4c3f0f3ae6fabea8631d
3
+ size 1443849
examples/wolf.mp4 ADDED
Binary file (379 kB). View file
 
examples/woman-running.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c622e40ca8b01a6a678eae719be376cb7969a2b7460955706a214720104880b7
3
+ size 1285788
preprocess_utils.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
2
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
3
+ # suppress partial model loading warning
4
+ logging.set_verbosity_error()
5
+
6
+ import os
7
+ from tqdm import tqdm, trange
8
+ import torch
9
+ import torch.nn as nn
10
+ import argparse
11
+ from torchvision.io import write_video
12
+ from pathlib import Path
13
+ from utils import *
14
+ import torchvision.transforms as T
15
+
16
+
17
def get_timesteps(scheduler, num_inference_steps, strength, device):
    """Return the tail of ``scheduler.timesteps`` selected by ``strength``.

    Returns a pair ``(timesteps, count)``: ``timesteps`` is
    ``scheduler.timesteps`` with the first ``t_start`` entries dropped and
    ``count`` is ``num_inference_steps - t_start``. ``device`` is accepted
    but not used.
    """
    # number of steps kept, capped at the total number of inference steps
    steps_kept = int(num_inference_steps * strength)
    if steps_kept > num_inference_steps:
        steps_kept = num_inference_steps

    first_index = num_inference_steps - steps_kept
    if first_index < 0:
        first_index = 0

    selected = scheduler.timesteps[first_index:]
    return selected, num_inference_steps - first_index
25
+
26
+
27
class Preprocess(nn.Module):
    """Encode video frames to latents and DDIM-invert them.

    The Stable Diffusion components (VAE, tokenizer, text encoder, UNet and
    scheduler) are supplied by the caller and stored as-is; the frames are
    loaded and encoded when the object is constructed.
    """

    def __init__(self, device, opt, vae, tokenizer, text_encoder, unet,scheduler, hf_key=None):
        """Store the models and load the frames described by ``opt``.

        Args:
            device: device the frames and latents are moved to.
            opt: config dict; read here for "sd_version", "data_path",
                "n_frames" and "frames".
            vae, tokenizer, text_encoder, unet, scheduler: pre-loaded model
                components.
            hf_key: optional custom model key; overrides the key derived
                from ``opt["sd_version"]``.

        Raises:
            ValueError: if ``hf_key`` is None and the version is unsupported.
        """
        super().__init__()

        self.device = device
        self.sd_version = opt["sd_version"]
        self.use_depth = False
        self.config = opt

        print(f'[INFO] loading stable diffusion...')
        model_keys = {
            '2.1': "stabilityai/stable-diffusion-2-1-base",
            '2.0': "stabilityai/stable-diffusion-2-base",
            '1.5': "runwayml/stable-diffusion-v1-5",
            'ControlNet': "runwayml/stable-diffusion-v1-5",
            'depth': "stabilityai/stable-diffusion-2-depth",
        }
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version in model_keys:
            model_key = model_keys[self.sd_version]
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        self.model_key = model_key

        # The components are injected by the caller rather than loaded here.
        self.vae = vae
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.unet = unet
        self.scheduler = scheduler
        self.total_inverted_latents = {}

        self.paths, self.frames, self.latents = self.get_data(self.config["data_path"], self.config["n_frames"])
        print("self.frames", self.frames.shape)
        print("self.latents", self.latents.shape)

        if self.sd_version == 'ControlNet':
            from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
            controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16).to(self.device)
            control_pipe = StableDiffusionControlNetPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
            ).to(self.device)
            # In ControlNet mode the injected UNet is replaced by the pipeline's.
            self.unet = control_pipe.unet
            self.controlnet = control_pipe.controlnet
            self.canny_cond = self.get_canny_cond()
        elif self.sd_version == 'depth':
            self.depth_maps = self.prepare_depth_maps()

        self.unet.enable_xformers_memory_efficient_attention()
        print(f'[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def prepare_depth_maps(self, model_type='DPT_Large', device=None):
        """Predict a depth map per frame path with MiDaS, scaled to [-1, 1].

        ``device`` defaults to ``self.device``. The maps are resized to 1/8
        of each image's height and width and returned as one float16 tensor
        on ``self.device``.
        """
        if device is None:
            device = self.device
        depth_maps = []
        midas = torch.hub.load("intel-isl/MiDaS", model_type)
        midas.to(device)
        midas.eval()

        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

        if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
            transform = midas_transforms.dpt_transform
        else:
            transform = midas_transforms.small_transform

        for path in self.paths:
            img = cv2.imread(path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            latent_h = img.shape[0] // 8
            latent_w = img.shape[1] // 8

            input_batch = transform(img).to(device)
            prediction = midas(input_batch)

            depth_map = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=(latent_h, latent_w),
                mode="bicubic",
                align_corners=False,
            )
            # min-max normalise each map to [-1, 1]
            depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
            depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
            depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
            depth_maps.append(depth_map)

        return torch.cat(depth_maps).to(self.device).to(torch.float16)

    @torch.no_grad()
    def get_canny_cond(self):
        """Return Canny edge maps of ``self.frames`` as a float16 tensor.

        Each edge map is replicated over three channels and scaled to [0, 1].
        """
        canny_cond = []
        for image in self.frames.cpu().permute(0, 2, 3, 1):
            image = np.uint8(np.array(255 * image))
            low_threshold = 100
            high_threshold = 200

            image = cv2.Canny(image, low_threshold, high_threshold)
            image = image[:, :, None]
            image = np.concatenate([image, image, image], axis=2)
            image = torch.from_numpy((image.astype(np.float32) / 255.0))
            canny_cond.append(image)
        canny_cond = torch.stack(canny_cond).permute(0, 3, 1, 2).to(self.device).to(torch.float16)
        return canny_cond

    def controlnet_pred(self, latent_model_input, t, text_embed_input, controlnet_cond):
        """Return the UNet noise prediction with ControlNet residuals applied."""
        down_block_res_samples, mid_block_res_sample = self.controlnet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embed_input,
            controlnet_cond=controlnet_cond,
            conditioning_scale=1,
            return_dict=False,
        )

        # apply the denoising network
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embed_input,
            cross_attention_kwargs={},
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
            return_dict=False,
        )[0]
        return noise_pred

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt, device=None):
        """Return the negative-prompt and prompt embeddings, concatenated in that order.

        ``device`` is where the token ids are moved before encoding; it
        defaults to ``self.device``.
        """
        if device is None:
            device = self.device
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode latents with the VAE in batches of 8; outputs are clamped to [0, 1]."""
        decoded = []
        batch_size = 8
        for b in range(0, latents.shape[0], batch_size):
            # undo the 0.18215 scaling applied in encode_imgs
            latents_batch = 1 / 0.18215 * latents[b:b + batch_size]
            imgs = self.vae.decode(latents_batch).sample
            imgs = (imgs / 2 + 0.5).clamp(0, 1)
            decoded.append(imgs)
        return torch.cat(decoded)

    @torch.no_grad()
    def encode_imgs(self, imgs, batch_size=10, deterministic=True):
        """Encode images with the VAE in batches and scale the result by 0.18215.

        The input is first mapped with ``2 * imgs - 1``. With
        ``deterministic`` the posterior mean is used, otherwise a sample.
        """
        imgs = 2 * imgs - 1
        latents = []
        for i in range(0, len(imgs), batch_size):
            posterior = self.vae.encode(imgs[i:i + batch_size]).latent_dist
            latent = posterior.mean if deterministic else posterior.sample()
            latents.append(latent * 0.18215)
        latents = torch.cat(latents)
        return latents

    def get_data(self, frames_path, n_frames):
        """Load the frames and encode them to latents.

        Frames come from ``self.config["frames"]`` when it is non-empty,
        otherwise from ``frames_path/00000.png`` ... (falling back to
        ``.jpg`` when the first ``.png`` is missing).

        Returns:
            ``(paths, frames, latents)``; ``paths`` is None when the frames
            were supplied in memory.
        """
        # NOTE(review): presumably n_frames can arrive as a float from a UI
        # slider - confirm. It is used as a range/slice bound below.
        n_frames = int(n_frames)
        paths = None

        # load frames
        if not self.config["frames"]:
            paths = [f"{frames_path}/{i:05d}.png" for i in range(n_frames)]
            print(paths)
            if not os.path.exists(paths[0]):
                paths = [f"{frames_path}/{i:05d}.jpg" for i in range(n_frames)]
            self.paths = paths
            frames = [Image.open(path).convert('RGB') for path in paths]
            # only square frames are resized to 512x512
            if frames[0].size[0] == frames[0].size[1]:
                frames = [frame.resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
        else:
            frames = self.config["frames"][:n_frames]
        frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(torch.float16).to(self.device)
        # encode to latents
        latents = self.encode_imgs(frames, deterministic=True).to(torch.float16).to(self.device)
        print("frames", frames.shape)
        print("latents", latents.shape)

        return paths, frames, latents

    @torch.no_grad()
    def ddim_inversion(self, cond, latent_frames, save_path, batch_size, save_latents=True, timesteps_to_save=None):
        """DDIM-invert ``latent_frames`` in place and return it.

        Args:
            cond: text embedding, repeated for every frame of a batch.
            latent_frames: latents to invert; overwritten batch by batch.
            save_path: directory whose ``latents`` sub-folder receives the
                saved tensors; only used when ``save_latents`` is true.
            batch_size: number of frames per UNet call.
            save_latents: whether to write latents to disk.
            timesteps_to_save: timesteps whose latents are kept; defaults to
                every timestep.

        Intermediate latents are also stored in ``self.total_inverted_latents``
        when ``self.config["frames"]`` is not None.
        """
        timesteps = reversed(self.scheduler.timesteps)
        timesteps_to_save = timesteps_to_save if timesteps_to_save is not None else timesteps
        # int() once, so range() and the slices below use the same value
        batch_size = int(batch_size)

        return_inverted_latents = self.config["frames"] is not None
        for i, t in enumerate(tqdm(timesteps)):
            # These depend only on the timestep, so compute them once per step.
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[timesteps[i - 1]]
                if i > 0 else self.scheduler.final_alpha_cumprod
            )

            mu = alpha_prod_t ** 0.5
            mu_prev = alpha_prod_t_prev ** 0.5
            sigma = (1 - alpha_prod_t) ** 0.5
            sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

            for b in range(0, latent_frames.shape[0], batch_size):
                x_batch = latent_frames[b:b + batch_size]
                model_input = x_batch
                cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
                if self.sd_version == 'depth':
                    depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
                    model_input = torch.cat([x_batch, depth_maps],dim=1)

                eps = self.unet(model_input, t, encoder_hidden_states=cond_batch).sample if self.sd_version != 'ControlNet' \
                    else self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))
                pred_x0 = (x_batch - sigma_prev * eps) / mu_prev
                latent_frames[b:b + batch_size] = mu * pred_x0 + sigma * eps

            if return_inverted_latents and t in timesteps_to_save:
                self.total_inverted_latents[f'noisy_latents_{t}'] = latent_frames.clone()

            if save_latents and t in timesteps_to_save:
                torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))

        # Always keep the latents of the final timestep.
        if save_latents:
            torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
        if return_inverted_latents:
            self.total_inverted_latents[f'noisy_latents_{t}'] = latent_frames.clone()

        return latent_frames

    @torch.no_grad()
    def ddim_sample(self, x, cond, batch_size):
        """Run DDIM sampling over ``self.scheduler.timesteps``; ``x`` is updated in place and returned."""
        timesteps = self.scheduler.timesteps
        batch_size = int(batch_size)
        for i, t in enumerate(tqdm(timesteps)):
            # These depend only on the timestep, so compute them once per step.
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[timesteps[i + 1]]
                if i < len(timesteps) - 1
                else self.scheduler.final_alpha_cumprod
            )
            mu = alpha_prod_t ** 0.5
            sigma = (1 - alpha_prod_t) ** 0.5
            mu_prev = alpha_prod_t_prev ** 0.5
            sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

            for b in range(0, x.shape[0], batch_size):
                x_batch = x[b:b + batch_size]
                model_input = x_batch
                cond_batch = cond.repeat(x_batch.shape[0], 1, 1)

                if self.sd_version == 'depth':
                    depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
                    model_input = torch.cat([x_batch, depth_maps],dim=1)

                eps = self.unet(model_input, t, encoder_hidden_states=cond_batch).sample if self.sd_version != 'ControlNet' \
                    else self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))

                pred_x0 = (x_batch - sigma * eps) / mu
                x[b:b + batch_size] = mu_prev * pred_x0 + sigma_prev * eps
        return x

    @torch.no_grad()
    def extract_latents(self,
                        num_steps,
                        save_path,
                        batch_size,
                        timesteps_to_save,
                        inversion_prompt='',
                        reconstruct=False):
        """Invert the encoded frames and return frames, latents and inverted latents.

        Args:
            num_steps: number of scheduler timesteps for the inversion.
            save_path: directory for saved latents; latents are written to
                disk only when this is truthy.
            batch_size: number of frames per UNet call.
            timesteps_to_save: timesteps whose latents are kept.
            inversion_prompt: prompt whose embedding conditions the inversion.
            reconstruct: when true, also sample back from the inverted
                latents and decode the result.

        Returns:
            ``(self.frames, self.latents, self.total_inverted_latents,
            rgb_reconstruction)``; the last item is None unless
            ``reconstruct`` is true.
        """
        self.scheduler.set_timesteps(int(num_steps))
        cond = self.get_text_embeds(inversion_prompt, "")[1].unsqueeze(0)
        # ddim_inversion overwrites its input in place; work on a copy so
        # self.latents keeps the encoded frames that are returned below.
        latent_frames = self.latents.clone()
        print("latent_frames", latent_frames.shape)

        inverted_x = self.ddim_inversion(cond,
                                         latent_frames,
                                         save_path,
                                         batch_size=batch_size,
                                         save_latents=True if save_path else False,
                                         timesteps_to_save=timesteps_to_save)

        rgb_reconstruction = None
        if reconstruct:
            latent_reconstruction = self.ddim_sample(inverted_x, cond, batch_size=batch_size)
            rgb_reconstruction = self.decode_latents(latent_reconstruction)

        return self.frames, self.latents, self.total_inverted_latents, rgb_reconstruction
330
+
331
+
332
def prep(opt):
    """Load the Stable Diffusion components and DDIM-invert the frames described by ``opt``.

    Args:
        opt: config dict read for "sd_version", "save_steps", "seed",
            "frames", "steps", "batch_size" and "inversion_prompt"; when
            "frames" is empty it is also read for "save_dir", "data_path"
            and "n_frames".

    Returns:
        The 4-tuple returned by ``Preprocess.extract_latents``:
        ``(frames, latents, total_inverted_latents, rgb_reconstruction)``.

    Raises:
        ValueError: if ``opt["sd_version"]`` is not a supported version.
    """
    model_keys = {
        '2.1': "stabilityai/stable-diffusion-2-1-base",
        '2.0': "stabilityai/stable-diffusion-2-base",
        '1.5': "runwayml/stable-diffusion-v1-5",
        'ControlNet': "runwayml/stable-diffusion-v1-5",
        'depth': "stabilityai/stable-diffusion-2-depth",
    }
    sd_version = opt["sd_version"]
    # Fail with a clear message instead of an UnboundLocalError on `model_key`.
    if sd_version not in model_keys:
        raise ValueError(f'Stable-diffusion version {sd_version} not supported.')
    model_key = model_keys[sd_version]

    # `device` is not defined in the visible part of this module, so it is
    # resolved here.
    # NOTE(review): confirm no module-level `device` is expected to come from
    # the wildcard import of `utils`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # timesteps to save: a throw-away scheduler is used only to enumerate them
    toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
    toy_scheduler.set_timesteps(opt["save_steps"])
    timesteps_to_save, _ = get_timesteps(toy_scheduler, num_inference_steps=opt["save_steps"],
                                         strength=1.0,
                                         device=device)

    seed_everything(opt["seed"])
    if not opt["frames"]:  # original non demo setting
        save_path = os.path.join(opt["save_dir"],
                                 f'sd_{opt["sd_version"]}',
                                 Path(opt["data_path"]).stem,
                                 f'steps_{opt["steps"]}',
                                 f'nframes_{opt["n_frames"]}')
        os.makedirs(os.path.join(save_path, 'latents'), exist_ok=True)
        add_dict_to_yaml_file(os.path.join(opt["save_dir"], 'inversion_prompts.yaml'),
                              Path(opt["data_path"]).stem, opt["inversion_prompt"])
        # save inversion prompt in a txt file
        with open(os.path.join(save_path, 'inversion_prompt.txt'), 'w') as f:
            f.write(opt["inversion_prompt"])
    else:
        save_path = None

    # Preprocess requires the model components to be passed in, so load them
    # here before constructing it.
    vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", revision="fp16",
                                        torch_dtype=torch.float16).to(device)
    tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder", revision="fp16",
                                                 torch_dtype=torch.float16).to(device)
    unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", revision="fp16",
                                                torch_dtype=torch.float16).to(device)
    scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

    model = Preprocess(device, opt,
                       vae=vae,
                       tokenizer=tokenizer,
                       text_encoder=text_encoder,
                       unet=unet,
                       scheduler=scheduler)

    frames, latents, total_inverted_latents, rgb_reconstruction = model.extract_latents(
        num_steps=model.config["steps"],
        save_path=save_path,
        batch_size=model.config["batch_size"],
        timesteps_to_save=timesteps_to_save,
        inversion_prompt=model.config["inversion_prompt"],
    )

    return frames, latents, total_inverted_latents, rgb_reconstruction
375
+
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pillow
2
+ diffusers
3
+ ftfy
4
+ transformers
5
+ opencv-python
6
+ tqdm
7
+ numpy
8
+ pyyaml
9
+ accelerate
10
+ xformers
11
+ tensorboard
12
+ kornia
13
+ av
14
+ torchvision==0.15.2
style.css ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ This CSS file is modified from:
3
+ https://huggingface.co/spaces/DeepFloyd/IF/blob/main/style.css
4
+ */
5
+
6
+ h1 {
7
+ text-align: center;
8
+ }
9
+
10
+ .gradio-container {
11
+ font-family: 'IBM Plex Sans', sans-serif;
12
+ }
13
+
14
+ .gr-button {
15
+ color: white;
16
+ border-color: black;
17
+ background: black;
18
+ }
19
+
20
+ input[type='range'] {
21
+ accent-color: black;
22
+ }
23
+
24
+ .dark input[type='range'] {
25
+ accent-color: #dfdfdf;
26
+ }
27
+
28
+ .container {
29
+ max-width: 730px;
30
+ margin: auto;
31
+ }
32
+
33
/* Focus ring for buttons, assembled from Tailwind-style --tw-* custom properties. */
.gr-button:focus {
    border-color: rgb(147 197 253 / var(--tw-border-opacity));
    outline: none;
    box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
    --tw-border-opacity: 1;
    --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
    /* calc() needs an explicit operator between its operands; without the "+"
       the box-shadow that consumes this variable is invalid. */
    --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
    --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
    --tw-ring-opacity: .5;
}
43
+
44
+ .gr-form {
45
+ flex: 1 1 50%;
46
+ border-top-right-radius: 0;
47
+ border-bottom-right-radius: 0;
48
+ }
49
+
50
+ #prompt-container {
51
+ gap: 0;
52
+ }
53
+
54
+ #prompt-text-input,
55
+ #negative-prompt-text-input {
56
+ padding: .45rem 0.625rem
57
+ }
58
+
59
+ /* #component-16 {
60
+ border-top-width: 1px !important;
61
+ margin-top: 1em
62
+ } */
63
+
64
+ .image_duplication {
65
+ position: absolute;
66
+ width: 100px;
67
+ left: 50px
68
+ }
69
+
70
+ #component-0 {
71
+ max-width: 730px;
72
+ margin: auto;
73
+ padding-top: 1.5rem;
74
+ }
75
+
76
+ #share-btn-container {
77
+ display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; margin-left: auto;
78
+ }
79
+ #share-btn {
80
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
81
+ }
82
+ #share-btn * {
83
+ all: unset;
84
+ }
85
+ #share-btn-container div:nth-child(-n+2){
86
+ width: auto !important;
87
+ min-height: 0px !important;
88
+ }
89
+ #share-btn-container .wrap {
90
+ display: none !important;
91
+ }
tokenflow_pnp.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import numpy as np
4
+ import cv2
5
+ from pathlib import Path
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchvision.transforms as T
9
+ import argparse
10
+ from PIL import Image
11
+ import yaml
12
+ from tqdm import tqdm
13
+ from transformers import logging
14
+ from diffusers import DDIMScheduler, StableDiffusionPipeline
15
+
16
+ from tokenflow_utils import *
17
+ from utils import save_video, seed_everything
18
+
19
+ # suppress partial model loading warning
20
+ logging.set_verbosity_error()
21
+
22
VAE_BATCH_SIZE = 10  # default number of frames per VAE encode/decode call


class TokenFlow(nn.Module):
    """TokenFlow video editing with Plug-and-Play (PnP) feature injection.

    The editor works in one of two modes, selected by the constructor
    arguments:

    * file mode (``frames is None`` / ``inverted_latents is None``): frames
      and inverted latents are read from disk and results are written under
      ``config["output_path"]``;
    * in-memory ("demo") mode: frames and inverted latents are passed in and
      the edited frames are returned instead of saved.
    """

    # Hugging Face model ids for each supported ``sd_version``.
    _MODEL_KEYS = {
        '2.1': "stabilityai/stable-diffusion-2-1-base",
        '2.0': "stabilityai/stable-diffusion-2-base",
        '1.5': "runwayml/stable-diffusion-v1-5",
        'depth': "stabilityai/stable-diffusion-2-depth",
    }

    def __init__(self, config,
                 pipe,
                 frames=None,
                 inverted_latents=None):
        """Set up models, scheduler, data and text embeddings.

        Args:
            config: dict of run options (``device``, ``sd_version``,
                ``n_timesteps``, ``prompt``, ``negative_prompt``,
                ``pnp_inversion_prompt``, ``n_frames``, ``batch_size``, ...).
                ``config["n_frames"]`` is updated in place.
            pipe: an already-loaded pipeline exposing ``vae``, ``tokenizer``,
                ``text_encoder`` and ``unet``.
            frames: optional in-memory frames; when ``None`` they are read
                from ``config["data_path"]``.
            inverted_latents: optional dict keyed ``'noisy_latents_<t>'``;
                when ``None`` latents are read from disk.

        Raises:
            ValueError: if ``config["sd_version"]`` is not supported.
        """
        super().__init__()
        self.config = config
        self.device = config["device"]

        sd_version = config["sd_version"]
        self.sd_version = sd_version
        if sd_version not in self._MODEL_KEYS:
            raise ValueError(f'Stable-diffusion version {sd_version} not supported.')
        model_key = self._MODEL_KEYS[sd_version]

        # Create SD models
        print('Loading SD model')

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        self.scheduler.set_timesteps(config["n_timesteps"], device=self.device)
        print('SD model loaded')

        # data
        self.frames, self.inverted_latents = frames, inverted_latents
        self.latents_path = self.get_latents_path()

        # load frames
        self.paths, self.frames, self.latents, self.eps = self.get_data()

        if self.sd_version == 'depth':
            self.depth_maps = self.prepare_depth_maps()

        self.text_embeds = self.get_text_embeds(config["prompt"], config["negative_prompt"])
        # Both halves are the inversion-prompt embedding; keep only one of them.
        self.pnp_guidance_embeds = self.get_text_embeds(config["pnp_inversion_prompt"], config["pnp_inversion_prompt"]).chunk(2)[0]

    @torch.no_grad()
    def prepare_depth_maps(self, model_type='DPT_Large', device='cuda'):
        """Predict a MiDaS depth map for every frame file in ``self.paths``.

        Each prediction is resized to 1/8 of the image resolution and
        rescaled to ``[-1, 1]`` per frame. The result is returned as a single
        float16 tensor on ``self.device``.

        NOTE(review): this reads frames from ``self.paths``, which
        ``get_data`` sets to ``None`` in in-memory mode, so the 'depth'
        version currently only works in file mode. A frame with constant
        predicted depth would also divide by zero below.
        """
        depth_maps = []
        midas = torch.hub.load("intel-isl/MiDaS", model_type)
        midas.to(device)
        midas.eval()

        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

        if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
            transform = midas_transforms.dpt_transform
        else:
            transform = midas_transforms.small_transform

        for path in self.paths:
            img = cv2.imread(path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            latent_h = img.shape[0] // 8
            latent_w = img.shape[1] // 8

            input_batch = transform(img).to(device)
            prediction = midas(input_batch)

            depth_map = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=(latent_h, latent_w),
                mode="bicubic",
                align_corners=False,
            )
            depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
            depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
            depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
            depth_maps.append(depth_map)

        return torch.cat(depth_maps).to(torch.float16).to(self.device)

    def get_pnp_inversion_prompt(self):
        """Read ``inversion_prompt.txt`` stored next to the latents directory."""
        inv_prompts_path = os.path.join(str(Path(self.latents_path).parent), 'inversion_prompt.txt')
        # read inversion prompt
        with open(inv_prompts_path, 'r') as f:
            inv_prompt = f.read()
        return inv_prompt

    def get_latents_path(self):
        """Resolve the latents directory and clamp ``config["n_frames"]``.

        In file mode the ``nframes_<k>`` directory with the largest ``k``
        under ``<latents_path>/sd_<version>/<video stem>/steps_<n>`` is
        chosen. In both modes ``config["n_frames"]`` is capped by the number
        of available frames and rounded down to a multiple of
        ``config["batch_size"]``.

        Returns:
            Path of the ``latents`` sub-directory in file mode, ``None`` in
            in-memory mode.

        Raises:
            ValueError: if no inverted latents are found on disk, or if fewer
                than ``batch_size`` frames are available.
        """
        read_from_files = self.frames is None
        latents_path = None
        if read_from_files:
            root = os.path.join(self.config["latents_path"], f'sd_{self.config["sd_version"]}',
                                Path(self.config["data_path"]).stem, f'steps_{self.config["n_inversion_steps"]}')
            # Keep only the extension-less "nframes_<k>" entries; parse the
            # entry name itself so the result is independent of the path separator.
            candidates = [p for p in glob.glob(f'{root}/*')
                          if '.' not in Path(p).name and 'nframes' in Path(p).name]
            if not candidates:
                raise ValueError(f'No inverted latents found under {root}')
            n_frames = [int(Path(p).name.split('_')[1]) for p in candidates]
            print("n_frames", n_frames)
            latents_path = candidates[int(np.argmax(n_frames))]
            print("latents_path", latents_path)
            self.config["n_frames"] = min(max(n_frames), self.config["n_frames"])
        else:
            n_frames = self.frames.shape[0]
            self.config["n_frames"] = min(n_frames, self.config["n_frames"])

        # make n_frames divisible by batch_size
        remainder = self.config["n_frames"] % self.config["batch_size"]
        if remainder != 0:
            self.config["n_frames"] = self.config["n_frames"] - remainder
        if self.config["n_frames"] <= 0:
            raise ValueError(
                f'Not enough frames for batch_size {self.config["batch_size"]}: '
                'at least one full batch is required.')
        print("Number of frames: ", self.config["n_frames"])

        if read_from_files:
            return os.path.join(latents_path, 'latents')
        return None

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt, batch_size=1):
        """Encode ``prompt`` and ``negative_prompt`` with the text encoder.

        Returns:
            The unconditional (negative) embeddings followed by the
            conditional ones, each repeated ``batch_size`` times along dim 0.
        """
        max_length = self.tokenizer.model_max_length

        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Do the same for unconditional embeddings. Truncate here as well so
        # an over-long negative prompt cannot exceed the encoder's max length.
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=max_length,
                                      truncation=True, return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        text_embeddings = torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size)
        return text_embeddings

    @torch.no_grad()
    def encode_imgs(self, imgs, batch_size=VAE_BATCH_SIZE, deterministic=False):
        """Encode images to scaled VAE latents, ``batch_size`` at a time.

        ``imgs`` is mapped from ``[0, 1]`` to ``[-1, 1]`` first. With
        ``deterministic=True`` the posterior mean is used instead of a sample.
        Latents are multiplied by 0.18215 (undone in ``decode_latents``).
        """
        imgs = 2 * imgs - 1
        latents = []
        for start in range(0, len(imgs), batch_size):
            posterior = self.vae.encode(imgs[start:start + batch_size]).latent_dist
            latent = posterior.mean if deterministic else posterior.sample()
            latents.append(latent * 0.18215)
        return torch.cat(latents)

    @torch.no_grad()
    def decode_latents(self, latents, batch_size=VAE_BATCH_SIZE):
        """Decode scaled latents to images clamped to ``[0, 1]``."""
        latents = 1 / 0.18215 * latents
        imgs = []
        for start in range(0, len(latents), batch_size):
            imgs.append(self.vae.decode(latents[start:start + batch_size]).sample)
        imgs = torch.cat(imgs)
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def get_data(self):
        """Load frames, encode them and recover the DDIM noise.

        In file mode, ``config["n_frames"]`` images named ``%05d.jpg`` (or
        ``.png`` when the first ``.jpg`` is missing) are read from
        ``config["data_path"]``, resized to 512x512 only if the first frame is
        square, and three preview videos are written to
        ``config["output_path"]``. In in-memory mode the supplied frames are
        used.

        Returns:
            ``(paths, frames, latents, eps)``; ``paths`` is ``None`` in
            in-memory mode.
        """
        read_from_files = self.frames is None
        n_frames = self.config["n_frames"]
        if read_from_files:
            # load frames
            paths = [os.path.join(self.config["data_path"], "%05d.jpg" % idx) for idx in range(n_frames)]
            if not os.path.exists(paths[0]):
                paths = [os.path.join(self.config["data_path"], "%05d.png" % idx) for idx in range(n_frames)]
            frames = [Image.open(path).convert('RGB') for path in paths]
            if frames[0].size[0] == frames[0].size[1]:
                frames = [frame.resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
            frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(torch.float16).to(self.device)
            save_video(frames, f'{self.config["output_path"]}/input_fps10.mp4', fps=10)
            save_video(frames, f'{self.config["output_path"]}/input_fps20.mp4', fps=20)
            save_video(frames, f'{self.config["output_path"]}/input_fps30.mp4', fps=30)
        else:
            paths = None
            # get_latents_path() may have reduced n_frames; keep the frames
            # (and hence the latents) the same length as the eps computed below.
            frames = self.frames[:n_frames]
        # encode to latents
        latents = self.encode_imgs(frames, deterministic=True).to(torch.float16).to(self.device)
        # get noise
        eps = self.get_ddim_eps(latents, range(n_frames)).to(torch.float16).to(self.device)
        return paths, frames, latents, eps

    def get_ddim_eps(self, latent, indices):
        """Recover the noise that maps ``latent`` to its noisiest inversion.

        The noisiest stored latent ``x_T`` (largest timestep available on disk
        or in ``self.inverted_latents``) is selected at ``indices`` and

            eps = (x_T - sqrt(alpha_T) * latent) / sqrt(1 - alpha_T)

        is returned, with ``alpha_T = scheduler.alphas_cumprod[T]``.

        Raises:
            ValueError: if no noisy latents are available.
        """
        read_from_files = self.inverted_latents is None
        if read_from_files:
            steps = [int(x.split('_')[-1].split('.')[0])
                     for x in glob.glob(os.path.join(self.latents_path, f'noisy_latents_*.pt'))]
            if not steps:
                raise ValueError(f'No noisy latents found in {self.latents_path}')
            noisest = max(steps)
            print("noisets:", noisest)
            print("indecies:", indices)
            latents_path = os.path.join(self.latents_path, f'noisy_latents_{noisest}.pt')
            noisy_latent = torch.load(latents_path)[indices].to(self.device)
        else:
            noisest = max([int(key.split("_")[-1]) for key in self.inverted_latents.keys()])
            print("noisets:", noisest)
            print("indecies:", indices)
            noisy_latent = self.inverted_latents[f'noisy_latents_{noisest}'][indices]

        alpha_prod_T = self.scheduler.alphas_cumprod[noisest]
        mu_T, sigma_T = alpha_prod_T ** 0.5, (1 - alpha_prod_T) ** 0.5
        eps = (noisy_latent - mu_T * latent) / sigma_T
        return eps

    @torch.no_grad()
    def denoise_step(self, x, t, indices):
        """Run one guided denoising step on the frames ``x`` at timestep ``t``.

        The UNet batch is ``[source latents | x | x]`` paired with the text
        embeddings ``[inversion prompt | negative prompt | prompt]``; the
        source branch only feeds the injection hooks and its noise prediction
        is discarded.
        """
        # register the time step and features in pnp injection modules
        read_files = self.inverted_latents is None

        if read_files:
            source_latents = load_source_latents_t(t, self.latents_path)[indices]
        else:
            source_latents = self.inverted_latents[f'noisy_latents_{t}'][indices]

        latent_model_input = torch.cat([source_latents] + ([x] * 2))
        if self.sd_version == 'depth':
            latent_model_input = torch.cat([latent_model_input, torch.cat([self.depth_maps[indices]] * 3)], dim=1)

        register_time(self, t.item())

        # compute text embeddings
        text_embed_input = torch.cat([self.pnp_guidance_embeds.repeat(len(indices), 1, 1),
                                      torch.repeat_interleave(self.text_embeds, len(indices), dim=0)])

        # apply the denoising network
        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input)['sample']

        # perform guidance
        _, noise_pred_uncond, noise_pred_cond = noise_pred.chunk(3)
        noise_pred = noise_pred_uncond + self.config["guidance_scale"] * (noise_pred_cond - noise_pred_uncond)

        # compute the denoising step with the reference model
        denoised_latent = self.scheduler.step(noise_pred, t, x)['prev_sample']
        return denoised_latent

    @torch.autocast(dtype=torch.float16, device_type='cuda')
    def batched_denoise_step(self, x, t, indices):
        """Denoise all frames at timestep ``t`` in chunks of ``batch_size``.

        One random frame per chunk is first run in a "pivotal" pass (its
        output is discarded; the pass exists for the registered hooks), then
        every chunk is denoised with its batch index registered.
        """
        batch_size = self.config["batch_size"]
        denoised_latents = []
        # One random offset inside each consecutive chunk of batch_size frames.
        pivotal_idx = torch.randint(batch_size, (len(x)//batch_size,)) + torch.arange(0,len(x),batch_size)

        register_pivotal(self, True)
        self.denoise_step(x[pivotal_idx], t, indices[pivotal_idx])
        register_pivotal(self, False)
        for i, b in enumerate(range(0, len(x), batch_size)):
            register_batch_idx(self, i)
            denoised_latents.append(self.denoise_step(x[b:b + batch_size], t, indices[b:b + batch_size]))
        denoised_latents = torch.cat(denoised_latents)
        return denoised_latents

    def init_method(self, conv_injection_t, qk_injection_t):
        """Install the PnP/TokenFlow hooks on the UNet.

        The first ``qk_injection_t`` / ``conv_injection_t`` scheduler
        timesteps are used as injection schedules; a negative count yields an
        empty schedule.
        """
        self.qk_injection_timesteps = self.scheduler.timesteps[:qk_injection_t] if qk_injection_t >= 0 else []
        self.conv_injection_timesteps = self.scheduler.timesteps[:conv_injection_t] if conv_injection_t >= 0 else []
        register_extended_attention_pnp(self, self.qk_injection_timesteps)
        register_conv_injection(self, self.conv_injection_timesteps)
        set_tokenflow(self.unet)

    def save_vae_recon(self):
        """Decode ``self.latents`` and save them as PNG frames and three videos."""
        os.makedirs(f'{self.config["output_path"]}/vae_recon', exist_ok=True)
        decoded = self.decode_latents(self.latents)
        for i, frame in enumerate(decoded):
            T.ToPILImage()(frame).save(f'{self.config["output_path"]}/vae_recon/%05d.png' % i)
        save_video(decoded, f'{self.config["output_path"]}/vae_recon_10.mp4', fps=10)
        save_video(decoded, f'{self.config["output_path"]}/vae_recon_20.mp4', fps=20)
        save_video(decoded, f'{self.config["output_path"]}/vae_recon_30.mp4', fps=30)

    def edit_video(self):
        """Run the full edit.

        Returns:
            The edited frames in in-memory mode. In file mode the results are
            written under ``config["output_path"]`` and ``None`` is returned.
        """
        save_files = self.inverted_latents is None  # if we're in the original non-demo setting
        if save_files:
            os.makedirs(f'{self.config["output_path"]}/img_ode', exist_ok=True)
            self.save_vae_recon()
        pnp_f_t = int(self.config["n_timesteps"] * self.config["pnp_f_t"])
        pnp_attn_t = int(self.config["n_timesteps"] * self.config["pnp_attn_t"])

        self.init_method(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)

        noisy_latents = self.scheduler.add_noise(self.latents, self.eps, self.scheduler.timesteps[0])
        edited_frames = self.sample_loop(noisy_latents, torch.arange(self.config["n_frames"]))

        if save_files:
            save_video(edited_frames, f'{self.config["output_path"]}/tokenflow_PnP_fps_10.mp4')
            save_video(edited_frames, f'{self.config["output_path"]}/tokenflow_PnP_fps_20.mp4', fps=20)
            save_video(edited_frames, f'{self.config["output_path"]}/tokenflow_PnP_fps_30.mp4', fps=30)
            print('Done!')
        else:
            return edited_frames

    def sample_loop(self, x, indices):
        """Denoise ``x`` over all scheduler timesteps and decode the result.

        In file mode each decoded frame is also saved under
        ``<output_path>/img_ode``.
        """
        save_files = self.inverted_latents is None  # if we're in the original non-demo setting
        if save_files:
            os.makedirs(f'{self.config["output_path"]}/img_ode', exist_ok=True)
        for t in tqdm(self.scheduler.timesteps, desc="Sampling"):
            x = self.batched_denoise_step(x, t, indices)

        decoded_latents = self.decode_latents(x)
        if save_files:
            for i, frame in enumerate(decoded_latents):
                T.ToPILImage()(frame).save(f'{self.config["output_path"]}/img_ode/%05d.png' % i)

        return decoded_latents
337
+
338
+
339
+ # def run(config):
340
+ # seed_everything(config["seed"])
341
+ # print(config)
342
+ # editor = TokenFlow(config)
343
+ # editor.edit_video()
344
+
345
+
346
+ # if __name__ == '__main__':
347
+ # parser = argparse.ArgumentParser()
348
+ # parser.add_argument('--config_path', type=str, default='configs/config_pnp.yaml')
349
+ # opt = parser.parse_args()
350
+ # with open(opt.config_path, "r") as f:
351
+ # config = yaml.safe_load(f)
352
+ # config["output_path"] = os.path.join(config["output_path"] + f'_pnp_SD_{config["sd_version"]}',
353
+ # Path(config["data_path"]).stem,
354
+ # config["prompt"][:240],
355
+ # f'attn_{config["pnp_attn_t"]}_f_{config["pnp_f_t"]}',
356
+ # f'batch_size_{str(config["batch_size"])}',
357
+ # str(config["n_timesteps"]),
358
+ # )
359
+ # os.makedirs(config["output_path"], exist_ok=True)
360
+ # print(config["data_path"])
361
+ # assert os.path.exists(config["data_path"]), "Data path does not exist"
362
+ # with open(os.path.join(config["output_path"], "config.yaml"), "w") as f:
363
+ # yaml.dump(config, f)
364
+ # run(config)
tokenflow_utils.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Type
2
+ import torch
3
+ import os
4
+
5
+ from utils import isinstance_str, batch_cosine_sim
6
+
7
def register_pivotal(diffusion_model, is_pivotal):
    """Set ``pivotal_pass = is_pivotal`` on every ``BasicTransformerBlock``."""
    transformer_blocks = (
        submodule
        for _, submodule in diffusion_model.named_modules()
        # Matched by class name through isinstance_str rather than by import.
        if isinstance_str(submodule, "BasicTransformerBlock")
    )
    for transformer_block in transformer_blocks:
        transformer_block.pivotal_pass = is_pivotal
12
+
13
def register_batch_idx(diffusion_model, batch_idx):
    """Set ``batch_idx`` on every ``BasicTransformerBlock`` of the model."""
    transformer_blocks = (
        submodule
        for _, submodule in diffusion_model.named_modules()
        # Matched by class name through isinstance_str rather than by import.
        if isinstance_str(submodule, "BasicTransformerBlock")
    )
    for transformer_block in transformer_blocks:
        transformer_block.batch_idx = batch_idx
18
+
19
+
20
def register_time(model, t):
    """Store the current timestep ``t`` on the hooked UNet modules.

    ``t`` is written to ``up_blocks[1].resnets[1]`` and to ``attn1``/``attn2``
    of the first transformer block of: up blocks 1-3 (attentions 0-2), down
    blocks 0-2 (attentions 0-1) and the mid block (attention 0).
    """
    unet = model.unet
    unet.up_blocks[1].resnets[1].t = t

    def _stamp(attention):
        # Only the first transformer block of each attention module is tagged.
        first_block = attention.transformer_blocks[0]
        first_block.attn1.t = t
        first_block.attn2.t = t

    for res in (1, 2, 3):
        for block in (0, 1, 2):
            _stamp(unet.up_blocks[res].attentions[block])
    for res in (0, 1, 2):
        for block in (0, 1):
            _stamp(unet.down_blocks[res].attentions[block])
    _stamp(unet.mid_block.attentions[0])
41
+
42
+
43
def load_source_latents_t(t, latents_path):
    """Load ``noisy_latents_<t>.pt`` from ``latents_path``.

    Asserts that the file exists before handing it to ``torch.load``.
    """
    filename = f'noisy_latents_{t}.pt'
    latents_t_path = os.path.join(latents_path, filename)
    assert os.path.exists(latents_t_path), f'Missing latents at t {t} path {latents_t_path}'
    return torch.load(latents_t_path)
48
+
49
def register_conv_injection(model, injection_schedule):
    """Replace the forward of ``up_blocks[1].resnets[1]`` with an injecting one.

    The replacement runs the residual-block computation and, when the module's
    current ``t`` is in ``injection_schedule`` (or equals 1000), overwrites
    the second and third thirds of the batch with the features of the first
    third before the residual addition.
    """
    def conv_forward(self):
        def forward(input_tensor, temb):
            residual = input_tensor
            feats = self.nonlinearity(self.norm1(input_tensor))

            if self.upsample is not None:
                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                if feats.shape[0] >= 64:
                    residual = residual.contiguous()
                    feats = feats.contiguous()
                residual = self.upsample(residual)
                feats = self.upsample(feats)
            elif self.downsample is not None:
                residual = self.downsample(residual)
                feats = self.downsample(feats)

            feats = self.conv1(feats)

            has_temb = temb is not None
            if has_temb:
                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
                if self.time_embedding_norm == "default":
                    feats = feats + temb

            feats = self.norm2(feats)

            if has_temb and self.time_embedding_norm == "scale_shift":
                scale, shift = torch.chunk(temb, 2, dim=1)
                feats = feats * (1 + scale) + shift

            feats = self.conv2(self.dropout(self.nonlinearity(feats)))

            inject = self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000)
            if inject:
                n_src = int(feats.shape[0] // 3)
                src_feats = feats[:n_src]
                # inject unconditional
                feats[n_src:2 * n_src] = src_feats
                # inject conditional
                feats[2 * n_src:] = src_feats

            if self.conv_shortcut is not None:
                residual = self.conv_shortcut(residual)

            return (residual + feats) / self.output_scale_factor

        return forward

    target = model.unet.up_blocks[1].resnets[1]
    target.forward = conv_forward(target)
    setattr(target, 'injection_schedule', injection_schedule)
105
+
106
def register_extended_attention_pnp(model, injection_schedule):
    """Patch UNet self-attention with extended attention plus PnP q/k injection.

    Every ``attn1`` of a ``BasicTransformerBlock`` gets the replacement
    forward with an empty injection schedule; the ``attn1`` modules listed in
    ``res_dict`` (up blocks 1-3) are then patched again with
    ``injection_schedule``.
    """
    def sa_forward(self):
        # Use the first entry when to_out is a ModuleList
        # (presumably the output projection, skipping dropout -- TODO confirm).
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(x, encoder_hidden_states=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            # The batch is split into three equal groups: source | uncond | cond.
            n_frames = batch_size // 3
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            # NOTE(review): t == 1000 is treated as "always inject"; its origin is not visible here.
            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                # inject unconditional
                q[n_frames:2 * n_frames] = q[:n_frames]
                k[n_frames:2 * n_frames] = k[:n_frames]
                # inject conditional
                q[2 * n_frames:] = q[:n_frames]
                k[2 * n_frames:] = k[:n_frames]

            # Source frames keep per-frame keys/values; for uncond/cond the keys/values
            # of all frames are concatenated along the sequence axis and shared by every frame.
            k_source = k[:n_frames]
            k_uncond = k[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            k_cond = k[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)

            v_source = v[:n_frames]
            v_uncond = v[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            v_cond = v[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)

            q_source = self.head_to_batch_dim(q[:n_frames])
            q_uncond = self.head_to_batch_dim(q[n_frames:2 * n_frames])
            q_cond = self.head_to_batch_dim(q[2 * n_frames:])
            k_source = self.head_to_batch_dim(k_source)
            k_uncond = self.head_to_batch_dim(k_uncond)
            k_cond = self.head_to_batch_dim(k_cond)
            v_source = self.head_to_batch_dim(v_source)
            v_uncond = self.head_to_batch_dim(v_uncond)
            v_cond = self.head_to_batch_dim(v_cond)


            # Expose explicit (frame, head) axes so attention can be computed one head at a time.
            q_src = q_source.view(n_frames, h, sequence_length, dim // h)
            k_src = k_source.view(n_frames, h, sequence_length, dim // h)
            v_src = v_source.view(n_frames, h, sequence_length, dim // h)
            q_uncond = q_uncond.view(n_frames, h, sequence_length, dim // h)
            k_uncond = k_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
            v_uncond = v_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
            q_cond = q_cond.view(n_frames, h, sequence_length, dim // h)
            k_cond = k_cond.view(n_frames, h, sequence_length * n_frames, dim // h)
            v_cond = v_cond.view(n_frames, h, sequence_length * n_frames, dim // h)

            out_source_all = []
            out_uncond_all = []
            out_cond_all = []

            # Up to 12 frames are processed in a single chunk; above that, one frame at a time.
            single_batch = n_frames<=12
            b = n_frames if single_batch else 1

            for frame in range(0, n_frames, b):
                out_source = []
                out_uncond = []
                out_cond = []
                for j in range(h):
                    sim_source_b = torch.bmm(q_src[frame: frame+ b, j], k_src[frame: frame+ b, j].transpose(-1, -2)) * self.scale
                    sim_uncond_b = torch.bmm(q_uncond[frame: frame+ b, j], k_uncond[frame: frame+ b, j].transpose(-1, -2)) * self.scale
                    sim_cond = torch.bmm(q_cond[frame: frame+ b, j], k_cond[frame: frame+ b, j].transpose(-1, -2)) * self.scale

                    out_source.append(torch.bmm(sim_source_b.softmax(dim=-1), v_src[frame: frame+ b, j]))
                    out_uncond.append(torch.bmm(sim_uncond_b.softmax(dim=-1), v_uncond[frame: frame+ b, j]))
                    out_cond.append(torch.bmm(sim_cond.softmax(dim=-1), v_cond[frame: frame+ b, j]))

                out_source = torch.cat(out_source, dim=0)
                out_uncond = torch.cat(out_uncond, dim=0)
                out_cond = torch.cat(out_cond, dim=0)
                if single_batch:
                    # Concatenation above is head-major; reorder to frame-major (frame, head).
                    out_source = out_source.view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
                    out_uncond = out_uncond.view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
                    out_cond = out_cond.view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
                out_source_all.append(out_source)
                out_uncond_all.append(out_uncond)
                out_cond_all.append(out_cond)

            out_source = torch.cat(out_source_all, dim=0)
            out_uncond = torch.cat(out_uncond_all, dim=0)
            out_cond = torch.cat(out_cond_all, dim=0)

            # Reassemble in the input group order: source | uncond | cond.
            out = torch.cat([out_source, out_uncond, out_cond], dim=0)
            out = self.batch_to_head_dim(out)

            return to_out(out)

        return forward

    # Default: every self-attention layer gets the extended forward, with no injection.
    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)
            setattr(module.attn1, 'injection_schedule', [])

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)
215
+
216
def register_extended_attention(model):
    """Patch UNet self-attention with extended (cross-frame) attention.

    Same replacement forward as the PnP variant but without q/k injection and
    without frame chunking. It is installed on every ``attn1`` of a
    ``BasicTransformerBlock`` and then re-installed on the up-block modules
    listed in ``res_dict``.
    """
    def sa_forward(self):
        # Use the first entry when to_out is a ModuleList
        # (presumably the output projection, skipping dropout -- TODO confirm).
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(x, encoder_hidden_states=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            # The batch is split into three equal groups: source | uncond | cond.
            n_frames = batch_size // 3
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            # Source frames keep per-frame keys/values; for uncond/cond the keys/values
            # of all frames are concatenated along the sequence axis and shared by every frame.
            k_source = k[:n_frames]
            k_uncond = k[n_frames: 2*n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            k_cond = k[2*n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            v_source = v[:n_frames]
            v_uncond = v[n_frames:2*n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            v_cond = v[2*n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)

            q_source = self.head_to_batch_dim(q[:n_frames])
            q_uncond = self.head_to_batch_dim(q[n_frames: 2*n_frames])
            q_cond = self.head_to_batch_dim(q[2 * n_frames:])
            k_source = self.head_to_batch_dim(k_source)
            k_uncond = self.head_to_batch_dim(k_uncond)
            k_cond = self.head_to_batch_dim(k_cond)
            v_source = self.head_to_batch_dim(v_source)
            v_uncond = self.head_to_batch_dim(v_uncond)
            v_cond = self.head_to_batch_dim(v_cond)

            out_source = []
            out_uncond = []
            out_cond = []

            # Expose explicit (frame, head) axes so attention can be computed one head at a time.
            q_src = q_source.view(n_frames, h, sequence_length, dim // h)
            k_src = k_source.view(n_frames, h, sequence_length, dim // h)
            v_src = v_source.view(n_frames, h, sequence_length, dim // h)
            q_uncond = q_uncond.view(n_frames, h, sequence_length, dim // h)
            k_uncond = k_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
            v_uncond = v_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
            q_cond = q_cond.view(n_frames, h, sequence_length, dim // h)
            k_cond = k_cond.view(n_frames, h, sequence_length * n_frames, dim // h)
            v_cond = v_cond.view(n_frames, h, sequence_length * n_frames, dim // h)

            for j in range(h):
                sim_source_b = torch.bmm(q_src[:, j], k_src[:, j].transpose(-1, -2)) * self.scale
                sim_uncond_b = torch.bmm(q_uncond[:, j], k_uncond[:, j].transpose(-1, -2)) * self.scale
                sim_cond = torch.bmm(q_cond[:, j], k_cond[:, j].transpose(-1, -2)) * self.scale

                out_source.append(torch.bmm(sim_source_b.softmax(dim=-1), v_src[:, j]))
                out_uncond.append(torch.bmm(sim_uncond_b.softmax(dim=-1), v_uncond[:, j]))
                out_cond.append(torch.bmm(sim_cond.softmax(dim=-1), v_cond[:, j]))

            # Concatenation over heads is head-major; reorder to frame-major (frame, head).
            out_source = torch.cat(out_source, dim=0).view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
            out_uncond = torch.cat(out_uncond, dim=0).view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
            out_cond = torch.cat(out_cond, dim=0).view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)

            # Reassemble in the input group order: source | uncond | cond.
            out = torch.cat([out_source, out_uncond, out_cond], dim=0)
            out = self.batch_to_head_dim(out)

            return to_out(out)

        return forward

    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
295
+
296
def make_tokenflow_attention_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """
    Build a subclass of ``block_class`` whose ``forward`` implements TokenFlow
    feature propagation around the block's self-attention.

    The returned class reads several attributes that are NOT created here and
    must be set on the instance by the caller before ``forward`` runs:

    * ``pivotal_pass`` (bool): when true, self-attention is actually computed
      and its input/output are cached on the instance
      (``pivot_hidden_states`` / ``kf_attn_output``). When false, no
      self-attention is computed; the cached outputs are re-used instead.
    * ``batch_idx`` (int): index used to select cached entries in the
      non-pivotal pass (only read when ``pivotal_pass`` is false).
    * ``use_ada_layer_norm`` / ``use_ada_layer_norm_zero``, ``norm1``,
      ``norm2``, ``norm3``, ``attn1``, ``attn2``, ``ff``,
      ``only_cross_attention``: expected to come from ``block_class``.

    The incoming batch is treated as three equally sized groups stacked along
    dim 0 (``batch_size // 3`` entries each); the leading dimension must
    therefore be divisible by 3 or the ``view`` calls raise.
    NOTE(review): the three groups are presumably source / unconditional /
    conditional branches -- confirm against the caller.

    Args:
        block_class: the transformer block class to extend.

    Returns:
        The ``TokenFlowBlock`` subclass of ``block_class``.
    """

    class TokenFlowBlock(block_class):

        def forward(
            self,
            hidden_states,
            attention_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            timestep=None,
            cross_attention_kwargs=None,
            class_labels=None,
        ) -> torch.Tensor:
            """
            Run the block on ``hidden_states`` of shape
            ``(3 * n_frames, seq_len, dim)`` and return a tensor of the same
            shape. ``attention_mask`` is accepted for signature compatibility
            but is not used by this implementation.
            """

            batch_size, sequence_length, dim = hidden_states.shape
            n_frames = batch_size // 3
            hidden_states = hidden_states.view(3, n_frames, sequence_length, dim)

            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            norm_hidden_states = norm_hidden_states.view(3, n_frames, sequence_length, dim)
            if self.pivotal_pass:
                # Cache the normalised features; the non-pivotal pass matches against them.
                self.pivot_hidden_states = norm_hidden_states
            else:
                idx1 = []
                idx2 = []
                # Current cached entry, plus the previous one when it exists.
                batch_idxs = [self.batch_idx]
                if self.batch_idx > 0:
                    batch_idxs.append(self.batch_idx - 1)

                # Cosine similarity between every token of group 0 and every token
                # of the selected cached entries (group 0 only).
                sim = batch_cosine_sim(norm_hidden_states[0].reshape(-1, dim),
                                       self.pivot_hidden_states[0][batch_idxs].reshape(-1, dim))
                if len(batch_idxs) == 2:
                    sim1, sim2 = sim.chunk(2, dim=1)
                    # sim: n_frames * seq_len, len(batch_idxs) * seq_len
                    idx1.append(sim1.argmax(dim=-1))  # n_frames * seq_len
                    idx2.append(sim2.argmax(dim=-1))  # n_frames * seq_len
                else:
                    idx1.append(sim.argmax(dim=-1))
                # The nearest-neighbour indices computed on group 0 are shared by all three groups.
                idx1 = torch.stack(idx1 * 3, dim=0)  # 3, n_frames * seq_len
                idx1 = idx1.squeeze(1)
                if len(batch_idxs) == 2:
                    idx2 = torch.stack(idx2 * 3, dim=0)  # 3, n_frames * seq_len
                    idx2 = idx2.squeeze(1)

            # 1. Self-Attention
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            if self.pivotal_pass:
                # norm_hidden_states is (3, n_frames, seq_len, dim); attn1 expects
                # (3 * n_frames, seq_len, dim).
                self.attn_output = self.attn1(
                    norm_hidden_states.view(batch_size, sequence_length, dim),
                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                    **cross_attention_kwargs,
                )
                # Keep the attention output so later non-pivotal passes can re-use it.
                self.kf_attn_output = self.attn_output
            else:
                batch_kf_size, _, _ = self.kf_attn_output.shape
                self.attn_output = self.kf_attn_output.view(3, batch_kf_size // 3, sequence_length, dim)[:,
                                   batch_idxs]  # 3, n_frames, seq_len, dim --> 3, len(batch_idxs), seq_len, dim
            if self.use_ada_layer_norm_zero:
                self.attn_output = gate_msa.unsqueeze(1) * self.attn_output

            # gather values from attn_output, using idx as indices, and get a tensor of shape 3, n_frames, seq_len, dim
            if not self.pivotal_pass:
                if len(batch_idxs) == 2:
                    attn_1, attn_2 = self.attn_output[:, 0], self.attn_output[:, 1]
                    attn_output1 = attn_1.gather(dim=1, index=idx1.unsqueeze(-1).repeat(1, 1, dim))
                    attn_output2 = attn_2.gather(dim=1, index=idx2.unsqueeze(-1).repeat(1, 1, dim))

                    # NOTE(review): presumably the global indices of the frames in
                    # this batch -- confirm how the caller lays out frames.
                    s = torch.arange(n_frames, device=idx1.device) + batch_idxs[0] * n_frames
                    # distance from the pivot
                    p1 = batch_idxs[0] * n_frames + n_frames // 2
                    p2 = batch_idxs[1] * n_frames + n_frames // 2
                    d1 = torch.abs(s - p1)
                    d2 = torch.abs(s - p2)
                    # weight: the ratio lies in [0, 1], so after the sigmoid w1 is in [0.5, ~0.73]
                    w1 = d2 / (d1 + d2)
                    w1 = torch.sigmoid(w1)

                    w1 = w1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).repeat(3, 1, sequence_length, dim)
                    attn_output1 = attn_output1.view(3, n_frames, sequence_length, dim)
                    attn_output2 = attn_output2.view(3, n_frames, sequence_length, dim)
                    attn_output = w1 * attn_output1 + (1 - w1) * attn_output2
                else:
                    attn_output = self.attn_output[:, 0].gather(dim=1, index=idx1.unsqueeze(-1).repeat(1, 1, dim))

                attn_output = attn_output.reshape(
                    batch_size, sequence_length, dim)  # 3 * n_frames, seq_len, dim
            else:
                attn_output = self.attn_output
            hidden_states = hidden_states.reshape(batch_size, sequence_length, dim)  # 3 * n_frames, seq_len, dim
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states

    return TokenFlowBlock
430
+
431
+
432
def set_tokenflow(
        model: torch.nn.Module):
    """
    Sets the tokenflow attention blocks in a model.

    Every sub-module of ``model`` that has a class named
    ``BasicTransformerBlock`` in its ancestry gets its ``__class__`` swapped,
    in place, for the TokenFlow subclass produced by
    ``make_tokenflow_attention_block``.

    Args:
        model: the module whose transformer blocks should be patched.

    Returns:
        The same ``model`` object (patched in place).
    """

    for _, module in model.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.__class__ = make_tokenflow_attention_block(module.__class__)

            # Something needed for older versions of diffusers, whose blocks may
            # lack these flags. Each flag is defaulted independently so that an
            # existing value is never overwritten just because the other flag
            # is missing.
            for flag in ("use_ada_layer_norm", "use_ada_layer_norm_zero"):
                if not hasattr(module, flag):
                    setattr(module, flag, False)

    return model
utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from PIL import Image
3
+ import torch
4
+ import yaml
5
+ import math
6
+
7
+ import torchvision.transforms as T
8
+ from torchvision.io import read_video,write_video
9
+ import os
10
+ import random
11
+ import numpy as np
12
+ from torchvision.io import write_video
13
+ # from kornia.filters import joint_bilateral_blur
14
+ from kornia.geometry.transform import remap
15
+ from kornia.utils.grid import create_meshgrid
16
+ import cv2
17
+
18
def save_video_frames(video_path, img_size=(512,512)):
    """
    Decode a video and write every frame as a PNG under ``data/<video stem>/``.

    Frames are resized to ``img_size`` with Lanczos resampling and saved as
    zero-padded, five-digit file names (``00000.png``, ``00001.png``, ...).

    Args:
        video_path: path of the video file (``str`` or ``os.PathLike``).
        img_size: target size handed to ``PIL.Image.Image.resize``.
    """
    video_path = Path(video_path)
    video, _, _ = read_video(str(video_path), output_format="TCHW")
    # rotate video -90 degree if video is .mov format. this is a weird bug in torchvision
    # The suffix is compared case-insensitively so that "clip.MOV" is handled too.
    if video_path.suffix.lower() == '.mov':
        video = T.functional.rotate(video, -90)
    out_dir = os.path.join('data', video_path.stem)
    os.makedirs(out_dir, exist_ok=True)
    to_pil = T.ToPILImage()  # built once instead of once per frame
    for i, frame in enumerate(video):
        image_resized = to_pil(frame).resize(img_size, resample=Image.Resampling.LANCZOS)
        image_resized.save(os.path.join(out_dir, f'{i:05d}.png'))
30
+
31
def video_to_frames(video_path, img_size=(512,512)):
    """
    Decode a video and return its frames as resized PIL images.

    Args:
        video_path: path of the video file (``str`` or ``os.PathLike``).
        img_size: target size handed to ``PIL.Image.Image.resize``.

    Returns:
        A list with one Lanczos-resized ``PIL.Image.Image`` per frame, in
        playback order. Nothing is written to disk.
    """
    video_path = Path(video_path)
    video, _, _ = read_video(str(video_path), output_format="TCHW")
    # rotate video -90 degree if video is .mov format. this is a weird bug in torchvision
    # The suffix is compared case-insensitively so that "clip.MOV" is handled too.
    if video_path.suffix.lower() == '.mov':
        video = T.functional.rotate(video, -90)
    to_pil = T.ToPILImage()  # built once instead of once per frame
    return [
        to_pil(frame).resize(img_size, resample=Image.Resampling.LANCZOS)
        for frame in video
    ]
46
+
47
def add_dict_to_yaml_file(file_path, key, value):
    """
    Insert or overwrite ``key: value`` in the YAML mapping stored at ``file_path``.

    The file is created when it does not exist. An existing but empty file is
    treated as an empty mapping.

    Args:
        file_path: path of the YAML file to update.
        key: top-level key to set.
        value: value to store under ``key``.
    """
    data = {}

    # If the file already exists, load its contents into the data dictionary
    if os.path.exists(file_path):
        with open(file_path, 'r') as file:
            loaded = yaml.safe_load(file)
        # safe_load returns None for an empty document; keep the empty dict in
        # that case so the assignment below does not fail.
        if loaded is not None:
            data = loaded

    # Add or update the key-value pair
    data[key] = value

    # Save the data back to the YAML file
    with open(file_path, 'w') as file:
        yaml.dump(data, file)
61
+
62
def isinstance_str(x: object, cls_name: str):
    """
    Report whether any class in the MRO of ``x``'s class is *named* ``cls_name``.

    Only class names are compared, so the class object itself never has to be
    imported -- handy when patching third-party modules.
    """
    return any(klass.__name__ == cls_name for klass in x.__class__.__mro__)
75
+
76
+
77
def batch_cosine_sim(x, y):
    """
    Pairwise cosine similarity between the rows of ``x`` and the rows of ``y``.

    Args:
        x: floating-point tensor, or a list/tuple of tensors that is
            concatenated along dim 0 first.
        y: same as ``x``.

    Returns:
        ``x_normalised @ y_normalised.T``; for 2-D inputs of shapes ``(n, d)``
        and ``(m, d)`` this is an ``(n, m)`` tensor.

    Rows with zero norm yield a similarity of 0 rather than NaN.
    """
    if isinstance(x, (list, tuple)):
        x = torch.cat(x, dim=0)
    if isinstance(y, (list, tuple)):
        y = torch.cat(y, dim=0)
    # Clamp the norm to the smallest normal value of the dtype so an all-zero
    # row divides to 0 instead of 0/0 = NaN; non-degenerate rows are unaffected.
    x = x / x.norm(dim=-1, keepdim=True).clamp_min(torch.finfo(x.dtype).tiny)
    y = y / y.norm(dim=-1, keepdim=True).clamp_min(torch.finfo(y.dtype).tiny)
    similarity = x @ y.T
    return similarity
86
+
87
+
88
def load_imgs(data_path, n_frames, device='cuda', pil=False):
    """
    Load frames ``00000`` .. ``n_frames - 1`` from ``data_path`` as one tensor.

    For each index the ``.jpg`` file is used when it exists, otherwise the
    ``.png`` file with the same five-digit name is opened.

    Args:
        data_path: directory holding the numbered frame images.
        n_frames: number of consecutive frames to read, starting at 0.
        device: device the concatenated tensor is moved to.
        pil: when true, also return the list of opened PIL images.

    Returns:
        The frames concatenated along dim 0 on ``device``; a
        ``(tensor, pil_images)`` tuple when ``pil`` is true.
    """
    tensors = []
    pil_images = []
    for frame_idx in range(n_frames):
        frame_path = os.path.join(data_path, "%05d.jpg" % frame_idx)
        if not os.path.exists(frame_path):
            # No JPEG for this index: fall back to the PNG with the same number.
            frame_path = os.path.join(data_path, "%05d.png" % frame_idx)
        frame = Image.open(frame_path)
        pil_images.append(frame)
        # ToTensor yields one image tensor; add a leading dim so frames can be concatenated.
        tensors.append(T.ToTensor()(frame).unsqueeze(0))
    batch = torch.cat(tensors).to(device)
    if not pil:
        return batch
    return batch, pil_images
102
+
103
+
104
def save_video(raw_frames, save_path, fps=10):
    """
    Encode a batch of frames as an H.264 video file.

    Args:
        raw_frames: float tensor laid out as (frames, channels, height, width);
            the permute below moves channels last for ``write_video``. Values
            are expected in [0, 1] and are clamped to that range.
        save_path: destination file path.
        fps: frames per second of the written video.
    """
    video_codec = "libx264"
    video_options = {
        "crf": "18",  # Constant Rate Factor (lower value = higher quality, 18 is a good balance)
        "preset": "slow",  # Encoding preset (e.g., ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow)
    }

    # Clamp before the uint8 cast: values slightly outside [0, 1] would
    # otherwise wrap around and show up as speckle artefacts.
    frames = (raw_frames.clamp(0, 1) * 255).to(torch.uint8).cpu().permute(0, 2, 3, 1)
    write_video(save_path, frames, fps=fps, video_codec=video_codec, options=video_options)
113
+
114
+
115
def seed_everything(seed):
    """
    Seed the torch (CPU and CUDA), ``random`` and NumPy generators with ``seed``.
    """
    # Same seeding order as before: torch, torch.cuda, random, numpy.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        random.seed,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
120
+
121
+