# DiffRhythm-SimpleUI / simpler_app.py
# fffiloni's picture
# Create simpler_app.py
# 1f45bd9 verified
# raw
# history blame
# 4.46 kB
import gradio as gr
import requests
import json
# from volcenginesdkarkruntime import Ark
import torch
import torchaudio
from einops import rearrange
import argparse
import json
import os
import spaces
from tqdm import tqdm
import random
import numpy as np
import sys
import base64
from diffrhythm.infer.infer_utils import (
get_reference_latent,
get_lrc_token,
get_style_prompt,
prepare_model,
get_negative_style_prompt
)
from diffrhythm.infer.infer import inference
# Upper bound for randomly drawn seeds: the largest value of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Device string handed to prepare_model when the models are loaded below.
device='cuda'
# Load the models once at import time; infer_music reads these as module globals.
cfm, tokenizer, muq, vae = prepare_model(device)
# Replace the CFM model with its torch.compile-wrapped version before any inference call.
cfm = torch.compile(cfm)
def infer_music(lrc, ref_audio_path, seed=42, randomize_seed=False, steps=32, file_type='wav', max_frames=2048, device='cuda'):
    """Generate a song from lyrics and a reference audio file.

    Args:
        lrc: Lyrics text, forwarded to ``get_lrc_token``.
        ref_audio_path: Reference audio path, forwarded to ``get_style_prompt``.
        seed: Value used to seed torch's RNG when ``randomize_seed`` is false.
        randomize_seed: When true, ``seed`` is ignored and a random integer in
            ``[0, MAX_SEED]`` is drawn instead.
        steps: Step count forwarded to ``inference``; values below 32 also
            switch the sway sampling coefficient to ``-1``.
        file_type: Output format selector forwarded to ``inference``.
        max_frames: Passed both to ``get_reference_latent`` and as the
            ``duration`` argument of ``inference``.
        device: Device string handed to the token/latent helper functions.

    Returns:
        Whatever ``inference`` returns for the prepared inputs.
    """
    # Pick the effective seed first so torch is seeded exactly once per call.
    effective_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    torch.manual_seed(effective_seed)

    # Runs with fewer than 32 steps use a sway coefficient of -1; otherwise none.
    if steps < 32:
        sway = -1
    else:
        sway = None

    # Conditioning inputs, prepared in the same order as they are consumed below.
    text_tokens, lyrics_start = get_lrc_token(lrc, tokenizer, device)
    style = get_style_prompt(muq, ref_audio_path)
    negative_style = get_negative_style_prompt(device)
    reference_latent = get_reference_latent(device, max_frames)

    return inference(
        cfm_model=cfm,
        vae_model=vae,
        cond=reference_latent,
        text=text_tokens,
        duration=max_frames,
        style_prompt=style,
        negative_style_prompt=negative_style,
        steps=steps,
        sway_sampling_coef=sway,
        start_time=lyrics_start,
        file_type=file_type,
    )
import re
from transformers import pipeline
# Model identifiers; only zephyr_model is passed to the pipeline below
# (mixtral_model is not referenced anywhere else in the visible file).
zephyr_model = "HuggingFaceH4/zephyr-7b-beta"
mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Text-generation pipeline built once at import time; prepare_lyrics_with_llm calls it.
pipe = pipeline("text-generation", model=zephyr_model, torch_dtype=torch.bfloat16, device_map="auto")
def prepare_lyrics_with_llm(theme, tags, lyrics):
    """Ask the text-generation pipeline for timestamped lyrics.

    Builds a system/user prompt from the theme, style tags and optional
    lyrics, runs it through the module-level ``pipe``, and strips the prompt
    back out of the returned text.

    Args:
        theme: Theme the song should be centred around.
        tags: Music style tags interpolated into the prompt.
        lyrics: Optional user lyrics to be formatted. ``None``, an empty
            string, or whitespace-only text all mean "no lyrics supplied".

    Returns:
        The generated text with the prompt removed and leading newlines
        stripped.
    """
    language = "English"

    # The prompt below tells the model what to do when the lyrics are "None".
    # A blank lyrics field may arrive as an empty string rather than None, so
    # normalise blank input to None to make that instruction actually apply.
    if lyrics is None or not str(lyrics).strip():
        lyrics = None

    standard_sys = f"""
Please generate a complete song with lyrics in {language}, following the {tags} style and centered around the theme "{theme}".
If {lyrics} is provided, format it accordingly.
If {lyrics} is None, generate original lyrics based on the given theme and style.
Strictly adhere to the following requirements:
### Mandatory Formatting Rules
1. Only output the formatted lyrics—do not include any explanations, introductions, or additional messages.
2. Only include timestamps and lyrics. Do not use brackets, side notes, or section markers (e.g., chorus, instrumental, outro).
3. Each line must follow the format [mm:ss.xx]Lyrics content, with no spaces between the timestamp and lyrics. The lyrics should be continuous and complete.
4. The total song length must not exceed 1 minute 30 seconds.
5. Timestamps should be naturally distributed. The first lyric must not start at [00:00.00]—consider an intro before the lyrics begin.
### Prohibited Examples (Do Not Include)
- Incorrect: [01:30.00](Piano solo)
- Incorrect: [00:45.00][Chorus]
"""
    instruction = f"""
<|system|>
{standard_sys}</s>
<|user|>
theme: {theme}
tags: {tags}
lyrics: {lyrics}
"""
    prompt = f"{instruction.strip()}</s>"
    outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
    generated = outputs[0]["generated_text"]

    # Preferred path: drop everything from the system tag up to the assistant
    # tag, which removes the echoed prompt when the model emits that tag.
    pattern = r'\<\|system\|\>(.*?)\<\|assistant\|\>'
    cleaned_text = re.sub(pattern, '', generated, flags=re.DOTALL)

    # Fallback: the prompt itself never contains the assistant tag, so if the
    # model did not emit one the regex removes nothing and the whole prompt
    # would be returned as "lyrics". Strip the literal prompt prefix instead.
    if cleaned_text == generated and generated.startswith(prompt):
        cleaned_text = generated[len(prompt):]

    print(f"SUGGESTED Lyrics: {cleaned_text}")
    return cleaned_text.lstrip("\n")
def general_process(theme, tags, lyrics):
    """Handle a click on the submit button.

    Forwards the three textbox values to ``prepare_lyrics_with_llm`` and
    returns a pair matching the click handler's two outputs: ``None`` for the
    audio component and the suggested lyrics for the text component.
    """
    suggested_lyrics = prepare_lyrics_with_llm(theme, tags, lyrics)
    # No audio is produced on this path, so the first output slot stays empty.
    audio = None
    return audio, suggested_lyrics
# Build the UI. The original passed `css=css`, but no `css` name is defined or
# imported in this file, which raised NameError at import time; the app uses
# Gradio's default styling instead.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Simpler Diff Rythm")
        # Inputs forwarded to general_process in this order.
        theme_song = gr.Textbox(label="Theme")
        style_tags = gr.Textbox(label="Music style tags")
        lyrics = gr.Textbox(label="Lyrics optional")
        submit_btn = gr.Button("Submit")
        # Outputs filled from general_process's (audio, lyrics) return pair.
        song_result = gr.Audio(label="Song result")
        generated_lyrics = gr.Textbox(label="Generated Lyrics")

    # Event wiring must happen inside the Blocks context.
    submit_btn.click(
        fn=general_process,
        inputs=[theme_song, style_tags, lyrics],
        outputs=[song_result, generated_lyrics],
    )

demo.queue().launch(show_api=False, show_error=True, ssr_mode=False)