Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,735 Bytes
113884e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import torch
from torchvision import transforms
import torch.nn.functional as F
import random
from utils.lora import extract_lora_child_module
from utils.func_utils import tensor_to_vae_latent, sample_noise
def _prediction_target(noise_scheduler, latents, noise, timesteps):
    """Return the regression target matching the scheduler's prediction type.

    Raises:
        ValueError: if ``noise_scheduler.config.prediction_type`` is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "epsilon":
        return noise
    if prediction_type == "v_prediction":
        return noise_scheduler.get_velocity(latents, noise, timesteps)
    raise ValueError(f"Unknown prediction type {prediction_type}")


def _set_lora_scale(loras, scale):
    """Assign ``scale`` to the ``scale`` attribute of every module in ``loras``."""
    for lora in loras:
        lora.scale = scale


def DebiasedHybridLoss(
    train_loss_temporal,
    accelerator,
    optimizers,
    lr_schedulers,
    unet,
    vae,
    text_encoder,
    noise_scheduler,
    batch,
    step,
    config,
    random_hflip_img=False,
    spatial_lora_num=1
):
    """Run one training step combining a spatial and a debiased temporal loss.

    The step noises the latents at random timesteps, then:

    * with probability 0.8, computes a single-frame ("spatial") MSE loss,
      back-propagates it and steps the spatial optimizer;
    * always computes the full-clip ("temporal") MSE loss plus a debiased
      term in which one randomly chosen frame is subtracted from every frame
      of both prediction and target, back-propagates it and steps
      ``optimizers[0]``.

    Args:
        train_loss_temporal: Running temporal-loss accumulator; the gathered
            mean of this step's loss, divided by
            ``config.train.gradient_accumulation_steps``, is added to it.
        accelerator: Object providing ``gather`` and ``backward``.
        optimizers: Indexable collection; index 0 is stepped after the
            temporal backward pass, index ``1`` (or ``step + 1`` when
            ``spatial_lora_num != 1``) after the spatial one.
        lr_schedulers: Indexed the same way as ``optimizers``.
        unet: Denoising model, called as
            ``unet(x, timesteps, encoder_hidden_states=...).sample``.
        vae: Passed to ``tensor_to_vae_latent`` when latents are not cached.
        text_encoder: Called on ``batch['prompt_ids']``; element ``[0]`` of
            its output is used (detached) as conditioning.
        noise_scheduler: Provides ``add_noise``, ``get_velocity`` and
            ``config.num_train_timesteps`` / ``config.prediction_type``.
        batch: Mapping with ``'prompt_ids'`` and either ``'latents'`` (when
            ``config.train.cache_latents`` is truthy) or ``'pixel_values'``.
        step: Offset used to select the active spatial LoRA / optimizer /
            scheduler when ``spatial_lora_num != 1``.
        config: Needs ``train.cache_latents``, ``train.train_batch_size`` and
            ``train.gradient_accumulation_steps``.
        random_hflip_img: Compared against a uniform sample in [0, 1] to
            decide whether the spatial frame is horizontally flipped; the
            default ``False`` therefore never flips.
        spatial_lora_num: Stride used when enabling spatial LoRA modules.

    Returns:
        Tuple ``(loss_temporal, train_loss_temporal)``.

    Raises:
        ValueError: if the scheduler's prediction type is unsupported.
    """
    # Drawn first so the sequence of RNG calls is unchanged.
    mask_spatial_lora = random.uniform(0, 1) < 0.2

    if config.train.cache_latents:
        latents = batch["latents"]
    else:
        latents = tensor_to_vae_latent(batch["pixel_values"], vae)

    # Sample the noise that is added to the latents.
    # NOTE(review): 0.1 / False are presumably the offset-noise strength and
    # the use-offset-noise flag of ``sample_noise`` -- confirm against utils.
    noise = sample_noise(latents, 0.1, False)
    bsz = latents.shape[0]

    # One random timestep per batch element.
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
    )
    timesteps = timesteps.long()

    # Forward diffusion: noise the latents according to each timestep.
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Text conditioning is detached, so no gradient reaches the text encoder.
    token_ids = batch['prompt_ids']
    encoder_hidden_states = text_encoder(token_ids)[0].clone().detach()

    target = _prediction_target(noise_scheduler, latents, noise, timesteps)

    # ---- spatial (single-frame) branch ----
    spatial_loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
    loss_spatial = None
    if mask_spatial_lora:
        _set_lora_scale(spatial_loras, 0.)
    else:
        if spatial_lora_num == 1:
            _set_lora_scale(spatial_loras, 1.)
        else:
            # Enable only the ``step``-th LoRA within each stride.
            _set_lora_scale(spatial_loras, 0.)
            for lora_idx in range(0, len(spatial_loras), spatial_lora_num):
                spatial_loras[lora_idx + step].scale = 1.

        # Temporal LoRA modules (if any) are switched off for the spatial pass.
        # NOTE(review): nothing in this function restores their scale before
        # the temporal forward pass below -- confirm that is intended.
        temporal_loras = extract_lora_child_module(
            unet, target_replace_module=["TransformerTemporalModel"]
        )
        _set_lora_scale(temporal_loras, 0.)

        # Latents are indexed as 5-D with the frame axis at dim 2.
        ran_idx = torch.randint(0, noisy_latents.shape[2], (1,)).item()

        if random.uniform(0, 1) < random_hflip_img:
            # NOTE(review): this path reads batch["pixel_values"] even when
            # latents are cached -- verify the batch always carries it.
            pixel_values_spatial = transforms.functional.hflip(
                batch["pixel_values"][:, ran_idx, :, :, :]).unsqueeze(1)
            latents_spatial = tensor_to_vae_latent(pixel_values_spatial, vae)
            noise_spatial = sample_noise(latents_spatial, 0.1, False)
            noisy_latents_input = noise_scheduler.add_noise(latents_spatial, noise_spatial, timesteps)
            # Build the target from the scheduler's prediction type so that
            # v-prediction models are not regressed against raw noise.
            target_spatial = _prediction_target(
                noise_scheduler, latents_spatial, noise_spatial, timesteps
            )
            model_pred_spatial = unet(noisy_latents_input, timesteps,
                                      encoder_hidden_states=encoder_hidden_states).sample
            loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(),
                                      target_spatial[:, :, 0, :, :].float(), reduction="mean")
        else:
            noisy_latents_input = noisy_latents[:, :, ran_idx, :, :]
            target_spatial = target[:, :, ran_idx, :, :]
            model_pred_spatial = unet(noisy_latents_input.unsqueeze(2), timesteps,
                                      encoder_hidden_states=encoder_hidden_states).sample
            loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(),
                                      target_spatial.float(), reduction="mean")

    # ---- temporal (full-clip) branch ----
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
    loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # Debiased term: subtract one random anchor frame from every frame of
    # both prediction and target; alpha = sqrt(beta^2 + 1).
    beta = 1
    alpha = (beta ** 2 + 1) ** 0.5
    ran_idx = torch.randint(0, model_pred.shape[2], (1,)).item()
    model_pred_decent = alpha * model_pred - beta * model_pred[:, :, ran_idx, :, :].unsqueeze(2)
    target_decent = alpha * target - beta * target[:, :, ran_idx, :, :].unsqueeze(2)
    loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean")
    loss_temporal = loss_temporal + loss_ad_temporal

    # Gather the loss through the accelerator for logging.
    avg_loss_temporal = accelerator.gather(loss_temporal.repeat(config.train.train_batch_size)).mean()
    train_loss_temporal += avg_loss_temporal.item() / config.train.gradient_accumulation_steps

    # ---- optimisation ----
    spatial_slot = 1 if spatial_lora_num == 1 else step + 1

    if not mask_spatial_lora:
        accelerator.backward(loss_spatial, retain_graph=True)
        optimizers[spatial_slot].step()

    accelerator.backward(loss_temporal)
    optimizers[0].step()

    # The spatial scheduler advances every step, even when the spatial
    # branch was masked.
    lr_schedulers[spatial_slot].step()
    lr_schedulers[0].step()

    return loss_temporal, train_loss_temporal