Xu Xuenan committed
Commit 5152717 · 1 Parent(s): 6331da0

Multi-GPUs
app.py CHANGED
@@ -1,8 +1,8 @@
-import spaces
-
 from pathlib import Path
-import argparse
+from copy import deepcopy
 import shutil
+import os
+from datetime import datetime
 import time
 import uuid
 import subprocess
@@ -22,7 +22,6 @@ except FileNotFoundError:
     imagemagick_installed = False
 
 if not imagemagick_installed:
-    import os
     os.system("apt update -y")
     os.system("apt install -y imagemagick")
     os.system("cp policy.xml /etc/ImageMagick-6/")
@@ -41,7 +40,7 @@ default_music_config = config["music_generation"]
 
 
 def set_generating_progress_text(text):
-    return gr.update(visible=True, value=f"<h3>{text} ...</h3>")
+    return gr.update(visible=True, value=f"<h3>{text}</h3>")
 
 def set_text_invisible():
     return gr.update(visible=False)
@@ -67,8 +66,19 @@ def update_page(direction, page, story_data):
 def write_story_fn(story_topic, main_role, scene,
                    num_outline, temperature,
                    current_page,
+                   config,
                    progress=gr.Progress(track_tqdm=True)):
     config["story_dir"] = f"generated_stories/{time.strftime('%Y%m%d-%H%M%S') + '-' + str(uuid.uuid1().hex)}"
+    current_date = datetime.now()
+
+    if Path("generated_stories").exists():
+        for story_dir in Path("generated_stories").iterdir():
+            story_date = story_dir.name[:8]
+            story_date = datetime.strptime(story_date, '%Y%m%d')
+            date_difference = current_date - story_date
+            if date_difference.days >= 2:
+                shutil.rmtree(story_dir)
+
     deep_update(config, {
         "story_setting": {
             "story_topic": story_topic,
@@ -85,12 +95,11 @@ def write_story_fn(story_topic, main_role, scene,
     # story_data, story_accordion, story_content
     return pages, gr.update(visible=True), pages[current_page], gr.update()
 
-@spaces.GPU()
 def modality_assets_generation_fn(
         height, width, image_seed, sound_guidance_scale, sound_seed,
         n_candidate_per_text, music_duration,
-        story_data,
-        progress=gr.Progress(track_tqdm=True)):
+        config,
+        story_data):
     deep_update(config, {
         "image_generation": {
             "obj_cfg": {
@@ -119,60 +128,10 @@ def modality_assets_generation_fn(
     # image gallery
     return gr.update(visible=True, value=images, columns=[len(images)], rows=[1], height="auto")
 
-def speech_generation_fn(story_data):
-    story_gen_agent = MMStoryAgent()
-    story_gen_agent.generate_speech(config, story_data)
-
-@spaces.GPU(duration=60)
-def sound_generation_fn(sound_guidance_scale, sound_seed, n_candidate_per_text,
-                        story_data, progress=gr.Progress(track_tqdm=True)):
-    deep_update(config, {
-        "sound_generation": {
-            "call_cfg": {
-                "guidance_scale": sound_guidance_scale,
-                "seed": sound_seed,
-                "n_candidate_per_text": n_candidate_per_text
-            }
-        }
-    })
-    story_gen_agent = MMStoryAgent()
-    story_gen_agent.generate_sound(config, story_data)
-
-@spaces.GPU(duration=120)
-def music_generation_fn(music_duration,
-                        story_data, progress=gr.Progress(track_tqdm=True)):
-    deep_update(config, {
-        "music_generation": {
-            "call_cfg": {
-                "duration": music_duration
-            }
-        }
-    })
-    story_gen_agent = MMStoryAgent()
-    story_gen_agent.generate_music(config, story_data)
-
-@spaces.GPU(duration=120)
-def image_generation_fn(height, width, image_seed,
-                        story_data, progress=gr.Progress(track_tqdm=True)):
-    deep_update(config, {
-        "image_generation": {
-            "obj_cfg": {
-                "height": height,
-                "width": width,
-            },
-            "call_cfg": {
-                "seed": image_seed
-            }
-        },
-    })
-    story_gen_agent = MMStoryAgent()
-    result = story_gen_agent.generate_image(config, story_data)
-    images = result["images"]
-    return gr.update(visible=True, value=images, columns=[len(images)], rows=[1], height="auto")
-
 def compose_storytelling_video_fn(
         fade_duration, slide_duration, zoom_speed, move_ratio,
-        sound_volume, music_volume, bg_speech_ratio, fps,
+        sound_volume, music_volume, bg_speech_ratio, fps,
+        config,
         story_data,
         progress=gr.Progress(track_tqdm=True)):
     deep_update(config, {
@@ -194,121 +153,122 @@ def compose_storytelling_video_fn(
     return Path(config["story_dir"]) / "output.mp4"
 
 
-if __name__ == "__main__":
-
-    with gr.Blocks(theme=gr.themes.Soft()) as demo:
-
-        gr.HTML("""
-        <h1 style="text-align: center;">MM-StoryAgent</h1>
-        <p style="font-size: 16px;">This is a demo for generating attractive storytelling videos based on the given story setting.</p>
-        """)
-
-        with gr.Row():
-            with gr.Column():
-                story_topic = gr.Textbox(label="Story Topic", value=default_story_setting["story_topic"])
-                main_role = gr.Textbox(label="Main Role", value=default_story_setting["main_role"])
-                scene = gr.Textbox(label="Scene", value=default_story_setting["scene"])
-                chapter_num = gr.Number(label="Chapter Number", value=default_story_gen_config["num_outline"])
-                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature", value=default_story_gen_config["temperature"])
-
-                with gr.Accordion("Detailed Image Configuration (Optional)", open=False):
-                    height = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['height'])
-                    width = gr.Slider(label="Width", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['width'])
-                    image_seed = gr.Number(label="Image Seed", value=default_image_config["call_cfg"]['seed'])
-
-                with gr.Accordion("Detailed Sound Configuration (Optional)", open=False):
-                    sound_guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=7.0, step=0.5, value=default_sound_config["call_cfg"]['guidance_scale'])
-                    sound_seed = gr.Number(label="Sound Seed", value=default_sound_config["call_cfg"]['seed'])
-                    n_candidate_per_text = gr.Slider(label="Number of Candidates per Text", minimum=0, maximum=5, step=1, value=default_sound_config["call_cfg"]['n_candidate_per_text'])
-
-                with gr.Accordion("Detailed Music Configuration (Optional)", open=False):
-                    music_duration = gr.Number(label="Music Duration", min_width=30.0, maximum=120.0, value=default_music_config["call_cfg"]["duration"])
-
-                with gr.Accordion("Detailed Slideshow Effect (Optional)", open=False):
-                    fade_duration = gr.Slider(label="Fade Duration", minimum=0.1, maximum=1.5, step=0.1, value=default_slideshow_effect['fade_duration'])
-                    slide_duration = gr.Slider(label="Slide Duration", minimum=0.1, maximum=1.0, step=0.1, value=default_slideshow_effect['slide_duration'])
-                    zoom_speed = gr.Slider(label="Zoom Speed", minimum=0.1, maximum=2.0, step=0.1, value=default_slideshow_effect['zoom_speed'])
-                    move_ratio = gr.Slider(label="Move Ratio", minimum=0.8, maximum=1.0, step=0.05, value=default_slideshow_effect['move_ratio'])
-                    sound_volume = gr.Slider(label="Sound Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['sound_volume'])
-                    music_volume = gr.Slider(label="Music Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['music_volume'])
-                    bg_speech_ratio = gr.Slider(label="Background / Speech Ratio", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['bg_speech_ratio'])
-                    fps = gr.Slider(label="FPS", minimum=1, maximum=30, step=1, value=default_slideshow_effect['fps'])
-
-
-            with gr.Column():
-                story_data = gr.State([])
-
-                story_generation_information = gr.Markdown(
-                    label="Story Generation Status",
-                    value="<h3>Generating Story Script ......</h3>",
-                    visible=False)
-                with gr.Accordion(label="Story Content", open=False, visible=False) as story_accordion:
-                    with gr.Row():
-                        prev_button = gr.Button("Previous Page",)
-                        next_button = gr.Button("Next Page",)
-                    story_content = gr.Textbox(label="Page Content")
-                video_generation_information = gr.Markdown(label="Generation Status", value="<h3>Generating Video ......</h3>", visible=False)
-                image_gallery = gr.Gallery(label="Images", show_label=False, visible=False)
-                video_generation_btn = gr.Button("Generate Video")
-                video_output = gr.Video(label="Generated Story", interactive=False)
-
-        current_page = gr.State(0)
-
-        prev_button.click(
-            fn=update_page,
-            inputs=[gr.State("prev"), current_page, story_data],
-            outputs=[current_page, story_content]
-        )
-        next_button.click(
-            fn=update_page,
-            inputs=[gr.State("next"), current_page, story_data],
-            outputs=[current_page, story_content,])
-
-        video_generation_btn.click(
-            fn=set_generating_progress_text,
-            inputs=[gr.State("Generating Story")],
-            outputs=video_generation_information
-        ).then(
-            fn=write_story_fn,
-            inputs=[story_topic, main_role, scene,
-                    chapter_num, temperature,
-                    current_page],
-            outputs=[story_data, story_accordion, story_content, video_output]
-        ).then(
-            fn=set_generating_progress_text,
-            inputs=[gr.State("Generating Modality Assets")],
-            outputs=video_generation_information
-        ).then(
-            fn=speech_generation_fn,
-            inputs=[story_data]
-        ).then(
-            fn=sound_generation_fn,
-            inputs=[sound_guidance_scale, sound_seed, n_candidate_per_text, story_data]
-        ).then(
-            fn=music_generation_fn,
-            inputs=[music_duration, story_data]
-        ).then(
-            fn=image_generation_fn,
-            inputs=[height, width, image_seed, story_data],
-            outputs=[image_gallery]
-        ).then(
-            fn=set_generating_progress_text,
-            inputs=[gr.State("Composing Video")],
-            outputs=video_generation_information
-        ).then(
-            fn=compose_storytelling_video_fn,
-            inputs=[fade_duration, slide_duration, zoom_speed, move_ratio,
-                    sound_volume, music_volume, bg_speech_ratio, fps,
-                    story_data],
-            outputs=[video_output]
-        ).then(
-            fn=lambda : gr.update(visible=False),
-            inputs=[],
-            outputs=[image_gallery]
-        ).then(
-            fn=set_generating_progress_text,
-            inputs=[gr.State("Generation Finished")],
-            outputs=video_generation_information
-        )
-
-    demo.launch()
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+
+    gr.HTML("""
+    <h1 style="text-align: center;">MM-StoryAgent</h1>
+    <p style="font-size: 16px;">This is a demo for generating attractive storytelling videos based on the given story setting.</p>
+    """)
+
+    config = gr.State(deepcopy(config))
+
+    with gr.Row():
+        with gr.Column():
+            story_topic = gr.Textbox(label="Story Topic", value=default_story_setting["story_topic"])
+            main_role = gr.Textbox(label="Main Role", value=default_story_setting["main_role"])
+            scene = gr.Textbox(label="Scene", value=default_story_setting["scene"])
+            chapter_num = gr.Number(label="Chapter Number", value=default_story_gen_config["num_outline"])
+            temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature", value=default_story_gen_config["temperature"])
+
+            with gr.Accordion("Detailed Image Configuration (Optional)", open=False):
+                height = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['height'])
+                width = gr.Slider(label="Width", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['width'])
+                image_seed = gr.Number(label="Image Seed", value=default_image_config["call_cfg"]['seed'])
+
+            with gr.Accordion("Detailed Sound Configuration (Optional)", open=False):
+                sound_guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=7.0, step=0.5, value=default_sound_config["call_cfg"]['guidance_scale'])
+                sound_seed = gr.Number(label="Sound Seed", value=default_sound_config["call_cfg"]['seed'])
+                n_candidate_per_text = gr.Slider(label="Number of Candidates per Text", minimum=0, maximum=5, step=1, value=default_sound_config["call_cfg"]['n_candidate_per_text'])
+
+            with gr.Accordion("Detailed Music Configuration (Optional)", open=False):
+                music_duration = gr.Number(label="Music Duration", min_width=30.0, maximum=120.0, value=default_music_config["call_cfg"]["duration"])
+
+            with gr.Accordion("Detailed Slideshow Effect (Optional)", open=False):
+                fade_duration = gr.Slider(label="Fade Duration", minimum=0.1, maximum=1.5, step=0.1, value=default_slideshow_effect['fade_duration'])
+                slide_duration = gr.Slider(label="Slide Duration", minimum=0.1, maximum=1.0, step=0.1, value=default_slideshow_effect['slide_duration'])
+                zoom_speed = gr.Slider(label="Zoom Speed", minimum=0.1, maximum=2.0, step=0.1, value=default_slideshow_effect['zoom_speed'])
+                move_ratio = gr.Slider(label="Move Ratio", minimum=0.8, maximum=1.0, step=0.05, value=default_slideshow_effect['move_ratio'])
+                sound_volume = gr.Slider(label="Sound Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['sound_volume'])
+                music_volume = gr.Slider(label="Music Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['music_volume'])
+                bg_speech_ratio = gr.Slider(label="Background / Speech Ratio", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['bg_speech_ratio'])
+                fps = gr.Slider(label="FPS", minimum=1, maximum=30, step=1, value=default_slideshow_effect['fps'])
+
+
+        with gr.Column():
+            story_data = gr.State([])
+
+            story_generation_information = gr.Markdown(
+                label="Story Generation Status",
+                value="<h3>Generating Story Script ......</h3>",
+                visible=False)
+            with gr.Accordion(label="Story Content", open=False, visible=False) as story_accordion:
+                with gr.Row():
+                    prev_button = gr.Button("Previous Page",)
+                    next_button = gr.Button("Next Page",)
+                story_content = gr.Textbox(label="Page Content")
+            video_generation_information = gr.Markdown(label="Generation Status", value="<h3>Generating Video ......</h3>", visible=False)
+            image_gallery = gr.Gallery(label="Images", show_label=False, visible=False)
+            video_generation_btn = gr.Button("Generate Video")
+            video_output = gr.Video(label="Generated Story", interactive=False)
+
+    current_page = gr.State(0)
+
+    prev_button.click(
+        fn=update_page,
+        inputs=[gr.State("prev"), current_page, story_data],
+        outputs=[current_page, story_content]
+    )
+    next_button.click(
+        fn=update_page,
+        inputs=[gr.State("next"), current_page, story_data],
+        outputs=[current_page, story_content,])
+
+    # (possibly) update role description and scripts
+
+    video_generation_btn.click(
+        fn=set_generating_progress_text,
+        inputs=[gr.State("Generating Story ...")],
+        outputs=video_generation_information
+    ).then(
+        fn=write_story_fn,
+        inputs=[story_topic, main_role, scene,
+                chapter_num, temperature,
+                current_page,
+                config
+                ],
+        outputs=[story_data, story_accordion, story_content, video_output]
+    ).then(
+        fn=set_generating_progress_text,
+        inputs=[gr.State("Generating Modality Assets ...")],
+        outputs=video_generation_information
+    ).then(
+        fn=modality_assets_generation_fn,
+        inputs=[height, width, image_seed, sound_guidance_scale, sound_seed,
+                n_candidate_per_text, music_duration,
+                config,
+                story_data],
+        outputs=[image_gallery]
+    ).then(
+        fn=set_generating_progress_text,
+        inputs=[gr.State("Composing Video ...")],
+        outputs=video_generation_information
+    ).then(
+        fn=compose_storytelling_video_fn,
+        inputs=[fade_duration, slide_duration, zoom_speed, move_ratio,
+                sound_volume, music_volume, bg_speech_ratio, fps,
+                config,
+                story_data],
+        outputs=[video_output]
+    ).then(
+        fn=lambda : gr.update(visible=False),
+        inputs=[],
+        outputs=[image_gallery]
+    ).then(
+        fn=set_generating_progress_text,
+        inputs=[gr.State("Generation Finished!")],
+        outputs=video_generation_information
+    )
+
+
+if __name__ == "__main__":
+    demo.launch()
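
The structural change in app.py: the module-level `config` dict is now wrapped per session as `gr.State(deepcopy(config))` and threaded through every handler, so concurrent visitors mutate their own copy instead of one shared global, and stories older than two days are pruned on each new request. A minimal sketch of the per-session state pattern, assuming current Gradio semantics (in-place mutation of the object held by a `gr.State` persists within a session); the `settings` dict and `greet` handler below are illustrative, not part of this repo:

    import gradio as gr
    from copy import deepcopy

    settings = {"user": "default", "count": 0}  # module-level defaults

    with gr.Blocks() as demo:
        # Each session receives its own copy, so concurrent users
        # cannot race on the shared module-level dict.
        session_settings = gr.State(deepcopy(settings))
        name = gr.Textbox(label="Name")
        out = gr.Textbox(label="Result")

        def greet(name, cfg):
            cfg["user"] = name      # mutates only this session's copy
            cfg["count"] += 1
            return f"Hello {cfg['user']} (call #{cfg['count']})"

        name.submit(greet, inputs=[name, session_settings], outputs=out)

    if __name__ == "__main__":
        demo.launch()

The explicit `deepcopy` also keeps the module-level defaults pristine, so a fresh session always starts from the original configuration.
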
configs/mm_story_agent.yaml CHANGED
@@ -1,10 +1,9 @@
-story_dir: generated_stories/20240808_1130
 audio_sample_rate: &audio_sample_rate 16000
 audio_codec: mp3 # [mp3, aac, ...]
 
 
 story_setting:
-  story_topic: "Time Management: A child learning how to manage their time effectively."
+  story_topic: "learn to use computer"
   main_role: "(no main role specified)"
   scene: "(no scene specified)"
 
@@ -26,7 +25,7 @@ sound_generation:
   call_cfg:
     guidance_scale: 3.5
     seed: 0
-    ddim_steps: 200
+    ddim_steps: 100
     n_candidate_per_text: 3
   revise_cfg:
     num_turns: 3
@@ -44,7 +43,7 @@ image_generation:
     num_turns: 3
   obj_cfg:
     model_name: stabilityai/stable-diffusion-xl-base-1.0
-    id_length: 2
+    id_length: 1
    height: 512
    width: 1024
  call_cfg:
@@ -56,6 +55,7 @@ image_generation:
 music_generation:
   revise_cfg:
     num_turns: 3
+  obj_cfg: {}
   call_cfg:
     duration: 60.0
 
 
mm_story_agent/__init__.py CHANGED
@@ -22,10 +22,16 @@ class MMStoryAgent:
             "speech": CosyVoiceAgent,
             "music": MusicGenAgent
         }
+        self.modality_devices = {
+            "image": "cuda:0",
+            "sound": "cuda:1",
+            "music": "cuda:2",
+            "speech": "cuda:3"
+        }
         self.agents = {}
 
-    def call_modality_agent(self, agent, pages, save_path, return_dict):
-        result = agent.call(pages, save_path)
+    def call_modality_agent(self, agent, device, pages, save_path, return_dict):
+        result = agent.call(pages, device, save_path)
         modality = result["modality"]
         return_dict[modality] = result
 
@@ -73,7 +79,7 @@ class MMStoryAgent:
         return_dict = mp.Manager().dict()
 
         for modality in self.modalities:
-            p = mp.Process(target=self.call_modality_agent, args=(agents[modality], pages, story_dir / modality, return_dict), daemon=False)
+            p = mp.Process(target=self.call_modality_agent, args=(agents[modality], self.modality_devices[modality], pages, story_dir / modality, return_dict), daemon=False)
             processes.append(p)
             p.start()
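
The map above pins one modality per GPU and launches each agent in its own process, so the four models load and run concurrently. A sketch of that fan-out which also degrades when fewer than four GPUs are present; the modulo fallback and the dummy workload are illustrative assumptions, not what the commit does. Note that CUDA state does not survive `fork`, so the `spawn` start method is the safe choice with `torch`:

    import torch
    import torch.multiprocessing as mp

    def run_agent(name, device, return_dict):
        # Stand-in for agent.call(pages, device, save_path):
        # each process owns exactly one device.
        x = torch.randn(8, 8, device=device)
        return_dict[name] = float((x @ x.T).sum())

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")  # fork + CUDA does not mix
        n_gpus = torch.cuda.device_count()
        modalities = ["image", "sound", "music", "speech"]
        # Wrap around when the machine has fewer GPUs than modalities.
        devices = {m: (f"cuda:{i % n_gpus}" if n_gpus else "cpu")
                   for i, m in enumerate(modalities)}
        return_dict = ctx.Manager().dict()
        procs = [ctx.Process(target=run_agent, args=(m, d, return_dict), daemon=False)
                 for m, d in devices.items()]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        print(dict(return_dict))
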
mm_story_agent/modality_agents/image_agent.py CHANGED
@@ -389,6 +389,7 @@ class StoryDiffusionSynthesizer:
                  num_pages: int,
                  height: int,
                  width: int,
+                 device: str,
                  model_name: str = "stabilityai/stable-diffusion-xl-base-1.0",
                  model_path: str = None,
                  id_length: int = 4,
@@ -404,7 +405,7 @@ class StoryDiffusionSynthesizer:
         self.total_length = num_pages
         self.height = height
         self.width = width
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
         self.dtype = torch.float16
         self.num_steps = num_steps
         self.styles = {
@@ -525,7 +526,7 @@ class StoryDiffusionSynthesizer:
         return p.replace("{prompt}", positive)
 
     def call(self,
-             prompts: List[str],
+             prompts: List[str],
              input_id_images = None,
             start_merge_step = None,
             style_name: str = "Pixar/Disney Character",
@@ -581,7 +582,7 @@ class StoryDiffusionAgent:
         if llm_type == "qwen2":
             self.LLM = QwenAgent
 
-    def call(self, pages: List, save_path: str):
+    def call(self, pages: List, device: str, save_path: str):
         role_dict = self.extract_role_from_story(pages, **self.config["revise_cfg"])
         image_prompts = self.generate_image_prompt_from_story(pages, **self.config["revise_cfg"])
         image_prompts_with_role_desc = []
@@ -592,6 +593,7 @@ class StoryDiffusionAgent:
             image_prompts_with_role_desc.append(image_prompt)
         generation_agent = StoryDiffusionSynthesizer(
             num_pages=len(pages),
+            device=device,
             **self.config["obj_cfg"]
         )
         images = generation_agent.call(
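
With the new `device: str` parameter, the synthesizer pins its pipeline to whichever GPU the dispatcher assigns instead of auto-detecting `cuda`/`cpu`. A minimal standalone sketch of that placement with `diffusers` (the prompt and device string are illustrative):

    import torch
    from diffusers import StableDiffusionXLPipeline

    def build_pipeline(device: str):
        # An explicit device string ("cuda:0", "cuda:1", ...) lets each
        # worker process claim a distinct GPU.
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
        )
        return pipe.to(device)

    # pipe = build_pipeline("cuda:0")
    # image = pipe("a watercolor fox in a forest").images[0]
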
mm_story_agent/modality_agents/music_agent.py CHANGED
@@ -14,9 +14,10 @@ class MusicGenSynthesizer:
 
     def __init__(self,
                  model_name: str = 'facebook/musicgen-medium',
+                 device: str = 'cuda',
                  sample_rate: int = 16000,
                  ) -> None:
-        self.model = MusicGen.get_pretrained(model_name)
+        self.model = MusicGen.get_pretrained(model_name, device=device)
         self.sample_rate = sample_rate
 
     def call(self,
@@ -63,10 +64,10 @@ class MusicGenAgent:
 
         return music_prompt
 
-    def call(self, pages: List, save_path: str):
+    def call(self, pages: List, device: str, save_path: str):
         save_path = Path(save_path)
         music_prompt = self.generate_music_prompt_from_story(pages, **self.config["revise_cfg"])
-        generation_agent = MusicGenSynthesizer()
+        generation_agent = MusicGenSynthesizer(device=device)
         generation_agent.call(
             prompt=music_prompt,
             save_path=save_path / "music.wav",
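
`MusicGen.get_pretrained` accepts the target device at load time, so the synthesizer needs no separate move afterwards. A short usage sketch; the device choice, duration, and prompt are illustrative assumptions:

    import torch
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write

    device = "cuda:2" if torch.cuda.device_count() >= 3 else "cpu"
    # get_pretrained places the underlying models on `device` directly.
    model = MusicGen.get_pretrained("facebook/musicgen-medium", device=device)
    model.set_generation_params(duration=8)  # seconds
    wav = model.generate(["calm bedtime piano for a children's story"])[0]
    audio_write("music_demo", wav.cpu(), model.sample_rate, strategy="loudness")
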
mm_story_agent/modality_agents/sound_agent.py CHANGED
@@ -14,8 +14,9 @@ class AudioLDM2Synthesizer:
 
     def __init__(self,
                  model_path: str = None,
+                 device: str = "cuda"
                  ) -> None:
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
         self.pipe = AudioLDM2Pipeline.from_pretrained(
             model_path if model_path is not None else "cvssp/audioldm2",
             torch_dtype=torch.float16
@@ -49,7 +50,7 @@ class AudioLDM2Agent:
         if llm_type == "qwen2":
             self.LLM = QwenAgent
 
-    def call(self, pages: List, save_path: str):
+    def call(self, pages: List, device: str, save_path: str):
         sound_prompts = self.generate_sound_prompt_from_story(pages, **self.config["revise_cfg"])
         save_paths = []
         forward_prompts = []
@@ -59,7 +60,7 @@ class AudioLDM2Agent:
             save_paths.append(save_path / f"p{idx + 1}.wav")
             forward_prompts.append(sound_prompts[idx])
 
-        generation_agent = AudioLDM2Synthesizer()
+        generation_agent = AudioLDM2Synthesizer(device=device)
         if len(forward_prompts) > 0:
             sounds = generation_agent.call(
                 forward_prompts,
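
The same pattern for AudioLDM 2: the synthesizer records the assigned device string and moves its fp16 pipeline onto it. A minimal sketch; the prompt, step count, and default device are illustrative:

    import torch
    from diffusers import AudioLDM2Pipeline

    def build_audioldm2(device: str = "cuda:1"):
        pipe = AudioLDM2Pipeline.from_pretrained(
            "cvssp/audioldm2", torch_dtype=torch.float16
        )
        return pipe.to(device)  # pin to the worker's GPU

    # pipe = build_audioldm2()
    # audio = pipe("birds chirping in a forest",
    #              num_inference_steps=100, audio_length_in_s=5.0).audios[0]
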
mm_story_agent/modality_agents/speech_agent.py CHANGED
@@ -74,7 +74,7 @@ class CosyVoiceAgent:
     def __init__(self, config) -> None:
         self.config = config
 
-    def call(self, pages: List, save_path: str):
+    def call(self, pages: List, device: str, save_path: str):
         save_path = Path(save_path)
         generation_agent = CosyVoiceSynthesizer()
 
mm_story_agent/prompts_en.py CHANGED
@@ -89,10 +89,11 @@ The input consists of already written story content and the current chapter that
 Output the expanded story content for the current chapter. The result should be a list where each element corresponds to the plot of one page of the storybook.
 
 ## Notes
-1. Only expand the current chapter; do not overwrite content from other chapters.
-2. The expanded content should not be too lengthy, with a maximum of 3 pages and no more than 2 sentences per page.
-3. Maintain the tone of the story; do not add extra annotations, explanations, settings, or comments.
-4. If the story is already complete, no further writing is necessary.
+1. Only expand the current chapter. Do not overwrite content from other chapters.
+2. The expanded story content should not be too long, with a maximum of 3 pages. Each page contains only 1 sentence.
+3. Maintain the tone of the story. Do not add extra annotations, explanations, settings, or comments.
+4. Use simple and straightforward language suitable for children's stories.
+5. If the story is already complete, no further writing is necessary.
 """.strip()
 
 
nls-1.0.0-py3-none-any.whl DELETED
Binary file (47 kB)