import spaces
import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile
import time
import typing as tp
import subprocess as sp

import torch
import gradio as gr
from huggingface_hub import login

from audiocraft.data.audio_utils import f32_pcm, normalize_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import JASCO

# Authenticate with the Hugging Face Hub if a token is provided
hf_token = os.environ.get('HFTOKEN')
if hf_token:
    login(token=hf_token)

MODEL = None
MAX_BATCH_SIZE = 12
INTERRUPTING = False

os.makedirs(os.path.join(os.path.dirname(__file__), "models"), exist_ok=True)
def generate_chord_mappings():
    # Define basic chord mappings
    basic_chords = ['N', 'C', 'Dm7', 'Am', 'F', 'D', 'Ab', 'Bb'] + ['UNK']
    chord_to_index = {chord: idx for idx, chord in enumerate(basic_chords)}
    # Save the mapping
    mapping_path = os.path.join(os.path.dirname(__file__), "models", "chord_to_index_mapping.pkl")
    os.makedirs(os.path.dirname(mapping_path), exist_ok=True)
    with open(mapping_path, "wb") as f:
        pickle.dump(chord_to_index, f)
    return mapping_path
def create_default_chord_mapping():
    """Create a basic chord-to-index mapping with common chords."""
    basic_chords = [
        'N', 'C', 'Cm', 'C7', 'Cmaj7', 'Cm7',
        'D', 'Dm', 'D7', 'Dmaj7', 'Dm7',
        'E', 'Em', 'E7', 'Emaj7', 'Em7',
        'F', 'Fm', 'F7', 'Fmaj7', 'Fm7',
        'G', 'Gm', 'G7', 'Gmaj7', 'Gm7',
        'A', 'Am', 'A7', 'Amaj7', 'Am7',
        'B', 'Bm', 'B7', 'Bmaj7', 'Bm7',
        'Ab', 'Abm', 'Ab7', 'Abmaj7', 'Abm7',
        'Bb', 'Bbm', 'Bb7', 'Bbmaj7', 'Bbm7',
        'UNK'
    ]
    return {chord: idx for idx, chord in enumerate(basic_chords)}
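# For illustration: the indices follow the list order above, so
# create_default_chord_mapping() yields mapping['N'] == 0, mapping['Dm7'] == 10
# and mapping['UNK'] == 46 (47 entries in total).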
def initialize_chord_mapping():
    """Initialize chord mapping file if it doesn't exist."""
    mapping_dir = os.path.join(os.path.dirname(__file__), "models")
    os.makedirs(mapping_dir, exist_ok=True)
    mapping_file = os.path.join(mapping_dir, "chord_to_index_mapping.pkl")
    if not os.path.exists(mapping_file):
        chord_to_index = create_default_chord_mapping()
        with open(mapping_file, "wb") as f:
            pickle.dump(chord_to_index, f)
    return mapping_file


def validate_chord(chord, chord_mapping):
    if chord not in chord_mapping:
        return 'UNK'
    return chord
mapping_file = initialize_chord_mapping()
os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file


def chords_string_to_list(chords: str):
    """Parse a chord progression string into a list of (chord, start_time) tuples."""
    if chords == '':
        return []
    chords = chords.replace('[', '').replace(']', '').replace(' ', '')
    chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]
    # Load the chord mapping and map unknown chord symbols to 'UNK'
    mapping_path = os.path.join(os.path.dirname(__file__), "models", "chord_to_index_mapping.pkl")
    with open(mapping_path, 'rb') as f:
        chord_mapping = pickle.load(f)
    return [(validate_chord(x[0], chord_mapping), float(x[1])) for x in chrd_times]
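# For illustration, parsing is purely textual, e.g.:
#   chords_string_to_list("(C, 0.0), (D, 2.0)") -> [('C', 0.0), ('D', 2.0)]
# Any chord symbol missing from the saved mapping comes back as 'UNK'.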
# Wrap subprocess calls to keep the logs clean
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr

# Preallocate process pool
pool = ProcessPoolExecutor(4)
pool.__enter__()
def interrupt():
    global INTERRUPTING
    INTERRUPTING = True


class FileCleaner:
    """Delete generated files once they are older than `file_lifetime` seconds."""

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        now = time.time()
        for time_added, path in list(self.files):
            if now - time_added > self.file_lifetime:
                if path.exists():
                    path.unlink()
                self.files.pop(0)
            else:
                break


file_cleaner = FileCleaner()
def load_model(version='facebook/jasco-chords-drums-400M'):
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        MODEL = None
        # Set up the model directory
        model_dir = os.path.join(os.path.dirname(__file__), "models")
        os.makedirs(model_dir, exist_ok=True)
        # Generate and save the chord mapping if it is missing
        chord_mapping_path = os.path.join(model_dir, "chord_to_index_mapping.pkl")
        if not os.path.exists(chord_mapping_path):
            chord_mapping_path = generate_chord_mappings()
        try:
            # Initialize JASCO with the chord mapping path
            MODEL = JASCO.get_pretrained(
                version,
                device='cuda',
                chords_mapping_path=chord_mapping_path
            )
            MODEL.name = version
        except Exception as e:
            raise gr.Error(f"Error loading model: {str(e)}")
    if MODEL is None:
        raise gr.Error("Failed to load model")
    return MODEL
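# For illustration: load_model('facebook/jasco-chords-drums-melody-1B') caches the
# model in the global MODEL and only reloads when the requested version changes.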
def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
    MODEL.set_generation_params(**gen_kwargs)
    chords = chords_string_to_list(chords)

    if melody_matrix is not None:
        melody_matrix = torch.load(melody_matrix.name, weights_only=True)
        if len(melody_matrix.shape) != 2:
            raise gr.Error(f"Melody matrix should be a torch tensor of shape [n_melody_bins, T]; got: {melody_matrix.shape}")
        if melody_matrix.shape[0] > melody_matrix.shape[1]:
            melody_matrix = melody_matrix.permute(1, 0)

    if drum_prompt is None:
        preprocessed_drums_wav = None
        drums_sr = 32000
    else:
        # Gradio's numpy audio type yields a (sample_rate, waveform) tuple
        drums_sr, drums = drum_prompt[0], f32_pcm(torch.from_numpy(drum_prompt[1])).t()
        if drums.dim() == 1:
            drums = drums[None]
        drums = normalize_audio(drums, strategy="loudness", loudness_headroom_db=16, sample_rate=drums_sr)
        preprocessed_drums_wav = drums

    try:
        outputs = MODEL.generate_music(descriptions=texts, chords=chords,
                                       drums_wav=preprocessed_drums_wav,
                                       melody_salience_matrix=melody_matrix,
                                       drums_sample_rate=drums_sr, progress=progress)
    except RuntimeError as e:
        raise gr.Error("Error while generating: " + e.args[0])

    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_wavs.append(file.name)
            file_cleaner.add(file.name)
    return out_wavs
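# predict_full drives the UI: it generates two variations per request
# (texts=[text] * 2 below), matching the two audio outputs in the interface.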
@spaces.GPU  # assumed ZeroGPU entry point: the Space runs on Zero and imports `spaces`
def predict_full(model, text, chords_sym, melody_file,
                 drums_file, drums_mic, drum_input_src,
                 cfg_coef_all, cfg_coef_txt,
                 ode_rtol, ode_atol,
                 ode_solver, ode_steps,
                 progress=gr.Progress()):
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    load_model(model)

    max_generated = 0

    def _progress(generated, to_generate):
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")
    MODEL.set_custom_progress_callback(_progress)

    drums = drums_mic if drum_input_src == "mic" else drums_file
    wavs = _do_predictions(
        texts=[text] * 2,
        chords=chords_sym,
        drum_prompt=drums,
        melody_matrix=melody_file,
        progress=True,
        gradio_progress=progress,
        cfg_coef_all=cfg_coef_all,
        cfg_coef_txt=cfg_coef_txt,
        ode_rtol=ode_rtol,
        ode_atol=ode_atol,
        euler=ode_solver == 'euler',
        euler_steps=ode_steps)
    return wavs
with gr.Blocks() as demo:
    gr.Markdown("""
    # JASCO - Text-to-Music Generation with Temporal Control
    Generate 10-second music clips using text descriptions and temporal controls (chords, drums, melody).
    """)
    with gr.Row():
        with gr.Column():
            submit = gr.Button("Generate")
            interrupt_btn = gr.Button("Interrupt")
        with gr.Column():
            audio_output_0 = gr.Audio(label="Generated Audio 1", type='filepath')
            audio_output_1 = gr.Audio(label="Generated Audio 2", type='filepath')
    with gr.Row():
        with gr.Column():
            text = gr.Text(label="Input Text",
                           value="Strings, woodwind, orchestral, symphony.",
                           interactive=True)
        with gr.Column():
            model = gr.Radio([
                'facebook/jasco-chords-drums-400M',
                'facebook/jasco-chords-drums-1B',
                'facebook/jasco-chords-drums-melody-400M',
                'facebook/jasco-chords-drums-melody-1B'
            ], label="Model", value='facebook/jasco-chords-drums-melody-400M')
    gr.Markdown("### Chords Conditions")
    chords_sym = gr.Text(
        label="Chord Progression",
        value="(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
        interactive=True
    )
    gr.Markdown("### Drums Conditions")
    with gr.Row():
        drum_input_src = gr.Radio(["file", "mic"], value="file", label="Drums Input Source")
        drums_file = gr.Audio(sources=["upload"], type="numpy", label="Drums File")
        drums_mic = gr.Audio(sources=["microphone"], type="numpy", label="Drums Mic")
    gr.Markdown("### Melody Conditions")
    melody_file = gr.File(label="Melody File")
    with gr.Row():
        cfg_coef_all = gr.Number(label="CFG ALL", value=1.25, step=0.25)
        cfg_coef_txt = gr.Number(label="CFG TEXT", value=2.5, step=0.25)
        ode_rtol = gr.Number(label="ODE rtol", value=1e-4, step=1e-5)
        ode_atol = gr.Number(label="ODE atol", value=1e-4, step=1e-5)
        ode_solver = gr.Radio(['euler', 'dopri5'], label="ODE Solver", value='euler')
        ode_steps = gr.Number(label="Euler Steps", value=10, step=1)
    submit.click(
        fn=predict_full,
        inputs=[
            model, text, chords_sym, melody_file,
            drums_file, drums_mic, drum_input_src,
            cfg_coef_all, cfg_coef_txt,
            ode_rtol, ode_atol, ode_solver, ode_steps
        ],
        outputs=[audio_output_0, audio_output_1]
    )
    interrupt_btn.click(fn=interrupt, queue=False)
    gr.Examples(
        examples=[
            [
                "80s pop with groovy synth bass and electric piano",
                "(N, 0.0), (C, 0.32), (Dm7, 3.456), (Am, 4.608), (F, 8.32), (C, 9.216)",
                None,
                None,
            ],
            [
                "Strings, woodwind, orchestral, symphony.",
                "(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
                None,
                None,
            ],
        ],
        inputs=[text, chords_sym, melody_file, drums_file],
        outputs=[audio_output_0, audio_output_1]
    )

demo.queue().launch(ssr_mode=False)