File size: 5,000 Bytes
59f625c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
from typing import Dict, List, Any
import logger
import spaces
import gradio as gr
import json
import torch
import wavio
from tqdm import tqdm
from huggingface_hub import snapshot_download
from models import AudioDiffusion, DDPMScheduler
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL
from pydub import AudioSegment
from gradio import Markdown
import torch
#from diffusers.models.autoencoder_kl import AutoencoderKL
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers import DiffusionPipeline,AudioPipelineOutput
from transformers import CLIPTextModel, T5EncoderModel, AutoModel, T5Tokenizer, T5TokenizerFast
from typing import Union
from diffusers.utils.torch_utils import randn_tensor
from tqdm import tqdm
class Tango:
    """Text-to-audio wrapper around a TANGO checkpoint.

    Bundles the three modules stored in the checkpoint (a VAE, an STFT
    module and the audio diffusion model) together with the scheduler
    named in the checkpoint's main config.
    """

    def __init__(self, name="declare-lab/tango2", device=None):
        """Download checkpoint ``name`` and load its modules onto ``device``.

        Args:
            name: Repository id handed to ``snapshot_download``.
            device: Torch device for the modules and their weights. When
                omitted, ``"cuda:0"`` is used if CUDA is available and
                ``"cpu"`` otherwise.
        """
        # The default is resolved here rather than in the signature: a default
        # that names a module-level variable is evaluated when the ``def``
        # runs and fails with NameError if that variable does not exist.
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"

        path = snapshot_download(repo_id=name)

        vae_config = self._load_json("{}/vae_config.json".format(path))
        stft_config = self._load_json("{}/stft_config.json".format(path))
        main_config = self._load_json("{}/main_config.json".format(path))

        self.vae = AutoencoderKL(**vae_config).to(device)
        self.stft = TacotronSTFT(**stft_config).to(device)
        self.model = AudioDiffusion(**main_config).to(device)

        # NOTE(review): torch.load unpickles the file; only point ``name`` at
        # trusted repositories (or consider ``weights_only=True`` if the
        # installed torch version supports it — confirm).
        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)

        self.vae.load_state_dict(vae_weights)
        self.stft.load_state_dict(stft_weights)
        self.model.load_state_dict(main_weights)
        print("Successfully loaded checkpoint from:", name)

        # Inference only: switch every module to eval mode.
        self.vae.eval()
        self.stft.eval()
        self.model.eval()

        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")

    @staticmethod
    def _load_json(file_path):
        """Parse the JSON file at ``file_path``, closing the handle afterwards."""
        with open(file_path, encoding="utf-8") as handle:
            return json.load(handle)

    def chunks(self, lst, n):
        """Yield successive ``n``-sized slices of ``lst``; the last may be shorter."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """Generate audio for a single prompt string.

        Args:
            prompt: Text prompt; wrapped in a one-element list for the model.
            steps: Forwarded to ``self.model.inference`` as the step count.
            guidance: Forwarded to ``self.model.inference`` as the guidance value.
            samples: Forwarded to ``self.model.inference`` as the sample count.
            disable_progress: Forwarded to ``self.model.inference``.

        Returns:
            The first element of the decoded waveform batch.
        """
        with torch.no_grad():
            latents = self.model.inference(
                [prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress
            )
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave[0]

    def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
        """Generate audio for a list of prompt strings.

        Prompts are processed in slices of ``batch_size``; each slice is run
        through the diffusion model and decoded with the VAE.

        Args:
            prompts: List of text prompts.
            steps: Forwarded to ``self.model.inference`` as the step count.
            guidance: Forwarded to ``self.model.inference`` as the guidance value.
            samples: Forwarded to ``self.model.inference`` as the sample count.
            batch_size: Number of prompts sent to the model per call.
            disable_progress: Forwarded to ``self.model.inference``.

        Returns:
            When ``samples == 1``, a flat list of every decoded item in order.
            Otherwise the same items regrouped into consecutive lists of
            ``samples`` items (presumably one group per prompt — confirm
            against ``AudioDiffusion.inference``).
        """
        outputs = []
        for k in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[k: k + batch_size]
            with torch.no_grad():
                latents = self.model.inference(
                    batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress
                )
                mel = self.vae.decode_first_stage(latents)
                wave = self.vae.decode_to_waveform(mel)
            outputs.extend(wave)

        if samples == 1:
            return outputs
        return list(self.chunks(outputs, samples))
# Request handler that wraps the TANGO model for endpoint inference
class EndpointHandler():
    """Callable request handler that owns a single TANGO model instance."""

    def __init__(self, path=""):
        """Load the model once so every request reuses it.

        Args:
            path: Not used by this handler; the checkpoint is resolved by
                ``Tango`` from its default repository id.
        """
        # The class defined in this file is ``Tango``; the lowercase name
        # ``tango`` does not exist and raised NameError on construction.
        self.model = Tango(device='cuda')

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Validate a request body and prepare the generation settings.

        Args:
            data: Request body. Must hold the prompt as a non-empty string
                under ``"inputs"`` (checked first) or ``"prompt"``. May hold a
                ``"parameters"`` dict with ``num_inference_steps``,
                ``width``, ``height``, ``guidance_scale`` and ``seed``.
                The prompt key and ``"parameters"`` are popped from ``data``.

        Raises:
            ValueError: If neither key carries a non-empty string.
        """
        logger.info(f"Received incoming request with {data=}")

        # Look the prompt up *before* removing anything from ``data``: popping
        # "inputs" up front made the "inputs" branch unreachable.
        prompt = None
        for key in ("inputs", "prompt"):
            candidate = data.get(key)
            if isinstance(candidate, str) and candidate:
                prompt = data.pop(key)
                break
        if prompt is None:
            raise ValueError(
                "Provided input body must contain either the key `inputs` or `prompt` with the"
                " prompt to use for the audio generation, and it needs to be a non-empty string."
            )

        # Tolerate an explicit ``"parameters": null`` in the request body.
        parameters = data.pop("parameters", None) or {}
        num_inference_steps = parameters.get("num_inference_steps", 30)
        # NOTE(review): width/height have no counterpart in Tango.generate
        # (prompt, steps, guidance, samples, disable_progress); they look like
        # leftovers from an image handler — confirm before removing.
        width = parameters.get("width", 1024)
        height = parameters.get("height", 768)
        guidance_scale = parameters.get("guidance_scale", 3.5)

        # seed generator (seed cannot be provided as is but via a generator)
        seed = parameters.get("seed", 0)
        generator = torch.manual_seed(seed)

        # NOTE(review): the handler stops here and implicitly returns None even
        # though the signature promises a list of dicts. The generation call
        # (e.g. self.model.generate(prompt, steps=num_inference_steps,
        # guidance=guidance_scale)) and the response format still need to be
        # implemented.
|