Spaces:
Sleeping
Sleeping
import os | |
import time | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
import torch.nn as nn | |
from collections import OrderedDict | |
import json | |
from models.tta.autoencoder.autoencoder import AutoencoderKL | |
from models.tta.ldm.inference_utils.vocoder import Generator | |
from models.tta.ldm.audioldm import AudioLDM | |
from transformers import T5EncoderModel, AutoTokenizer | |
from diffusers import PNDMScheduler | |
import matplotlib.pyplot as plt | |
from scipy.io.wavfile import write | |
from utils.util import load_config | |
import gradio as gr | |
class AttrDict(dict):
    """Dictionary whose entries are also reachable as instance attributes.

    ``d.key`` and ``d["key"]`` refer to the same value, for both reads and
    writes, because the instance namespace is the mapping itself.
    """

    def __init__(self, *args, **kwargs):
        # Fill the mapping exactly the way a plain ``dict`` constructor would.
        dict.__init__(self, *args, **kwargs)
        # Point the attribute namespace at the mapping so both access styles
        # share a single store.
        self.__dict__ = self
def build_autoencoderkl(cfg, device):
    """Create the KL autoencoder, load its weights and freeze it.

    Args:
        cfg: Experiment config; ``cfg.model.autoencoderkl`` is passed to the
            ``AutoencoderKL`` constructor and ``cfg.model.autoencoder_path``
            is the checkpoint file to load.
        device: Device the model is moved to after the weights are loaded.

    Returns:
        The autoencoder on ``device``, in eval mode, with gradients disabled.
    """
    vae = AutoencoderKL(cfg.model.autoencoderkl)

    # The checkpoint is deserialized on CPU first; the module is moved to the
    # target device only after its weights are in place.
    state = torch.load(cfg.model.autoencoder_path, map_location="cpu")
    vae.load_state_dict(state["model"])

    vae = vae.to(device=device)
    vae.requires_grad_(requires_grad=False)
    vae.eval()
    return vae
def build_textencoder(device):
    """Load the T5 tokenizer and text encoder used to embed prompts.

    The ``t5-base`` identifier is tried first; if that fails for any ordinary
    reason, the copies stored under ``ckpts/tta`` are used instead.

    Args:
        device: Device the text encoder is moved to.

    Returns:
        A ``(tokenizer, text_encoder)`` tuple. The encoder is on ``device``,
        in eval mode, with gradients disabled.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
        text_encoder = T5EncoderModel.from_pretrained("t5-base")
    except Exception:
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate instead of silently
        # triggering the local-checkpoint fallback.
        tokenizer = AutoTokenizer.from_pretrained("ckpts/tta/tokenizer")
        text_encoder = T5EncoderModel.from_pretrained("ckpts/tta/text_encoder")

    text_encoder = text_encoder.to(device=device)
    text_encoder.requires_grad_(requires_grad=False)
    text_encoder.eval()
    return tokenizer, text_encoder
def build_vocoder(device):
    """Build the HiFi-GAN generator and load its trained weights.

    The hyper-parameters come from the JSON config next to the checkpoint and
    are wrapped in an ``AttrDict`` before being handed to ``Generator``.

    Args:
        device: Device the generator is created on and the checkpoint is
            mapped to.

    Returns:
        The generator with the ``"generator"`` state dict loaded.
    """
    with open(os.path.join("ckpts/tta/hifigan_checkpoints/config.json")) as config_fp:
        hparams = AttrDict(json.load(config_fp))

    generator = Generator(hparams).to(device)
    state = torch.load(
        "ckpts/tta/hifigan_checkpoints/g_01250000", map_location=device
    )
    generator.load_state_dict(state["generator"])
    return generator
def build_model(cfg):
    """Instantiate the AudioLDM network from ``cfg.model.audioldm``.

    No weights are loaded here; the caller is responsible for that.
    """
    return AudioLDM(cfg.model.audioldm)
def get_text_embedding(text, tokenizer, text_encoder, device):
    """Encode a prompt together with an empty prompt for guidance.

    Args:
        text: The prompt string.
        tokenizer: Callable tokenizer exposing ``model_max_length`` and
            returning an object with an ``input_ids`` tensor.
        text_encoder: Callable whose first output is the embedding tensor.
        device: Device the token ids are moved to before encoding.

    Returns:
        The unconditional embedding followed by the prompt embedding,
        concatenated along the first dimension.
    """
    # Conditional branch: the prompt is truncated to the tokenizer limit and
    # left unpadded.
    cond_tokens = tokenizer(
        [text],
        max_length=tokenizer.model_max_length,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    )
    cond_ids = cond_tokens.input_ids
    cond_embeddings = text_encoder(cond_ids.to(device))[0]

    # Unconditional branch: an empty prompt padded to the prompt's token
    # length so the two embeddings can be concatenated.
    uncond_tokens = tokenizer(
        [""],
        padding="max_length",
        max_length=cond_ids.shape[-1],
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_tokens.input_ids.to(device))[0]

    return torch.cat((uncond_embeddings, cond_embeddings))
def _prompt_to_filename(text, max_bytes=200):
    """Turn a free-form prompt into a file stem that is safe to write.

    Characters other than letters, digits, space, ``-`` and ``_`` are replaced
    by ``_`` so the prompt cannot introduce path separators or ``..``
    components. The stem is capped at ``max_bytes`` UTF-8 bytes and falls back
    to ``"audio"`` when nothing usable is left.
    """
    stem = "".join(
        ch if ch.isalnum() or ch in " -_" else "_" for ch in str(text)
    )
    # Truncate on the encoded form; "ignore" drops a multi-byte character that
    # the cut would otherwise split in half.
    stem = stem.encode("utf-8")[:max_bytes].decode("utf-8", "ignore").strip()
    return stem or "audio"


def tta_inference(
    text,
    guidance_scale=4,
    diffusion_steps=100,
):
    """Generate a WAV file from a text prompt and return its path.

    Every call loads the autoencoder, text encoder, vocoder and AudioLDM
    checkpoint from disk, runs the PNDM sampling loop with classifier-free
    guidance, decodes the latents to a mel spectrogram, vocodes it and writes
    the result under ``result/``.

    Args:
        text: Prompt describing the audio to generate. Also used (after
            sanitizing) as the output file name.
        guidance_scale: Weight applied to the difference between the
            text-conditioned and unconditional noise predictions.
        diffusion_steps: Number of scheduler inference steps.

    Returns:
        Path of the written ``.wav`` file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.environ["WORK_DIR"] = "./"
    cfg = load_config("egs/tta/audioldm/exp_config.json")

    autoencoderkl = build_autoencoderkl(cfg, device)
    tokenizer, text_encoder = build_textencoder(device)
    vocoder = build_vocoder(device)
    model = build_model(cfg)

    checkpoint_path = "ckpts/tta/audioldm_debug_latent_size_4_5_39/checkpoints/step-0570000_loss-0.2521.pt"
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model = model.to(device)

    text_embeddings = get_text_embedding(text, tokenizer, text_encoder, device)

    noise_scheduler = PNDMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        skip_prk_steps=True,
        set_alpha_to_one=False,
        steps_offset=1,
        prediction_type="epsilon",
    )
    # UI sliders may deliver the step count as a float; the scheduler needs
    # an integer number of steps.
    noise_scheduler.set_timesteps(int(diffusion_steps))

    # Spatial latent size is the 80 x 624 grid reduced by one factor of two
    # per autoencoder resolution level after the first.
    # NOTE(review): 80 x 624 is presumably (mel bins, frames) — confirm.
    downsample = 2 ** (len(cfg.model.autoencoderkl.ch_mult) - 1)
    latents = torch.randn(
        (
            1,
            cfg.model.autoencoderkl.z_channels,
            80 // downsample,
            624 // downsample,
        )
    ).to(device)

    model.eval()
    for t in tqdm(noise_scheduler.timesteps):
        t = t.to(device)
        # Duplicate the latents so the unconditional and text-conditioned
        # predictions come out of a single forward pass.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(
            latent_model_input, timestep=t
        )
        # Predict the noise residual.
        with torch.no_grad():
            noise_pred = model(
                latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
            )
        # Classifier-free guidance: move from the unconditional prediction
        # towards the text-conditioned one.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )
        # Compute the previous noisy sample x_t -> x_t-1.
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    vocoder.eval()
    vocoder.remove_weight_norm()
    with torch.no_grad():
        mel_out = autoencoderkl.decode(latents)
        # Take the first channel of the first item and restore a leading
        # batch dimension for the vocoder, staying on the device.
        melspec = mel_out[0, 0].unsqueeze(0).float()
        waveform = vocoder(melspec).squeeze()

    # Scale to 16-bit PCM and clamp first: without the clamp a sample at or
    # above +1.0 would wrap around when cast to int16.
    pcm = torch.clamp(waveform * 32768.0, -32768.0, 32767.0)
    audio = pcm.cpu().numpy().astype("int16")

    os.makedirs("result", exist_ok=True)
    # The prompt comes from the user, so it is sanitized before being used as
    # a file name.
    output_path = os.path.join("result", _prompt_to_filename(text) + ".wav")
    write(output_path, 16000, audio)
    return output_path
# Input widgets, in the same order as the parameters of ``tta_inference``:
# prompt text, guidance scale, number of diffusion steps.
demo_inputs = [
    gr.Textbox(
        label="Text prompt you want to generate",
        type="text",
        value="birds singing and a man whistling",
    ),
    gr.Slider(
        minimum=1,
        maximum=10,
        step=1,
        value=4,
        label="Classifier free guidance",
    ),
    gr.Slider(
        minimum=50,
        maximum=1000,
        step=1,
        value=100,
        label="Diffusion Inference Steps",
        info="As the step number increases, the synthesis quality will be better while the inference speed will be lower",
    ),
]

# Audio player fed with the file path returned by ``tta_inference``.
demo_outputs = gr.Audio(label="")

demo = gr.Interface(
    title="Amphion Text to Audio",
    fn=tta_inference,
    inputs=demo_inputs,
    outputs=demo_outputs,
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()