from pathlib import Path
from typing import List, Optional
import json

import torch
import soundfile as sf
from diffusers import AudioLDM2Pipeline

from mm_story_agent.prompts_en import story_to_sound_reviser_system, story_to_sound_review_system
from mm_story_agent.modality_agents.llm import QwenAgent


# Literal markers exchanged with the LLM; they are compared against stripped
# replies, so they must not carry surrounding whitespace themselves.
_PROMPT_PREFIX = "Sound description:"
_NO_SOUND_MARKER = "No sounds."
_REVIEW_PASSED = "Check passed."


class AudioLDM2Synthesizer:
    """Thin wrapper around a diffusers ``AudioLDM2Pipeline`` for text-to-audio."""

    def __init__(self,
                 model_path: Optional[str] = None,
                 device: str = "cuda",
                 ) -> None:
        """Load the pipeline in float16 and move it to ``device``.

        Args:
            model_path: Checkpoint passed to ``from_pretrained``; falls back to
                ``"cvssp/audioldm2"`` when ``None``.
            device: Device the pipeline is moved to and on which the random
                generator is created.
        """
        self.device = device
        self.pipe = AudioLDM2Pipeline.from_pretrained(
            model_path if model_path is not None else "cvssp/audioldm2",
            torch_dtype=torch.float16,
        ).to(self.device)

    def call(self,
             prompts: List[str],
             n_candidate_per_text: int = 3,
             seed: int = 0,
             guidance_scale: float = 3.5,
             ddim_steps: int = 100,
             audio_length_in_s: float = 10.0,
             ):
        """Generate one audio clip per prompt.

        Args:
            prompts: Text descriptions of the sounds to synthesize.
            n_candidate_per_text: Number of waveforms the pipeline produces
                for each prompt; only one per prompt is returned.
            seed: Seed for the ``torch.Generator`` handed to the pipeline.
            guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
            ddim_steps: Forwarded to the pipeline as ``num_inference_steps``.
            audio_length_in_s: Forwarded to the pipeline as
                ``audio_length_in_s``; defaults to the previously hard-coded
                10 seconds.

        Returns:
            Every ``n_candidate_per_text``-th entry of the pipeline's
            ``audios`` output, i.e. the first candidate of each prompt's group.
        """
        generator = torch.Generator(device=self.device).manual_seed(seed)
        audios = self.pipe(
            prompts,
            num_inference_steps=ddim_steps,
            audio_length_in_s=audio_length_in_s,
            guidance_scale=guidance_scale,
            generator=generator,
            num_waveforms_per_prompt=n_candidate_per_text,
        ).audios

        # Keep the first candidate of each prompt's group.
        # NOTE(review): per the diffusers docs the candidates are ranked
        # best-first when num_waveforms_per_prompt > 1 — confirm for the
        # installed diffusers version.
        return audios[::n_candidate_per_text]


class AudioLDM2Agent:
    """Turns story pages into sound-effect prompts and renders them to WAV files."""

    def __init__(self, config, llm_type="qwen2") -> None:
        """Store the config and pick the LLM class used for prompt writing.

        Args:
            config: Mapping read later for the keys ``"revise_cfg"``,
                ``"call_cfg"`` and ``"sample_rate"``.
            llm_type: LLM backend name; only ``"qwen2"`` is supported.

        Raises:
            ValueError: If ``llm_type`` is not a supported backend. Failing
                here avoids a late ``AttributeError`` on ``self.LLM``.
        """
        self.config = config
        if llm_type == "qwen2":
            self.LLM = QwenAgent
        else:
            raise ValueError(
                f"Unsupported llm_type {llm_type!r}; expected 'qwen2'."
            )

    def call(self, pages: List, device: str, save_path: str):
        """Generate sound prompts for ``pages`` and write the audio to disk.

        Pages whose prompt equals ``"No sounds."`` are skipped. The remaining
        pages are written to ``<save_path>/p<N>.wav`` where ``N`` is the
        1-based page index.

        Args:
            pages: Story pages, one entry per page.
            device: Device on which the synthesizer is created.
            save_path: Directory that receives the WAV files; created if it
                does not exist and there is something to write.

        Returns:
            A dict with ``"prompts"`` (one prompt per page, including the
            skipped ones) and ``"modality"`` set to ``"sound"``.
        """
        sound_prompts = self.generate_sound_prompt_from_story(
            pages, **self.config["revise_cfg"]
        )

        save_path = Path(save_path)
        save_paths = []
        forward_prompts = []
        for idx, prompt in enumerate(sound_prompts):
            if prompt != _NO_SOUND_MARKER:
                save_paths.append(save_path / f"p{idx + 1}.wav")
                forward_prompts.append(prompt)

        if forward_prompts:
            # Only load the (heavy) synthesis model when there is work to do.
            save_path.mkdir(parents=True, exist_ok=True)
            generation_agent = AudioLDM2Synthesizer(device=device)
            sounds = generation_agent.call(
                forward_prompts,
                **self.config["call_cfg"]
            )
            for sound, path in zip(sounds, save_paths):
                sf.write(str(path), sound, self.config["sample_rate"])

        return {
            "prompts": sound_prompts,
            "modality": "sound"
        }

    def generate_sound_prompt_from_story(
            self,
            pages: List,
            num_turns: int = 3
        ):
        """Write one sound prompt per page via a revise/review LLM loop.

        For each page the reviser proposes a prompt and the reviewer critiques
        it. The loop stops as soon as the review is ``"Check passed."`` or
        after ``num_turns`` rounds, and the latest prompt is kept.

        Args:
            pages: Story pages, one entry per page.
            num_turns: Maximum number of revise/review rounds per page.

        Returns:
            A list with one prompt string per page, stripped of the
            ``"Sound description:"`` prefix and surrounding whitespace.
        """
        sound_prompt_reviser = self.LLM(story_to_sound_reviser_system, track_history=False)
        sound_prompt_reviewer = self.LLM(story_to_sound_review_system, track_history=False)

        sound_prompts = []
        for page in pages:
            review = ""
            sound_prompt = ""
            for _ in range(num_turns):
                sound_prompt, _ = sound_prompt_reviser.run(json.dumps({
                    "story": page,
                    "previous_result": sound_prompt,
                    "improvement_suggestions": review,
                }, ensure_ascii=False))
                if sound_prompt.startswith(_PROMPT_PREFIX):
                    sound_prompt = sound_prompt[len(_PROMPT_PREFIX):]
                # Without stripping, "Sound description: No sounds." would
                # become " No sounds." and slip past the skip check in `call`.
                sound_prompt = sound_prompt.strip()

                review, _ = sound_prompt_reviewer.run(json.dumps({
                    "story": page,
                    "sound_description": sound_prompt
                }, ensure_ascii=False))
                # Tolerate stray whitespace/newlines around the pass marker.
                if review.strip() == _REVIEW_PASSED:
                    break
            sound_prompts.append(sound_prompt)
        return sound_prompts