Xu Xuenan committed on
Commit 676ec69 · 1 Parent(s): 5152717

Transformers MusicGen

app.py CHANGED
@@ -97,7 +97,7 @@ def write_story_fn(story_topic, main_role, scene,
 
 def modality_assets_generation_fn(
         height, width, image_seed, sound_guidance_scale, sound_seed,
-        n_candidate_per_text, music_duration,
+        n_candidate_per_text,
         config,
         story_data):
     deep_update(config, {
@@ -117,11 +117,6 @@ def modality_assets_generation_fn(
                 "n_candidate_per_text": n_candidate_per_text
             }
         },
-        "music_generation": {
-            "call_cfg": {
-                "duration": music_duration
-            }
-        }
     })
     story_gen_agent = MMStoryAgent()
     images = story_gen_agent.generate_modality_assets(config, story_data)
@@ -180,9 +175,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             sound_seed = gr.Number(label="Sound Seed", value=default_sound_config["call_cfg"]['seed'])
             n_candidate_per_text = gr.Slider(label="Number of Candidates per Text", minimum=0, maximum=5, step=1, value=default_sound_config["call_cfg"]['n_candidate_per_text'])
 
-        with gr.Accordion("Detailed Music Configuration (Optional)", open=False):
-            music_duration = gr.Number(label="Music Duration", min_width=30.0, maximum=120.0, value=default_music_config["call_cfg"]["duration"])
-
         with gr.Accordion("Detailed Slideshow Effect (Optional)", open=False):
             fade_duration = gr.Slider(label="Fade Duration", minimum=0.1, maximum=1.5, step=0.1, value=default_slideshow_effect['fade_duration'])
             slide_duration = gr.Slider(label="Slide Duration", minimum=0.1, maximum=1.0, step=0.1, value=default_slideshow_effect['slide_duration'])
@@ -244,7 +236,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     ).then(
         fn=modality_assets_generation_fn,
         inputs=[height, width, image_seed, sound_guidance_scale, sound_seed,
-                n_candidate_per_text, music_duration,
+                n_candidate_per_text,
                 config,
                 story_data],
        outputs=[image_gallery]
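
For context on the hunks above: this commit removes the music_duration control end to end, so the override dict built in modality_assets_generation_fn no longer carries a "music_generation" block. A minimal, self-contained sketch of that merge step follows; the recursive-merge helper is a hypothetical stand-in for the repo's deep_update (whose implementation is not part of this diff), and the outer "sound_generation" key name is inferred from the parameter names rather than shown verbatim.

# Hypothetical stand-in for the repo's deep_update, for illustration only.
def deep_update(base: dict, overrides: dict) -> dict:
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value
    return base

# "sound_generation" is an inferred key name; only "n_candidate_per_text"
# appears verbatim in the hunk above.
config = {"sound_generation": {"call_cfg": {"n_candidate_per_text": 1}}}
# After this commit only sound settings are overridden from the UI; the former
# {"music_generation": {"call_cfg": {"duration": ...}}} entry is gone.
deep_update(config, {"sound_generation": {"call_cfg": {"n_candidate_per_text": 3}}})
print(config)  # {'sound_generation': {'call_cfg': {'n_candidate_per_text': 3}}}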
configs/mm_story_agent.yaml CHANGED
@@ -56,8 +56,7 @@ music_generation:
   revise_cfg:
     num_turns: 3
   obj_cfg: {}
-  call_cfg:
-    duration: 60.0
+  call_cfg: {}
 
 slideshow_effect:
   fade_duration: 0.8
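
Keeping call_cfg as an empty mapping, rather than deleting the key, preserves the pattern of splatting call_cfg into the synthesizer's call(). A small self-contained sketch under that assumption (the commented call site is hypothetical, not shown in this diff):

import yaml

snippet = """
music_generation:
  revise_cfg:
    num_turns: 3
  obj_cfg: {}
  call_cfg: {}
"""
cfg = yaml.safe_load(snippet)["music_generation"]
call_cfg = cfg["call_cfg"]  # {} -- no per-call kwargs after this commit
# synthesizer.call(prompt, save_path, **call_cfg)  # hypothetical call site
print(call_cfg)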
mm_story_agent/modality_agents/music_agent.py CHANGED
@@ -2,9 +2,9 @@ from pathlib import Path
 import json
 from typing import List, Union
 
+import soundfile as sf
 import torchaudio
-from audiocraft.models import MusicGen
-from audiocraft.data.audio import audio_write
+from transformers import AutoProcessor, MusicgenForConditionalGeneration
 
 from mm_story_agent.modality_agents.llm import QwenAgent
 from mm_story_agent.prompts_en import story_to_music_reviser_system, story_to_music_reviewer_system
@@ -17,19 +17,23 @@ class MusicGenSynthesizer:
                  device: str = 'cuda',
                  sample_rate: int = 16000,
                  ) -> None:
-        self.model = MusicGen.get_pretrained(model_name, device=device).to(device)
+        self.device = device
+        self.processor = AutoProcessor.from_pretrained(model_name)
+        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(device)
         self.sample_rate = sample_rate
 
     def call(self,
              prompt: Union[str, List[str]],
              save_path: Union[str, Path],
-             duration: float = 60.0,
              ):
-        self.model.set_generation_params(duration=duration)
-        wav = self.model.generate([prompt], progress=True)[0].cpu()
-        wav = torchaudio.functional.resample(wav, self.model.sample_rate, self.sample_rate)
-        save_path = Path(save_path).parent / Path(save_path).stem
-        audio_write(save_path, wav, self.sample_rate)
+        inputs = self.processor(
+            text=[prompt],
+            padding=True,
+            return_tensors="pt",
+        ).to(self.device)
+        wav = self.model.generate(**inputs, max_new_tokens=1536)[0, 0].cpu()
+        wav = torchaudio.functional.resample(wav, self.model.config.audio_encoder.sampling_rate, self.sample_rate)
+        sf.write(save_path, wav.numpy(), self.sample_rate)
 
 
 class MusicGenAgent:
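
The new Transformers path can be exercised standalone. A minimal sketch, assuming the checkpoint "facebook/musicgen-small" (the diff does not show the default model_name): MusicGen's audio codec emits roughly 50 tokens per second, so the hard-coded max_new_tokens=1536 caps output at about 30 seconds, which is what replaces the old duration knob.

import soundfile as sf
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration

model_name = "facebook/musicgen-small"  # assumption: default not shown in the diff
device = "cuda"                         # matches the synthesizer's default

processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(device)

inputs = processor(text=["gentle lullaby with soft strings"],
                   padding=True, return_tensors="pt").to(device)
# ~50 audio tokens per second, so 1536 new tokens is roughly 30 s of music.
wav = model.generate(**inputs, max_new_tokens=1536)[0, 0].cpu()
# Downsample from the codec's native rate (32 kHz for MusicGen) to 16 kHz.
wav = torchaudio.functional.resample(wav, model.config.audio_encoder.sampling_rate, 16000)
sf.write("music.wav", wav.numpy(), 16000)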