Spaces:
Sleeping
Sleeping
from pathlib import Path | |
from typing import List | |
import json | |
import torch | |
import soundfile as sf | |
from diffusers import AudioLDM2Pipeline | |
from mm_story_agent.prompts_en import story_to_sound_reviser_system, story_to_sound_review_system | |
from mm_story_agent.modality_agents.llm import QwenAgent | |
class AudioLDM2Synthesizer:
    """Text-to-audio generator backed by a diffusers ``AudioLDM2Pipeline``."""

    def __init__(self,
                 model_path: str = None,
                 device: str = "cuda"
                 ) -> None:
        """Load the pipeline and move it to ``device``.

        Args:
            model_path: Checkpoint name or path handed to ``from_pretrained``;
                ``None`` selects ``"cvssp/audioldm2"``.
            device: Torch device string the pipeline is moved to.
        """
        self.device = device
        # diffusers warns that float16 pipelines cannot run on CPU, so only
        # use half precision when the target is not a CPU device.
        on_cpu = str(device).startswith("cpu")
        dtype = torch.float32 if on_cpu else torch.float16
        self.pipe = AudioLDM2Pipeline.from_pretrained(
            model_path if model_path is not None else "cvssp/audioldm2",
            torch_dtype=dtype
        ).to(self.device)

    def call(self,
             prompts: List[str],
             n_candidate_per_text: int = 3,
             seed: int = 0,
             guidance_scale: float = 3.5,
             ddim_steps: int = 100,
             audio_length_in_s: float = 10.0,
             ):
        """Generate one waveform per prompt.

        Args:
            prompts: Text descriptions of the sounds to synthesize.
            n_candidate_per_text: Number of waveforms the pipeline generates
                for each prompt; only one per prompt is returned.
            seed: Seed for the torch generator passed to the pipeline.
            guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
            ddim_steps: Forwarded to the pipeline as ``num_inference_steps``.
            audio_length_in_s: Forwarded to the pipeline as
                ``audio_length_in_s`` (defaults to the previously hard-coded
                10 seconds).

        Returns:
            The pipeline's ``audios`` output reduced to every
            ``n_candidate_per_text``-th entry, i.e. the first candidate of
            each prompt's group. An empty list when ``prompts`` is empty.
        """
        if not prompts:
            # Nothing to synthesize; avoid invoking the pipeline at all.
            return []

        generator = torch.Generator(device=self.device).manual_seed(seed)
        audios = self.pipe(
            prompts,
            num_inference_steps=ddim_steps,
            audio_length_in_s=audio_length_in_s,
            guidance_scale=guidance_scale,
            generator=generator,
            num_waveforms_per_prompt=n_candidate_per_text).audios
        # Keep the first waveform of each prompt's candidate group.
        # NOTE(review): presumably this is the best-ranked candidate, since
        # diffusers documents automatic scoring when several waveforms are
        # requested per prompt -- confirm against the installed version.
        return audios[::n_candidate_per_text]
class AudioLDM2Agent:
    """Turn story pages into sound-effect prompts and render them to WAV files."""

    def __init__(self, config, llm_type="qwen2") -> None:
        """Store the configuration and select the LLM backend class.

        Args:
            config: Mapping read by ``call``; the keys ``"revise_cfg"``,
                ``"call_cfg"`` and ``"sample_rate"`` are accessed there.
            llm_type: Name of the LLM backend; only ``"qwen2"`` is supported.

        Raises:
            ValueError: If ``llm_type`` is not a supported backend.
        """
        self.config = config
        if llm_type == "qwen2":
            self.LLM = QwenAgent
        else:
            # Fail here instead of with an AttributeError on first use.
            raise ValueError(f"Unsupported llm_type: {llm_type!r}")

    def call(self, pages: List, device: str, save_path: str):
        """Generate sound prompts for ``pages`` and write the audio files.

        A page whose prompt is exactly ``"No sounds."`` gets no audio file.
        Every other page ``i`` (1-based) is written to ``p{i}.wav`` inside
        ``save_path``.

        Args:
            pages: Story pages, one entry per page.
            device: Torch device string used for the synthesizer.
            save_path: Directory that receives the WAV files; created when
                missing and at least one sound has to be written.

        Returns:
            A dict with the per-page prompts under ``"prompts"`` and the
            constant ``"sound"`` under ``"modality"``.
        """
        sound_prompts = self.generate_sound_prompt_from_story(pages, **self.config["revise_cfg"])
        save_dir = Path(save_path)
        save_paths = []
        forward_prompts = []
        for idx, prompt in enumerate(sound_prompts):
            if prompt != "No sounds.":
                save_paths.append(save_dir / f"p{idx + 1}.wav")
                forward_prompts.append(prompt)

        if forward_prompts:
            save_dir.mkdir(parents=True, exist_ok=True)
            # Loading the model is expensive, so only do it when there is
            # at least one prompt to synthesize.
            generation_agent = AudioLDM2Synthesizer(device=device)
            sounds = generation_agent.call(
                forward_prompts,
                **self.config["call_cfg"]
            )
            for sound, path in zip(sounds, save_paths):
                sf.write(str(path), sound, self.config["sample_rate"])

        return {
            "prompts": sound_prompts,
            "modality": "sound"
        }

    def generate_sound_prompt_from_story(
            self,
            pages: List,
            num_turns: int = 3
    ):
        """Produce one sound description per page via a revise/review loop.

        For each page a reviser LLM proposes a description and a reviewer LLM
        critiques it. The loop stops once the reviewer answers
        ``"Check passed."`` or after ``num_turns`` rounds.

        Args:
            pages: Story pages, one entry per page.
            num_turns: Maximum number of revise/review rounds per page.

        Returns:
            A list with the final sound description for every page, in order.
        """
        sound_prompt_reviser = self.LLM(story_to_sound_reviser_system, track_history=False)
        sound_prompt_reviewer = self.LLM(story_to_sound_review_system, track_history=False)
        prefix = "Sound description:"
        sound_prompts = []
        for page in pages:
            review = ""
            sound_prompt = ""
            for _ in range(num_turns):
                # NOTE(review): the success flag returned by run() is ignored
                # here, as in the original code -- confirm that is intended.
                sound_prompt, _success = sound_prompt_reviser.run(json.dumps({
                    "story": page,
                    "previous_result": sound_prompt,
                    "improvement_suggestions": review,
                }, ensure_ascii=False))
                if sound_prompt.startswith(prefix):
                    sound_prompt = sound_prompt[len(prefix):]
                # Drop surrounding whitespace so that e.g.
                # "Sound description: No sounds." still matches the
                # "No sounds." sentinel checked in call().
                sound_prompt = sound_prompt.strip()
                review, _success = sound_prompt_reviewer.run(json.dumps({
                    "story": page,
                    "sound_description": sound_prompt
                }, ensure_ascii=False))
                if review == "Check passed.":
                    break
            sound_prompts.append(sound_prompt)
        return sound_prompts