yuancwang committed
Commit f3af09b · 1 Parent(s): 1ee4329
Files changed (1)
  1. app.py +214 -1
app.py CHANGED
@@ -16,4 +16,217 @@ from diffusers import PNDMScheduler
 import matplotlib.pyplot as plt
 from scipy.io.wavfile import write
 
-import gradio as gr
+from utils.util import load_config
+import gradio as gr
+
+class AttrDict(dict):
+    # Dict whose keys can also be read as attributes (used for the HiFi-GAN config).
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def build_autoencoderkl(cfg, device):
+    # Load the pretrained VAE that maps mel spectrograms to and from the diffusion latent space.
+    autoencoderkl = AutoencoderKL(cfg.model.autoencoderkl)
+    autoencoder_path = cfg.model.autoencoder_path
+    checkpoint = torch.load(autoencoder_path, map_location="cpu")
+    autoencoderkl.load_state_dict(checkpoint["model"])
+    autoencoderkl = autoencoderkl.to(device=device)
+    autoencoderkl.requires_grad_(requires_grad=False)
+    autoencoderkl.eval()
+    return autoencoderkl
+
+def build_textencoder(device):
+    # T5 text encoder and tokenizer, loaded from local checkpoints.
+    # tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
+    # text_encoder = T5EncoderModel.from_pretrained("t5-base")
+    tokenizer = AutoTokenizer.from_pretrained("ckpts/tta/tokenizer")
+    text_encoder = T5EncoderModel.from_pretrained("ckpts/tta/text_encoder")
+    text_encoder = text_encoder.to(device=device)
+    text_encoder.requires_grad_(requires_grad=False)
+    text_encoder.eval()
+    return tokenizer, text_encoder
+
+def build_vocoder(device):
+    # HiFi-GAN vocoder that turns the decoded mel spectrogram into a waveform.
+    config_file = os.path.join("ckpts/tta/hifigan_checkpoints/config.json")
+    with open(config_file) as f:
+        data = f.read()
+    json_config = json.loads(data)
+    h = AttrDict(json_config)
+    vocoder = Generator(h).to(device)
+    checkpoint_dict = torch.load(
+        "ckpts/tta/hifigan_checkpoints/g_01250000", map_location=device
+    )
+    vocoder.load_state_dict(checkpoint_dict["generator"])
+    return vocoder
+
+def build_model(cfg):
+    model = AudioLDM(cfg.model.audioldm)
+    return model
+
+def get_text_embedding(text, tokenizer, text_encoder, device):
+    prompt = [text]
+
+    text_input = tokenizer(
+        prompt,
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+        padding="do_not_pad",
+        return_tensors="pt",
+    )
+    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
+
+    # Unconditional ("") embedding for classifier-free guidance, padded to the same length.
+    max_length = text_input.input_ids.shape[-1]
+    uncond_input = tokenizer(
+        [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
+    )
+    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
+    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+    return text_embeddings
+
+def tta_inference(
+    text,
+    guidance_scale=4,
+    diffusion_steps=100,
+):
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    os.environ["WORK_DIR"] = "./"
+    cfg = load_config("egs/tta/audioldm/exp_config.json")
+
+    autoencoderkl = build_autoencoderkl(cfg, device)
+    tokenizer, text_encoder = build_textencoder(device)
+    vocoder = build_vocoder(device)
+    model = build_model(cfg)
+
+    checkpoint_path = "ckpts/tta/audioldm_debug_latent_size_4_5_39/checkpoints/step-0570000_loss-0.2521.pt"
+    checkpoint = torch.load(checkpoint_path, map_location="cpu")
+    model.load_state_dict(checkpoint["model"])
+    model = model.to(device)
+
+    text_embeddings = get_text_embedding(text, tokenizer, text_encoder, device)
+
+    num_steps = diffusion_steps
+
+    noise_scheduler = PNDMScheduler(
+        num_train_timesteps=1000,
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        skip_prk_steps=True,
+        set_alpha_to_one=False,
+        steps_offset=1,
+        prediction_type="epsilon",
+    )
+
+    noise_scheduler.set_timesteps(num_steps)
+
+    # Start from Gaussian noise in the VAE latent space
+    # (80 mel bins x 624 frames, downsampled by the autoencoder).
+    latents = torch.randn(
+        (
+            1,
+            cfg.model.autoencoderkl.z_channels,
+            80 // (2 ** (len(cfg.model.autoencoderkl.ch_mult) - 1)),
+            624 // (2 ** (len(cfg.model.autoencoderkl.ch_mult) - 1)),
+        )
+    ).to(device)
+
+    model.eval()
+    for t in tqdm(noise_scheduler.timesteps):
+        t = t.to(device)
+
+        # Expand the latents so the conditional and unconditional branches of
+        # classifier-free guidance share a single forward pass.
+        latent_model_input = torch.cat([latents] * 2)
+        latent_model_input = noise_scheduler.scale_model_input(
+            latent_model_input, timestep=t
+        )
+
+        # Predict the noise residual.
+        with torch.no_grad():
+            noise_pred = model(
+                latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
+            )
+
+        # Perform classifier-free guidance.
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        noise_pred = noise_pred_uncond + guidance_scale * (
+            noise_pred_text - noise_pred_uncond
+        )
+
+        # Compute the previous noisy sample x_t -> x_{t-1}.
+        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+
+    latents_out = latents
+
+    # Decode the latents to a mel spectrogram with the VAE decoder.
+    with torch.no_grad():
+        mel_out = autoencoderkl.decode(latents_out)
+
+    melspec = mel_out[0, 0].cpu().detach().numpy()
+
+    # Vocode the mel spectrogram to a 16 kHz waveform and save it as 16-bit PCM.
+    vocoder.eval()
+    vocoder.remove_weight_norm()
+
+    with torch.no_grad():
+        melspec = np.expand_dims(melspec, 0)
+        melspec = torch.FloatTensor(melspec).to(device)
+
+        y = vocoder(melspec)
+        audio = y.squeeze()
+        audio = audio * 32768.0
+        audio = audio.cpu().numpy().astype("int16")
+
+    os.makedirs("result", exist_ok=True)
+    write(os.path.join("result", text + ".wav"), 16000, audio)
+
+    return os.path.join("result", text + ".wav")
+
+demo_inputs = [
+    gr.Textbox(
+        value="birds singing and a man whistling",
+        label="Text prompt you want to generate",
+        type="text",
+    ),
+    gr.Slider(
+        1,
+        10,
+        value=4,
+        step=1,
+        label="Classifier-free guidance scale",
+    ),
+    gr.Slider(
+        50,
+        1000,
+        value=100,
+        step=1,
+        label="Diffusion Inference Steps",
+        info="More steps generally improve synthesis quality at the cost of slower inference.",
+    ),
+]
+
+demo_outputs = gr.Audio(label="")
+
+demo = gr.Interface(
+    fn=tta_inference,
+    inputs=demo_inputs,
+    outputs=demo_outputs,
+    title="Amphion Text to Audio",
+)
+
+if __name__ == "__main__":
+    demo.launch()
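For a quick local check of this commit, the generation function can also be called outside the Gradio UI. The snippet below is a minimal sketch, assuming the checkpoints under ckpts/tta/ and the config at egs/tta/audioldm/exp_config.json from this Space are already in place:

# Minimal sketch (assumes the ckpts/tta/* checkpoints and exp_config.json exist locally).
from app import tta_inference

wav_path = tta_inference(
    "birds singing and a man whistling",  # text prompt
    guidance_scale=4,                     # classifier-free guidance strength
    diffusion_steps=100,                  # PNDM sampling steps
)
print(wav_path)  # path of the 16 kHz wav written under result/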