"""Gradio demo for JASCO text-to-music generation with temporal controls.

The UI takes a text description plus optional chord, drum and melody
conditions, loads the selected ``facebook/jasco-*`` checkpoint and returns two
generated clips for the same prompt.
"""
# `spaces` is imported first, as in the original file (HF ZeroGPU decorator).
import spaces
import logging
import os
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile
import time
import typing as tp
import subprocess as sp

import torch
import gradio as gr

from audiocraft.data.audio_utils import f32_pcm, normalize_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import JASCO
from huggingface_hub import login

MODEL = None  # currently loaded JASCO model; (re)assigned by load_model()
MAX_BATCH_SIZE = 12
INTERRUPTING = False  # set by the "Interrupt" button, polled in the progress callback

hf_token = os.environ.get('HFTOKEN')
if hf_token:
    login(token=hf_token)

# Wrap subprocess call to clean logs: every later lookup of `subprocess.call`
# gets a version whose stdout/stderr are discarded.
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    """Run the original ``subprocess.call`` with stdout and stderr discarded.

    Returns:
        The child's exit code, exactly like ``subprocess.call``, so callers
        that check the return value keep working.
    """
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr

# Preallocate process pool.
# NOTE(review): `pool` is not used anywhere in this file; confirm it is still needed.
pool = ProcessPoolExecutor(4)
pool.__enter__()


def interrupt():
    """Ask the in-flight generation to stop at its next progress update."""
    global INTERRUPTING
    INTERRUPTING = True


class FileCleaner:
    """Delete registered files once they are older than ``file_lifetime`` seconds.

    Cleanup is lazy: expired files are only removed when a new file is added.
    """

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        # (registration timestamp, path) pairs, oldest first because add() appends.
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Register ``path`` for deletion, first purging already-expired files."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        """Delete every registered file whose lifetime has elapsed."""
        now = time.time()
        # Entries are in time order, so stop at the first one that is still fresh.
        while self.files and now - self.files[0][0] > self.file_lifetime:
            path = self.files[0][1]
            # missing_ok: the file may already be gone (avoids an exists/unlink race).
            path.unlink(missing_ok=True)
            self.files.pop(0)


file_cleaner = FileCleaner()


def chords_string_to_list(chords: str):
    """Parse a chord progression string into ``(chord, start_time)`` pairs.

    Expected format: ``"(C, 0.0), (D, 2.0), ..."``. Square brackets and all
    whitespace are ignored. An empty or blank string (or ``None``) means
    "no chord conditioning" and yields ``[]``.

    Example: ``"(C, 0.0), (Dm7, 3.5)"`` -> ``[('C', 0.0), ('Dm7', 3.5)]``.

    Raises:
        gr.Error: if the string is not in the expected format.
    """
    if not chords:
        return []
    # Strip brackets and every kind of whitespace (spaces, tabs, newlines).
    cleaned = ''.join(chords.replace('[', '').replace(']', '').split())
    if not cleaned:
        return []
    try:
        # Drop the outer "(" and ")", then split between consecutive pairs.
        chrd_times = [x.split(',') for x in cleaned[1:-1].split('),(')]
        return [(x[0], float(x[1])) for x in chrd_times]
    except (IndexError, ValueError) as e:
        raise gr.Error(
            f"Could not parse chord progression {chords!r}; "
            "expected a format like '(C, 0.0), (D, 2.0)'") from e


def load_model(version='facebook/jasco-chords-drums-400M'):
    """Ensure the global ``MODEL`` holds the ``version`` checkpoint.

    The model is only reloaded when none is loaded or a different one is.
    """
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        # Drop the old reference first, presumably so its weights can be freed
        # before the new checkpoint is loaded.
        MODEL = None
        MODEL = JASCO.get_pretrained(version, token=hf_token)


@spaces.GPU
def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
    """Generate one clip per text description and write each one to a temp WAV.

    Args:
        texts: list of text descriptions; one output clip per entry.
        chords: chord progression string (see ``chords_string_to_list``).
        melody_matrix: the ``gr.File`` value (path string or object with
            ``.name``) of a ``torch.save``-d 2-D tensor, or None. A tensor
            whose first dim is larger than its second is transposed, i.e.
            T is assumed to exceed n_melody_bins.
        drum_prompt: ``(sample_rate, samples)`` from ``gr.Audio(type="numpy")``,
            or None.
        progress: forwarded to ``MODEL.generate_music``.
        gradio_progress: unused here; kept for call-site compatibility.
        **gen_kwargs: forwarded to ``MODEL.set_generation_params``.

    Returns:
        List of WAV file paths, each registered with ``file_cleaner``.

    Raises:
        gr.Error: on invalid chords or melody input, or a generation failure.
    """
    MODEL.set_generation_params(**gen_kwargs)
    chords = chords_string_to_list(chords)

    if melody_matrix is not None:
        melody_path = getattr(melody_matrix, 'name', melody_matrix)
        try:
            # weights_only=True refuses arbitrary pickled objects from the upload.
            melody_matrix = torch.load(melody_path, weights_only=True)
        except Exception as e:  # user upload: surface a readable message in the UI
            raise gr.Error(f"Could not load melody matrix: {e}") from e
        if not isinstance(melody_matrix, torch.Tensor):
            raise gr.Error(f"Melody matrix should be a torch tensor of shape [n_melody_bins, T]; got: {type(melody_matrix).__name__}")
        if len(melody_matrix.shape) != 2:
            raise gr.Error(f"Melody matrix should be a torch tensor of shape [n_melody_bins, T]; got: {melody_matrix.shape}")
        if melody_matrix.shape[0] > melody_matrix.shape[1]:
            melody_matrix = melody_matrix.permute(1, 0)

    if drum_prompt is None:
        preprocessed_drums_wav = None
        drums_sr = 32000
    else:
        # Assumes samples are [T] or [T, channels]; .t() makes them channels-first.
        drums_sr, drums = drum_prompt[0], f32_pcm(torch.from_numpy(drum_prompt[1])).t()
        if drums.dim() == 1:
            drums = drums[None]  # mono: add a channel axis
        drums = normalize_audio(drums, strategy="loudness", loudness_headroom_db=16, sample_rate=drums_sr)
        preprocessed_drums_wav = drums

    try:
        outputs = MODEL.generate_music(descriptions=texts, chords=chords,
                                       drums_wav=preprocessed_drums_wav,
                                       melody_salience_matrix=melody_matrix,
                                       drums_sample_rate=drums_sr, progress=progress)
    except RuntimeError as e:
        # str(e) is safe even when the exception carries no (or non-str) args.
        raise gr.Error(f"Error while generating {e}") from e

    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        # delete=False: the file must outlive this function so Gradio can serve
        # it; file_cleaner removes it later.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_wavs.append(file.name)
            file_cleaner.add(file.name)
    return out_wavs


@spaces.GPU
def predict_full(model, text, chords_sym, melody_file, drums_file, drums_mic, drum_input_src,
                 cfg_coef_all, cfg_coef_txt, ode_rtol, ode_atol, ode_solver, ode_steps,
                 progress=gr.Progress()):
    """Gradio "Generate" callback: load ``model`` and return two clips for ``text``.

    ``drum_input_src`` ("file" or "mic") selects which drums input is used.
    ``ode_solver == 'euler'`` enables the Euler solver with ``ode_steps`` steps.
    Generation aborts with ``gr.Error("Interrupted.")`` after ``interrupt()``.

    Returns:
        A list of two WAV file paths, one per output audio component.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    load_model(model)

    max_generated = 0

    def _progress(generated, to_generate):
        nonlocal max_generated
        # Never let the bar move backwards, and cap it at the total.
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    MODEL.set_custom_progress_callback(_progress)

    drums = drums_mic if drum_input_src == "mic" else drums_file
    wavs = _do_predictions(
        texts=[text] * 2,  # two samples of the same prompt -> two audio outputs
        chords=chords_sym,
        drum_prompt=drums,
        melody_matrix=melody_file,
        progress=True,
        gradio_progress=progress,
        cfg_coef_all=cfg_coef_all,
        cfg_coef_txt=cfg_coef_txt,
        ode_rtol=ode_rtol,
        ode_atol=ode_atol,
        euler=ode_solver == 'euler',
        euler_steps=ode_steps)
    return wavs


with gr.Blocks() as demo:
    gr.Markdown("""
    # JASCO - Text-to-Music Generation with Temporal Control

    Generate 10-second music clips using text descriptions and temporal controls (chords, drums, melody).
    """)

    with gr.Row():
        with gr.Column():
            submit = gr.Button("Generate")
            interrupt_btn = gr.Button("Interrupt")
        with gr.Column():
            audio_output_0 = gr.Audio(label="Generated Audio 1", type='filepath')
            audio_output_1 = gr.Audio(label="Generated Audio 2", type='filepath')

    with gr.Row():
        with gr.Column():
            text = gr.Text(label="Input Text",
                           value="Strings, woodwind, orchestral, symphony.",
                           interactive=True)
        with gr.Column():
            model = gr.Radio([
                'facebook/jasco-chords-drums-400M',
                'facebook/jasco-chords-drums-1B',
                'facebook/jasco-chords-drums-melody-400M',
                'facebook/jasco-chords-drums-melody-1B',
            ], label="Model", value='facebook/jasco-chords-drums-melody-400M')

    gr.Markdown("### Chords Conditions")
    chords_sym = gr.Text(
        label="Chord Progression",
        value="(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
        interactive=True,
    )

    gr.Markdown("### Drums Conditions")
    with gr.Row():
        drum_input_src = gr.Radio(["file", "mic"], value="file", label="Drums Input Source")
        drums_file = gr.Audio(sources=["upload"], type="numpy", label="Drums File")
        drums_mic = gr.Audio(sources=["microphone"], type="numpy", label="Drums Mic")

    gr.Markdown("### Melody Conditions")
    melody_file = gr.File(label="Melody File")

    with gr.Row():
        cfg_coef_all = gr.Number(label="CFG ALL", value=1.25, step=0.25)
        cfg_coef_txt = gr.Number(label="CFG TEXT", value=2.5, step=0.25)
        ode_tol = gr.Number(label="ODE Tolerance", value=1e-4, step=1e-5)
        ode_solver = gr.Radio(['euler', 'dopri5'], label="ODE Solver", value='euler')
        # precision=0 makes Gradio return an int (a step count) instead of a float.
        ode_steps = gr.Number(label="Euler Steps", value=10, step=1, precision=0)

    submit.click(
        fn=predict_full,
        inputs=[
            model, text, chords_sym, melody_file, drums_file, drums_mic, drum_input_src,
            cfg_coef_all, cfg_coef_txt,
            ode_tol, ode_tol,  # a single tolerance field feeds both ode_rtol and ode_atol
            ode_solver, ode_steps,
        ],
        outputs=[audio_output_0, audio_output_1],
    )
    # queue=False so the interrupt runs immediately, not behind the generation job.
    interrupt_btn.click(fn=interrupt, queue=False)

    gr.Examples(
        examples=[
            [
                "80s pop with groovy synth bass and electric piano",
                "(N, 0.0), (C, 0.32), (Dm7, 3.456), (Am, 4.608), (F, 8.32), (C, 9.216)",
                None,
                None,
            ],
            [
                "Strings, woodwind, orchestral, symphony.",
                "(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
                None,
                None,
            ],
        ],
        inputs=[text, chords_sym, melody_file, drums_file],
        outputs=[audio_output_0, audio_output_1],
    )

demo.queue().launch(ssr_mode=False)