import os # os.environ['TORCH_LOGS'] = '+dynamic' # os.environ['TORCH_LOGS'] = '+export' # os.environ['TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED']="u0 >= 0" # os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CPP']="1" # os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL']="u0" from kokoro import phonemize, tokenize, length_to_mask import torch.nn.functional as F from models_scripting import build_model import torch from typing import Dict device = "cpu" #'cuda' if torch.cuda.is_available() else 'cpu' model = build_model('kokoro-v0_19.pth', device) voicepack = torch.load('voices/af.pt', weights_only=True).to(device) speed = 1. text = "How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born." ps = phonemize(text, "a") tokens = tokenize(ps) tokens = torch.LongTensor([[0, *tokens, 0]]).to(device) class StyleTTS2(torch.nn.Module): def __init__(self, model, voicepack): super().__init__() # self.model = model self.bert = model.bert self.bert_encoder = model.bert_encoder self.predictor = model.predictor self.decoder = model.decoder self.text_encoder = model.text_encoder self.voicepack = voicepack def forward(self, tokens : torch.Tensor): speed = 1. # tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1])) device = tokens.device input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device) text_mask = length_to_mask(input_lengths).to(device) bert_dur = self.bert(tokens) d_en = self.bert_encoder(bert_dur).transpose(-1, -2) ref_s = self.voicepack[tokens.shape[1]] s = ref_s[:, 128:] d = self.predictor.text_encoder.inference(d_en, s) x, _ = self.predictor.lstm(d) duration = self.predictor.duration_proj(x) duration = torch.sigmoid(duration).sum(axis=-1) / speed pred_dur = torch.round(duration).clamp(min=1).long() c_start = F.pad(pred_dur,(1,0), "constant").cumsum(dim=1)[0,0:-1] c_end = c_start + pred_dur[0,:] # torch._check(pred_dur.sum().item()>0, lambda: print(f"Got {pred_dur.sum().item()}")) indices = torch.arange(0, pred_dur.sum().item()).long().to(device) pred_aln_trg_list=[] for cs, ce in zip(c_start, c_end): row = torch.where((indices>=cs) & (indices