Spaces:
Running
Running
File size: 5,546 Bytes
d3378e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
# Imports
import gradio as gr
import matplotlib.pyplot as plt
import torch
import torchaudio
from torch import nn
import pytorch_lightning as pl
from ema_pytorch import EMA
import yaml
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
# Load configs
def load_configs(config_path):
    """Read the training config and return the sections needed to rebuild the model.

    Args:
        config_path: Path to a YAML file with a top-level ``model`` mapping that
            itself contains a nested ``model`` mapping.

    Returns:
        ``(pl_configs, model_configs)``: the ``model`` section (hyperparameters
        for the Lightning wrapper) and its nested ``model`` section (arguments
        for the diffusion network).

    Raises:
        KeyError: If either ``model`` section is missing.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    pl_configs = config['model']
    model_configs = pl_configs['model']
    return pl_configs, model_configs
# plot mel spectrogram
def plot_mel_spectrogram(sample, sr):
    """Plot the mel spectrogram (in dB) of an audio tensor.

    Args:
        sample: Audio tensor whose first dimension is averaged away to get a
            mono signal (presumably ``(channels, time)`` -- confirm with callers).
        sr: Sample rate of ``sample``, passed to the mel transform.

    Returns:
        The matplotlib figure. It is created through ``pyplot`` and left as the
        current figure, so callers may still adjust it with ``plt.*`` calls;
        closing it is the caller's responsibility.
    """
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_fft=1024,
        hop_length=512,
        n_mels=80,
        center=True,
        norm="slaney",
    )
    mono = torch.mean(sample, dim=0)  # downmix channels before the transform
    spectrogram = transform(mono)  # power spectrogram (MelSpectrogram's default power=2.0)
    # For a power spectrogram, dB = 10 * log10(power / ref). With ref = 1.0 the
    # db_multiplier term is log10(ref) = 0, and top_db=80 clamps everything more
    # than 80 dB below the peak.
    spectrogram = torchaudio.functional.amplitude_to_DB(
        spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80.0
    )
    # Plot the mel spectrogram
    fig = plt.figure(figsize=(7, 4))
    plt.imshow(spectrogram.detach().cpu().numpy(), aspect='auto', origin='lower')
    plt.colorbar(format='%+2.0f dB')
    plt.xlabel('Frame')
    plt.ylabel('Mel Bin')
    plt.title('Mel Spectrogram')
    plt.tight_layout()
    return fig
# Define PyTorch Lightning model
class Model(pl.LightningModule):
    """Lightning wrapper that pairs a diffusion network with an EMA copy of it.

    Only what is needed to restore a training checkpoint is kept: the optimizer
    hyperparameters, the wrapped network (``self.model``) and its EMA wrapper
    (``self.model_ema``). The attribute names ``model`` and ``model_ema``
    become the ``state_dict`` key prefixes, so they must match the names used
    when the checkpoint was saved.
    """

    def __init__(
        self,
        lr: float,
        lr_beta1: float,
        lr_beta2: float,
        lr_eps: float,
        lr_weight_decay: float,
        ema_beta: float,
        ema_power: float,
        model: nn.Module,
    ):
        super().__init__()
        # Optimizer hyperparameters: stored only, no optimizer is configured here.
        self.lr, self.lr_beta1, self.lr_beta2 = lr, lr_beta1, lr_beta2
        self.lr_eps, self.lr_weight_decay = lr_eps, lr_weight_decay
        # Register the network before its EMA wrapper (same order as the checkpoint).
        self.model = model
        self.model_ema = EMA(model, beta=ema_beta, power=ema_power)
# Instantiate model (must match model that was trained)
def load_model(model_configs, pl_configs) -> nn.Module:
    """Rebuild the freshly initialised network and Lightning wrapper used in training.

    The architecture must match the checkpoints loaded later, so every U-Net
    argument is read from ``model_configs`` and every wrapper hyperparameter
    from ``pl_configs``.

    Returns:
        A ``Model`` wrapping a ``DiffusionModel`` built on ``UNetV0`` with
        ``VDiffusion`` and ``VSampler``.

    Raises:
        KeyError: If a required config entry is missing.
    """
    unet_keys = (
        'in_channels',         # number of input/output (audio) channels
        'channels',            # channels at each U-Net layer
        'factors',             # down/upsampling factor at each layer
        'items',               # repeating items at each layer
        'attentions',          # attention enabled/disabled at each layer
        'attention_heads',     # attention heads per attention item
        'attention_features',  # attention features per attention item
    )
    network = DiffusionModel(
        net_t=UNetV0,           # network type used for diffusion
        diffusion_t=VDiffusion, # diffusion method
        sampler_t=VSampler,     # diffusion sampler
        **{key: model_configs[key] for key in unet_keys},
    )

    hparam_keys = (
        'lr',
        'lr_beta1',
        'lr_beta2',
        'lr_eps',
        'lr_weight_decay',
        'ema_beta',
        'ema_power',
    )
    return Model(model=network, **{key: pl_configs[key] for key in hparam_keys})
# Assign to GPU
def assign_to_gpu(model):
    """Move *model* to the default CUDA device when one is available.

    When CUDA is unavailable the same object is returned unchanged. In both
    cases the resulting device is printed. *model* must expose ``.to()`` and
    ``.device`` (a LightningModule does; a plain ``nn.Module`` has no
    ``.device``).
    """
    placed = model.to('cuda') if torch.cuda.is_available() else model
    print(f"Device: {placed.device}")
    return placed
# Load model checkpoint
def load_checkpoint(model, ckpt_path) -> None:
    """Overwrite *model*'s weights with the ``'state_dict'`` stored in *ckpt_path*.

    The file is deserialised onto the CPU, and its ``'state_dict'`` entry is
    loaded with ``load_state_dict``'s default strict matching. A missing or
    unexpected key therefore raises instead of being ignored. The matching
    report returned by ``load_state_dict`` is discarded (nothing is printed).

    NOTE(review): ``torch.load`` unpickles the file, so use trusted checkpoints
    only. Newer PyTorch releases default to ``weights_only=True``; if these
    checkpoints hold non-tensor objects, loading may fail there -- confirm.
    """
    saved = torch.load(ckpt_path, map_location='cpu')
    weights = saved['state_dict']
    model.load_state_dict(weights)
# Generate Samples
def generate_samples(model_name, num_samples, num_steps, duration=32768):
    """Generate audio with the selected checkpoint and plot its mel spectrogram.

    Relies on module-level globals: ``model`` (the wrapper whose weights are
    overwritten here), ``models`` (display name -> checkpoint path) and ``sr``
    (sample rate).

    Args:
        model_name: Key into ``models`` selecting the checkpoint to load.
        num_samples: Number of clips to generate; they are concatenated in time.
        num_steps: Number of diffusion sampling steps per clip (suggested 10-100).
        duration: Length of each clip, in samples.

    Returns:
        ``((sr, audio), fig)``: ``audio`` is a NumPy array of shape
        ``(total_length, 2)``, and ``fig`` is the spectrogram figure.

    NOTE(review): the shared global ``model`` has its weights swapped on every
    call. If requests can run concurrently, one request may sample with
    another's checkpoint -- confirm the Gradio concurrency settings.
    """
    load_checkpoint(model, models[model_name])

    # Gradio sliders may deliver floats; range() and the sampler need ints.
    num_samples = int(num_samples)
    num_steps = int(num_steps)

    # The empty (2, 0) seed keeps the result well-formed when no clips are made.
    clips = [torch.zeros(2, 0)]
    with torch.no_grad():
        for _ in range(num_samples):
            noise = torch.randn((1, 2, int(duration)), device=model.device)  # [batch_size, in_channels, length]
            clip = model.model_ema.ema_model.sample(noise, num_steps=num_steps)
            clips.append(clip.squeeze(0).cpu())
        # Concatenate once instead of re-copying a growing buffer on every iteration.
        all_samples = torch.cat(clips, dim=1)
    torch.cuda.empty_cache()

    fig = plot_mel_spectrogram(all_samples, sr)
    fig.gca().set_title(f"{model_name} Mel Spectrogram")
    # Remove the figure from pyplot's registry so figures do not pile up across
    # requests; the Figure object itself can still be rendered by the caller.
    plt.close(fig)
    return (sr, all_samples.cpu().detach().numpy().T), fig  # (sample rate, audio), plot
# load model & configs
sr = 44100  # sampling rate used for playback and the spectrogram
config_path = "saved_models/config.yaml"  # config path
pl_configs, model_configs = load_configs(config_path)
model = load_model(model_configs, pl_configs)
model = assign_to_gpu(model)
# Inference only: nn.Module instances start in training mode, so switch every
# registered submodule (presumably including the EMA copy used for sampling)
# to eval mode. This keeps any dropout/normalization layers deterministic.
model.eval()
models = {
    "Kicks": "saved_models/kicks/kicks_v7.ckpt",
    "Snares": "saved_models/snares/snares_v0.ckpt",
    "Hi-hats": "saved_models/hihats/hihats_v2.ckpt",
    "Percussion": "saved_models/percussion/percussion_v0.ckpt"
}
demo = gr.Interface(
    generate_samples,
    inputs=[
        gr.Dropdown(choices=list(models.keys()), value=list(models.keys())[0], label="Model"),
        gr.Slider(1, 25, step=1, label="Number of Samples to Generate", value=1),
        gr.Slider(1, 100, step=1, label="Number of Diffusion Steps", value=10)
    ],
    outputs=[
        gr.Audio(label="Generated Audio Sample"),
        gr.Plot(label="Generated Audio Spectrogram")
    ]
)
if __name__ == "__main__":
    demo.launch()
|