Niki Zhang committed
Commit cf1091a · verified · 1 parent: 9434e0e

Update app.py


Combine with TTS module
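
This commit wires app.py to a sibling tts module that is not included in the diff; only its call sites appear below. From those call sites, tts.predict(text, language, reference_audio, mic_audio, use_mic, agree) must return a (waveform_video, audio) pair that Gradio can route to the gr.Video and gr.Audio output components. Here is a minimal interface sketch: the Coqui CPML notice added by the commit suggests an XTTS-style backend, but everything inside the function body is an assumption, not part of this commit.

    # tts.py: interface sketch implied by the call sites in app.py
    from typing import Optional, Tuple

    def predict(text: str,
                language: str,                   # e.g. "en" or "zh-cn", matching the Dropdown choices
                audio_file_path: Optional[str],  # reference voice sample (gr.Audio, type="filepath")
                mic_file_path: Optional[str],    # microphone recording (gr.Audio, type="filepath")
                use_mic: bool,
                agree: bool) -> Tuple[Optional[str], Optional[str]]:
        """Return (waveform_video_path, synthesized_audio_path) for the Gradio outputs."""
        if not agree:
            # Mirrors the Coqui Public Model License notice added at the top of app.py.
            raise RuntimeError("Please accept the terms of the Coqui Public Model License.")
        speaker_wav = mic_file_path if use_mic else audio_file_path
        # A real implementation would synthesize speech conditioned on speaker_wav here,
        # write a wav file plus a waveform video to disk, and return both paths.
        raise NotImplementedError("model inference goes here")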

Files changed (1): app.py (+86 -21)
app.py CHANGED

@@ -18,6 +18,16 @@ from caption_anything.segmenter import build_segmenter
 from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
 from segment_anything import sam_model_registry
 import easyocr
+import tts
+
+
+
+
+article = """
+<div style='margin:20px auto;'>
+<p>By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml</p>
+</div>
+"""
 
 args = parse_augment()
 args.segmenter = "huge"

@@ -102,12 +112,12 @@ def init_openai_api_key(api_key=""):
     print(text_refiner)
     openai_available = text_refiner is not None
     if openai_available:
-        return [gr.update(visible=True)]*6 + [gr.update(visible=False)]*2 + [text_refiner, visual_chatgpt, None]
+        return [gr.update(visible=True)]*7 + [gr.update(visible=False)]*2 + [text_refiner, visual_chatgpt, None]
     else:
-        return [gr.update(visible=False)]*6 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']
+        return [gr.update(visible=False)]*7 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']
 
 def init_wo_openai_api_key():
-    return [gr.update(visible=False)]*4 + [gr.update(visible=True)]*2 + [gr.update(visible=False)]*2 + [None, None, None]
+    return [gr.update(visible=False)]*4 + [gr.update(visible=True)]*3 + [gr.update(visible=False)]*2 + [None, None, None]
 
 def get_click_prompt(chat_input, click_state, click_mode):
     inputs = json.loads(chat_input)

@@ -256,7 +266,8 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
 
 
 def submit_caption(image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
-                   out_state, click_index_state, input_mask_state, input_points_state, input_labels_state):
+                   out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
+                   input_text, input_language, input_audio, input_mic, use_mic, agree):
     print("state",state)
 
     click_index = click_index_state

@@ -291,13 +302,23 @@ def submit_caption(image_input, state, generated_caption, text_refiner, visual_c
         print("new_cap",new_cap)
         refined_image_input = create_bubble_frame(np.array(origin_image_input), new_cap, click_index, input_mask,
                                                   input_points=input_points, input_labels=input_labels)
-        txt2speech(new_cap)
-        yield state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state
-
-    else:
-        txt2speech(generated_caption)
-        yield state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state
+        try:
+            waveform_visual, audio_output = tts.predict(new_cap, input_language, input_audio, input_mic, use_mic, agree)
+            print("error tts")
+            yield state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
+        except Exception as e:
+            state = state + [(None, f"Error during TTS prediction: {str(e)}")]
+            print(f"Error during TTS prediction: {str(e)}")
+            yield state, state, refined_image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
 
+    else:
+        try:
+            waveform_visual, audio_output = tts.predict(generated_caption, input_language, input_audio, input_mic, use_mic, agree)
+            yield state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, waveform_visual, audio_output
+        except Exception as e:
+            state = state + [(None, f"Error during TTS prediction: {str(e)}")]
+            print(f"Error during TTS prediction: {str(e)}")
+            yield state, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state, None, None
 
 
 

@@ -531,6 +552,7 @@ def create_ui():
                     interactive=True,
                     label="Generated Caption Length",
                 )
+                # Whether to fold wiki content into the generated caption
                 enable_wiki = gr.Radio(
                     choices=["Yes", "No"],
                     value="No",

@@ -541,6 +563,7 @@ def create_ui():
                 examples=examples,
                 inputs=[example_image],
             )
+
             with gr.Column(scale=0.5):
                 with gr.Column(visible=True) as module_key_input:
                     openai_api_key = gr.Textbox(

@@ -567,18 +590,52 @@ def create_ui():
                 with gr.Row():
                     clear_button_text = gr.Button(value="Clear Text", interactive=True)
                     submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
+
+                # TTS interface, hidden initially
+                with gr.Column(visible=False) as tts_interface:
+                    input_text = gr.Textbox(label="Text Prompt", value="Hello, World !, here is an example of light voice cloning. Try to upload your best audio samples quality")
+                    input_language = gr.Dropdown(label="Language", choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"], value="en")
+                    input_audio = gr.Audio(label="Reference Audio", type="filepath", value="examples/female.wav")
+                    input_mic = gr.Audio(source="microphone", type="filepath", label="Use Microphone for Reference")
+                    use_mic = gr.Checkbox(label="Check to use Microphone as Reference", value=False)
+                    agree = gr.Checkbox(label="Agree", value=True)
+                    output_waveform = gr.Video(label="Waveform Visual")
+                    output_audio = gr.Audio(label="Synthesised Audio")
+
+                    with gr.Row():
+                        submit_tts = gr.Button(value="Submit", interactive=True)
+                        clear_tts = gr.Button(value="Clear", interactive=True)
+
+
+        def clear_tts_fields():
+            return [gr.update(value=""), gr.update(value=""), None, None, gr.update(value=False), gr.update(value=True), None, None]
+
+        submit_tts.click(
+            tts.predict,
+            inputs=[input_text, input_language, input_audio, input_mic, use_mic, agree],
+            outputs=[output_waveform, output_audio],
+            queue=True
+        )
+
+        clear_tts.click(
+            clear_tts_fields,
+            inputs=None,
+            outputs=[input_text, input_language, input_audio, input_mic, use_mic, agree, output_waveform, output_audio],
+            queue=False
+        )
 
+
         openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
                               outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
-                                       modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
+                                       modules_not_need_gpt2, tts_interface, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
         enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
                                     outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
                                              modules_not_need_gpt,
-                                             modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
+                                             modules_not_need_gpt2, tts_interface, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
         disable_chatGPT_button.click(init_wo_openai_api_key,
                                      outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
                                               modules_not_need_gpt,
-                                              modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
+                                              modules_not_need_gpt2, tts_interface, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
 
         enable_chatGPT_button.click(
             lambda: (None, [], [], [[], [], []], "", "", ""),

@@ -663,13 +720,19 @@ def create_ui():
 
 
         submit_button_click.click(
-            submit_caption,
-            inputs=[image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
-                    out_state, click_index_state, input_mask_state, input_points_state, input_labels_state],
-            outputs=[chatbot, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state],
-            show_progress=True, queue=True
-        )
-
+            submit_caption,
+            inputs=[
+                image_input, state, generated_caption, text_refiner, visual_chatgpt, enable_wiki, length, sentiment, factuality, language,
+                out_state, click_index_state, input_mask_state, input_points_state, input_labels_state,
+                input_text, input_language, input_audio, input_mic, use_mic, agree
+            ],
+            outputs=[
+                chatbot, state, image_input, click_index_state, input_mask_state, input_points_state, input_labels_state, out_state,
+                output_waveform, output_audio
+            ],
+            show_progress=True,
+            queue=True
+        )
 
 
 

@@ -683,6 +746,9 @@ def create_ui():
             show_progress=False, queue=True
         )
 
+
+
+
     return iface
 
 

@@ -690,4 +756,3 @@ if __name__ == '__main__':
     iface = create_ui()
     iface.queue(concurrency_count=5, api_open=False, max_size=10)
     iface.launch(server_name="0.0.0.0", enable_queue=True)
-
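
A note on the visibility updates above: Gradio matches a returned list of gr.update(...) values to an event's outputs list by position. Adding tts_interface to each outputs list is what forces [gr.update(visible=True)]*6 to become *7 in init_openai_api_key, and adds one more visible=True update in init_wo_openai_api_key. A self-contained sketch of the pattern, with hypothetical stand-in components:

    import gradio as gr

    with gr.Blocks() as demo:
        # Stand-ins for the seven GPT/TTS panels and the two fallback panels in app.py.
        gpt_modules = [gr.Textbox(visible=False, label=f"gpt module {i}") for i in range(7)]
        fallbacks = [gr.Textbox(visible=True, label=f"fallback {i}") for i in range(2)]
        api_key = gr.Textbox(label="OpenAI API key")
        api_key.submit(
            lambda key: [gr.update(visible=True)] * 7 + [gr.update(visible=False)] * 2,
            inputs=[api_key],
            outputs=gpt_modules + fallbacks,  # nine outputs, nine updates, matched in order
        )

    if __name__ == "__main__":
        demo.launch()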