crlandsc commited on
Commit
a4100ac
1 Parent(s): 4bb01b5

added conditional diffusion, descriptions, and examples

Browse files
Files changed (1) hide show
  1. app.py +97 -22
app.py CHANGED
@@ -1,5 +1,6 @@
1
  # Imports
2
  import gradio as gr
 
3
  import matplotlib.pyplot as plt
4
  import torch
5
  import torchaudio
@@ -109,19 +110,40 @@ def load_checkpoint(model, ckpt_path) -> None:
109
 
110
 
111
  # Generate Samples
112
- def generate_samples(model_name, num_samples, num_steps, duration=32768):
113
  # load_checkpoint
114
  ckpt_path = models[model_name]
115
  load_checkpoint(model, ckpt_path)
116
-
117
  if num_samples > 1:
118
- duration = duration / 2
119
 
 
120
  with torch.no_grad():
121
- all_samples = torch.zeros(2, 0) # initialize all samples
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  for i in range(num_samples):
123
- noise = torch.randn((1, 2, int(duration)), device=model.device) # [batch_size, in_channels, length]
124
- generated_sample = model.model_ema.ema_model.sample(noise, num_steps=num_steps).squeeze(0).cpu() # Suggested num_steps 10-100
 
 
 
125
 
126
  # concatenate all samples:
127
  all_samples = torch.concat((all_samples, generated_sample), dim=1)
@@ -133,6 +155,8 @@ def generate_samples(model_name, num_samples, num_steps, duration=32768):
133
 
134
  return (sr, all_samples.cpu().detach().numpy().T), fig # (sample rate, audio), plot
135
 
 
 
136
  # load model & configs
137
  sr = 44100 # sampling rate
138
  config_path = "saved_models/config.yaml" # config path
@@ -147,19 +171,70 @@ models = {
147
  "Percussion": "saved_models/percussion/percussion_v0.ckpt"
148
  }
149
 
150
- demo = gr.Interface(
151
- generate_samples,
152
- inputs=[
153
- gr.Dropdown(choices=list(models.keys()), value=list(models.keys())[3], label="Model"),
154
- gr.Slider(1, 25, step=1, label="Number of Samples to Generate", value=3),
155
- gr.Slider(1, 100, step=1, label="Number of Diffusion Steps", value=15)
156
- ],
157
- outputs=[
158
- gr.Audio(label="Generated Audio Sample"),
159
- gr.Plot(label="Generated Audio Spectrogram")
160
- ]
161
- )
162
-
163
- if __name__ == "__main__":
164
- demo.launch()
165
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Imports
2
  import gradio as gr
3
+ import os
4
  import matplotlib.pyplot as plt
5
  import torch
6
  import torchaudio
 
110
 
111
 
112
  # Generate Samples
113
+ def generate_samples(model_name, num_samples, num_steps, init_audio=None, noise_level=0.7, duration=32768):
114
  # load_checkpoint
115
  ckpt_path = models[model_name]
116
  load_checkpoint(model, ckpt_path)
117
+
118
  if num_samples > 1:
119
+ duration = int(duration / 2)
120
 
121
+ # Generate samples
122
  with torch.no_grad():
123
+ if init_audio:
124
+ # load audio sample
125
+ audio_sample = torch.tensor(init_audio[1].T, dtype=torch.float32).unsqueeze(0).to(model.device)
126
+ audio_sample = audio_sample / torch.max(torch.abs(audio_sample)) # normalize init_audio
127
+
128
+ # Trim audio
129
+ og_shape = audio_sample.shape
130
+ if duration < og_shape[2]:
131
+ audio_sample = audio_sample[:,:,:duration]
132
+ elif duration > og_shape[2]:
133
+ # Pad tensor with zeros to match sample length
134
+ audio_sample = torch.concat((audio_sample, torch.zeros(og_shape[0], og_shape[1], duration - og_shape[2]).to(model.device)), dim=2)
135
+
136
+ else:
137
+ audio_sample = torch.zeros((1, 2, int(duration)), device=model.device)
138
+ noise_level = 1.0
139
+
140
+ all_samples = torch.zeros(2, 0)
141
  for i in range(num_samples):
142
+ noise = torch.randn_like(audio_sample, device=model.device) * noise_level # [batch_size, in_channels, length]
143
+ audio = (audio_sample * abs(1-noise_level)) + noise # add noise
144
+
145
+ # generate samples
146
+ generated_sample = model.model_ema.ema_model.sample(audio, num_steps=num_steps).squeeze(0).cpu() # Suggested num_steps 10-100
147
 
148
  # concatenate all samples:
149
  all_samples = torch.concat((all_samples, generated_sample), dim=1)
 
155
 
156
  return (sr, all_samples.cpu().detach().numpy().T), fig # (sample rate, audio), plot
157
 
158
+
159
+ # Define Constants & initialize model
160
  # load model & configs
161
  sr = 44100 # sampling rate
162
  config_path = "saved_models/config.yaml" # config path
 
171
  "Percussion": "saved_models/percussion/percussion_v0.ckpt"
172
  }
173
 
174
+ intro = """
175
+ <h1 style="font-weight: 1400; text-align: center; margin-bottom: 6px;">
176
+ Tiny Audio Diffusion
177
+ </h1>
178
+ <h3 style="font-weight: 600; text-align: center;">
179
+ Christopher Landschoot - Audio waveform diffusion built to run on consumer-grade hardware (<2GB VRAM)
180
+ </h3>
181
+ <h4 style="text-align: center; margin-bottom: 6px;">
182
+ <a href="https://github.com/crlandsc/tiny-audio-diffusion" style="text-decoration: underline;" target="_blank">GitHub Repo</a>
183
+ | <a href="https://www.youtube.com/watch?v=m6Eh2srtTro&t=3s" style="text-decoration: underline;" target="_blank">Repo Tutorial Video</a>
184
+ | <a href="https://medium.com/towards-data-science/tiny-audio-diffusion-ddc19e90af9b" style="text-decoration: underline;" target="_blank">Towards Data Science Article</a>
185
+ </h4>
186
+ """
187
+
188
+
189
+ with gr.Blocks() as demo:
190
+ # Layout
191
+ gr.HTML(intro)
192
+
193
+ with gr.Row(equal_height=False):
194
+ with gr.Column():
195
+ # Inputs
196
+ model_name = gr.Dropdown(choices=list(models.keys()), value=list(models.keys())[3], label="Model")
197
+ num_samples = gr.Slider(1, 25, step=1, label="Number of Samples to Generate", value=3)
198
+ num_steps = gr.Slider(1, 100, step=1, label="Number of Diffusion Steps", value=15)
199
+
200
+ # Conditioning Audio Input
201
+ with gr.Accordion("Input Audio (optional)", open=False):
202
+ init_audio_description = gr.HTML('Upload an audio file to perform conditional "style transfer" diffusion.<br>Leaving input audio blank results in unconditional generation.')
203
+ init_audio = gr.Audio(label="Input Audio Sample")
204
+ init_audio_noise = gr.Slider(0, 1, step=0.01, label="Noise to add to input audio", value=0.70)#, visible=True)
205
+
206
+ # Examples
207
+ gr.Examples(
208
+ examples=[
209
+ os.path.join(os.path.dirname(__file__), "samples", "guitar.wav"),
210
+ os.path.join(os.path.dirname(__file__), "samples", "snare.wav"),
211
+ os.path.join(os.path.dirname(__file__), "samples", "kick.wav"),
212
+ os.path.join(os.path.dirname(__file__), "samples", "hihat.wav")
213
+ ],
214
+ inputs=init_audio,
215
+ label="Example Audio Inputs"
216
+ )
217
+
218
+ # Buttons
219
+ with gr.Row():
220
+ with gr.Column():
221
+ clear_button = gr.Button(value="Reset All")
222
+ with gr.Column():
223
+ generate_btn = gr.Button("Generate Samples!")
224
+
225
+ with gr.Column():
226
+ # Outputs
227
+ output_audio = gr.Audio(label="Generated Audio Sample")
228
+ output_plot = gr.Plot(label="Generated Audio Spectrogram")
229
+
230
+ # Functionality
231
+ # Generate samples
232
+ generate_btn.click(fn=generate_samples, inputs=[model_name, num_samples, num_steps, init_audio, init_audio_noise], outputs=[output_audio, output_plot])
233
+
234
+ # clear_button button to reset everything
235
+ clear_button.click(fn=lambda: [3, 15, None, 0.70, None, None], outputs=[num_samples, num_steps, init_audio, init_audio_noise, output_audio, output_plot])
236
+
237
+
238
+
239
+ if __name__ == "__main__":
240
+ demo.launch()