Francke committed on
Commit 5d63776 · 1 Parent(s): aad5337
scripts/inference.py ADDED
@@ -0,0 +1,99 @@
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
+ from omegaconf import OmegaConf
+ import torch
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from latentsync.models.unet import UNet3DConditionModel
+ from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
+ from diffusers.utils.import_utils import is_xformers_available
+ from accelerate.utils import set_seed
+ from latentsync.whisper.audio2feature import Audio2Feature
+
+
+ def main(config, args):
+     print(f"Input video path: {args.video_path}")
+     print(f"Input audio path: {args.audio_path}")
+     print(f"Loaded checkpoint path: {args.inference_ckpt_path}")
+
+     scheduler = DDIMScheduler.from_pretrained("configs")
+
+     if config.model.cross_attention_dim == 768:
+         whisper_model_path = "checkpoints/whisper/small.pt"
+     elif config.model.cross_attention_dim == 384:
+         whisper_model_path = "checkpoints/whisper/tiny.pt"
+     else:
+         raise NotImplementedError("cross_attention_dim must be 768 or 384")
+
+     audio_encoder = Audio2Feature(model_path=whisper_model_path, device="cuda", num_frames=config.data.num_frames)
+
+     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+     vae.config.scaling_factor = 0.18215
+     vae.config.shift_factor = 0
+
+     unet, _ = UNet3DConditionModel.from_pretrained(
+         OmegaConf.to_container(config.model),
+         args.inference_ckpt_path,  # load checkpoint
+         device="cpu",
+     )
+
+     unet = unet.to(dtype=torch.float16)
+
+     # set xformers
+     if is_xformers_available():
+         unet.enable_xformers_memory_efficient_attention()
+
+     pipeline = LipsyncPipeline(
+         vae=vae,
+         audio_encoder=audio_encoder,
+         unet=unet,
+         scheduler=scheduler,
+     ).to("cuda")
+
+     if args.seed != -1:
+         set_seed(args.seed)
+     else:
+         torch.seed()
+
+     print(f"Initial seed: {torch.initial_seed()}")
+
+     pipeline(
+         video_path=args.video_path,
+         audio_path=args.audio_path,
+         video_out_path=args.video_out_path,
+         video_mask_path=args.video_out_path.replace(".mp4", "_mask.mp4"),
+         num_frames=config.data.num_frames,
+         num_inference_steps=config.run.inference_steps,
+         guidance_scale=args.guidance_scale,
+         weight_dtype=torch.float16,
+         width=config.data.resolution,
+         height=config.data.resolution,
+     )
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
+     parser.add_argument("--inference_ckpt_path", type=str, required=True)
+     parser.add_argument("--video_path", type=str, required=True)
+     parser.add_argument("--audio_path", type=str, required=True)
+     parser.add_argument("--video_out_path", type=str, required=True)
+     parser.add_argument("--guidance_scale", type=float, default=1.0)
+     parser.add_argument("--seed", type=int, default=1247)
+     args = parser.parse_args()
+
+     config = OmegaConf.load(args.unet_config_path)
+
+     main(config, args)
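
For reference, the argparse block above is the script's full CLI surface; only --unet_config_path has a default. A minimal driver sketch follows; the checkpoint and media paths are illustrative placeholders, not paths shipped in this commit.

import subprocess

# Build the command from the flags defined in scripts/inference.py's argparse block.
cmd = [
    "python", "scripts/inference.py",
    "--unet_config_path", "configs/unet.yaml",
    "--inference_ckpt_path", "checkpoints/latentsync_unet.pt",  # assumed checkpoint location
    "--video_path", "assets/input_video.mp4",  # placeholder input video
    "--audio_path", "assets/input_audio.wav",  # placeholder input audio
    "--video_out_path", "video_out.mp4",
    "--guidance_scale", "1.0",
    "--seed", "1247",
]
subprocess.run(cmd, check=True)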
scripts/train_syncnet.py ADDED
@@ -0,0 +1,336 @@
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from tqdm.auto import tqdm
+ import os, argparse, datetime, math
+ import logging
+ from omegaconf import OmegaConf
+ import shutil
+
+ from latentsync.data.syncnet_dataset import SyncNetDataset
+ from latentsync.models.syncnet import SyncNet
+ from latentsync.models.syncnet_wav2lip import SyncNetWav2Lip
+ from latentsync.utils.util import gather_loss, plot_loss_chart
+ from accelerate.utils import set_seed
+
+ import torch
+ from diffusers import AutoencoderKL
+ from diffusers.utils.logging import get_logger
+ from einops import rearrange
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.utils.data.distributed import DistributedSampler
+ from latentsync.utils.util import init_dist, cosine_loss
+
+ logger = get_logger(__name__)
+
+
+ def main(config):
+     # Initialize distributed training
+     local_rank = init_dist()
+     global_rank = dist.get_rank()
+     num_processes = dist.get_world_size()
+     is_main_process = global_rank == 0
+
+     seed = config.run.seed + global_rank
+     set_seed(seed)
+
+     # Logging folder
+     folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
+     output_dir = os.path.join(config.data.train_output_dir, folder_name)
+
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+
+     # Handle the output folder creation
+     if is_main_process:
+         os.makedirs(output_dir, exist_ok=True)
+         os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
+         os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
+         shutil.copy(config.config_path, output_dir)
+
+     device = torch.device(local_rank)
+
+     if config.data.latent_space:
+         vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+         vae.requires_grad_(False)
+         vae.to(device)
+     else:
+         vae = None
+
+     # Dataset and Dataloader setup
+     train_dataset = SyncNetDataset(config.data.train_data_dir, config.data.train_fileslist, config)
+     val_dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)
+
+     train_distributed_sampler = DistributedSampler(
+         train_dataset,
+         num_replicas=num_processes,
+         rank=global_rank,
+         shuffle=True,
+         seed=config.run.seed,
+     )
+
+     # DataLoaders creation:
+     train_dataloader = torch.utils.data.DataLoader(
+         train_dataset,
+         batch_size=config.data.batch_size,
+         shuffle=False,
+         sampler=train_distributed_sampler,
+         num_workers=config.data.num_workers,
+         pin_memory=False,
+         drop_last=True,
+         worker_init_fn=train_dataset.worker_init_fn,
+     )
+
+     num_samples_limit = 640
+
+     val_batch_size = min(
+         num_samples_limit // config.data.num_frames, config.data.batch_size
+     )  # limit batch size to avoid CUDA OOM
+
+     val_dataloader = torch.utils.data.DataLoader(
+         val_dataset,
+         batch_size=val_batch_size,
+         shuffle=False,
+         num_workers=config.data.num_workers,
+         pin_memory=False,
+         drop_last=False,
+         worker_init_fn=val_dataset.worker_init_fn,
+     )
+
+     # Model
+     syncnet = SyncNet(OmegaConf.to_container(config.model)).to(device)
+     # syncnet = SyncNetWav2Lip().to(device)
+
+     optimizer = torch.optim.AdamW(
+         list(filter(lambda p: p.requires_grad, syncnet.parameters())), lr=config.optimizer.lr
+     )
+
+     if config.ckpt.resume_ckpt_path != "":
+         if is_main_process:
+             logger.info(f"Load checkpoint from: {config.ckpt.resume_ckpt_path}")
+         ckpt = torch.load(config.ckpt.resume_ckpt_path, map_location=device)
+
+         syncnet.load_state_dict(ckpt["state_dict"])
+         global_step = ckpt["global_step"]
+         train_step_list = ckpt["train_step_list"]
+         train_loss_list = ckpt["train_loss_list"]
+         val_step_list = ckpt["val_step_list"]
+         val_loss_list = ckpt["val_loss_list"]
+     else:
+         global_step = 0
+         train_step_list = []
+         train_loss_list = []
+         val_step_list = []
+         val_loss_list = []
+
+     # DDP wrapper
+     syncnet = DDP(syncnet, device_ids=[local_rank], output_device=local_rank)
+
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+     num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)
+     # validation_steps = int(config.ckpt.save_ckpt_steps // 5)
+     # validation_steps = 100
+
+     if is_main_process:
+         logger.info("***** Running training *****")
+         logger.info(f" Num examples = {len(train_dataset)}")
+         logger.info(f" Num Epochs = {num_train_epochs}")
+         logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
+         logger.info(f" Total train batch size (w. parallel & distributed) = {config.data.batch_size * num_processes}")
+         logger.info(f" Total optimization steps = {config.run.max_train_steps}")
+
+     first_epoch = global_step // num_update_steps_per_epoch
+     num_val_batches = config.data.num_val_samples // (num_processes * config.data.batch_size)
+
+     # Only show the progress bar once on each machine.
+     progress_bar = tqdm(
+         range(0, config.run.max_train_steps), initial=global_step, desc="Steps", disable=not is_main_process
+     )
+
+     # Support mixed-precision training
+     scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None
+
+     for epoch in range(first_epoch, num_train_epochs):
+         train_dataloader.sampler.set_epoch(epoch)
+         syncnet.train()
+
+         for step, batch in enumerate(train_dataloader):
+             ### >>>> Training >>>> ###
+
+             frames = batch["frames"].to(device, dtype=torch.float16)
+             audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
+             y = batch["y"].to(device, dtype=torch.float32)
+
+             if config.data.latent_space:
+                 max_batch_size = (
+                     num_samples_limit // config.data.num_frames
+                 )  # due to the limited cuda memory, we split the input frames into parts
+                 if frames.shape[0] > max_batch_size:
+                     assert (
+                         frames.shape[0] % max_batch_size == 0
+                     ), f"batch_size {frames.shape[0]} should be divisible by max_batch_size {max_batch_size}"
+                     frames_part_results = []
+                     for i in range(0, frames.shape[0], max_batch_size):
+                         frames_part = frames[i : i + max_batch_size]
+                         frames_part = rearrange(frames_part, "b f c h w -> (b f) c h w")
+                         with torch.no_grad():
+                             frames_part = vae.encode(frames_part).latent_dist.sample() * 0.18215
+                         frames_part_results.append(frames_part)
+                     frames = torch.cat(frames_part_results, dim=0)
+                 else:
+                     frames = rearrange(frames, "b f c h w -> (b f) c h w")
+                     with torch.no_grad():
+                         frames = vae.encode(frames).latent_dist.sample() * 0.18215
+
+                 frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
+             else:
+                 frames = rearrange(frames, "b f c h w -> b (f c) h w")
+
+             if config.data.lower_half:
+                 height = frames.shape[2]
+                 frames = frames[:, :, height // 2 :, :]
+
+             # audio_embeds = wav2vec_encoder(audio_samples).last_hidden_state
+
+             # Mixed-precision training
+             with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
+                 vision_embeds, audio_embeds = syncnet(frames, audio_samples)
+
+             loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()
+
+             optimizer.zero_grad()
+
+             # Backpropagate
+             if config.run.mixed_precision_training:
+                 scaler.scale(loss).backward()
+                 """ >>> gradient clipping >>> """
+                 scaler.unscale_(optimizer)
+                 torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)
+                 """ <<< gradient clipping <<< """
+                 scaler.step(optimizer)
+                 scaler.update()
+             else:
+                 loss.backward()
+                 """ >>> gradient clipping >>> """
+                 torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)
+                 """ <<< gradient clipping <<< """
+                 optimizer.step()
+
+             progress_bar.update(1)
+             global_step += 1
+
+             global_average_loss = gather_loss(loss, device)
+             train_step_list.append(global_step)
+             train_loss_list.append(global_average_loss)
+
+             if is_main_process and global_step % config.run.validation_steps == 0:
+                 logger.info(f"Validation at step {global_step}")
+                 val_loss = validation(
+                     val_dataloader,
+                     device,
+                     syncnet,
+                     cosine_loss,
+                     config.data.latent_space,
+                     config.data.lower_half,
+                     vae,
+                     num_val_batches,
+                 )
+                 val_step_list.append(global_step)
+                 val_loss_list.append(val_loss)
+                 logger.info(f"Validation loss at step {global_step} is {val_loss:0.3f}")
+
+             if is_main_process and global_step % config.ckpt.save_ckpt_steps == 0:
+                 checkpoint_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
+                 torch.save(
+                     {
+                         "state_dict": syncnet.module.state_dict(),  # to unwrap DDP
+                         "global_step": global_step,
+                         "train_step_list": train_step_list,
+                         "train_loss_list": train_loss_list,
+                         "val_step_list": val_step_list,
+                         "val_loss_list": val_loss_list,
+                     },
+                     checkpoint_save_path,
+                 )
+                 logger.info(f"Saved checkpoint to {checkpoint_save_path}")
+                 plot_loss_chart(
+                     os.path.join(output_dir, f"loss_charts/loss_chart-{global_step}.png"),
+                     ("Train loss", train_step_list, train_loss_list),
+                     ("Val loss", val_step_list, val_loss_list),
+                 )
+
+             progress_bar.set_postfix({"step_loss": global_average_loss})
+             if global_step >= config.run.max_train_steps:
+                 break
+
+     progress_bar.close()
+     dist.destroy_process_group()
+
+
+ @torch.no_grad()
+ def validation(val_dataloader, device, syncnet, cosine_loss, latent_space, lower_half, vae, num_val_batches):
+     syncnet.eval()
+
+     losses = []
+     val_step = 0
+     while True:
+         for step, batch in enumerate(val_dataloader):
+             ### >>>> Validation >>>> ###
+
+             frames = batch["frames"].to(device, dtype=torch.float16)
+             audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
+             y = batch["y"].to(device, dtype=torch.float32)
+
+             if latent_space:
+                 num_frames = frames.shape[1]
+                 frames = rearrange(frames, "b f c h w -> (b f) c h w")
+                 frames = vae.encode(frames).latent_dist.sample() * 0.18215
+                 frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=num_frames)
+             else:
+                 frames = rearrange(frames, "b f c h w -> b (f c) h w")
+
+             if lower_half:
+                 height = frames.shape[2]
+                 frames = frames[:, :, height // 2 :, :]
+
+             with torch.autocast(device_type="cuda", dtype=torch.float16):
+                 vision_embeds, audio_embeds = syncnet(frames, audio_samples)
+
+             loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()
+
+             losses.append(loss.item())
+
+             val_step += 1
+             if val_step > num_val_batches:
+                 syncnet.train()
+                 if len(losses) == 0:
+                     raise RuntimeError("No validation data")
+                 return sum(losses) / len(losses)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Code to train the expert lip-sync discriminator")
+     parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_vae.yaml")
+     args = parser.parse_args()
+
+     # Load a configuration file
+     config = OmegaConf.load(args.config_path)
+     config.config_path = args.config_path
+
+     main(config)
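
Both the training step and the validation loop above reduce a per-pair score with cosine_loss from latentsync.utils.util, which is not part of this commit. A minimal sketch of what it is assumed to compute, following the standard SyncNet/Wav2Lip formulation: binary cross-entropy on the cosine similarity between vision and audio embeddings, with y = 1 for in-sync and y = 0 for out-of-sync pairs.

import torch
import torch.nn.functional as F


def cosine_loss_sketch(vision_embeds: torch.Tensor, audio_embeds: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Cosine similarity per sample, mapped from [-1, 1] to [0, 1] so it behaves like a probability.
    sims = F.cosine_similarity(vision_embeds, audio_embeds, dim=-1)
    probs = ((sims + 1) / 2).clamp(1e-6, 1 - 1e-6)
    # Per-sample BCE against the sync label; the callers above apply .mean() themselves.
    return F.binary_cross_entropy(probs, y.view(-1).float(), reduction="none")

If the real helper reduces internally, the trailing .mean() in the callers is a no-op on a scalar, so the sketch stays consistent with how it is used here.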
scripts/train_unet.py ADDED
@@ -0,0 +1,510 @@
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import math
+ import argparse
+ import shutil
+ import datetime
+ import logging
+ from omegaconf import OmegaConf
+
+ from tqdm.auto import tqdm
+ from einops import rearrange
+
+ import torch
+ import torch.nn.functional as F
+ import torch.distributed as dist
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ import diffusers
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from diffusers.utils.logging import get_logger
+ from diffusers.optimization import get_scheduler
+ from diffusers.utils.import_utils import is_xformers_available
+ from accelerate.utils import set_seed
+
+ from latentsync.data.unet_dataset import UNetDataset
+ from latentsync.models.unet import UNet3DConditionModel
+ from latentsync.models.syncnet import SyncNet
+ from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
+ from latentsync.utils.util import (
+     init_dist,
+     cosine_loss,
+     reversed_forward,
+ )
+ from latentsync.utils.util import plot_loss_chart, gather_loss
+ from latentsync.whisper.audio2feature import Audio2Feature
+ from latentsync.trepa import TREPALoss
+ from eval.syncnet import SyncNetEval
+ from eval.syncnet_detect import SyncNetDetector
+ from eval.eval_sync_conf import syncnet_eval
+ import lpips
+
+
+ logger = get_logger(__name__)
+
+
+ def main(config):
+     # Initialize distributed training
+     local_rank = init_dist()
+     global_rank = dist.get_rank()
+     num_processes = dist.get_world_size()
+     is_main_process = global_rank == 0
+
+     seed = config.run.seed + global_rank
+     set_seed(seed)
+
+     # Logging folder
+     folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
+     output_dir = os.path.join(config.data.train_output_dir, folder_name)
+
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+
+     # Handle the output folder creation
+     if is_main_process:
+         diffusers.utils.logging.set_verbosity_info()
+         os.makedirs(output_dir, exist_ok=True)
+         os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
+         os.makedirs(f"{output_dir}/val_videos", exist_ok=True)
+         os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
+         shutil.copy(config.unet_config_path, output_dir)
+         shutil.copy(config.data.syncnet_config_path, output_dir)
+
+     device = torch.device(local_rank)
+
+     noise_scheduler = DDIMScheduler.from_pretrained("configs")
+
+     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+     vae.config.scaling_factor = 0.18215
+     vae.config.shift_factor = 0
+     vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+     vae.requires_grad_(False)
+     vae.to(device)
+
+     syncnet_eval_model = SyncNetEval(device=device)
+     syncnet_eval_model.loadParameters("checkpoints/auxiliary/syncnet_v2.model")
+
+     syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")
+
+     if config.model.cross_attention_dim == 768:
+         whisper_model_path = "checkpoints/whisper/small.pt"
+     elif config.model.cross_attention_dim == 384:
+         whisper_model_path = "checkpoints/whisper/tiny.pt"
+     else:
+         raise NotImplementedError("cross_attention_dim must be 768 or 384")
+
+     audio_encoder = Audio2Feature(
+         model_path=whisper_model_path,
+         device=device,
+         audio_embeds_cache_dir=config.data.audio_embeds_cache_dir,
+         num_frames=config.data.num_frames,
+     )
+
+     unet, resume_global_step = UNet3DConditionModel.from_pretrained(
+         OmegaConf.to_container(config.model),
+         config.ckpt.resume_ckpt_path,  # load checkpoint
+         device=device,
+     )
+
+     if config.model.add_audio_layer and config.run.use_syncnet:
+         syncnet_config = OmegaConf.load(config.data.syncnet_config_path)
+         if syncnet_config.ckpt.inference_ckpt_path == "":
+             raise ValueError("SyncNet path is not provided")
+         syncnet = SyncNet(OmegaConf.to_container(syncnet_config.model)).to(device=device, dtype=torch.float16)
+         syncnet_checkpoint = torch.load(syncnet_config.ckpt.inference_ckpt_path, map_location=device)
+         syncnet.load_state_dict(syncnet_checkpoint["state_dict"])
+         syncnet.requires_grad_(False)
+
+     unet.requires_grad_(True)
+     trainable_params = list(unet.parameters())
+
+     if config.optimizer.scale_lr:
+         config.optimizer.lr = config.optimizer.lr * num_processes
+
+     optimizer = torch.optim.AdamW(trainable_params, lr=config.optimizer.lr)
+
+     if is_main_process:
+         logger.info(f"trainable params number: {len(trainable_params)}")
+         logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
+
+     # Enable xformers
+     if config.run.enable_xformers_memory_efficient_attention:
+         if is_xformers_available():
+             unet.enable_xformers_memory_efficient_attention()
+         else:
+             raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+     # Enable gradient checkpointing
+     if config.run.enable_gradient_checkpointing:
+         unet.enable_gradient_checkpointing()
+
+     # Get the training dataset
+     train_dataset = UNetDataset(config.data.train_data_dir, config)
+     distributed_sampler = DistributedSampler(
+         train_dataset,
+         num_replicas=num_processes,
+         rank=global_rank,
+         shuffle=True,
+         seed=config.run.seed,
+     )
+
+     # DataLoaders creation:
+     train_dataloader = torch.utils.data.DataLoader(
+         train_dataset,
+         batch_size=config.data.batch_size,
+         shuffle=False,
+         sampler=distributed_sampler,
+         num_workers=config.data.num_workers,
+         pin_memory=False,
+         drop_last=True,
+         worker_init_fn=train_dataset.worker_init_fn,
+     )
+
+     # Get the training iteration
+     if config.run.max_train_steps == -1:
+         assert config.run.max_train_epochs != -1
+         config.run.max_train_steps = config.run.max_train_epochs * len(train_dataloader)
+
+     # Scheduler
+     lr_scheduler = get_scheduler(
+         config.optimizer.lr_scheduler,
+         optimizer=optimizer,
+         num_warmup_steps=config.optimizer.lr_warmup_steps,
+         num_training_steps=config.run.max_train_steps,
+     )
+
+     if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
+         lpips_loss_func = lpips.LPIPS(net="vgg").to(device)
+
+     if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
+         trepa_loss_func = TREPALoss(device=device)
+
+     # Validation pipeline
+     pipeline = LipsyncPipeline(
+         vae=vae,
+         audio_encoder=audio_encoder,
+         unet=unet,
+         scheduler=noise_scheduler,
+     ).to(device)
+     pipeline.set_progress_bar_config(disable=True)
+
+     # DDP wrapper
+     unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
+
+     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+     # Afterwards we recalculate our number of training epochs
+     num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)
+
+     # Train!
+     total_batch_size = config.data.batch_size * num_processes
+
+     if is_main_process:
+         logger.info("***** Running training *****")
+         logger.info(f" Num examples = {len(train_dataset)}")
+         logger.info(f" Num Epochs = {num_train_epochs}")
+         logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
+         logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+         logger.info(f" Total optimization steps = {config.run.max_train_steps}")
+     global_step = resume_global_step
+     first_epoch = resume_global_step // num_update_steps_per_epoch
+
+     # Only show the progress bar once on each machine.
+     progress_bar = tqdm(
+         range(0, config.run.max_train_steps),
+         initial=resume_global_step,
+         desc="Steps",
+         disable=not is_main_process,
+     )
+
+     train_step_list = []
+     sync_loss_list = []
+     recon_loss_list = []
+
+     val_step_list = []
+     sync_conf_list = []
+
+     # Support mixed-precision training
+     scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None
+
+     for epoch in range(first_epoch, num_train_epochs):
+         train_dataloader.sampler.set_epoch(epoch)
+         unet.train()
+
+         for step, batch in enumerate(train_dataloader):
+             ### >>>> Training >>>> ###
+
+             if config.model.add_audio_layer:
+                 if batch["mel"] != []:
+                     mel = batch["mel"].to(device, dtype=torch.float16)
+
+                 audio_embeds_list = []
+                 try:
+                     for idx in range(len(batch["video_path"])):
+                         video_path = batch["video_path"][idx]
+                         start_idx = batch["start_idx"][idx]
+
+                         with torch.no_grad():
+                             audio_feat = audio_encoder.audio2feat(video_path)
+                         audio_embeds = audio_encoder.crop_overlap_audio_window(audio_feat, start_idx)
+                         audio_embeds_list.append(audio_embeds)
+                 except Exception as e:
+                     logger.info(f"{type(e).__name__} - {e} - {video_path}")
+                     continue
+                 audio_embeds = torch.stack(audio_embeds_list)  # (B, 16, 50, 384)
+                 audio_embeds = audio_embeds.to(device, dtype=torch.float16)
+             else:
+                 audio_embeds = None
+
+             # Convert videos to latent space
+             gt_images = batch["gt"].to(device, dtype=torch.float16)
+             gt_masked_images = batch["masked_gt"].to(device, dtype=torch.float16)
+             mask = batch["mask"].to(device, dtype=torch.float16)
+             ref_images = batch["ref"].to(device, dtype=torch.float16)
+
+             gt_images = rearrange(gt_images, "b f c h w -> (b f) c h w")
+             gt_masked_images = rearrange(gt_masked_images, "b f c h w -> (b f) c h w")
+             mask = rearrange(mask, "b f c h w -> (b f) c h w")
+             ref_images = rearrange(ref_images, "b f c h w -> (b f) c h w")
+
+             with torch.no_grad():
+                 gt_latents = vae.encode(gt_images).latent_dist.sample()
+                 gt_masked_images = vae.encode(gt_masked_images).latent_dist.sample()
+                 ref_images = vae.encode(ref_images).latent_dist.sample()
+
+             mask = torch.nn.functional.interpolate(mask, size=config.data.resolution // vae_scale_factor)
+
+             gt_latents = (
+                 rearrange(gt_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
+             ) * vae.config.scaling_factor
+             gt_masked_images = (
+                 rearrange(gt_masked_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
+                 - vae.config.shift_factor
+             ) * vae.config.scaling_factor
+             ref_images = (
+                 rearrange(ref_images, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
+             ) * vae.config.scaling_factor
+             mask = rearrange(mask, "(b f) c h w -> b c f h w", f=config.data.num_frames)
+
+             # Sample noise that we'll add to the latents
+             if config.run.use_mixed_noise:
+                 # Refer to the paper: https://arxiv.org/abs/2305.10474
+                 noise_shared_std_dev = (config.run.mixed_noise_alpha**2 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
+                 noise_shared = torch.randn_like(gt_latents) * noise_shared_std_dev
+                 noise_shared = noise_shared[:, :, 0:1].repeat(1, 1, config.data.num_frames, 1, 1)
+
+                 noise_ind_std_dev = (1 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
+                 noise_ind = torch.randn_like(gt_latents) * noise_ind_std_dev
+                 noise = noise_ind + noise_shared
+             else:
+                 noise = torch.randn_like(gt_latents)
+                 noise = noise[:, :, 0:1].repeat(
+                     1, 1, config.data.num_frames, 1, 1
+                 )  # Using the same noise for all frames, refer to the paper: https://arxiv.org/abs/2308.09716
+
+             bsz = gt_latents.shape[0]
+
+             # Sample a random timestep for each video
+             timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device)
+             timesteps = timesteps.long()
+
+             # Add noise to the latents according to the noise magnitude at each timestep
+             # (this is the forward diffusion process)
+             noisy_tensor = noise_scheduler.add_noise(gt_latents, noise, timesteps)
+
+             # Get the target for loss depending on the prediction type
+             if noise_scheduler.config.prediction_type == "epsilon":
+                 target = noise
+             elif noise_scheduler.config.prediction_type == "v_prediction":
+                 raise NotImplementedError
+             else:
+                 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+             unet_input = torch.cat([noisy_tensor, mask, gt_masked_images, ref_images], dim=1)
+
+             # Predict the noise and compute loss
+             # Mixed-precision training
+             with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
+                 pred_noise = unet(unet_input, timesteps, encoder_hidden_states=audio_embeds).sample
+
+                 if config.run.recon_loss_weight != 0:
+                     recon_loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
+                 else:
+                     recon_loss = 0
+
+                 pred_latents = reversed_forward(noise_scheduler, pred_noise, timesteps, noisy_tensor)
+
+                 if config.run.pixel_space_supervise:
+                     pred_images = vae.decode(
+                         rearrange(pred_latents, "b c f h w -> (b f) c h w") / vae.config.scaling_factor
+                         + vae.config.shift_factor
+                     ).sample
+
+                 if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
+                     pred_images_perceptual = pred_images[:, :, pred_images.shape[2] // 2 :, :]
+                     gt_images_perceptual = gt_images[:, :, gt_images.shape[2] // 2 :, :]
+                     lpips_loss = lpips_loss_func(pred_images_perceptual.float(), gt_images_perceptual.float()).mean()
+                 else:
+                     lpips_loss = 0
+
+                 if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
+                     trepa_pred_images = rearrange(pred_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
+                     trepa_gt_images = rearrange(gt_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
+                     trepa_loss = trepa_loss_func(trepa_pred_images, trepa_gt_images)
+                 else:
+                     trepa_loss = 0
+
+                 if config.model.add_audio_layer and config.run.use_syncnet:
+                     if config.run.pixel_space_supervise:
+                         syncnet_input = rearrange(pred_images, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
+                     else:
+                         syncnet_input = rearrange(pred_latents, "b c f h w -> b (f c) h w")
+
+                     if syncnet_config.data.lower_half:
+                         height = syncnet_input.shape[2]
+                         syncnet_input = syncnet_input[:, :, height // 2 :, :]
+                     ones_tensor = torch.ones((config.data.batch_size, 1)).float().to(device=device)
+                     vision_embeds, audio_embeds = syncnet(syncnet_input, mel)
+                     sync_loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), ones_tensor).mean()
+                     sync_loss_list.append(gather_loss(sync_loss, device))
+                 else:
+                     sync_loss = 0
+
+                 loss = (
+                     recon_loss * config.run.recon_loss_weight
+                     + sync_loss * config.run.sync_loss_weight
+                     + lpips_loss * config.run.perceptual_loss_weight
+                     + trepa_loss * config.run.trepa_loss_weight
+                 )
+
+             train_step_list.append(global_step)
+             if config.run.recon_loss_weight != 0:
+                 recon_loss_list.append(gather_loss(recon_loss, device))
+
+             optimizer.zero_grad()
+
+             # Backpropagate
+             if config.run.mixed_precision_training:
+                 scaler.scale(loss).backward()
+                 """ >>> gradient clipping >>> """
+                 scaler.unscale_(optimizer)
+                 torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
+                 """ <<< gradient clipping <<< """
+                 scaler.step(optimizer)
+                 scaler.update()
+             else:
+                 loss.backward()
+                 """ >>> gradient clipping >>> """
+                 torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
+                 """ <<< gradient clipping <<< """
+                 optimizer.step()
+
+             # Check the grad of attn blocks for debugging
+             # print(unet.module.up_blocks[3].attentions[2].transformer_blocks[0].audio_cross_attn.attn.to_q.weight.grad)
+
+             lr_scheduler.step()
+             progress_bar.update(1)
+             global_step += 1
+
+             ### <<<< Training <<<< ###
+
+             # Save checkpoint and conduct validation
+             if is_main_process and (global_step % config.ckpt.save_ckpt_steps == 0):
+                 if config.run.recon_loss_weight != 0:
+                     plot_loss_chart(
+                         os.path.join(output_dir, f"loss_charts/recon_loss_chart-{global_step}.png"),
+                         ("Reconstruction loss", train_step_list, recon_loss_list),
+                     )
+                 if config.model.add_audio_layer:
+                     if sync_loss_list != []:
+                         plot_loss_chart(
+                             os.path.join(output_dir, f"loss_charts/sync_loss_chart-{global_step}.png"),
+                             ("Sync loss", train_step_list, sync_loss_list),
+                         )
+                 model_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
+                 state_dict = {
+                     "global_step": global_step,
+                     "state_dict": unet.module.state_dict(),  # to unwrap DDP
+                 }
+                 try:
+                     torch.save(state_dict, model_save_path)
+                     logger.info(f"Saved checkpoint to {model_save_path}")
+                 except Exception as e:
+                     logger.error(f"Error saving model: {e}")
+
+                 # Validation
+                 logger.info("Running validation... ")
+
+                 validation_video_out_path = os.path.join(output_dir, f"val_videos/val_video_{global_step}.mp4")
+                 validation_video_mask_path = os.path.join(output_dir, f"val_videos/val_video_mask.mp4")
+
+                 with torch.autocast(device_type="cuda", dtype=torch.float16):
+                     pipeline(
+                         config.data.val_video_path,
+                         config.data.val_audio_path,
+                         validation_video_out_path,
+                         validation_video_mask_path,
+                         num_frames=config.data.num_frames,
+                         num_inference_steps=config.run.inference_steps,
+                         guidance_scale=config.run.guidance_scale,
+                         weight_dtype=torch.float16,
+                         width=config.data.resolution,
+                         height=config.data.resolution,
+                         mask=config.data.mask,
+                     )
+
+                 logger.info(f"Saved validation video output to {validation_video_out_path}")
+
+                 val_step_list.append(global_step)
+
+                 if config.model.add_audio_layer:
+                     try:
+                         _, conf = syncnet_eval(syncnet_eval_model, syncnet_detector, validation_video_out_path, "temp")
+                     except Exception as e:
+                         logger.info(e)
+                         conf = 0
+                     sync_conf_list.append(conf)
+                     plot_loss_chart(
+                         os.path.join(output_dir, f"loss_charts/sync_conf_chart-{global_step}.png"),
+                         ("Sync confidence", val_step_list, sync_conf_list),
+                     )
+
+             logs = {"step_loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]}
+             progress_bar.set_postfix(**logs)
+
+             if global_step >= config.run.max_train_steps:
+                 break
+
+     progress_bar.close()
+     dist.destroy_process_group()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     # Config file path
+     parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
+
+     args = parser.parse_args()
+     config = OmegaConf.load(args.unet_config_path)
+     config.unet_config_path = args.unet_config_path
+
+     main(config)
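
The loss computation above leans on reversed_forward from latentsync.utils.util, which is not included in this commit. Since the scheduler is configured for epsilon prediction, it is assumed to invert the DDPM forward process and recover the predicted clean latents, so the pixel-space losses (LPIPS, TREPA, SyncNet) can be applied to decoded frames. A sketch under that assumption:

import torch


def reversed_forward_sketch(noise_scheduler, pred_noise, timesteps, noisy_tensor):
    # Invert x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps for x_0,
    # using the alphas_cumprod table that diffusers schedulers expose.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(device=noisy_tensor.device, dtype=noisy_tensor.dtype)
    alpha_bar_t = alphas_cumprod[timesteps].view(-1, 1, 1, 1, 1)  # broadcast over (b, c, f, h, w)
    return (noisy_tensor - (1 - alpha_bar_t) ** 0.5 * pred_noise) / alpha_bar_t**0.5

The mixed-noise branch, by contrast, is self-contained: the shared and per-frame noise components have variances alpha^2 / (1 + alpha^2) and 1 / (1 + alpha^2), which sum to 1, so the combined noise keeps unit variance while being correlated across the frame axis.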
tools/count_videos_time.py ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import matplotlib.pyplot as plt
+ from latentsync.utils.util import count_video_time, gather_video_paths_recursively
+ from tqdm import tqdm
+
+
+ def plot_histogram(data, fig_path):
+     # Create histogram
+     plt.hist(data, bins=30, edgecolor="black")
+
+     # Add titles and labels
+     plt.title("Histogram of Data Distribution")
+     plt.xlabel("Video time")
+     plt.ylabel("Frequency")
+
+     # Save plot as an image file
+     plt.savefig(fig_path)  # Save as PNG file. You can also use 'histogram.jpg', 'histogram.pdf', etc.
+
+
+ def main(input_dir, fig_path):
+     video_paths = gather_video_paths_recursively(input_dir)
+     video_times = []
+     for video_path in tqdm(video_paths):
+         video_times.append(count_video_time(video_path))
+     plot_histogram(video_times, fig_path)
+
+
+ if __name__ == "__main__":
+     input_dir = "validation"
+     fig_path = "histogram.png"
+
+     main(input_dir, fig_path)
tools/download_youtube_videos.py ADDED
@@ -0,0 +1,113 @@
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import subprocess
+ from concurrent.futures import ThreadPoolExecutor
+ import pandas as pd
+ from tqdm import tqdm
+
+ """
+ To use this python file, first install yt-dlp by:
+
+ pip install yt-dlp==2024.5.27
+ """
+
+
+ def download_video(video_url, video_path):
+     get_video_channel_command = f"yt-dlp --print channel {video_url}"
+     result = subprocess.run(get_video_channel_command, shell=True, capture_output=True, text=True)
+     channel = result.stdout.strip()
+     if channel in unwanted_channels:
+         return
+     download_video_command = f"yt-dlp -f bestvideo+bestaudio --skip-unavailable-fragments --merge-output-format mp4 '{video_url}' --output '{video_path}' --external-downloader aria2c --external-downloader-args '-x 16 -k 1M'"
+     try:
+         subprocess.run(download_video_command, shell=True)  # ignore_security_alert_wait_for_fix RCE
+     except KeyboardInterrupt:
+         print("Stopped")
+         exit()
+     except:
+         print(f"Error downloading video {video_url}")
+
+
+ def download_videos(num_workers, video_urls, video_paths):
+     with ThreadPoolExecutor(max_workers=num_workers) as executor:
+         executor.map(download_video, video_urls, video_paths)
+
+
+ def read_video_urls(csv_file_path: str, language_column, video_url_column):
+     video_urls = []
+     print("Reading video urls...")
+     df = pd.read_csv(csv_file_path, sep=",")
+     for row in tqdm(df.itertuples(), total=len(df)):
+         language = getattr(row, language_column)
+         video_url = getattr(row, video_url_column)
+         if "clip" in video_url:
+             continue
+         video_urls.append((language, video_url))
+     return video_urls
+
+
+ def extract_vid(video_url):
+     if "watch?v=" in video_url:  # ignore_security_alert_wait_for_fix RCE
+         return video_url.split("watch?v=")[1][:11]
+     elif "shorts/" in video_url:
+         return video_url.split("shorts/")[1][:11]
+     elif "youtu.be/" in video_url:
+         return video_url.split("youtu.be/")[1][:11]
+     elif "&v=" in video_url:
+         return video_url.split("&v=")[1][:11]
+     else:
+         print(f"Invalid video url: {video_url}")
+         return None
+
+
+ def main(csv_file_path, language_column, video_url_column, output_dir, num_workers):
+     os.makedirs(output_dir, exist_ok=True)
+     all_video_urls = read_video_urls(csv_file_path, language_column, video_url_column)
+
+     video_paths = []
+     video_urls = []
+
+     print("Extracting vid...")
+     for language, video_url in tqdm(all_video_urls):
+         vid = extract_vid(video_url)
+         if vid is None:
+             continue
+         video_path = os.path.join(output_dir, language.lower(), f"vid_{vid}.mp4")
+         if os.path.isfile(video_path):
+             continue
+         os.makedirs(os.path.dirname(video_path), exist_ok=True)
+         video_paths.append(video_path)
+         video_urls.append(video_url)
+
+     if len(video_paths) == 0:
+         print("All videos have been downloaded")
+         exit()
+     else:
+         print(f"Downloading {len(video_paths)} videos")
+
+     download_videos(num_workers, video_urls, video_paths)
+
+
+ if __name__ == "__main__":
+     csv_file_path = "dcc.csv"
+     language_column = "video_language"
+     video_url_column = "video_link"
+     output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual/raw"
+     num_workers = 50
+
+     unwanted_channels = ["TEDx Talks", "DaePyeong Mukbang", "Joeman"]
+
+     main(csv_file_path, language_column, video_url_column, output_dir, num_workers)
tools/move_files_recur.py ADDED
@@ -0,0 +1,48 @@
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import shutil
+ from tqdm import tqdm
+
+ paths = []
+
+
+ def gather_paths(input_dir, output_dir):
+     os.makedirs(output_dir, exist_ok=True)
+
+     for video in sorted(os.listdir(input_dir)):
+         if video.endswith(".mp4"):
+             video_input = os.path.join(input_dir, video)
+             video_output = os.path.join(output_dir, video)
+             if os.path.isfile(video_output):
+                 continue
+             paths.append([video_input, output_dir])
+         elif os.path.isdir(os.path.join(input_dir, video)):
+             gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video))
+
+
+ def main(input_dir, output_dir):
+     print(f"Recursively gathering video paths of {input_dir} ...")
+     gather_paths(input_dir, output_dir)
+
+     for video_input, output_dir in tqdm(paths):
+         shutil.move(video_input, output_dir)
+
+
+ if __name__ == "__main__":
+     input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual_dcc"
+     output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual"
+
+     main(input_dir, output_dir)
tools/occupy_gpu.py ADDED
@@ -0,0 +1,60 @@
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import os
+ import torch.multiprocessing as mp
+ import time
+
+
+ def check_mem(cuda_device):
+     devices_info = (
+         os.popen('"/usr/bin/nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader')
+         .read()
+         .strip()
+         .split("\n")
+     )
+     total, used = devices_info[int(cuda_device)].split(",")
+     return total, used
+
+
+ def loop(cuda_device):
+     cuda_i = torch.device(f"cuda:{cuda_device}")
+     total, used = check_mem(cuda_device)
+     total = int(total)
+     used = int(used)
+     max_mem = int(total * 0.9)
+     block_mem = max_mem - used
+     while True:
+         x = torch.rand(20, 512, 512, dtype=torch.float, device=cuda_i)
+         y = torch.rand(20, 512, 512, dtype=torch.float, device=cuda_i)
+         time.sleep(0.001)
+         x = torch.matmul(x, y)
+
+
+ def main():
+     if torch.cuda.is_available():
+         num_processes = torch.cuda.device_count()
+         processes = list()
+         for i in range(num_processes):
+             p = mp.Process(target=loop, args=(i,))
+             p.start()
+             processes.append(p)
+         for p in processes:
+             p.join()
+
+
+ if __name__ == "__main__":
+     torch.multiprocessing.set_start_method("spawn")
+     main()
tools/remove_outdated_files.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import subprocess
+
+
+ def remove_outdated_files(input_dir, begin_date, end_date):
+     # Remove files from a specific time period
+     for subdir in os.listdir(input_dir):
+         if subdir >= begin_date and subdir <= end_date:
+             subdir_path = os.path.join(input_dir, subdir)
+             command = f"rm -rf {subdir_path}"
+             subprocess.run(command, shell=True)
+             print(f"Deleted: {subdir_path}")
+
+
+ if __name__ == "__main__":
+     input_dir = "/mnt/bn/video-datasets/output/syncnet"
+     begin_date = "train-2024_06_19-16:25:44"
+     end_date = "train-2024_08_03-07:39:58"
+
+     remove_outdated_files(input_dir, begin_date, end_date)
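
The deletion above goes through the shell, so directory names containing spaces or glob characters would need quoting. A pure-Python sketch of the same cleanup (not part of the commit) that sidesteps shell quoting entirely:

import os
import shutil


def remove_outdated_dirs(input_dir: str, begin_date: str, end_date: str) -> None:
    # Lexicographic comparison works here because the folder names are zero-padded timestamps.
    for subdir in os.listdir(input_dir):
        if begin_date <= subdir <= end_date:
            subdir_path = os.path.join(input_dir, subdir)
            shutil.rmtree(subdir_path, ignore_errors=True)
            print(f"Deleted: {subdir_path}")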
tools/write_fileslist.py ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from tqdm import tqdm
+ from latentsync.utils.util import gather_video_paths_recursively
+
+
+ def write_fileslist(fileslist_path):
+     with open(fileslist_path, "w") as _:
+         pass
+
+
+ def append_fileslist(fileslist_path, video_paths):
+     with open(fileslist_path, "a") as f:
+         for video_path in tqdm(video_paths):
+             f.write(f"{video_path}\n")
+
+
+ def process_input_dir(fileslist_path, input_dir):
+     print(f"Processing input dir: {input_dir}")
+     video_paths = gather_video_paths_recursively(input_dir)
+     append_fileslist(fileslist_path, video_paths)
+
+
+ if __name__ == "__main__":
+     fileslist_path = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/all_data_v6.txt"
+
+     write_fileslist(fileslist_path)
+     process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/train")
+     process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/high_visual_quality/train")
+     process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/avatars/high_visual_quality/train")
+     process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/multilingual/high_visual_quality")
+     process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/celebv_text/high_visual_quality/train")
+     process_input_dir(fileslist_path, "/mnt/bn/maliva-gen-ai-v2/chunyu.li/youtube/high_visual_quality")