# LatentSync / scripts /train_syncnet.py
# Francke's picture
# 4
# 5d63776
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tqdm.auto import tqdm
import os, argparse, datetime, math
import logging
from omegaconf import OmegaConf
import shutil
from latentsync.data.syncnet_dataset import SyncNetDataset
from latentsync.models.syncnet import SyncNet
from latentsync.models.syncnet_wav2lip import SyncNetWav2Lip
from latentsync.utils.util import gather_loss, plot_loss_chart
from accelerate.utils import set_seed
import torch
from diffusers import AutoencoderKL
from diffusers.utils.logging import get_logger
from einops import rearrange
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from latentsync.utils.util import init_dist, cosine_loss
# Module-level logger, created through the diffusers logging helper imported above.
logger = get_logger(__name__)
def _encode_frames_to_latents(vae, frames, max_batch_size):
    """Encode frames with the VAE, at most ``max_batch_size`` batch items per call.

    Args:
        vae: Frozen VAE; only ``vae.encode(...).latent_dist.sample()`` is used.
        frames: Tensor laid out as ``(b, f, c, h, w)``.
        max_batch_size: Largest number of batch items passed to one ``vae.encode`` call.

    Returns:
        The sampled latents laid out as ``((b f), c, h, w)``, multiplied by 0.18215.

    Raises:
        AssertionError: If the batch is larger than ``max_batch_size`` but not a
            multiple of it.
    """
    batch_size = frames.shape[0]
    if batch_size > max_batch_size:
        assert (
            batch_size % max_batch_size == 0
        ), f"batch_size {batch_size} should be divisible by max_batch_size {max_batch_size}"

    latent_parts = []
    # When the batch fits in one call this loop runs exactly once over the whole batch.
    for start in range(0, batch_size, max_batch_size):
        frames_part = rearrange(frames[start : start + max_batch_size], "b f c h w -> (b f) c h w")
        with torch.no_grad():
            latent_parts.append(vae.encode(frames_part).latent_dist.sample() * 0.18215)
    return torch.cat(latent_parts, dim=0)


def _prepare_train_frames(frames, vae, config, max_batch_size):
    """Turn a ``(b, f, c, h, w)`` frame batch into the ``(b, (f c), h, w)`` SyncNet input.

    When ``config.data.latent_space`` is set the frames are first encoded by ``vae``;
    when ``config.data.lower_half`` is set only the lower half of the height axis is kept.
    """
    if config.data.latent_space:
        frames = _encode_frames_to_latents(vae, frames, max_batch_size)
        frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
    else:
        frames = rearrange(frames, "b f c h w -> b (f c) h w")

    if config.data.lower_half:
        height = frames.shape[2]
        frames = frames[:, :, height // 2 :, :]
    return frames


def main(config):
    """Train SyncNet with DistributedDataParallel; rank 0 validates and checkpoints.

    Args:
        config: OmegaConf configuration. Keys read here: ``run.seed``,
            ``run.max_train_steps``, ``run.validation_steps``,
            ``run.mixed_precision_training``, ``data.*`` (data directories and file
            lists, ``train_output_dir``, ``batch_size``, ``num_workers``,
            ``num_frames``, ``num_val_samples``, ``latent_space``, ``lower_half``),
            ``model``, ``optimizer.lr``, ``optimizer.max_grad_norm``,
            ``ckpt.resume_ckpt_path``, ``ckpt.save_ckpt_steps`` and ``config_path``.
    """
    # Initialize distributed training
    local_rank = init_dist()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    # Offset the seed by rank so every process gets a different random stream.
    seed = config.run.seed + global_rank
    set_seed(seed)

    # Logging folder
    folder_name = "train" + datetime.datetime.now().strftime("-%Y_%m_%d-%H:%M:%S")
    output_dir = os.path.join(config.data.train_output_dir, folder_name)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Handle the output folder creation (only rank 0 ever writes to it)
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
        shutil.copy(config.config_path, output_dir)

    device = torch.device(local_rank)

    if config.data.latent_space:
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
        vae.requires_grad_(False)
        vae.to(device)
    else:
        vae = None

    # Dataset and Dataloader setup
    train_dataset = SyncNetDataset(config.data.train_data_dir, config.data.train_fileslist, config)
    val_dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)

    train_distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=config.run.seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        sampler=train_distributed_sampler,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=True,
        worker_init_fn=train_dataset.worker_init_fn,
    )

    num_samples_limit = 640
    # Largest number of batch items processed at once; limits memory to avoid CUDA OOM.
    max_batch_size = num_samples_limit // config.data.num_frames
    val_batch_size = min(max_batch_size, config.data.batch_size)

    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=False,
        worker_init_fn=val_dataset.worker_init_fn,
    )

    # Model
    syncnet = SyncNet(OmegaConf.to_container(config.model)).to(device)
    # syncnet = SyncNetWav2Lip().to(device)

    optimizer = torch.optim.AdamW(
        list(filter(lambda p: p.requires_grad, syncnet.parameters())), lr=config.optimizer.lr
    )

    if config.ckpt.resume_ckpt_path != "":
        if is_main_process:
            logger.info(f"Load checkpoint from: {config.ckpt.resume_ckpt_path}")
        ckpt = torch.load(config.ckpt.resume_ckpt_path, map_location=device)

        # Only the model weights and the loss history are restored; the checkpoints
        # written below do not contain optimizer state.
        syncnet.load_state_dict(ckpt["state_dict"])
        global_step = ckpt["global_step"]
        train_step_list = ckpt["train_step_list"]
        train_loss_list = ckpt["train_loss_list"]
        val_step_list = ckpt["val_step_list"]
        val_loss_list = ckpt["val_loss_list"]
    else:
        global_step = 0
        train_step_list = []
        train_loss_list = []
        val_step_list = []
        val_loss_list = []

    # DDP wrapper
    syncnet = DDP(syncnet, device_ids=[local_rank], output_device=local_rank)

    num_update_steps_per_epoch = len(train_dataloader)
    num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)

    if is_main_process:
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_dataset)}")
        logger.info(f" Num Epochs = {num_train_epochs}")
        logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
        logger.info(f" Total train batch size (w. parallel & distributed) = {config.data.batch_size * num_processes}")
        logger.info(f" Total optimization steps = {config.run.max_train_steps}")

    first_epoch = global_step // num_update_steps_per_epoch
    num_val_batches = config.data.num_val_samples // (num_processes * config.data.batch_size)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(0, config.run.max_train_steps), initial=global_step, desc="Steps", disable=not is_main_process
    )

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None

    for epoch in range(first_epoch, num_train_epochs):
        # A resumed run restarts its epoch from the first batch, so the step budget can
        # run out before the epoch range does. Without this guard every remaining epoch
        # would still execute one extra optimisation step before hitting the inner break.
        if global_step >= config.run.max_train_steps:
            break

        train_dataloader.sampler.set_epoch(epoch)
        syncnet.train()

        for batch in train_dataloader:
            ### >>>> Training >>>> ###

            frames = batch["frames"].to(device, dtype=torch.float16)
            audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
            y = batch["y"].to(device, dtype=torch.float32)

            frames = _prepare_train_frames(frames, vae, config, max_batch_size)

            # Mixed-precision training
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
                vision_embeds, audio_embeds = syncnet(frames, audio_samples)

            # The loss is computed outside autocast on float32 embeddings.
            loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()

            optimizer.zero_grad()

            # Backpropagate
            if config.run.mixed_precision_training:
                scaler.scale(loss).backward()
                # Gradient clipping: unscale first so the norm is measured on real gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)
                optimizer.step()

            progress_bar.update(1)
            global_step += 1

            # Called on every rank, not just rank 0.
            global_average_loss = gather_loss(loss, device)
            train_step_list.append(global_step)
            train_loss_list.append(global_average_loss)

            if is_main_process and global_step % config.run.validation_steps == 0:
                logger.info(f"Validation at step {global_step}")
                # Validate with the unwrapped module: this runs on rank 0 only, so a
                # forward through the DDP wrapper could issue collectives (e.g. buffer
                # broadcasts) that the other ranks never join.
                val_loss = validation(
                    val_dataloader,
                    device,
                    syncnet.module,
                    cosine_loss,
                    config.data.latent_space,
                    config.data.lower_half,
                    vae,
                    num_val_batches,
                )
                val_step_list.append(global_step)
                val_loss_list.append(val_loss)
                logger.info(f"Validation loss at step {global_step} is {val_loss:0.3f}")

            if is_main_process and global_step % config.ckpt.save_ckpt_steps == 0:
                checkpoint_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
                torch.save(
                    {
                        "state_dict": syncnet.module.state_dict(),  # to unwrap DDP
                        "global_step": global_step,
                        "train_step_list": train_step_list,
                        "train_loss_list": train_loss_list,
                        "val_step_list": val_step_list,
                        "val_loss_list": val_loss_list,
                    },
                    checkpoint_save_path,
                )
                logger.info(f"Saved checkpoint to {checkpoint_save_path}")
                plot_loss_chart(
                    os.path.join(output_dir, f"loss_charts/loss_chart-{global_step}.png"),
                    ("Train loss", train_step_list, train_loss_list),
                    ("Val loss", val_step_list, val_loss_list),
                )

            progress_bar.set_postfix({"step_loss": global_average_loss})
            if global_step >= config.run.max_train_steps:
                break

    progress_bar.close()
    dist.destroy_process_group()
@torch.no_grad()
def validation(val_dataloader, device, syncnet, cosine_loss, latent_space, lower_half, vae, num_val_batches):
    """Return the mean validation loss over ``num_val_batches + 1`` batches.

    The dataloader is iterated again from the start whenever it is exhausted before
    enough batches have been evaluated. The model is put in eval mode for the duration
    of the call and switched back to train mode on exit, including when an error is
    raised.

    Args:
        val_dataloader: Iterable of batches; each batch is indexed with the keys
            ``"frames"``, ``"audio_samples"`` and ``"y"``.
        device: Device the batch tensors are moved to.
        syncnet: Model called as ``syncnet(frames, audio_samples)``; it returns the
            vision and audio embeddings.
        cosine_loss: Loss callable taking ``(vision_embeds, audio_embeds, y)``.
        latent_space: If true, frames are encoded with ``vae`` before the model call.
        lower_half: If true, only the lower half of the height axis is kept.
        vae: VAE used when ``latent_space`` is true; unused otherwise.
        num_val_batches: Evaluation stops once more than this many batches were seen.

    Returns:
        The arithmetic mean of the per-batch losses as a Python float.

    Raises:
        RuntimeError: If ``val_dataloader`` yields no batches.
    """
    syncnet.eval()
    losses = []

    try:
        while True:
            saw_batch = False
            for batch in val_dataloader:
                ### >>>> Validation >>>> ###
                saw_batch = True

                frames = batch["frames"].to(device, dtype=torch.float16)
                audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
                y = batch["y"].to(device, dtype=torch.float32)

                if latent_space:
                    num_frames = frames.shape[1]
                    frames = rearrange(frames, "b f c h w -> (b f) c h w")
                    frames = vae.encode(frames).latent_dist.sample() * 0.18215
                    frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=num_frames)
                else:
                    frames = rearrange(frames, "b f c h w -> b (f c) h w")

                if lower_half:
                    height = frames.shape[2]
                    frames = frames[:, :, height // 2 :, :]

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    vision_embeds, audio_embeds = syncnet(frames, audio_samples)

                loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()
                losses.append(loss.item())

                # One loss is recorded per batch, so len(losses) is the batch counter.
                if len(losses) > num_val_batches:
                    return sum(losses) / len(losses)

            # An empty dataloader would otherwise make the outer loop spin forever.
            if not saw_batch:
                raise RuntimeError("No validation data")
    finally:
        syncnet.train()
if __name__ == "__main__":
    # Command line: the only option is the path of the YAML training configuration.
    arg_parser = argparse.ArgumentParser(description="Code to train the expert lip-sync discriminator")
    arg_parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_vae.yaml")
    cli_args = arg_parser.parse_args()
    chosen_config_path = cli_args.config_path

    # Load the configuration and record where it came from; main() copies that file
    # into the run's output directory.
    config = OmegaConf.load(chosen_config_path)
    config.config_path = chosen_config_path

    main(config)