jadechoghari committed
Commit 94ccc87
1 Parent(s): f5bb4af

Create invert.py

Files changed (1)
  1. invert.py +289 -0
invert.py ADDED
@@ -0,0 +1,289 @@
import os

import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import logging

from utils import load_config, save_config
from utils import get_controlnet_kwargs, get_latents_dir, init_model, seed_everything
from utils import load_video, prepare_depth, save_frames, control_preprocess

# suppress partial model loading warning
logging.set_verbosity_error()


class Inverter(nn.Module):
    def __init__(self, pipe, scheduler, config):
        super().__init__()

        self.device = config.device
        self.use_depth = config.sd_version == "depth"
        self.model_key = config.model_key

        self.config = config
        inv_config = config.inversion

        float_precision = inv_config.float_precision if "float_precision" in inv_config else config.float_precision
        if float_precision == "fp16":
            self.dtype = torch.float16
            print("[INFO] float precision fp16. Use torch.float16.")
        else:
            self.dtype = torch.float32
            print("[INFO] float precision fp32. Use torch.float32.")

        self.pipe = pipe
        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.unet = pipe.unet
        self.text_encoder = pipe.text_encoder
        if config.enable_xformers_memory_efficient_attention:
            try:
                pipe.enable_xformers_memory_efficient_attention()
            except ModuleNotFoundError:
                print("[WARNING] xformers not found. Disable xformers attention.")

        self.control = inv_config.control
        if self.control != "none":
            self.controlnet = pipe.controlnet
        self.controlnet_scale = inv_config.control_scale

        # Record which timesteps to checkpoint, then set the full inversion schedule
        scheduler.set_timesteps(inv_config.save_steps)
        self.timesteps_to_save = scheduler.timesteps
        scheduler.set_timesteps(inv_config.steps)
        self.scheduler = scheduler

        self.prompt = inv_config.prompt
        self.recon = inv_config.recon
        self.save_latents = inv_config.save_intermediate
        self.use_blip = inv_config.use_blip
        self.steps = inv_config.steps
        self.batch_size = inv_config.batch_size
        self.force = inv_config.force

        self.n_frames = inv_config.n_frames
        self.frame_height, self.frame_width = config.height, config.width
        self.work_dir = config.work_dir

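    # Encode the prompt with the CLIP text encoder; when a negative prompt is
    # given, the unconditional embedding is concatenated in front for
    # classifier-free guidance (the standard [uncond, cond] layout used by
    # Stable Diffusion).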
    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt=None, device="cuda"):
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
        if negative_prompt is not None:
            uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                          truncation=True, return_tensors='pt')
            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device))[0]
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    @torch.no_grad()
    def decode_latents(self, latents):
        with torch.autocast(device_type=self.device, dtype=self.dtype):
            latents = 1 / 0.18215 * latents  # undo the SD VAE scaling factor
            imgs = self.vae.decode(latents).sample
            imgs = (imgs / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
        return imgs

    @torch.no_grad()
    def decode_latents_batch(self, latents):
        imgs = []
        batch_latents = latents.split(self.batch_size, dim=0)
        for latent in batch_latents:
            imgs += [self.decode_latents(latent)]
        imgs = torch.cat(imgs)
        return imgs

    @torch.no_grad()
    def encode_imgs(self, imgs):
        with torch.autocast(device_type=self.device, dtype=self.dtype):
            imgs = 2 * imgs - 1  # [0, 1] -> [-1, 1]
            posterior = self.vae.encode(imgs).latent_dist
            latents = posterior.mean * 0.18215  # apply the SD VAE scaling factor
        return latents

    @torch.no_grad()
    def encode_imgs_batch(self, imgs):
        latents = []
        batch_imgs = imgs.split(self.batch_size, dim=0)
        for img in batch_imgs:
            latents += [self.encode_imgs(img)]
        latents = torch.cat(latents)
        return latents

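    # DDIM inversion walks the deterministic DDIM trajectory in reverse:
    # starting from the clean latents x_0, each step predicts the noise at
    # timestep t and maps x toward x_T, optionally checkpointing intermediate
    # noisy latents so later stages can resume from any saved timestep.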
    @torch.no_grad()
    def ddim_inversion(self, x, conds, save_path):
        print("[INFO] start DDIM Inversion!")
        timesteps = reversed(self.scheduler.timesteps)
        with torch.autocast(device_type=self.device, dtype=self.dtype):
            for i, t in enumerate(tqdm(timesteps)):
                noises = []
                x_index = torch.arange(len(x))
                batches = x_index.split(self.batch_size, dim=0)
                for batch in batches:
                    noise = self.pred_noise(
                        x[batch], conds[batch], t, batch_idx=batch)
                    noises += [noise]
                noises = torch.cat(noises)
                x = self.pred_next_x(x, noises, t, i, inversion=True)
                if self.save_latents and t in self.timesteps_to_save:
                    torch.save(x, os.path.join(
                        save_path, f'noisy_latents_{t}.pt'))

        # Save inverted noise latents
        pth = os.path.join(save_path, f'noisy_latents_{t}.pt')
        torch.save(x, pth)
        print(f"[INFO] inverted latent saved to: {pth}")
        return x

    @torch.no_grad()
    def ddim_sample(self, x, conds):
        print("[INFO] reconstructing frames...")
        timesteps = self.scheduler.timesteps
        with torch.autocast(device_type=self.device, dtype=self.dtype):
            for i, t in enumerate(tqdm(timesteps)):
                noises = []
                x_index = torch.arange(len(x))
                batches = x_index.split(self.batch_size, dim=0)
                for batch in batches:
                    noise = self.pred_noise(
                        x[batch], conds[batch], t, batch_idx=batch)
                    noises += [noise]
                noises = torch.cat(noises)
                x = self.pred_next_x(x, noises, t, i, inversion=False)
        return x

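    # Predict the noise residual for one batch of latent frames. For the
    # depth-conditioned SD variant, the per-frame depth map is concatenated
    # as an extra latent channel; with ControlNet enabled, the control
    # residuals are passed through to the UNet via extra kwargs.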
    @torch.no_grad()
    def pred_noise(self, x, cond, t, batch_idx=None):
        # For sd-depth model
        if self.use_depth:
            depth = self.depths
            if batch_idx is not None:
                depth = depth[batch_idx]
            x = torch.cat([x, depth.to(x)], dim=1)

        kwargs = dict()
        # Compute controlnet outputs
        if self.control != "none":
            if batch_idx is None:
                controlnet_cond = self.controlnet_images
            else:
                controlnet_cond = self.controlnet_images[batch_idx]
            controlnet_kwargs = get_controlnet_kwargs(self.controlnet, x, cond, t, controlnet_cond, self.controlnet_scale)
            kwargs.update(controlnet_kwargs)

        eps = self.unet(x, t, encoder_hidden_states=cond, **kwargs).sample
        return eps

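    # One deterministic DDIM step. With mu = sqrt(alpha_bar) and
    # sigma = sqrt(1 - alpha_bar), sampling applies
    #     pred_x0 = (x_t - sigma_t * eps) / mu_t
    #     x_prev  = mu_prev * pred_x0 + sigma_prev * eps
    # and inversion applies the same update with the roles of the two
    # timesteps exchanged, stepping from clean latents toward noise.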
    @torch.no_grad()
    def pred_next_x(self, x, eps, t, i, inversion=False):
        if inversion:
            timesteps = reversed(self.scheduler.timesteps)
        else:
            timesteps = self.scheduler.timesteps
        alpha_prod_t = self.scheduler.alphas_cumprod[t]
        if inversion:
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[timesteps[i - 1]]
                if i > 0 else self.scheduler.final_alpha_cumprod
            )
        else:
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[timesteps[i + 1]]
                if i < len(timesteps) - 1
                else self.scheduler.final_alpha_cumprod
            )
        mu = alpha_prod_t ** 0.5
        sigma = (1 - alpha_prod_t) ** 0.5
        mu_prev = alpha_prod_t_prev ** 0.5
        sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

        if inversion:
            pred_x0 = (x - sigma_prev * eps) / mu_prev
            x = mu * pred_x0 + sigma * eps
        else:
            pred_x0 = (x - sigma * eps) / mu
            x = mu_prev * pred_x0 + sigma_prev * eps

        return x

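    # Build per-frame text embeddings: a single prompt string is broadcast to
    # all frames, while a list supplies one prompt per frame.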
    @torch.no_grad()
    def prepare_cond(self, prompts, n_frames):
        if isinstance(prompts, str):
            prompts = [prompts] * n_frames
            cond = self.get_text_embeds(prompts[0])
            conds = torch.cat([cond] * n_frames)
        elif isinstance(prompts, list):
            cond_ls = []
            for prompt in prompts:
                cond = self.get_text_embeds(prompt)
                cond_ls += [cond]
            conds = torch.cat(cond_ls)
        return conds, prompts

    def check_latent_exists(self, save_path):
        save_timesteps = [self.scheduler.timesteps[0]]
        if self.save_latents:
            save_timesteps += self.timesteps_to_save
        for ts in save_timesteps:
            latent_path = os.path.join(
                save_path, f'noisy_latents_{ts}.pt')
            if not os.path.exists(latent_path):
                return False
        return True

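    # Full inversion pipeline: load and trim the video, prepare depth /
    # ControlNet conditions and prompt embeddings, encode frames to clean
    # latents, run DDIM inversion, and optionally reconstruct the frames
    # from the inverted latents as a sanity check.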
    @torch.no_grad()
    def __call__(self, data_path, save_path):
        self.scheduler.set_timesteps(self.steps)
        save_path = get_latents_dir(save_path, self.model_key)
        os.makedirs(save_path, exist_ok=True)
        if self.check_latent_exists(save_path) and not self.force:
            print(f"[INFO] inverted latents exist at: {save_path}. Skip inversion! Set 'inversion.force: True' to invert again.")
            return

        frames = load_video(data_path, self.frame_height, self.frame_width, device=self.device)

        frame_ids = list(range(len(frames)))
        if self.n_frames is not None:
            frame_ids = frame_ids[:self.n_frames]
        frames = frames[frame_ids]

        if self.use_depth:
            self.depths = prepare_depth(self.pipe, frames, frame_ids, self.work_dir)
        conds, prompts = self.prepare_cond(self.prompt, len(frames))
        with open(os.path.join(save_path, 'inversion_prompts.txt'), 'w') as f:
            f.write('\n'.join(prompts))

        if self.control != "none":
            images = control_preprocess(
                frames, self.control)
            self.controlnet_images = images.to(self.device)

        latents = self.encode_imgs_batch(frames)
        torch.cuda.empty_cache()
        print(f"[INFO] clean latents shape: {latents.shape}")

        inverted_x = self.ddim_inversion(latents, conds, save_path)
        save_config(self.config, save_path, inv=True)
        if self.recon:
            latent_reconstruction = self.ddim_sample(inverted_x, conds)

            torch.cuda.empty_cache()
            recon_frames = self.decode_latents_batch(
                latent_reconstruction)

            recon_save_path = os.path.join(save_path, 'recon_frames')
            save_frames(recon_frames, recon_save_path, frame_ids=frame_ids)

if __name__ == "__main__":
    config = load_config()
    pipe, scheduler, model_key = init_model(
        config.device, config.sd_version, config.model_key, config.inversion.control, config.float_precision)
    config.model_key = model_key
    seed_everything(config.seed)
    inversion = Inverter(pipe, scheduler, config)
    inversion(config.input_path, config.inversion.save_path)