File size: 3,706 Bytes
a121edc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5152717
a121edc
5152717
a121edc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5152717
a121edc
 
 
 
 
 
 
 
 
5152717
a121edc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import json
from pathlib import Path
from typing import List, Optional

import soundfile as sf
import torch
from diffusers import AudioLDM2Pipeline

from mm_story_agent.prompts_en import story_to_sound_reviser_system, story_to_sound_review_system
from mm_story_agent.modality_agents.llm import QwenAgent


class AudioLDM2Synthesizer:
    """Text-to-audio generator backed by a diffusers ``AudioLDM2Pipeline``.

    The pipeline is loaded once in the constructor and moved to ``device``;
    :meth:`call` then produces one waveform per text prompt.
    """

    def __init__(self,
                 model_path: Optional[str] = None,
                 device: str = "cuda"
                 ) -> None:
        """Load the pipeline and move it to ``device``.

        Args:
            model_path: Checkpoint passed to ``from_pretrained``. ``None``
                selects ``"cvssp/audioldm2"``.
            device: Torch device the pipeline is moved to; the random
                generator used by :meth:`call` is created on it as well.
        """
        self.device = device
        # Half precision is kept for accelerators (the original behaviour);
        # on CPU fall back to float32, where half-precision kernels are not
        # reliably available.
        on_cpu = torch.device(device).type == "cpu"
        self.pipe = AudioLDM2Pipeline.from_pretrained(
            model_path if model_path is not None else "cvssp/audioldm2",
            torch_dtype=torch.float32 if on_cpu else torch.float16,
        ).to(self.device)

    def call(self,
             prompts: List[str],
             n_candidate_per_text: int = 3,
             seed: int = 0,
             guidance_scale: float = 3.5,
             ddim_steps: int = 100,
             audio_length_in_s: float = 10.0,
             ):
        """Generate one waveform for each prompt.

        Args:
            prompts: Text descriptions of the sounds to synthesize.
            n_candidate_per_text: Number of waveforms the pipeline generates
                per prompt (forwarded as ``num_waveforms_per_prompt``).
            seed: Seed for the torch random generator.
            guidance_scale: Forwarded to the pipeline unchanged.
            ddim_steps: Forwarded to the pipeline as ``num_inference_steps``.
            audio_length_in_s: Requested clip length in seconds; defaults to
                the previously hard-coded ``10.0``.

        Returns:
            Every ``n_candidate_per_text``-th entry of the pipeline's
            ``audios`` output starting at index 0, i.e. one entry per prompt.
            An empty list when ``prompts`` is empty.

        Raises:
            ValueError: If ``n_candidate_per_text`` is smaller than 1.
        """
        # Fail before the expensive generation rather than at the final slice.
        if n_candidate_per_text < 1:
            raise ValueError(
                f"n_candidate_per_text must be >= 1, got {n_candidate_per_text}"
            )
        # Nothing to synthesize: do not invoke the pipeline on an empty batch.
        if not prompts:
            return []

        generator = torch.Generator(device=self.device).manual_seed(seed)
        audios = self.pipe(
            prompts,
            num_inference_steps=ddim_steps,
            audio_length_in_s=audio_length_in_s,
            guidance_scale=guidance_scale,
            generator=generator,
            num_waveforms_per_prompt=n_candidate_per_text,
        ).audios

        # Keep the first candidate of each prompt. This assumes the pipeline
        # returns the candidates of a prompt contiguously (and, per the
        # diffusers documentation, ranked best-first) -- TODO confirm if the
        # diffusers version changes.
        return audios[::n_candidate_per_text]


class AudioLDM2Agent:
    """Derive sound-effect prompts from story pages and render them to WAV.

    For every page an LLM "reviser" drafts a sound description and an LLM
    "reviewer" checks it; pages whose final prompt is not ``"No sounds."``
    are synthesized with ``AudioLDM2Synthesizer`` and written to disk.

    The ``config`` mapping is read for three keys: ``"revise_cfg"`` (keyword
    arguments for :meth:`generate_sound_prompt_from_story`), ``"call_cfg"``
    (keyword arguments for ``AudioLDM2Synthesizer.call``) and
    ``"sample_rate"`` (passed to ``soundfile.write``).
    """

    def __init__(self, config, llm_type="qwen2") -> None:
        """Store the configuration and select the LLM backend class.

        Args:
            config: Mapping with the keys described on the class.
            llm_type: Backend identifier; only ``"qwen2"`` is supported.

        Raises:
            ValueError: If ``llm_type`` is not a supported backend.
        """
        self.config = config
        if llm_type == "qwen2":
            self.LLM = QwenAgent
        else:
            # Fail here instead of with an AttributeError on first use.
            raise ValueError(f"Unsupported llm_type: {llm_type!r}")

    def call(self, pages: List, device: str, save_path: str):
        """Generate prompts for ``pages`` and write the resulting audio files.

        Args:
            pages: Story pages, one entry per page.
            device: Torch device handed to ``AudioLDM2Synthesizer``.
            save_path: Directory that receives ``p<N>.wav`` files, where
                ``N`` is the 1-based page index.

        Returns:
            A dict with the per-page prompts under ``"prompts"`` and the
            constant ``"sound"`` under ``"modality"``.
        """
        sound_prompts = self.generate_sound_prompt_from_story(pages, **self.config["revise_cfg"])
        out_dir = Path(save_path)

        # Pair every page that needs audio with its output file.
        save_paths = []
        forward_prompts = []
        for idx, prompt in enumerate(sound_prompts):
            if prompt != "No sounds.":
                save_paths.append(out_dir / f"p{idx + 1}.wav")
                forward_prompts.append(prompt)

        # Only load the (large) synthesis model when there is work to do.
        if forward_prompts:
            out_dir.mkdir(parents=True, exist_ok=True)
            generation_agent = AudioLDM2Synthesizer(device=device)
            sounds = generation_agent.call(
                forward_prompts,
                **self.config["call_cfg"]
            )
            for sound, path in zip(sounds, save_paths):
                sf.write(str(path), sound, self.config["sample_rate"])

        return {
            "prompts": sound_prompts,
            "modality": "sound"
        }

    def generate_sound_prompt_from_story(
            self,
            pages: List,
            num_turns: int = 3
        ):
        """Produce one sound prompt per page via a revise/review loop.

        Each turn asks the reviser for a prompt (given the page, the previous
        prompt and the last review) and then asks the reviewer to check it.
        The loop ends early when the reviewer answers ``"Check passed."``.

        Args:
            pages: Story pages, one entry per page.
            num_turns: Maximum number of revise/review rounds per page.

        Returns:
            A list with one prompt string per page, stripped of surrounding
            whitespace and of a leading ``"Sound description:"`` label. The
            entry is ``""`` when ``num_turns`` is 0.
        """
        prefix = "Sound description:"
        sound_prompt_reviser = self.LLM(story_to_sound_reviser_system, track_history=False)
        sound_prompt_reviewer = self.LLM(story_to_sound_review_system, track_history=False)

        sound_prompts = []
        for page in pages:
            review = ""
            sound_prompt = ""
            for _ in range(num_turns):
                # NOTE(review): the success flag returned by run() is ignored,
                # as before -- confirm whether failed runs need handling.
                sound_prompt, _success = sound_prompt_reviser.run(json.dumps({
                    "story": page,
                    "previous_result": sound_prompt,
                    "improvement_suggestions": review,
                }, ensure_ascii=False))
                # Normalize whitespace so that e.g. "Sound description: No sounds."
                # ends up as exactly "No sounds." and is recognized by call().
                sound_prompt = sound_prompt.strip()
                if sound_prompt.startswith(prefix):
                    sound_prompt = sound_prompt[len(prefix):].strip()
                review, _success = sound_prompt_reviewer.run(json.dumps({
                    "story": page,
                    "sound_description": sound_prompt
                }, ensure_ascii=False))
                if review.strip() == "Check passed.":
                    break
            sound_prompts.append(sound_prompt)

        return sound_prompts