yuancwang committed
Commit · f3af09b
1 Parent(s): 1ee4329
commit
app.py CHANGED
@@ -16,4 +16,217 @@ from diffusers import PNDMScheduler
 import matplotlib.pyplot as plt
 from scipy.io.wavfile import write
 
-
+from utils.util import load_config
+import gradio as gr
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def build_autoencoderkl(cfg, device):
+    autoencoderkl = AutoencoderKL(cfg.model.autoencoderkl)
+    autoencoder_path = cfg.model.autoencoder_path
+    checkpoint = torch.load(autoencoder_path, map_location="cpu")
+    autoencoderkl.load_state_dict(checkpoint["model"])
+    autoencoderkl = autoencoderkl.to(device=device)
+    autoencoderkl.requires_grad_(requires_grad=False)
+    autoencoderkl.eval()
+    return autoencoderkl
+
+def build_textencoder(device):
+    # tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
+    # text_encoder = T5EncoderModel.from_pretrained("t5-base")
+    tokenizer = AutoTokenizer.from_pretrained("ckpts/tta/tokenizer")
+    text_encoder = T5EncoderModel.from_pretrained("ckpts/tta/text_encoder")
+    text_encoder = text_encoder.to(device=device)
+    text_encoder.requires_grad_(requires_grad=False)
+    text_encoder.eval()
+    return tokenizer, text_encoder
+
+def build_vocoder(device):
+    config_file = os.path.join("ckpts/tta/hifigan_checkpoints/config.json")
+    with open(config_file) as f:
+        data = f.read()
+    json_config = json.loads(data)
+    h = AttrDict(json_config)
+    vocoder = Generator(h).to(device)
+    checkpoint_dict = torch.load(
+        "ckpts/tta/hifigan_checkpoints/g_01250000", map_location=device
+    )
+    vocoder.load_state_dict(checkpoint_dict["generator"])
+    return vocoder
+
+def build_model(cfg):
+    model = AudioLDM(cfg.model.audioldm)
+    return model
+
+def get_text_embedding(text, tokenizer, text_encoder, device):
+
+    prompt = [text]
+
+    text_input = tokenizer(
+        prompt,
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+        padding="do_not_pad",
+        return_tensors="pt",
+    )
+    text_embeddings = text_encoder(
+        text_input.input_ids.to(device)
+    )[0]
+
+    max_length = text_input.input_ids.shape[-1]
+    uncond_input = tokenizer(
+        [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
+    )
+    uncond_embeddings = text_encoder(
+        uncond_input.input_ids.to(device)
+    )[0]
+    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+    return text_embeddings
+
+def tta_inference(
+    text,
+    guidance_scale=4,
+    diffusion_steps=100,
+):
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    os.environ["WORK_DIR"] = "./"
+    cfg = load_config("egs/tta/audioldm/exp_config.json")
+
+    autoencoderkl = build_autoencoderkl(cfg, device)
+    tokenizer, text_encoder = build_textencoder(device)
+    vocoder = build_vocoder(device)
+    model = build_model(cfg)
+
+    checkpoint_path = "ckpts/tta/audioldm_debug_latent_size_4_5_39/checkpoints/step-0570000_loss-0.2521.pt"
+    checkpoint = torch.load(checkpoint_path, map_location="cpu")
+    model.load_state_dict(checkpoint["model"])
+    model = model.to(device)
+
+    text_embeddings = get_text_embedding(text, tokenizer, text_encoder, device)
+
+    num_steps = diffusion_steps
+
+    noise_scheduler = PNDMScheduler(
+        num_train_timesteps=1000,
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        skip_prk_steps=True,
+        set_alpha_to_one=False,
+        steps_offset=1,
+        prediction_type="epsilon",
+    )
+
+    noise_scheduler.set_timesteps(num_steps)
+
+
+    latents = torch.randn(
+        (
+            1,
+            cfg.model.autoencoderkl.z_channels,
+            80 // (2 ** (len(cfg.model.autoencoderkl.ch_mult) - 1)),
+            624 // (2 ** (len(cfg.model.autoencoderkl.ch_mult) - 1)),
+        )
+    ).to(device)
+
+    model.eval()
+    for t in tqdm(noise_scheduler.timesteps):
+        t = t.to(device)
+
+        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+        latent_model_input = torch.cat([latents] * 2)
+
+        latent_model_input = noise_scheduler.scale_model_input(
+            latent_model_input, timestep=t
+        )
+        # print(latent_model_input.shape)
+
+        # predict the noise residual
+        with torch.no_grad():
+            noise_pred = model(
+                latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
+            )
+
+        # perform guidance
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        print(guidance_scale)
+        noise_pred = noise_pred_uncond + guidance_scale * (
+            noise_pred_text - noise_pred_uncond
+        )
+
+        # compute the previous noisy sample x_t -> x_t-1
+        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+        # print(latents.shape)
+
+    latents_out = latents
+
+    with torch.no_grad():
+        mel_out = autoencoderkl.decode(latents_out)
+
+    melspec = mel_out[0, 0].cpu().detach().numpy()
+
+    vocoder.eval()
+    vocoder.remove_weight_norm()
+
+    with torch.no_grad():
+        melspec = np.expand_dims(melspec, 0)
+        melspec = torch.FloatTensor(melspec).to(device)
+
+        y = vocoder(melspec)
+        audio = y.squeeze()
+        audio = audio * 32768.0
+        audio = audio.cpu().numpy().astype("int16")
+
+    os.makedirs("result", exist_ok=True)
+    write(os.path.join("result", text + ".wav"), 16000, audio)
+
+    return os.path.join("result", text + ".wav")
+
+demo_inputs = [
+    gr.Textbox(
+        value="birds singing and a man whistling",
+        label="Text prompt you want to generate",
+        type="text",
+    ),
+    gr.Slider(
+        1,
+        10,
+        value=4,
+        step=1,
+        label="Classifier free guidance",
+    ),
+    gr.Slider(
+        50,
+        1000,
+        value=100,
+        step=1,
+        label="Diffusion Inference Steps",
+        info="As the step number increases, the synthesis quality will be better while the inference speed will be lower",
+    ),
+]
+
+demo_outputs = gr.Audio(label="")
+
+demo = gr.Interface(
+    fn=tta_inference,
+    inputs=demo_inputs,
+    outputs=demo_outputs,
+    title="Amphion Text to Audio"
+)
+
+if __name__ == "__main__":
+    demo.launch()
+
+
+
+
+
+
+
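
For reference, a minimal sketch of calling the new tta_inference entry point directly, without the Gradio UI. This assumes the Space's app.py is importable as app and that the ckpts/tta checkpoints and egs/tta/audioldm/exp_config.json referenced in the diff are present in the working directory; it is not part of the commit itself.

# Minimal usage sketch (assumption: run from the Space root with checkpoints in place).
from app import tta_inference  # app.py from this commit

# Runs the full pipeline (text encoder -> latent diffusion -> HiFi-GAN vocoder)
# and returns the path of the 16 kHz wav written under result/.
wav_path = tta_inference(
    "birds singing and a man whistling",
    guidance_scale=4,
    diffusion_steps=100,
)
print(wav_path)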