# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import argparse
import shutil
import datetime
import logging

from omegaconf import OmegaConf
from tqdm.auto import tqdm
from einops import rearrange

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.logging import get_logger
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from accelerate.utils import set_seed

from latentsync.data.unet_dataset import UNetDataset
from latentsync.models.unet import UNet3DConditionModel
from latentsync.models.syncnet import SyncNet
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.utils.util import (
    init_dist,
    cosine_loss,
    reversed_forward,
)
from latentsync.utils.util import plot_loss_chart, gather_loss
from latentsync.whisper.audio2feature import Audio2Feature
from latentsync.trepa import TREPALoss
from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from eval.eval_sync_conf import syncnet_eval
import lpips

logger = get_logger(__name__)


def main(config):
    # Initialize distributed training
    local_rank = init_dist()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    seed = config.run.seed + global_rank
    set_seed(seed)

    # Logging folder
    folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
    output_dir = os.path.join(config.data.train_output_dir, folder_name)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Handle the output folder creation
    if is_main_process:
        diffusers.utils.logging.set_verbosity_info()
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/val_videos", exist_ok=True)
        os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
        shutil.copy(config.unet_config_path, output_dir)
        shutil.copy(config.data.syncnet_config_path, output_dir)

    device = torch.device(local_rank)

    noise_scheduler = DDIMScheduler.from_pretrained("configs")

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    vae.requires_grad_(False)
    vae.to(device)
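
    # Note: scaling_factor = 0.18215 and shift_factor = 0 follow the Stable Diffusion
    # convention for the sd-vae-ft-mse VAE. That VAE has 4 entries in block_out_channels,
    # so vae_scale_factor = 2 ** (4 - 1) = 8: each frame of size `resolution` is encoded
    # into a (resolution // 8) x (resolution // 8) latent grid.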

    syncnet_eval_model = SyncNetEval(device=device)
    syncnet_eval_model.loadParameters("checkpoints/auxiliary/syncnet_v2.model")

    syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")

    if config.model.cross_attention_dim == 768:
        whisper_model_path = "checkpoints/whisper/small.pt"
    elif config.model.cross_attention_dim == 384:
        whisper_model_path = "checkpoints/whisper/tiny.pt"
    else:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")
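
    # The Whisper checkpoint is picked to match the UNet's audio cross-attention width:
    # whisper-tiny emits 384-dim features and whisper-small emits 768-dim features, so the
    # encoder output can be fed to the audio cross-attention layers directly.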

    audio_encoder = Audio2Feature(
        model_path=whisper_model_path,
        device=device,
        audio_embeds_cache_dir=config.data.audio_embeds_cache_dir,
        num_frames=config.data.num_frames,
    )

    unet, resume_global_step = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        config.ckpt.resume_ckpt_path,  # load checkpoint
        device=device,
    )

    if config.model.add_audio_layer and config.run.use_syncnet:
        syncnet_config = OmegaConf.load(config.data.syncnet_config_path)
        if syncnet_config.ckpt.inference_ckpt_path == "":
            raise ValueError("SyncNet path is not provided")
        syncnet = SyncNet(OmegaConf.to_container(syncnet_config.model)).to(device=device, dtype=torch.float16)
        syncnet_checkpoint = torch.load(syncnet_config.ckpt.inference_ckpt_path, map_location=device)
        syncnet.load_state_dict(syncnet_checkpoint["state_dict"])
        syncnet.requires_grad_(False)

    unet.requires_grad_(True)
    trainable_params = list(unet.parameters())

    if config.optimizer.scale_lr:
        config.optimizer.lr = config.optimizer.lr * num_processes

    optimizer = torch.optim.AdamW(trainable_params, lr=config.optimizer.lr)

    if is_main_process:
        logger.info(f"trainable params number: {len(trainable_params)}")
        logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    # Enable xformers
    if config.run.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if config.run.enable_gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Get the training dataset
    train_dataset = UNetDataset(config.data.train_data_dir, config)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=config.run.seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=True,
        worker_init_fn=train_dataset.worker_init_fn,
    )
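
    # drop_last=True keeps every batch at exactly config.data.batch_size samples; the SyncNet
    # branch below builds its target tensor with that same batch size, so a smaller final
    # batch would otherwise cause a shape mismatch.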

    # Get the training iteration
    if config.run.max_train_steps == -1:
        assert config.run.max_train_epochs != -1
        config.run.max_train_steps = config.run.max_train_epochs * len(train_dataloader)

    # Scheduler
    lr_scheduler = get_scheduler(
        config.optimizer.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.optimizer.lr_warmup_steps,
        num_training_steps=config.run.max_train_steps,
    )

    if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
        lpips_loss_func = lpips.LPIPS(net="vgg").to(device)

    if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
        trepa_loss_func = TREPALoss(device=device)

    # Validation pipeline
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=noise_scheduler,
    ).to(device)
    pipeline.set_progress_bar_config(disable=True)

    # DDP wrapper
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
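
    # The validation pipeline above was constructed with the bare UNet, so it keeps a reference
    # to the unwrapped module; only the training forward/backward goes through the DDP wrapper,
    # and checkpoints are later saved from unet.module to strip the wrapper again.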

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = config.data.batch_size * num_processes

    if is_main_process:
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_dataset)}")
        logger.info(f" Num Epochs = {num_train_epochs}")
        logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f" Total optimization steps = {config.run.max_train_steps}")

    global_step = resume_global_step
    first_epoch = resume_global_step // num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(0, config.run.max_train_steps),
        initial=resume_global_step,
        desc="Steps",
        disable=not is_main_process,
    )

    train_step_list = []
    sync_loss_list = []
    recon_loss_list = []

    val_step_list = []
    sync_conf_list = []

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            ### >>>> Training >>>> ###

            if config.model.add_audio_layer:
                if batch["mel"] != []:
                    mel = batch["mel"].to(device, dtype=torch.float16)

                audio_embeds_list = []
                try:
                    for idx in range(len(batch["video_path"])):
                        video_path = batch["video_path"][idx]
                        start_idx = batch["start_idx"][idx]

                        with torch.no_grad():
                            audio_feat = audio_encoder.audio2feat(video_path)
                        audio_embeds = audio_encoder.crop_overlap_audio_window(audio_feat, start_idx)
                        audio_embeds_list.append(audio_embeds)
                except Exception as e:
                    logger.info(f"{type(e).__name__} - {e} - {video_path}")
                    continue
                audio_embeds = torch.stack(audio_embeds_list)  # (B, 16, 50, 384)
                audio_embeds = audio_embeds.to(device, dtype=torch.float16)
            else:
                audio_embeds = None

            # Convert videos to latent space
            gt_images = batch["gt"].to(device, dtype=torch.float16)
            gt_masked_images = batch["masked_gt"].to(device, dtype=torch.float16)
            mask = batch["mask"].to(device, dtype=torch.float16)
            ref_images = batch["ref"].to(device, dtype=torch.float16)

            gt_images = rearrange(gt_images, "b f c h w -> (b f) c h w")
            gt_masked_images = rearrange(gt_masked_images, "b f c h w -> (b f) c h w")
            mask = rearrange(mask, "b f c h w -> (b f) c h w")
            ref_images = rearrange(ref_images, "b f c h w -> (b f) c h w")

            with torch.no_grad():
                gt_latents = vae.encode(gt_images).latent_dist.sample()
                gt_masked_images = vae.encode(gt_masked_images).latent_dist.sample()
                ref_images = vae.encode(ref_images).latent_dist.sample()
                mask = torch.nn.functional.interpolate(mask, size=config.data.resolution // vae_scale_factor)

            gt_latents = (
                rearrange(gt_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
            ) * vae.config.scaling_factor
            gt_masked_images = (
                rearrange(gt_masked_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                - vae.config.shift_factor
            ) * vae.config.scaling_factor
            ref_images = (
                rearrange(ref_images, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
            ) * vae.config.scaling_factor
            mask = rearrange(mask, "(b f) c h w -> b c f h w", f=config.data.num_frames)
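
            # Latents are normalized as (z - shift_factor) * scaling_factor, the inverse of the
            # z / scaling_factor + shift_factor applied before vae.decode() further down; with
            # shift_factor = 0 this reduces to the usual 0.18215 * z Stable Diffusion convention.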

            # Sample noise that we'll add to the latents
            if config.run.use_mixed_noise:
                # Refer to the paper: https://arxiv.org/abs/2305.10474
                noise_shared_std_dev = (config.run.mixed_noise_alpha**2 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_shared = torch.randn_like(gt_latents) * noise_shared_std_dev
                noise_shared = noise_shared[:, :, 0:1].repeat(1, 1, config.data.num_frames, 1, 1)

                noise_ind_std_dev = (1 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_ind = torch.randn_like(gt_latents) * noise_ind_std_dev
                noise = noise_ind + noise_shared
            else:
                noise = torch.randn_like(gt_latents)
                noise = noise[:, :, 0:1].repeat(
                    1, 1, config.data.num_frames, 1, 1
                )  # Using the same noise for all frames, refer to the paper: https://arxiv.org/abs/2308.09716
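
            # Mixed noise (arXiv:2305.10474) splits the unit-variance noise into a component shared
            # across all frames and a per-frame component. With alpha = mixed_noise_alpha,
            # Var(shared) = alpha^2 / (1 + alpha^2) and Var(independent) = 1 / (1 + alpha^2), so
            # their sum still has variance 1 and the scheduler's noise statistics are preserved.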

            bsz = gt_latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_tensor = noise_scheduler.add_noise(gt_latents, noise, timesteps)
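
            # add_noise implements the closed-form forward process
            # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
            # with one timestep sampled per video and broadcast over the (c, f, h, w) dimensions.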

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            unet_input = torch.cat([noisy_tensor, mask, gt_masked_images, ref_images], dim=1)
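
            # Channel layout of the UNet input (assuming the standard 4-channel SD VAE latents):
            # 4 noisy latents + 1 downsampled mask + 4 masked-frame latents + 4 reference latents
            # = 13 channels, concatenated along dim=1 of the (b, c, f, h, w) tensors.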

            # Predict the noise and compute loss
            # Mixed-precision training
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
                pred_noise = unet(unet_input, timesteps, encoder_hidden_states=audio_embeds).sample

            if config.run.recon_loss_weight != 0:
                recon_loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
            else:
                recon_loss = 0

            pred_latents = reversed_forward(noise_scheduler, pred_noise, timesteps, noisy_tensor)
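
            # reversed_forward (from latentsync.utils.util) converts the predicted noise back into an
            # estimate of the clean latents; for epsilon-prediction this amounts to the standard
            # x_0 estimate x_0 ≈ (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t),
            # which lets the pixel-space losses below supervise decoded frames instead of noise.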

            if config.run.pixel_space_supervise:
                pred_images = vae.decode(
                    rearrange(pred_latents, "b c f h w -> (b f) c h w") / vae.config.scaling_factor
                    + vae.config.shift_factor
                ).sample

            if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
                pred_images_perceptual = pred_images[:, :, pred_images.shape[2] // 2 :, :]
                gt_images_perceptual = gt_images[:, :, gt_images.shape[2] // 2 :, :]
                lpips_loss = lpips_loss_func(pred_images_perceptual.float(), gt_images_perceptual.float()).mean()
            else:
                lpips_loss = 0

            if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
                trepa_pred_images = rearrange(pred_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                trepa_gt_images = rearrange(gt_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                trepa_loss = trepa_loss_func(trepa_pred_images, trepa_gt_images)
            else:
                trepa_loss = 0

            if config.model.add_audio_layer and config.run.use_syncnet:
                if config.run.pixel_space_supervise:
                    syncnet_input = rearrange(pred_images, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
                else:
                    syncnet_input = rearrange(pred_latents, "b c f h w -> b (f c) h w")

                if syncnet_config.data.lower_half:
                    height = syncnet_input.shape[2]
                    syncnet_input = syncnet_input[:, :, height // 2 :, :]
                ones_tensor = torch.ones((config.data.batch_size, 1)).float().to(device=device)
                vision_embeds, audio_embeds = syncnet(syncnet_input, mel)
                sync_loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), ones_tensor).mean()
                sync_loss_list.append(gather_loss(sync_loss, device))
            else:
                sync_loss = 0
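
            # The sync loss pulls the SyncNet embeddings of the generated frames and the mel
            # spectrogram together: cosine_loss scores their cosine similarity against the
            # all-ones target built above, so higher audio-visual similarity gives a lower loss.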

            loss = (
                recon_loss * config.run.recon_loss_weight
                + sync_loss * config.run.sync_loss_weight
                + lpips_loss * config.run.perceptual_loss_weight
                + trepa_loss * config.run.trepa_loss_weight
            )

            train_step_list.append(global_step)
            if config.run.recon_loss_weight != 0:
                recon_loss_list.append(gather_loss(recon_loss, device))

            optimizer.zero_grad()

            # Backpropagate
            if config.run.mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()
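
            # With AMP, gradients are unscaled before clipping so clip_grad_norm_ sees the true
            # gradient norms; scaler.step() then skips the optimizer update whenever the backward
            # pass produced inf/NaN gradients, and scaler.update() adjusts the loss scale.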

            # Check the grad of attn blocks for debugging
            # print(unet.module.up_blocks[3].attentions[2].transformer_blocks[0].audio_cross_attn.attn.to_q.weight.grad)

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

            ### <<<< Training <<<< ###

            # Save checkpoint and conduct validation
            if is_main_process and (global_step % config.ckpt.save_ckpt_steps == 0):
                if config.run.recon_loss_weight != 0:
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/recon_loss_chart-{global_step}.png"),
                        ("Reconstruction loss", train_step_list, recon_loss_list),
                    )
                if config.model.add_audio_layer:
                    if sync_loss_list != []:
                        plot_loss_chart(
                            os.path.join(output_dir, f"loss_charts/sync_loss_chart-{global_step}.png"),
                            ("Sync loss", train_step_list, sync_loss_list),
                        )
                model_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
                state_dict = {
                    "global_step": global_step,
                    "state_dict": unet.module.state_dict(),  # to unwrap DDP
                }
                try:
                    torch.save(state_dict, model_save_path)
                    logger.info(f"Saved checkpoint to {model_save_path}")
                except Exception as e:
                    logger.error(f"Error saving model: {e}")

                # Validation
                logger.info("Running validation... ")

                validation_video_out_path = os.path.join(output_dir, f"val_videos/val_video_{global_step}.mp4")
                validation_video_mask_path = os.path.join(output_dir, f"val_videos/val_video_mask.mp4")

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pipeline(
                        config.data.val_video_path,
                        config.data.val_audio_path,
                        validation_video_out_path,
                        validation_video_mask_path,
                        num_frames=config.data.num_frames,
                        num_inference_steps=config.run.inference_steps,
                        guidance_scale=config.run.guidance_scale,
                        weight_dtype=torch.float16,
                        width=config.data.resolution,
                        height=config.data.resolution,
                        mask=config.data.mask,
                    )
logger.info(f"Saved validation video output to {validation_video_out_path}") | |
val_step_list.append(global_step) | |
if config.model.add_audio_layer: | |
try: | |
_, conf = syncnet_eval(syncnet_eval_model, syncnet_detector, validation_video_out_path, "temp") | |
except Exception as e: | |
logger.info(e) | |
conf = 0 | |
sync_conf_list.append(conf) | |
plot_loss_chart( | |
os.path.join(output_dir, f"loss_charts/sync_conf_chart-{global_step}.png"), | |
("Sync confidence", val_step_list, sync_conf_list), | |
) | |
logs = {"step_loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]} | |
progress_bar.set_postfix(**logs) | |
if global_step >= config.run.max_train_steps: | |
break | |

    progress_bar.close()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Config file path
    parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")

    args = parser.parse_args()
    config = OmegaConf.load(args.unet_config_path)
    config.unet_config_path = args.unet_config_path
    main(config)
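
# Typical launch (a sketch, assuming init_dist reads the environment variables set by torchrun,
# e.g. LOCAL_RANK and WORLD_SIZE):
#   torchrun --nproc_per_node=8 <this script> --unet_config_path configs/unet.yaml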