added conditional diffusion, descriptions, and examples
app.py (CHANGED)
@@ -1,5 +1,6 @@
 # Imports
 import gradio as gr
+import os
 import matplotlib.pyplot as plt
 import torch
 import torchaudio
@@ -109,19 +110,40 @@ def load_checkpoint(model, ckpt_path) -> None:
 
 
 # Generate Samples
-def generate_samples(model_name, num_samples, num_steps, duration=32768):
+def generate_samples(model_name, num_samples, num_steps, init_audio=None, noise_level=0.7, duration=32768):
     # load_checkpoint
     ckpt_path = models[model_name]
     load_checkpoint(model, ckpt_path)
+
     if num_samples > 1:
-        duration = duration / 2
+        duration = int(duration / 2)
 
+    # Generate samples
     with torch.no_grad():
+        if init_audio:
+            # load audio sample
+            audio_sample = torch.tensor(init_audio[1].T, dtype=torch.float32).unsqueeze(0).to(model.device)
+            audio_sample = audio_sample / torch.max(torch.abs(audio_sample)) # normalize init_audio
+
+            # Trim audio
+            og_shape = audio_sample.shape
+            if duration < og_shape[2]:
+                audio_sample = audio_sample[:,:,:duration]
+            elif duration > og_shape[2]:
+                # Pad tensor with zeros to match sample length
+                audio_sample = torch.concat((audio_sample, torch.zeros(og_shape[0], og_shape[1], duration - og_shape[2]).to(model.device)), dim=2)
+
+        else:
+            audio_sample = torch.zeros((1, 2, int(duration)), device=model.device)
+            noise_level = 1.0
+
+        all_samples = torch.zeros(2, 0)
         for i in range(num_samples):
-            noise = torch.
+            noise = torch.randn_like(audio_sample, device=model.device) * noise_level # [batch_size, in_channels, length]
+            audio = (audio_sample * abs(1-noise_level)) + noise # add noise
+
+            # generate samples
+            generated_sample = model.model_ema.ema_model.sample(audio, num_steps=num_steps).squeeze(0).cpu() # Suggested num_steps 10-100
 
             # concatenate all samples:
             all_samples = torch.concat((all_samples, generated_sample), dim=1)
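For context, `init_audio` arrives from a `gr.Audio` input as a `(sample_rate, numpy_array)` tuple by default, which is why the new code reads `init_audio[1].T`; the clip is then normalized, trimmed or zero-padded to `duration`, and mixed with Gaussian noise, with `noise_level = 1.0` (no input audio) collapsing to unconditional generation. A minimal, self-contained sketch of that trim/pad-and-blend step follows; it is illustrative only, not part of the commit, and the function name, shapes, and dummy data are assumptions.

```python
# Illustrative sketch (not part of the commit): the conditioning path on a
# dummy stereo clip, assuming the same (batch, channels, samples) layout.
import torch

def blend_with_noise(audio_sample: torch.Tensor, noise_level: float, duration: int) -> torch.Tensor:
    # Trim or zero-pad along the time axis to the target duration
    b, c, t = audio_sample.shape
    if t > duration:
        audio_sample = audio_sample[:, :, :duration]
    elif t < duration:
        audio_sample = torch.cat((audio_sample, torch.zeros(b, c, duration - t)), dim=2)
    # Mix the clip with Gaussian noise; noise_level=1.0 means pure noise
    noise = torch.randn_like(audio_sample) * noise_level
    return audio_sample * abs(1 - noise_level) + noise

dummy = torch.rand(1, 2, 40000) * 2 - 1  # stand-in for a normalized stereo clip
print(blend_with_noise(dummy, noise_level=0.7, duration=32768).shape)  # torch.Size([1, 2, 32768])
```

At the default `noise_level=0.7` the input is scaled by `abs(1 - noise_level)` (0.3 here) before the noise is added, which is the behavior exposed by the "Noise to add to input audio" slider in the UI below.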
@@ -133,6 +155,8 @@ def generate_samples(model_name, num_samples, num_steps, duration=32768):
 
     return (sr, all_samples.cpu().detach().numpy().T), fig # (sample rate, audio), plot
 
+
+# Define Constants & initialize model
 # load model & configs
 sr = 44100 # sampling rate
 config_path = "saved_models/config.yaml" # config path
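The `(sr, all_samples.cpu().detach().numpy().T)` return value follows the tuple convention a `gr.Audio` output component accepts: `(sample_rate, numpy_array)` with the array shaped `(samples, channels)`. A small sketch of that convention, illustrative only and using random data:

```python
# Illustrative sketch (not part of the commit): the (sample_rate, ndarray)
# tuple a gr.Audio output expects, transposed from (channels, samples).
import numpy as np

sr = 44100
all_samples = np.random.randn(2, sr).astype(np.float32)  # fake (channels, samples) output
audio_out = (sr, all_samples.T)                           # -> (samples, channels) for Gradio
print(audio_out[1].shape)                                 # (44100, 2)
```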
@@ -147,19 +171,70 @@ models = {
     "Percussion": "saved_models/percussion/percussion_v0.ckpt"
 }
 
+intro = """
+<h1 style="font-weight: 1400; text-align: center; margin-bottom: 6px;">
+Tiny Audio Diffusion
+</h1>
+<h3 style="font-weight: 600; text-align: center;">
+Christopher Landschoot - Audio waveform diffusion built to run on consumer-grade hardware (<2GB VRAM)
+</h3>
+<h4 style="text-align: center; margin-bottom: 6px;">
+<a href="https://github.com/crlandsc/tiny-audio-diffusion" style="text-decoration: underline;" target="_blank">GitHub Repo</a>
+| <a href="https://www.youtube.com/watch?v=m6Eh2srtTro&t=3s" style="text-decoration: underline;" target="_blank">Repo Tutorial Video</a>
+| <a href="https://medium.com/towards-data-science/tiny-audio-diffusion-ddc19e90af9b" style="text-decoration: underline;" target="_blank">Towards Data Science Article</a>
+</h4>
+"""
+
+
+with gr.Blocks() as demo:
+    # Layout
+    gr.HTML(intro)
+
+    with gr.Row(equal_height=False):
+        with gr.Column():
+            # Inputs
+            model_name = gr.Dropdown(choices=list(models.keys()), value=list(models.keys())[3], label="Model")
+            num_samples = gr.Slider(1, 25, step=1, label="Number of Samples to Generate", value=3)
+            num_steps = gr.Slider(1, 100, step=1, label="Number of Diffusion Steps", value=15)
+
+            # Conditioning Audio Input
+            with gr.Accordion("Input Audio (optional)", open=False):
+                init_audio_description = gr.HTML('Upload an audio file to perform conditional "style transfer" diffusion.<br>Leaving input audio blank results in unconditional generation.')
+                init_audio = gr.Audio(label="Input Audio Sample")
+                init_audio_noise = gr.Slider(0, 1, step=0.01, label="Noise to add to input audio", value=0.70)#, visible=True)
+
+                # Examples
+                gr.Examples(
+                    examples=[
+                        os.path.join(os.path.dirname(__file__), "samples", "guitar.wav"),
+                        os.path.join(os.path.dirname(__file__), "samples", "snare.wav"),
+                        os.path.join(os.path.dirname(__file__), "samples", "kick.wav"),
+                        os.path.join(os.path.dirname(__file__), "samples", "hihat.wav")
+                    ],
+                    inputs=init_audio,
+                    label="Example Audio Inputs"
+                )
+
+            # Buttons
+            with gr.Row():
+                with gr.Column():
+                    clear_button = gr.Button(value="Reset All")
+                with gr.Column():
+                    generate_btn = gr.Button("Generate Samples!")
+
+        with gr.Column():
+            # Outputs
+            output_audio = gr.Audio(label="Generated Audio Sample")
+            output_plot = gr.Plot(label="Generated Audio Spectrogram")
+
+    # Functionality
+    # Generate samples
+    generate_btn.click(fn=generate_samples, inputs=[model_name, num_samples, num_steps, init_audio, init_audio_noise], outputs=[output_audio, output_plot])
+
+    # clear_button button to reset everything
+    clear_button.click(fn=lambda: [3, 15, None, 0.70, None, None], outputs=[num_samples, num_steps, init_audio, init_audio_noise, output_audio, output_plot])
+
+
+
+if __name__ == "__main__":
+    demo.launch()
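The event wiring at the bottom relies on two Gradio Blocks conventions: `click` passes the listed input components' values to the handler positionally, and a handler that returns a list updates the listed `outputs` in order, which is how the reset lambda restores the slider defaults and clears the audio and plot components. A toy sketch of the same pattern; it is illustrative only, and the component names and the `echo` function are made up.

```python
# Illustrative sketch (not part of the commit): the same Blocks wiring pattern
# on a toy app; all component names and the echo function are assumptions.
import gradio as gr

def echo(text, repeats):
    # Inputs arrive positionally in the order listed in `inputs=[...]`
    return ", ".join([text] * int(repeats))

with gr.Blocks() as toy:
    text = gr.Textbox(label="Text", value="hi")
    repeats = gr.Slider(1, 5, step=1, value=2, label="Repeats")
    out = gr.Textbox(label="Output")
    run = gr.Button("Run")
    reset = gr.Button("Reset")

    run.click(fn=echo, inputs=[text, repeats], outputs=out)
    # A returned list is assigned to the `outputs` components in order,
    # mirroring how the commit's reset lambda restores defaults and clears fields
    reset.click(fn=lambda: ["hi", 2, None], outputs=[text, repeats, out])

if __name__ == "__main__":
    toy.launch()
```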