"""Gradio demo app for the JASCO text-to-music model from AudioCraft.

The module builds a chord-to-index mapping file, loads a JASCO checkpoint on
demand, and exposes a Gradio UI that generates two audio samples from a text
prompt plus optional chord, drum and melody conditions.
"""
# NOTE: `spaces` is deliberately kept as the very first import, exactly as in
# the original file, so that it is imported before torch.
import spaces

import logging
import os
import pickle
import subprocess as sp
import time
import typing as tp
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile

import torch
import gradio as gr
from audiocraft.data.audio_utils import f32_pcm, normalize_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import JASCO
from huggingface_hub import login

title = """# 🙋🏻‍♂️Welcome to 🌟Tonic's 🎼Jasco🎶AudioCraft Demo"""

description = """Facebook presents JASCO, a temporally controlled text-to-music generation model utilizing both symbolic and audio-based conditions. JASCO can generate high-quality music samples conditioned on global text descriptions along with fine-grained local controls. JASCO is based on the Flow Matching modeling paradigm together with a novel conditioning method, allowing for music generation controlled both locally (e.g., chords) and globally (text description)."""

join_us = """ ## Join us: 🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface: [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [MultiTonic](https://github.com/MultiTonic/thinking-dataset) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗 """

# Authenticate against the Hugging Face Hub only when a token is configured.
hf_token = os.environ.get('HFTOKEN')
if hf_token:
    login(token=hf_token)

MODEL = None
MAX_BATCH_SIZE = 12
INTERRUPTING = False

# Directory (next to this file) holding the chord mapping pickle.
_MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
_CHORD_MAPPING_PATH = os.path.join(_MODELS_DIR, "chord_to_index_mapping.pkl")

os.makedirs(_MODELS_DIR, exist_ok=True)


def generate_chord_mappings():
    """Write a small chord-to-index mapping to disk and return its path.

    The mapping covers only the chords used by the bundled examples plus
    ``'UNK'``. Any existing mapping file at the same path is overwritten.

    Returns:
        Path of the pickle file that was written.
    """
    # Define basic chord mappings
    basic_chords = ['N', 'C', 'Dm7', 'Am', 'F', 'D', 'Ab', 'Bb'] + ['UNK']
    chord_to_index = {chord: idx for idx, chord in enumerate(basic_chords)}

    # Save the mapping
    os.makedirs(os.path.dirname(_CHORD_MAPPING_PATH), exist_ok=True)
    with open(_CHORD_MAPPING_PATH, "wb") as f:
        pickle.dump(chord_to_index, f)
    return _CHORD_MAPPING_PATH


def create_default_chord_mapping():
    """Create a basic chord-to-index mapping with common chords."""
    basic_chords = [
        'N', 'C', 'Cm', 'C7', 'Cmaj7', 'Cm7',
        'D', 'Dm', 'D7', 'Dmaj7', 'Dm7',
        'E', 'Em', 'E7', 'Emaj7', 'Em7',
        'F', 'Fm', 'F7', 'Fmaj7', 'Fm7',
        'G', 'Gm', 'G7', 'Gmaj7', 'Gm7',
        'A', 'Am', 'A7', 'Amaj7', 'Am7',
        'B', 'Bm', 'B7', 'Bmaj7', 'Bm7',
        'Ab', 'Abm', 'Ab7', 'Abmaj7', 'Abm7',
        'Bb', 'Bbm', 'Bb7', 'Bbmaj7', 'Bbm7',
        'UNK'
    ]
    return {chord: idx for idx, chord in enumerate(basic_chords)}


def initialize_chord_mapping():
    """Initialize the chord mapping file if it doesn't exist.

    Returns:
        Path of the chord mapping pickle (existing or freshly created).
    """
    os.makedirs(_MODELS_DIR, exist_ok=True)
    if not os.path.exists(_CHORD_MAPPING_PATH):
        chord_to_index = create_default_chord_mapping()
        with open(_CHORD_MAPPING_PATH, "wb") as f:
            pickle.dump(chord_to_index, f)
    return _CHORD_MAPPING_PATH


def validate_chord(chord, chord_mapping):
    """Return ``chord`` if it is a key of ``chord_mapping``, else ``'UNK'``."""
    if chord not in chord_mapping:
        return 'UNK'
    return chord


mapping_file = initialize_chord_mapping()
os.environ['AUDIOCRAFT_CHORD_MAPPING'] = mapping_file


def chords_string_to_list(chords: str):
    """Parse a chord progression string into ``(chord, start_time)`` tuples.

    The expected format is ``"(C, 0.0), (D, 2.0), ..."``; square brackets and
    spaces are ignored. Chords missing from the on-disk chord mapping are
    replaced by ``'UNK'``.

    This is the only definition of the function: the original file defined it
    a second time further down without validation, which silently shadowed
    this one and disabled the ``'UNK'`` fallback.

    Args:
        chords: User-supplied chord progression text.

    Returns:
        A list of ``(chord_name, time_in_seconds)`` tuples; empty when the
        input is empty or only whitespace.

    Raises:
        IndexError, ValueError: if an entry is not of the form
            ``(name, number)``.
    """
    chords = chords.replace('[', '').replace(']', '').replace(' ', '')
    # Checked after stripping so that whitespace-only input is also "empty".
    if chords == '':
        return []
    chrd_times = [x.split(',') for x in chords[1:-1].split('),(')]

    # Load chord mapping. The pickle is produced by this module itself
    # (see initialize_chord_mapping), not taken from user input.
    with open(_CHORD_MAPPING_PATH, 'rb') as f:
        chord_mapping = pickle.load(f)

    return [(validate_chord(x[0], chord_mapping), float(x[1])) for x in chrd_times]


# Wrap subprocess call to clean logs
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    """Run ``subprocess.call`` with stdout/stderr discarded.

    Returns:
        The exit code reported by the wrapped ``subprocess.call`` (the
        original wrapper dropped it, so callers always saw ``None``).
    """
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr

# Preallocate process pool
pool = ProcessPoolExecutor(4)
pool.__enter__()


def interrupt():
    """Flag the running generation for interruption (read by the progress callback)."""
    global INTERRUPTING
    INTERRUPTING = True


class FileCleaner:
    """Track generated files and delete them once they exceed a lifetime.

    Cleanup is triggered each time a new file is registered via :meth:`add`.
    """

    def __init__(self, file_lifetime: float = 3600):
        """Store the lifetime (in seconds) and start with no tracked files."""
        self.file_lifetime = file_lifetime
        # (time_added, path) pairs in insertion order, i.e. oldest first.
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Purge expired files, then start tracking ``path``."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        """Delete tracked files older than ``file_lifetime`` seconds."""
        now = time.time()
        # Iterate over a snapshot because entries are popped while looping.
        for time_added, path in list(self.files):
            if now - time_added > self.file_lifetime:
                # missing_ok avoids a race between an exists() check and unlink().
                path.unlink(missing_ok=True)
                self.files.pop(0)
            else:
                # Entries are ordered by age, so the rest are still fresh.
                break


file_cleaner = FileCleaner()

# Create necessary directories
os.makedirs("models", exist_ok=True)


@spaces.GPU
def load_model(version='facebook/jasco-chords-drums-400M'):
    """Load the requested JASCO checkpoint into the global ``MODEL``.

    The model is only (re)loaded when none is loaded yet or when a different
    ``version`` is requested.

    Args:
        version: Pretrained model identifier passed to ``JASCO.get_pretrained``.

    Returns:
        The loaded model (also stored in the module-level ``MODEL``).

    Raises:
        gr.Error: if loading fails or yields no model.
    """
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        MODEL = None

        # Setup model directory
        os.makedirs(_MODELS_DIR, exist_ok=True)

        # Generate and save chord mappings
        chord_mapping_path = _CHORD_MAPPING_PATH
        if not os.path.exists(chord_mapping_path):
            chord_mapping_path = generate_chord_mappings()

        try:
            # Initialize JASCO with the chord mapping path
            MODEL = JASCO.get_pretrained(
                version,
                device='cuda',
                chords_mapping_path=chord_mapping_path
            )
            MODEL.name = version
        except Exception as e:
            raise gr.Error(f"Error loading model: {str(e)}") from e

    if MODEL is None:
        raise gr.Error("Failed to load model")
    return MODEL


@spaces.GPU
def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs):
    """Run generation with the global ``MODEL`` and write the results to wav files.

    Args:
        texts: List of text descriptions, one per sample to generate.
        chords: Chord progression string (see ``chords_string_to_list``).
        melody_matrix: Uploaded melody file (object with a ``.name`` path, or
            a path string) holding a 2-D torch tensor, or ``None``.
        drum_prompt: ``(sample_rate, numpy_audio)`` tuple or ``None``.
        progress: Forwarded to ``MODEL.generate_music``.
        gradio_progress: Unused; kept for interface compatibility.
        **gen_kwargs: Forwarded to ``MODEL.set_generation_params``.

    Returns:
        List of paths to the generated ``.wav`` files.

    Raises:
        gr.Error: on an invalid melody matrix or a generation failure.
    """
    MODEL.set_generation_params(**gen_kwargs)
    chords = chords_string_to_list(chords)

    if melody_matrix is not None:
        # Accept both a file-like object exposing `.name` and a bare path.
        melody_path = getattr(melody_matrix, "name", melody_matrix)
        melody_matrix = torch.load(melody_path, weights_only=True)
        if len(melody_matrix.shape) != 2:
            raise gr.Error(f"Melody matrix should be a torch tensor of shape [n_melody_bins, T]; got: {melody_matrix.shape}")
        # Put the longer axis last so the layout matches [n_melody_bins, T].
        if melody_matrix.shape[0] > melody_matrix.shape[1]:
            melody_matrix = melody_matrix.permute(1, 0)

    if drum_prompt is None:
        preprocessed_drums_wav = None
        drums_sr = 32000
    else:
        drums_sr, drums = drum_prompt[0], f32_pcm(torch.from_numpy(drum_prompt[1])).t()
        if drums.dim() == 1:
            # Add a leading dimension to 1-D (single channel) audio.
            drums = drums[None]
        drums = normalize_audio(drums, strategy="loudness", loudness_headroom_db=16, sample_rate=drums_sr)
        preprocessed_drums_wav = drums

    try:
        outputs = MODEL.generate_music(descriptions=texts, chords=chords,
                                       drums_wav=preprocessed_drums_wav,
                                       melody_salience_matrix=melody_matrix,
                                       drums_sample_rate=drums_sr, progress=progress)
    except RuntimeError as e:
        # str(e) is safe even when the error carries no or non-string args.
        raise gr.Error("Error while generating " + str(e)) from e

    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        # delete=False: the file must outlive this block so Gradio can serve
        # it; file_cleaner removes it later.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_wavs.append(file.name)
            file_cleaner.add(file.name)
    return out_wavs


@spaces.GPU
def predict_full(model, text, chords_sym, melody_file, drums_file, drums_mic, drum_input_src,
                 cfg_coef_all, cfg_coef_txt, ode_rtol, ode_atol, ode_solver, ode_steps,
                 progress=gr.Progress()):
    """Gradio handler: load the model and generate two samples for one prompt.

    Args:
        model: Pretrained model identifier selected in the UI.
        text: Text description of the music.
        chords_sym: Chord progression string.
        melody_file: Optional uploaded melody file.
        drums_file: Optional uploaded drum audio.
        drums_mic: Optional recorded drum audio.
        drum_input_src: ``"mic"`` to use ``drums_mic``, anything else uses
            ``drums_file``.
        cfg_coef_all, cfg_coef_txt, ode_rtol, ode_atol: Generation parameters
            forwarded to the model.
        ode_solver: ``'euler'`` enables the Euler solver.
        ode_steps: Number of Euler steps.
        progress: Gradio progress tracker.

    Returns:
        List of paths to the generated wav files.
    """
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    load_model(model)

    max_generated = 0

    def _progress(generated, to_generate):
        # Keep the reported progress monotonic and never above the total.
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    MODEL.set_custom_progress_callback(_progress)

    drums = drums_mic if drum_input_src == "mic" else drums_file
    wavs = _do_predictions(
        texts=[text] * 2,
        chords=chords_sym,
        drum_prompt=drums,
        melody_matrix=melody_file,
        progress=True,
        gradio_progress=progress,
        cfg_coef_all=cfg_coef_all,
        cfg_coef_txt=cfg_coef_txt,
        ode_rtol=ode_rtol,
        ode_atol=ode_atol,
        euler=ode_solver == 'euler',
        # The Number widget may deliver a float; a step count must be integral.
        euler_steps=int(ode_steps))
    return wavs


with gr.Blocks() as demo:
    gr.Markdown(title)
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(description)
        with gr.Column():
            with gr.Group():
                gr.Markdown(join_us)
    with gr.Row():
        with gr.Column():
            submit = gr.Button("Generate")
            interrupt_btn = gr.Button("Interrupt")
        with gr.Column():
            audio_output_0 = gr.Audio(label="Generated Audio 1", type='filepath')
            audio_output_1 = gr.Audio(label="Generated Audio 2", type='filepath')
    with gr.Row():
        with gr.Column():
            text = gr.Text(label="Input Text",
                           value="Strings, woodwind, orchestral, symphony.",
                           interactive=True)
        with gr.Column():
            model = gr.Radio([
                'facebook/jasco-chords-drums-400M',
                'facebook/jasco-chords-drums-1B',
                'facebook/jasco-chords-drums-melody-400M',
                'facebook/jasco-chords-drums-melody-1B'
            ], label="Model", value='facebook/jasco-chords-drums-melody-400M')

    gr.Markdown("### Chords Conditions")
    chords_sym = gr.Text(
        label="Chord Progression",
        value="(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
        interactive=True
    )

    gr.Markdown("### Drums Conditions")
    with gr.Row():
        drum_input_src = gr.Radio(["file", "mic"], value="file", label="Drums Input Source")
        drums_file = gr.Audio(sources=["upload"], type="numpy", label="Drums File")
        drums_mic = gr.Audio(sources=["microphone"], type="numpy", label="Drums Mic")

    gr.Markdown("### Melody Conditions")
    melody_file = gr.File(label="Melody File")

    with gr.Row():
        cfg_coef_all = gr.Number(label="CFG ALL", value=1.25, step=0.25)
        cfg_coef_txt = gr.Number(label="CFG TEXT", value=2.5, step=0.25)
        ode_tol = gr.Number(label="ODE Tolerance", value=1e-4, step=1e-5)
        ode_solver = gr.Radio(['euler', 'dopri5'], label="ODE Solver", value='euler')
        ode_steps = gr.Number(label="Euler Steps", value=10, step=1)

    submit.click(
        fn=predict_full,
        inputs=[
            model, text, chords_sym, melody_file, drums_file, drums_mic, drum_input_src,
            cfg_coef_all, cfg_coef_txt,
            # The single tolerance widget feeds both ode_rtol and ode_atol.
            ode_tol, ode_tol,
            ode_solver, ode_steps
        ],
        outputs=[audio_output_0, audio_output_1]
    )
    interrupt_btn.click(fn=interrupt, queue=False)

    gr.Examples(
        examples=[
            [
                "80s pop with groovy synth bass and electric piano",
                "(N, 0.0), (C, 0.32), (Dm7, 3.456), (Am, 4.608), (F, 8.32), (C, 9.216)",
                None,
                None,
            ],
            [
                "Strings, woodwind, orchestral, symphony.",
                "(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)",
                None,
                None,
            ],
        ],
        inputs=[text, chords_sym, melody_file, drums_file],
        outputs=[audio_output_0, audio_output_1]
    )

demo.queue().launch(ssr_mode=False)