jadechoghari committed
Commit f5bb4af
1 Parent(s): ecbaf30

Create generate.py

Files changed (1)
generate.py +375 -0
generate.py ADDED
@@ -0,0 +1,375 @@
+ import torch.nn as nn
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ import os
+ from transformers import logging
+
+ from utils import CONTROLNET_DICT
+ from utils import load_config, save_config
+ from utils import get_controlnet_kwargs, get_frame_ids, get_latents_dir, init_model, seed_everything
+ from utils import prepare_control, load_latent, load_video, prepare_depth, save_video
+ from utils import register_time, register_attention_control, register_conv_control
+
+ import vidtome
+
+ # Suppress the partial model loading warning from transformers.
+ logging.set_verbosity_error()
+
+
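+ # Generator wraps a Stable Diffusion pipeline and denoises whole videos:
+ # frames are grouped into chunks, merged with VidToMe token merging, and
+ # optionally conditioned via ControlNet, depth maps, or PnP feature injection.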
+ class Generator(nn.Module):
+     def __init__(self, pipe, scheduler, config):
+         super().__init__()
+
+         self.device = config.device
+         self.seed = config.seed
+
+         self.model_key = config.model_key
+
+         self.config = config
+         gene_config = config.generation
+         float_precision = gene_config.float_precision if "float_precision" in gene_config else config.float_precision
+         if float_precision == "fp16":
+             self.dtype = torch.float16
+             print("[INFO] float precision fp16. Use torch.float16.")
+         else:
+             self.dtype = torch.float32
+             print("[INFO] float precision fp32. Use torch.float32.")
+
+         self.pipe = pipe
+         self.vae = pipe.vae
+         self.tokenizer = pipe.tokenizer
+         self.unet = pipe.unet
+         self.text_encoder = pipe.text_encoder
+         if config.enable_xformers_memory_efficient_attention:
+             try:
+                 pipe.enable_xformers_memory_efficient_attention()
+             except ModuleNotFoundError:
+                 print("[WARNING] xformers not found. Disable xformers attention.")
+         self.n_timesteps = gene_config.n_timesteps
+         scheduler.set_timesteps(gene_config.n_timesteps, device=self.device)
+         self.scheduler = scheduler
+
+         self.batch_size = 2
+         self.control = gene_config.control
+         self.use_depth = config.sd_version == "depth"
+         self.use_controlnet = self.control in CONTROLNET_DICT.keys()
+         self.use_pnp = self.control == "pnp"
+         if self.use_controlnet:
+             self.controlnet = pipe.controlnet
+             self.controlnet_scale = gene_config.control_scale
+         elif self.use_pnp:
+             pnp_f_t = int(gene_config.n_timesteps * gene_config.pnp_f_t)
+             pnp_attn_t = int(gene_config.n_timesteps * gene_config.pnp_attn_t)
+             self.batch_size += 1
+             self.init_pnp(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)
+
+         self.chunk_size = gene_config.chunk_size
+         self.chunk_ord = gene_config.chunk_ord
+         self.merge_global = gene_config.merge_global
+         self.local_merge_ratio = gene_config.local_merge_ratio
+         self.global_merge_ratio = gene_config.global_merge_ratio
+         self.global_rand = gene_config.global_rand
+         self.align_batch = gene_config.align_batch
+
+         self.prompt = gene_config.prompt
+         self.negative_prompt = gene_config.negative_prompt
+         self.guidance_scale = gene_config.guidance_scale
+         self.save_frame = gene_config.save_frame
+
+         self.frame_height, self.frame_width = config.height, config.width
+         self.work_dir = config.work_dir
+
+         if "mix" in self.chunk_ord:
+             self.perm_div = float(self.chunk_ord.split("-")[-1]) if "-" in self.chunk_ord else 3.
+             self.chunk_ord = "mix"
+         # Patch VidToMe into the model
+         self.activate_vidtome()
+
+         if gene_config.use_lora:
+             self.pipe.load_lora_weights(**gene_config.lora)
+
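+     # Apply the VidToMe patch to the pipeline. Merging strength is controlled
+     # by local_merge_ratio (within a chunk) and global_merge_ratio (across chunks).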
+     def activate_vidtome(self):
+         vidtome.apply_patch(self.pipe, self.local_merge_ratio, self.merge_global, self.global_merge_ratio,
+                             seed=self.seed, batch_size=self.batch_size, align_batch=self.use_pnp or self.align_batch, global_rand=self.global_rand)
+
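+     # Build the conditioning embeddings for generation. Under PnP control an
+     # extra empty-prompt embedding is prepended for the source-frame branch.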
+     @torch.no_grad()
+     def get_text_embeds_input(self, prompt, negative_prompt):
+         text_embeds = self.get_text_embeds(
+             prompt, negative_prompt, self.device)
+         if self.use_pnp:
+             pnp_guidance_embeds = self.get_text_embeds("", device=self.device)
+             text_embeds = torch.cat(
+                 [pnp_guidance_embeds, text_embeds], dim=0)
+         return text_embeds
+
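+     # Encode prompts with the text encoder. When a negative prompt is given,
+     # the result is [uncond, cond] stacked along the batch dimension for
+     # classifier-free guidance.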
+     @torch.no_grad()
+     def get_text_embeds(self, prompt, negative_prompt=None, device="cuda"):
+         text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
+                                     truncation=True, return_tensors='pt')
+         text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
+         if negative_prompt is not None:
+             uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
+                                           truncation=True, return_tensors='pt')
+             uncond_embeddings = self.text_encoder(
+                 uncond_input.input_ids.to(device))[0]
+             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+         return text_embeddings
+
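+     # Load input frames and the initial noisy latent, plus depth maps or
+     # ControlNet conditioning images when those controls are enabled.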
+     @torch.no_grad()
+     def prepare_data(self, data_path, latent_path, frame_ids):
+         self.frames = load_video(data_path, self.frame_height,
+                                  self.frame_width, frame_ids=frame_ids, device=self.device)
+         self.init_noise = load_latent(
+             latent_path, t=self.scheduler.timesteps[0], frame_ids=frame_ids).to(self.dtype).to(self.device)
+
+         if self.use_depth:
+             self.depths = prepare_depth(
+                 self.pipe, self.frames, frame_ids, self.work_dir).to(self.init_noise)
+
+         if self.use_controlnet:
+             self.controlnet_images = prepare_control(
+                 self.control, self.frames, frame_ids, self.work_dir).to(self.init_noise)
+
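+     # VAE decode/encode helpers. 0.18215 is the Stable Diffusion latent scaling
+     # factor; the *_batch variants process inputs in batch_size-sized slices.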
+     @torch.no_grad()
+     def decode_latents(self, latents):
+         with torch.autocast(device_type=self.device, dtype=self.dtype):
+             latents = 1 / 0.18215 * latents
+             imgs = self.vae.decode(latents).sample
+             imgs = (imgs / 2 + 0.5).clamp(0, 1)
+         return imgs
+
+     @torch.no_grad()
+     def decode_latents_batch(self, latents):
+         imgs = []
+         batch_latents = latents.split(self.batch_size, dim=0)
+         for latent in batch_latents:
+             imgs += [self.decode_latents(latent)]
+         imgs = torch.cat(imgs)
+         return imgs
+
+     @torch.no_grad()
+     def encode_imgs(self, imgs):
+         with torch.autocast(device_type=self.device, dtype=self.dtype):
+             imgs = 2 * imgs - 1
+             posterior = self.vae.encode(imgs).latent_dist
+             latents = posterior.mean * 0.18215
+         return latents
+
+     @torch.no_grad()
+     def encode_imgs_batch(self, imgs):
+         latents = []
+         batch_imgs = imgs.split(self.batch_size, dim=0)
+         for img in batch_imgs:
+             latents += [self.encode_imgs(img)]
+         latents = torch.cat(latents)
+         return latents
+
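+     # Partition frame indices into chunks. The first chunk gets a random length
+     # and the order is optionally shuffled, so chunk boundaries and the
+     # global-merging order vary across denoising steps.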
+     def get_chunks(self, flen):
+         x_index = torch.arange(flen)
+
+         # The first chunk has a random length
+         rand_first = np.random.randint(0, self.chunk_size) + 1
+         chunks = x_index[rand_first:].split(self.chunk_size, dim=0)
+         chunks = [x_index[:rand_first]] + list(chunks) if len(chunks[0]) > 0 else [x_index[:rand_first]]
+         if np.random.rand() > 0.5:
+             chunks = chunks[::-1]
+
+         # Chunk order only matters when we do global token merging
+         if not self.merge_global:
+             return chunks
+
+         # Chunk order. "seq": sequential order. "rand": full permutation. "mix": partial permutation.
+         if self.chunk_ord == "rand":
+             order = torch.randperm(len(chunks))
+         elif self.chunk_ord == "mix":
+             randord = torch.randperm(len(chunks)).tolist()
+             rand_len = int(len(randord) / self.perm_div)
+             seqord = sorted(randord[rand_len:])
+             if rand_len > 0:
+                 randord = randord[:rand_len]
+                 # Attach the sequential part to whichever of its ends is closer
+                 # to the last randomly ordered chunk
+                 if abs(seqord[-1] - randord[-1]) < abs(seqord[0] - randord[-1]):
+                     seqord = seqord[::-1]
+                 order = randord + seqord
+             else:
+                 order = seqord
+         else:
+             order = torch.arange(len(chunks))
+         chunks = [chunks[i] for i in order]
+         return chunks
+
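+     # Main denoising loop: at each timestep, predict noise chunk by chunk,
+     # then take one DDIM step over the whole video.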
+     @torch.no_grad()
+     def ddim_sample(self, x, conds):
+         print("[INFO] denoising frames...")
+         timesteps = self.scheduler.timesteps
+         noises = torch.zeros_like(x)
+
+         for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
+             self.pre_iter(x, t)
+
+             # Split video into chunks and denoise
+             chunks = self.get_chunks(len(x))
+             for chunk in chunks:
+                 torch.cuda.empty_cache()
+                 noises[chunk] = self.pred_noise(
+                     x[chunk], conds, t, batch_idx=chunk)
+
+             x = self.pred_next_x(x, noises, t, i, inversion=False)
+
+             self.post_iter(x, t)
+         return x
+
+     def pre_iter(self, x, t):
+         if self.use_pnp:
+             # Prepare PnP
+             register_time(self, t.item())
+             cur_latents = load_latent(self.latent_path, t=t, frame_ids=self.frame_ids)
+             self.cur_latents = cur_latents
+
+     def post_iter(self, x, t):
+         if self.merge_global:
+             # Reset global tokens
+             vidtome.update_patch(self.pipe, global_tokens=None)
+
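+     # The UNet batch is laid out as [source (PnP only), uncond, cond];
+     # the last two chunks drive classifier-free guidance.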
+     @torch.no_grad()
+     def pred_noise(self, x, cond, t, batch_idx=None):
+         flen = len(x)
+         text_embed_input = cond.repeat_interleave(flen, dim=0)
+
+         # For classifier-free guidance
+         latent_model_input = torch.cat([x, x])
+         batch_size = 2
+
+         if self.use_pnp:
+             # Concat latents from inverted source frames for the PnP operation
+             source_latents = self.cur_latents
+             if batch_idx is not None:
+                 source_latents = source_latents[batch_idx]
+             latent_model_input = torch.cat([source_latents.to(x), latent_model_input])
+             batch_size += 1
+
+         # For the sd-depth model
+         if self.use_depth:
+             depth = self.depths
+             if batch_idx is not None:
+                 depth = depth[batch_idx]
+             depth = depth.repeat(batch_size, 1, 1, 1)
+             latent_model_input = torch.cat([latent_model_input, depth.to(x)], dim=1)
+
+         kwargs = dict()
+         # Compute ControlNet outputs
+         if self.use_controlnet:
+             controlnet_cond = self.controlnet_images
+             if batch_idx is not None:
+                 controlnet_cond = controlnet_cond[batch_idx]
+             controlnet_cond = controlnet_cond.repeat(batch_size, 1, 1, 1)
+             controlnet_kwargs = get_controlnet_kwargs(
+                 self.controlnet, latent_model_input, text_embed_input, t, controlnet_cond, self.controlnet_scale)
+             kwargs.update(controlnet_kwargs)
+
+         # Predict the noise residual
+         eps = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input, **kwargs).sample
+         noise_pred_uncond, noise_pred_cond = eps.chunk(batch_size)[-2:]
+         # Classifier-free guidance
+         noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+         return noise_pred
+
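+     # Deterministic DDIM update: recover pred_x0 = (x - sigma * eps) / mu,
+     # then step to x_prev = mu_prev * pred_x0 + sigma_prev * eps
+     # (the directions swap for inversion).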
+     @torch.no_grad()
+     def pred_next_x(self, x, eps, t, i, inversion=False):
+         if inversion:
+             timesteps = reversed(self.scheduler.timesteps)
+         else:
+             timesteps = self.scheduler.timesteps
+         alpha_prod_t = self.scheduler.alphas_cumprod[t]
+         if inversion:
+             alpha_prod_t_prev = (
+                 self.scheduler.alphas_cumprod[timesteps[i - 1]]
+                 if i > 0 else self.scheduler.final_alpha_cumprod
+             )
+         else:
+             alpha_prod_t_prev = (
+                 self.scheduler.alphas_cumprod[timesteps[i + 1]]
+                 if i < len(timesteps) - 1
+                 else self.scheduler.final_alpha_cumprod
+             )
+         mu = alpha_prod_t ** 0.5
+         sigma = (1 - alpha_prod_t) ** 0.5
+         mu_prev = alpha_prod_t_prev ** 0.5
+         sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
+
+         if inversion:
+             pred_x0 = (x - sigma_prev * eps) / mu_prev
+             x = mu * pred_x0 + sigma * eps
+         else:
+             pred_x0 = (x - sigma * eps) / mu
+             x = mu_prev * pred_x0 + sigma_prev * eps
+
+         return x
+
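+     # Register PnP feature injection on attention (qk) and conv layers for the
+     # first qk_injection_t / conv_injection_t denoising steps.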
+     def init_pnp(self, conv_injection_t, qk_injection_t):
+         qk_injection_timesteps = self.scheduler.timesteps[:qk_injection_t] if qk_injection_t >= 0 else []
+         conv_injection_timesteps = self.scheduler.timesteps[:conv_injection_t] if conv_injection_t >= 0 else []
+         register_attention_control(
+             self, qk_injection_timesteps, num_inputs=self.batch_size)
+         register_conv_control(
+             self, conv_injection_timesteps, num_inputs=self.batch_size)
+
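+     # PnP needs an inverted latent saved at every generation timestep;
+     # otherwise only the initial noisy latent is required.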
+     def check_latent_exists(self, latent_path):
+         if self.use_pnp:
+             timesteps = self.scheduler.timesteps
+         else:
+             timesteps = [self.scheduler.timesteps[0]]
+
+         for ts in timesteps:
+             cur_latent_path = os.path.join(
+                 latent_path, f'noisy_latents_{ts}.pt')
+             if not os.path.exists(cur_latent_path):
+                 return False
+         return True
+
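+     # End-to-end generation: load data, then denoise and save one video per
+     # edit prompt in the config.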
+     @torch.no_grad()
+     def __call__(self, data_path, latent_path, output_path, frame_ids):
+         self.scheduler.set_timesteps(self.n_timesteps)
+         latent_path = get_latents_dir(latent_path, self.model_key)
+         assert self.check_latent_exists(
+             latent_path), f"Required latent not found at {latent_path}. \
+             Note: If using PnP as control, you need inversion latents saved \
+             at each generation timestep."
+
+         self.data_path = data_path
+         self.latent_path = latent_path
+         self.frame_ids = frame_ids
+         self.prepare_data(data_path, latent_path, frame_ids)
+
+         print(f"[INFO] initial noise latent shape: {self.init_noise.shape}")
+
+         for edit_name, edit_prompt in self.prompt.items():
+             print(f"[INFO] current prompt: {edit_prompt}")
+             conds = self.get_text_embeds_input(edit_prompt, self.negative_prompt)
+             clean_latent = self.ddim_sample(self.init_noise, conds)
+             # Comment this out if you have enough GPU memory
+             torch.cuda.empty_cache()
+             clean_frames = self.decode_latents_batch(clean_latent)
+             cur_output_path = os.path.join(output_path, edit_name)
+             save_config(self.config, cur_output_path, gene=True)
+             save_video(clean_frames, cur_output_path, save_frame=self.save_frame)
+
+
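+ # Entry point: build the pipeline from the config, seed RNGs, and run
+ # generation on the selected frame range.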
+ if __name__ == "__main__":
+     config = load_config()
+     pipe, scheduler, model_key = init_model(
+         config.device, config.sd_version, config.model_key, config.generation.control, config.float_precision)
+     config.model_key = model_key
+     seed_everything(config.seed)
+     generator = Generator(pipe, scheduler, config)
+     frame_ids = get_frame_ids(
+         config.generation.frame_range, config.generation.frame_ids)
+     generator(config.input_path, config.generation.latents_path,
+               config.generation.output_path, frame_ids=frame_ids)