"""Gradio demo for JASCO text-to-music generation with chord, drum and melody conditions.

At import time this module logs in to the Hugging Face Hub when ``HFTOKEN`` is
set, writes a default chord-to-index mapping next to this file, builds the
Gradio UI and launches it.
"""
import spaces  # kept first, as in the original file
import logging
import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile
import time
import typing as tp
import subprocess as sp

import torch
import gradio as gr

from audiocraft.data.audio_utils import f32_pcm, normalize_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import JASCO

from huggingface_hub import login

hf_token = os.environ.get('HFTOKEN')
if hf_token:
    login(token=hf_token)

MODEL = None
MAX_BATCH_SIZE = 12
INTERRUPTING = False

# Directory next to this file holding the chord-to-index mapping pickle.
MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
CHORD_MAPPING_PATH = os.path.join(MODELS_DIR, "chord_to_index_mapping.pkl")

os.makedirs(MODELS_DIR, exist_ok=True)


def generate_chord_mappings():
    """Write the default chord-to-index mapping to disk, overwriting any existing file.

    Uses the same mapping as ``create_default_chord_mapping`` so every code path
    that creates the file produces identical contents.

    Returns:
        Path of the written pickle file.
    """
    os.makedirs(MODELS_DIR, exist_ok=True)
    with open(CHORD_MAPPING_PATH, "wb") as f:
        pickle.dump(create_default_chord_mapping(), f)
    return CHORD_MAPPING_PATH


def create_default_chord_mapping():
    """Create a basic chord-to-index mapping with common chords.

    Returns:
        Dict mapping chord symbol -> consecutive integer index; ``'N'`` is 0 and
        ``'UNK'`` is the last index.
    """
    basic_chords = [
        'N',
        'C', 'Cm', 'C7', 'Cmaj7', 'Cm7',
        'D', 'Dm', 'D7', 'Dmaj7', 'Dm7',
        'E', 'Em', 'E7', 'Emaj7', 'Em7',
        'F', 'Fm', 'F7', 'Fmaj7', 'Fm7',
        'G', 'Gm', 'G7', 'Gmaj7', 'Gm7',
        'A', 'Am', 'A7', 'Amaj7', 'Am7',
        'B', 'Bm', 'B7', 'Bmaj7', 'Bm7',
        'Ab', 'Abm', 'Ab7', 'Abmaj7', 'Abm7',
        'Bb', 'Bbm', 'Bb7', 'Bbmaj7', 'Bbm7',
        'UNK',
    ]
    return {chord: idx for idx, chord in enumerate(basic_chords)}


def initialize_chord_mapping():
    """Create the chord mapping file if it does not exist yet (idempotent).

    Returns:
        Path of the mapping pickle file.
    """
    os.makedirs(MODELS_DIR, exist_ok=True)
    if not os.path.exists(CHORD_MAPPING_PATH):
        generate_chord_mappings()
    return CHORD_MAPPING_PATH


def validate_chord(chord, chord_mapping):
    """Return ``chord`` if it is a key of ``chord_mapping``, otherwise ``'UNK'``."""
    if chord not in chord_mapping:
        return 'UNK'
    return chord


mapping_file = initialize_chord_mapping()
os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file


def chords_string_to_list(chords: str):
    """Parse a chord progression string into ``(chord, start_time)`` tuples.

    The expected format is ``"(C, 0.0), (D, 2.0), ..."``. All whitespace and
    square brackets are ignored. Chords that are not in the chord mapping on
    disk are replaced by ``'UNK'`` via ``validate_chord``.

    Returns:
        List of ``(chord, float_time)`` tuples; empty for empty/``None`` input.

    Raises:
        gr.Error: if the string does not follow the expected format.
    """
    if not chords:
        return []
    compact = ''.join(chords.split()).replace('[', '').replace(']', '')
    if not compact:
        return []
    try:
        # Drop the outer "(" and ")" then split between consecutive pairs.
        chrd_times = [item.split(',') for item in compact[1:-1].split('),(')]
        pairs = [(item[0], float(item[1])) for item in chrd_times]
    except (IndexError, ValueError) as e:
        raise gr.Error(
            f"Could not parse chord progression: {chords!r}. Expected e.g. '(C, 0.0), (F, 2.0)'."
        ) from e
    # The mapping file is written by this module itself (initialize_chord_mapping).
    with open(CHORD_MAPPING_PATH, 'rb') as f:
        chord_mapping = pickle.load(f)
    return [(validate_chord(chord, chord_mapping), start) for chord, start in pairs]


# Wrap subprocess call to clean logs
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    """Run the original ``subprocess.call`` with stdout/stderr discarded; return its exit code."""
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr

# Preallocate process pool (not referenced elsewhere in this file).
pool = ProcessPoolExecutor(4)
pool.__enter__()


def interrupt():
    """Ask the running generation to stop; checked by the progress callback."""
    global INTERRUPTING
    INTERRUPTING = True


class FileCleaner:
    """Delete tracked files once they are older than ``file_lifetime`` seconds.

    Cleanup is lazy: expired files are only removed when a new file is added.
    """

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        # (time_added, path) pairs; oldest first because add() appends in time order.
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Purge expired files, then start tracking ``path``."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        now = time.time()
        for time_added, path in list(self.files):
            if now - time_added > self.file_lifetime:
                # missing_ok: the file may already have been removed by someone else.
                path.unlink(missing_ok=True)
                self.files.pop(0)
            else:
                # Entries are time-ordered, so nothing after this one has expired.
                break


file_cleaner = FileCleaner()

# Create necessary directories (relative to the current working directory).
os.makedirs("models", exist_ok=True)


@spaces.GPU
def load_model(version='facebook/jasco-chords-drums-400M'):
    """Load the pretrained JASCO ``version`` into the global ``MODEL``, reusing it if already loaded.

    Returns:
        The loaded model.

    Raises:
        gr.Error: if loading fails.
    """
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        MODEL = None
        chord_mapping_path = initialize_chord_mapping()
        try:
            # Initialize JASCO with the chord mapping path
            MODEL = JASCO.get_pretrained(
                version,
                device='cuda',
                chords_mapping_path=chord_mapping_path,
            )
            MODEL.name = version
        except Exception as e:
            raise gr.Error(f"Error loading model: {str(e)}") from e
    if MODEL is None:
        raise gr.Error("Failed to load model")
    return MODEL


@spaces.GPU
def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
    """Generate one audio file per entry of ``texts`` with the global ``MODEL``.

    Args:
        texts: text descriptions, one per output.
        chords: chord progression string (see ``chords_string_to_list``).
        melody_matrix: uploaded tensor file (object with ``.name`` or a path), or None.
        drum_prompt: ``(sample_rate, numpy_audio)`` tuple from ``gr.Audio``, or None.
        progress: forwarded to ``MODEL.generate_music``.
        gradio_progress: unused; kept for interface compatibility.
        **gen_kwargs: forwarded to ``MODEL.set_generation_params``.

    Returns:
        List of paths of the written ``.wav`` files.
    """
    MODEL.set_generation_params(**gen_kwargs)
    chords = chords_string_to_list(chords)

    if melody_matrix is not None:
        # Gradio may hand over a file object or a plain path string.
        melody_path = getattr(melody_matrix, 'name', melody_matrix)
        melody_matrix = torch.load(melody_path, weights_only=True)
        if len(melody_matrix.shape) != 2:
            raise gr.Error(f"Melody matrix should be a torch tensor of shape [n_melody_bins, T]; got: {melody_matrix.shape}")
        # Transpose when the first dim is larger, presumably so bins come first — TODO confirm.
        if melody_matrix.shape[0] > melody_matrix.shape[1]:
            melody_matrix = melody_matrix.permute(1, 0)

    if drum_prompt is None:
        preprocessed_drums_wav = None
        drums_sr = 32000
    else:
        drums_sr, drums = drum_prompt[0], f32_pcm(torch.from_numpy(drum_prompt[1])).t()
        if drums.dim() == 1:
            drums = drums[None]
        drums = normalize_audio(drums, strategy="loudness", loudness_headroom_db=16, sample_rate=drums_sr)
        preprocessed_drums_wav = drums

    try:
        outputs = MODEL.generate_music(
            descriptions=texts,
            chords=chords,
            drums_wav=preprocessed_drums_wav,
            melody_salience_matrix=melody_matrix,
            drums_sample_rate=drums_sr,
            progress=progress,
        )
    except RuntimeError as e:
        raise gr.Error(f"Error while generating {e}") from e

    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_wavs.append(file.name)
            file_cleaner.add(file.name)
    return out_wavs


@spaces.GPU
def predict_full(model, text, chords_sym, melody_file, drums_file, drums_mic, drum_input_src,
                 cfg_coef_all, cfg_coef_txt, ode_rtol, ode_atol, ode_solver, ode_steps,
                 progress=gr.Progress()):
    """Gradio handler: load ``model`` and generate two variations for ``text``.

    Returns:
        List of two ``.wav`` file paths, one per output audio component.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    load_model(model)

    max_generated = 0

    def _progress(generated, to_generate):
        # Report monotonic progress and abort when the Interrupt button was pressed.
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    MODEL.set_custom_progress_callback(_progress)

    drums = drums_mic if drum_input_src == "mic" else drums_file
    wavs = _do_predictions(
        texts=[text] * 2,
        chords=chords_sym,
        drum_prompt=drums,
        melody_matrix=melody_file,
        progress=True,
        gradio_progress=progress,
        cfg_coef_all=cfg_coef_all,
        cfg_coef_txt=cfg_coef_txt,
        ode_rtol=ode_rtol,
        ode_atol=ode_atol,
        euler=ode_solver == 'euler',
        euler_steps=ode_steps,
    )
    return wavs


with gr.Blocks() as demo:
    gr.Markdown(
        "# JASCO - Text-to-Music Generation with Temporal Control\n\n"
        "Generate 10-second music clips using text descriptions and temporal controls "
        "(chords, drums, melody)."
    )

    with gr.Row():
        with gr.Column():
            submit = gr.Button("Generate")
            interrupt_btn = gr.Button("Interrupt")
        with gr.Column():
            audio_output_0 = gr.Audio(label="Generated Audio 1", type='filepath')
            audio_output_1 = gr.Audio(label="Generated Audio 2", type='filepath')

    with gr.Row():
        with gr.Column():
            text = gr.Text(label="Input Text", value="Strings, woodwind, orchestral, symphony.", interactive=True)
        with gr.Column():
            model = gr.Radio(
                [
                    'facebook/jasco-chords-drums-400M',
                    'facebook/jasco-chords-drums-1B',
                    'facebook/jasco-chords-drums-melody-400M',
                    'facebook/jasco-chords-drums-melody-1B',
                ],
                label="Model",
                value='facebook/jasco-chords-drums-melody-400M',
            )

    gr.Markdown("### Chords Conditions")
    chords_sym = gr.Text(
        label="Chord Progression",
        value="(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
        interactive=True,
    )

    gr.Markdown("### Drums Conditions")
    with gr.Row():
        drum_input_src = gr.Radio(["file", "mic"], value="file", label="Drums Input Source")
        drums_file = gr.Audio(sources=["upload"], type="numpy", label="Drums File")
        drums_mic = gr.Audio(sources=["microphone"], type="numpy", label="Drums Mic")

    gr.Markdown("### Melody Conditions")
    melody_file = gr.File(label="Melody File")

    with gr.Row():
        cfg_coef_all = gr.Number(label="CFG ALL", value=1.25, step=0.25)
        cfg_coef_txt = gr.Number(label="CFG TEXT", value=2.5, step=0.25)
        ode_tol = gr.Number(label="ODE Tolerance", value=1e-4, step=1e-5)
        ode_solver = gr.Radio(['euler', 'dopri5'], label="ODE Solver", value='euler')
        # precision=0 makes Gradio deliver the step count as an int.
        ode_steps = gr.Number(label="Euler Steps", value=10, step=1, precision=0)

    submit.click(
        fn=predict_full,
        inputs=[
            model, text, chords_sym, melody_file, drums_file, drums_mic, drum_input_src,
            cfg_coef_all, cfg_coef_txt,
            ode_tol, ode_tol,  # a single tolerance control feeds both ode_rtol and ode_atol
            ode_solver, ode_steps,
        ],
        outputs=[audio_output_0, audio_output_1],
    )
    interrupt_btn.click(fn=interrupt, queue=False)

    gr.Examples(
        examples=[
            [
                "80s pop with groovy synth bass and electric piano",
                "(N, 0.0), (C, 0.32), (Dm7, 3.456), (Am, 4.608), (F, 8.32), (C, 9.216)",
                None,
                None,
            ],
            [
                "Strings, woodwind, orchestral, symphony.",
                "(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
                None,
                None,
            ],
        ],
        inputs=[text, chords_sym, melody_file, drums_file],
        outputs=[audio_output_0, audio_output_1],
    )

demo.queue().launch(ssr_mode=False)