import os, sys |
|
|
|
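# Make the current working directory importable so the AR.* imports below resolve.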
now_dir = os.getcwd() |
|
sys.path.append(now_dir) |
|
from typing import Dict |
|
|
|
import torch |
|
from pytorch_lightning import LightningModule |
|
from AR.models.t2s_model import Text2SemanticDecoder |
|
from AR.modules.lr_schedulers import WarmupCosineLRSchedule |
|
from AR.modules.optim import ScaledAdam |
|
|
|
class Text2SemanticLightningModule(LightningModule): |
|
    def __init__(self, config, output_dir, is_train=True, flash_attn_enabled: bool = False):
|
super().__init__() |
|
self.config = config |
|
self.top_k = 3 |
|
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k, flash_attn_enabled=flash_attn_enabled)
|
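        # Optionally initialize from a pretrained s1 checkpoint; the return value of
        # load_state_dict is printed so any key mismatches are visible in the log.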
pretrained_s1 = config.get("pretrained_s1") |
|
if pretrained_s1 and is_train: |
|
|
|
print( |
|
self.load_state_dict( |
|
torch.load(pretrained_s1, map_location="cpu")["weight"] |
|
) |
|
) |
|
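        # Training uses manual optimization: gradients are accumulated and the LR
        # scheduler is stepped by hand in training_step. output_dir is expected to
        # be a pathlib.Path; evaluation artifacts go under output_dir / "eval".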
if is_train: |
|
self.automatic_optimization = False |
|
self.save_hyperparameters() |
|
self.eval_dir = output_dir / "eval" |
|
self.eval_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
def training_step(self, batch: Dict, batch_idx: int): |
|
opt = self.optimizers() |
|
scheduler = self.lr_schedulers() |
|
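        # Select the forward pass: `forward` when DPO training (if_dpo) is enabled, `forward_old` otherwise.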
        forward = self.model.forward if self.config["train"].get("if_dpo", False) else self.model.forward_old
|
loss, acc = forward( |
|
batch["phoneme_ids"], |
|
batch["phoneme_ids_len"], |
|
batch["semantic_ids"], |
|
batch["semantic_ids_len"], |
|
batch["bert_feature"], |
|
) |
|
self.manual_backward(loss) |
|
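        # Accumulate gradients over 4 micro-batches, then step the optimizer and LR scheduler manually.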
if batch_idx > 0 and batch_idx % 4 == 0: |
|
opt.step() |
|
opt.zero_grad() |
|
scheduler.step() |
|
|
|
self.log( |
|
"total_loss", |
|
loss, |
|
on_step=True, |
|
on_epoch=True, |
|
prog_bar=True, |
|
sync_dist=True, |
|
) |
|
self.log( |
|
"lr", |
|
scheduler.get_last_lr()[0], |
|
on_epoch=True, |
|
prog_bar=True, |
|
sync_dist=True, |
|
) |
|
self.log( |
|
f"top_{self.top_k}_acc", |
|
acc, |
|
on_step=True, |
|
on_epoch=True, |
|
prog_bar=True, |
|
sync_dist=True, |
|
) |
|
|
|
def validation_step(self, batch: Dict, batch_idx: int): |
|
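        # Validation is intentionally a no-op for this module.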
return |
|
def configure_optimizers(self): |
|
model_parameters = self.model.parameters() |
|
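        # ScaledAdam expects parameter names grouped per parameter group;
        # here all parameters are treated as a single group.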
parameters_names = [] |
|
parameters_names.append( |
|
[name_param_pair[0] for name_param_pair in self.model.named_parameters()] |
|
) |
|
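        # The lr given here is only an initial value; the WarmupCosineLRSchedule
        # configured below drives the learning rate during training.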
lm_opt = ScaledAdam( |
|
model_parameters, |
|
lr=0.01, |
|
betas=(0.9, 0.95), |
|
clipping_scale=2.0, |
|
parameters_names=parameters_names, |
|
show_dominant_parameters=False, |
|
clipping_update_period=1000, |
|
) |
|
|
|
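        # With automatic_optimization disabled, the scheduler returned here is
        # stepped explicitly in training_step rather than by Lightning.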
return { |
|
"optimizer": lm_opt, |
|
"lr_scheduler": { |
|
"scheduler": WarmupCosineLRSchedule( |
|
lm_opt, |
|
init_lr=self.config["optimizer"]["lr_init"], |
|
peak_lr=self.config["optimizer"]["lr"], |
|
end_lr=self.config["optimizer"]["lr_end"], |
|
warmup_steps=self.config["optimizer"]["warmup_steps"], |
|
total_steps=self.config["optimizer"]["decay_steps"], |
|
) |
|
}, |
|
} |
|
|