# Author: crlandsc — commit 18d2991 ("Fixed YouTube link")
# (web file-view residue: raw | history | blame | 9.34 kB)
# Imports
import gradio as gr
import os
import matplotlib.pyplot as plt
import torch
import torchaudio
from torch import nn
import pytorch_lightning as pl
from ema_pytorch import EMA
import yaml
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
# Load configs
def load_configs(config_path):
    """Read the YAML config and return its two model sections.

    Args:
        config_path: Path to a YAML file with a top-level ``model`` mapping
            that itself contains a nested ``model`` mapping.

    Returns:
        ``(pl_configs, model_configs)``: ``config['model']`` (the Lightning
        wrapper hyperparameters) and ``config['model']['model']`` (the U-Net
        network settings).

    Raises:
        KeyError: if either ``model`` key is missing.
    """
    # YAML is UTF-8 by spec; don't let the platform's locale encoding decide.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    pl_configs = config["model"]
    model_configs = pl_configs["model"]
    return pl_configs, model_configs
# plot mel spectrogram
def plot_mel_spectrogram(sample, sr):
    """Plot a dB-scaled mel spectrogram of *sample* and return the figure.

    Args:
        sample: Audio tensor shaped ``(channels, time)`` (channels are averaged
            to mono) or 1-D ``(time,)``. The transform below is built on the
            CPU, so *sample* is expected to be a CPU tensor.
        sr: Sample rate of *sample* in Hz.

    Returns:
        The matplotlib figure. It is left as pyplot's current figure, so
        callers may still adjust it with ``plt.*`` calls.
    """
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_fft=1024,
        hop_length=512,
        n_mels=80,
        center=True,
        norm="slaney",
    )
    # Downmix multichannel audio to mono; 1-D input is already mono.
    mono = torch.mean(sample, dim=0) if sample.dim() > 1 else sample
    spectrogram = transform(mono)  # power spectrogram (MelSpectrogram's default power=2.0)
    # Power -> dB: 10*log10(max(S, amin)) - 10*log10(ref) with ref=1.0, then clip
    # to 80 dB below the peak. Keywords on purpose: torchaudio's positional order
    # is (multiplier, amin, db_multiplier, top_db), unlike librosa.power_to_db.
    spectrogram = torchaudio.functional.amplitude_to_DB(
        spectrogram,
        multiplier=10.0,
        amin=1e-10,
        db_multiplier=0.0,
        top_db=80.0,
    )
    # Plot the Mel spectrogram
    fig = plt.figure(figsize=(7, 4))
    plt.imshow(spectrogram, aspect="auto", origin="lower")
    plt.colorbar(format="%+2.0f dB")
    plt.xlabel("Frame")
    plt.ylabel("Mel Bin")
    plt.title("Mel Spectrogram")
    plt.tight_layout()
    return fig
# Define PyTorch Lightning model
class Model(pl.LightningModule):
    """Lightning wrapper holding a network together with an ``EMA`` copy of it.

    Only the constructor is defined here; the optimizer hyperparameters are
    stored as attributes but not used within this class. The submodule
    attribute names ``model`` and ``model_ema`` become the ``state_dict`` key
    prefixes, so they must not change or saved checkpoints will not load.
    """

    def __init__(
        self,
        lr: float,
        lr_beta1: float,
        lr_beta2: float,
        lr_eps: float,
        lr_weight_decay: float,
        ema_beta: float,
        ema_power: float,
        model: nn.Module,
    ):
        super().__init__()
        # Optimizer hyperparameters.
        self.lr, self.lr_beta1, self.lr_beta2 = lr, lr_beta1, lr_beta2
        self.lr_eps, self.lr_weight_decay = lr_eps, lr_weight_decay
        # Wrapped network and its EMA shadow (ema_pytorch.EMA).
        self.model = model
        self.model_ema = EMA(model, beta=ema_beta, power=ema_power)
# Instantiate model (must match model that was trained)
def load_model(model_configs, pl_configs) -> nn.Module:
    """Construct a freshly initialized diffusion network inside the Lightning ``Model``.

    The architecture must match the one the checkpoints were saved from, so
    every U-Net setting is read from *model_configs*.

    Args:
        model_configs: Mapping with the U-Net keys ``in_channels``, ``channels``,
            ``factors``, ``items``, ``attentions``, ``attention_heads`` and
            ``attention_features``.
        pl_configs: Mapping with ``lr``, ``lr_beta1``, ``lr_beta2``, ``lr_eps``,
            ``lr_weight_decay``, ``ema_beta`` and ``ema_power``.

    Returns:
        A ``Model`` wrapping a ``DiffusionModel`` (UNetV0 network, VDiffusion
        method, VSampler sampler).

    Raises:
        KeyError: if any required config key is missing.
    """
    unet_keys = (
        "in_channels",         # number of input/output (audio) channels
        "channels",            # channels at each layer
        "factors",             # downsampling/upsampling factor at each layer
        "items",               # repeating items at each layer
        "attentions",          # attention enabled/disabled at each layer
        "attention_heads",     # attention heads per attention item
        "attention_features",  # attention features per attention item
    )
    network = DiffusionModel(
        net_t=UNetV0,
        diffusion_t=VDiffusion,
        sampler_t=VSampler,
        **{key: model_configs[key] for key in unet_keys},
    )

    wrapper_keys = (
        "lr", "lr_beta1", "lr_beta2", "lr_eps", "lr_weight_decay",
        "ema_beta", "ema_power",
    )
    return Model(model=network, **{key: pl_configs[key] for key in wrapper_keys})
# Assign to GPU
def assign_to_gpu(model):
    """Return *model* moved to CUDA when a GPU is available, otherwise unchanged.

    Prints the resulting ``model.device`` as a side effect.
    """
    placed = model.to("cuda") if torch.cuda.is_available() else model
    print(f"Device: {placed.device}")
    return placed
# Load model checkpoint
def load_checkpoint(model, ckpt_path) -> None:
    """Copy the weights stored under a checkpoint's ``"state_dict"`` key into *model*.

    The file is deserialized onto the CPU. ``load_state_dict`` runs with its
    default strict key matching, so missing or unexpected keys raise.
    """
    loaded = torch.load(ckpt_path, map_location="cpu")
    weights = loaded["state_dict"]
    # The missing/unexpected-key report returned by load_state_dict is discarded.
    model.load_state_dict(weights)
# Generate Samples
def _prepare_init_audio(init_audio, duration, target_sr, device):
    """Turn a ``(sample_rate, ndarray)`` pair into a ``(1, 2, duration)`` float tensor on *device*.

    Mono input is duplicated to two channels and channels beyond the second
    are dropped. The signal is resampled to *target_sr* and peak-normalized;
    silent input stays all-zero instead of turning into NaNs. Finally it is
    trimmed or zero-padded to *duration* frames.
    """
    input_sr, data = init_audio
    wave = torch.tensor(data, dtype=torch.float32)
    # gr.Audio numpy data is (time,) for mono or (time, channels); use (channels, time).
    wave = wave.unsqueeze(0) if wave.dim() == 1 else wave.T
    if wave.shape[0] == 1:
        wave = wave.repeat(2, 1)
    elif wave.shape[0] > 2:
        wave = wave[:2]
    # Output is produced at target_sr, so bring the conditioning audio to that rate.
    if int(input_sr) != target_sr and wave.shape[-1] > 0:
        wave = torchaudio.functional.resample(wave, int(input_sr), target_sr)
    peak = wave.abs().max() if wave.numel() > 0 else 0
    if peak > 0:
        wave = wave / peak
    length = wave.shape[-1]
    if length >= duration:
        wave = wave[:, :duration]
    else:
        wave = torch.nn.functional.pad(wave, (0, duration - length))
    return wave.unsqueeze(0).to(device)


def generate_samples(model_name, num_samples, num_steps, init_audio=None, noise_level=0.7, duration=32768):
    """Generate samples with the selected checkpoint; return the audio and a spectrogram.

    Args:
        model_name: Key into the module-level ``models`` dict of checkpoint paths.
        num_samples: Number of samples to generate; they are concatenated in time.
        num_steps: Number of diffusion sampling steps (10-100 suggested).
        init_audio: Optional ``(sample_rate, ndarray)`` pair (``gr.Audio`` numpy
            format) used as the starting point for conditional "style transfer"
            generation. Falsy means unconditional generation.
        noise_level: Amount of noise mixed into *init_audio* (0 keeps the input,
            1 is pure noise). Forced to 1.0 for unconditional generation.
        duration: Frames per sample at ``sr``; halved when more than one sample
            is requested.

    Returns:
        ``((sr, audio), fig)``: ``audio`` is a ``(time, channels)`` NumPy array
        and ``fig`` is the mel-spectrogram figure.
    """
    # Swap the requested checkpoint's weights into the shared module-level model.
    load_checkpoint(model, models[model_name])

    # Gradio sliders may deliver floats; range() and the sampler need ints.
    num_samples = int(num_samples)
    num_steps = int(num_steps)
    duration = int(duration)
    if num_samples > 1:
        duration = int(duration / 2)

    chunks = []
    with torch.no_grad():
        if init_audio:
            audio_sample = _prepare_init_audio(init_audio, duration, sr, model.device)
        else:
            audio_sample = torch.zeros((1, 2, duration), device=model.device)
            noise_level = 1.0
        for _ in range(num_samples):
            noise = torch.randn_like(audio_sample) * noise_level  # [batch_size, in_channels, length]
            audio = audio_sample * abs(1 - noise_level) + noise
            # Sample with the EMA weights; squeeze the batch dim -> (channels, length).
            generated = model.model_ema.ema_model.sample(audio, num_steps=num_steps)
            chunks.append(generated.squeeze(0).cpu())
            torch.cuda.empty_cache()

    # Concatenate once instead of re-copying the growing tensor every iteration.
    all_samples = torch.cat(chunks, dim=1) if chunks else torch.zeros(2, 0)
    fig = plot_mel_spectrogram(all_samples, sr)
    # Title the image axes through the figure rather than pyplot's global "current" state.
    fig.axes[0].set_title(f"{model_name} Mel Spectrogram")
    return (sr, all_samples.cpu().detach().numpy().T), fig  # (sample rate, audio), plot
# Define Constants & initialize model
# load model & configs
sr = 44100  # sampling rate (Hz) of the generated audio
# Resolve bundled assets relative to this file (as the Gradio example inputs
# do) rather than the current working directory, so the app also starts when
# launched from another directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(_APP_DIR, "saved_models", "config.yaml")  # config path
pl_configs, model_configs = load_configs(config_path)
model = load_model(model_configs, pl_configs)
model = assign_to_gpu(model)
# Checkpoint path per selectable model; the keys are the dropdown labels.
models = {
    "Kicks": os.path.join(_APP_DIR, "saved_models", "kicks", "kicks_v7.ckpt"),
    "Snares": os.path.join(_APP_DIR, "saved_models", "snares", "snares_v0.ckpt"),
    "Hi-hats": os.path.join(_APP_DIR, "saved_models", "hihats", "hihats_v2.ckpt"),
    "Percussion": os.path.join(_APP_DIR, "saved_models", "percussion", "percussion_v0.ckpt"),
}
# HTML header rendered at the top of the demo page.
# CSS font-weight only accepts 1-1000, so the heading uses 900; "<" in body text is escaped.
intro = """
<h1 style="font-weight: 900; text-align: center; margin-bottom: 6px;">
Tiny Audio Diffusion
</h1>
<h3 style="font-weight: 600; text-align: center;">
Christopher Landschoot - Audio waveform diffusion built to run on consumer-grade hardware (&lt;2GB VRAM)
</h3>
<h4 style="text-align: center; margin-bottom: 6px;">
<a href="https://github.com/crlandsc/tiny-audio-diffusion" style="text-decoration: underline;" target="_blank">GitHub Repo</a>
| <a href="https://youtu.be/m6Eh2srtTro" style="text-decoration: underline;" target="_blank">Repo Tutorial Video</a>
| <a href="https://medium.com/towards-data-science/tiny-audio-diffusion-ddc19e90af9b" style="text-decoration: underline;" target="_blank">Towards Data Science Article</a>
</h4>
"""
# Default control values, shared by the initial layout and the "Reset All" button
# so the two cannot drift apart.
DEFAULT_MODEL = "Percussion" if "Percussion" in models else list(models)[0]
DEFAULT_NUM_SAMPLES = 3
DEFAULT_NUM_STEPS = 15
DEFAULT_NOISE_LEVEL = 0.70

_samples_dir = os.path.join(os.path.dirname(__file__), "samples")

with gr.Blocks() as demo:
    # Layout
    gr.HTML(intro)
    with gr.Row(equal_height=False):
        with gr.Column():
            # Inputs
            model_name = gr.Dropdown(choices=list(models.keys()), value=DEFAULT_MODEL, label="Model")
            num_samples = gr.Slider(1, 25, step=1, label="Number of Samples to Generate", value=DEFAULT_NUM_SAMPLES)
            num_steps = gr.Slider(1, 100, step=1, label="Number of Diffusion Steps", value=DEFAULT_NUM_STEPS)

            # Conditioning Audio Input
            with gr.Accordion("Input Audio (optional)", open=False):
                init_audio_description = gr.HTML('Upload an audio file to perform conditional "style transfer" diffusion.<br>Leaving input audio blank results in unconditional generation.')
                init_audio = gr.Audio(label="Input Audio Sample")
                init_audio_noise = gr.Slider(0, 1, step=0.01, label="Noise to add to input audio", value=DEFAULT_NOISE_LEVEL)

                # Examples
                gr.Examples(
                    examples=[
                        os.path.join(_samples_dir, name)
                        for name in ("guitar.wav", "snare.wav", "kick.wav", "hihat.wav")
                    ],
                    inputs=init_audio,
                    label="Example Audio Inputs",
                )

            # Buttons
            with gr.Row():
                with gr.Column():
                    clear_button = gr.Button(value="Reset All")
                with gr.Column():
                    generate_btn = gr.Button("Generate Samples!")

        with gr.Column():
            # Outputs
            output_audio = gr.Audio(label="Generated Audio Sample")
            output_plot = gr.Plot(label="Generated Audio Spectrogram")

    # Functionality
    # Generate samples
    generate_btn.click(
        fn=generate_samples,
        inputs=[model_name, num_samples, num_steps, init_audio, init_audio_noise],
        outputs=[output_audio, output_plot],
    )
    # "Reset All" restores every control (including the model dropdown) and clears outputs.
    clear_button.click(
        fn=lambda: [DEFAULT_MODEL, DEFAULT_NUM_SAMPLES, DEFAULT_NUM_STEPS, None, DEFAULT_NOISE_LEVEL, None, None],
        outputs=[model_name, num_samples, num_steps, init_audio, init_audio_noise, output_audio, output_plot],
    )

if __name__ == "__main__":
    demo.launch()