umarigan committed
Commit 1a55abe · verified · 1 Parent(s): 6dc214e

Update app.py

Files changed (1):
  1. app.py +148 -75
app.py CHANGED
@@ -1,87 +1,160 @@
  import gradio as gr
- import numpy as np
  import torch
- from datasets import load_dataset
- from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor, pipeline
-
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
-
- # Load Whisper large-v2 model for multilingual speech translation
- asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v2", device=device)
-
- # Load MMS TTS model for multilingual text-to-speech (using German model as base)
- processor = SpeechT5Processor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
- model = SpeechT5ForTextToSpeech.from_pretrained("facebook/s2t-medium-mustc-multilingual-st").to(device)
- vocoder = SpeechT5HifiGan.from_pretrained("facebook/s2t-medium-mustc-multilingual-st").to(device)
-
- # Define supported languages (adjust based on the languages supported by the model)
- LANGUAGES = {
-     "German": "deu", "English": "eng", "French": "fra", "Spanish": "spa",
-     "Italian": "ita", "Portuguese": "por", "Polish": "pol", "Turkish": "tur"
- }
-
- def translate(audio, source_lang, target_lang):
-     outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={
-         "task": "transcribe",
-         "language": source_lang,
-     })
-     transcription = outputs["text"]
-
-     # Use Whisper for translation
-     translation = asr_pipe(transcription, max_new_tokens=256, generate_kwargs={
-         "task": "translate",
-         "language": target_lang,
-     })["text"]

-     return translation

- def synthesise(text, target_lang):
-     inputs = processor(text=text, return_tensors="pt")
-     speech = model.generate_speech(inputs["input_ids"].to(device), vocoder=vocoder, language=LANGUAGES[target_lang])
-     return speech.cpu()

- def speech_to_speech_translation(audio, source_lang, target_lang):
-     translated_text = translate(audio, LANGUAGES[source_lang], LANGUAGES[target_lang])
-     synthesised_speech = synthesise(translated_text, target_lang)
-     synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
-     return 16000, synthesised_speech

- title = "Multilingual Speech-to-Speech Translation"
- description = """
- Demo for multilingual speech-to-speech translation (STST), mapping from source speech in any supported language to target speech in any other supported language.
- """

- demo = gr.Blocks()

- with demo:
-     gr.Markdown(f"# {title}")
-     gr.Markdown(description)

      with gr.Row():
-         source_lang = gr.Dropdown(choices=list(LANGUAGES.keys()), label="Source Language")
-         target_lang = gr.Dropdown(choices=list(LANGUAGES.keys()), label="Target Language")
-
-     with gr.Tabs():
-         with gr.TabItem("Microphone"):
-             mic_input = gr.Audio(source="microphone", type="filepath")
-             mic_output = gr.Audio(label="Generated Speech", type="numpy")
-             mic_button = gr.Button("Translate")

-         with gr.TabItem("Audio File"):
-             file_input = gr.Audio(source="upload", type="filepath")
-             file_output = gr.Audio(label="Generated Speech", type="numpy")
-             file_button = gr.Button("Translate")
-
-     mic_button.click(
-         speech_to_speech_translation,
-         inputs=[mic_input, source_lang, target_lang],
-         outputs=mic_output
-     )
-
-     file_button.click(
-         speech_to_speech_translation,
-         inputs=[file_input, source_lang, target_lang],
-         outputs=file_output
-     )

- demo.launch()
  import gradio as gr
  import torch
+ import uuid
+ import json
+ import librosa
+ import os
+ import tempfile
+ import soundfile as sf
+ import scipy.io.wavfile as wav
+
+ from transformers import pipeline, VitsModel, AutoTokenizer, set_seed
+ from nemo.collections.asr.models import EncDecMultiTaskModel
+
+ # Constants
+ SAMPLE_RATE = 16000 # Hz
+
+ # load ASR model
+ canary_model = EncDecMultiTaskModel.from_pretrained('nvidia/canary-1b')
+ decode_cfg = canary_model.cfg.decoding
+ decode_cfg.beam.beam_size = 1
+ canary_model.change_decoding_strategy(decode_cfg)
+
+ # Function to convert audio to text using ASR
+ def gen_text(audio_filepath, action, source_lang, target_lang):
+     if audio_filepath is None:
+         raise gr.Error("Please provide some input audio.")

+     utt_id = uuid.uuid4()
+     with tempfile.TemporaryDirectory() as tmpdir:
+         # Convert to 16 kHz
+         data, sr = librosa.load(audio_filepath, sr=None, mono=True)
+         if sr != SAMPLE_RATE:
+             data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
+         converted_audio_filepath = os.path.join(tmpdir, f"{utt_id}.wav")
+         sf.write(converted_audio_filepath, data, SAMPLE_RATE)

+         # Transcribe audio
+         duration = len(data) / SAMPLE_RATE
+         manifest_data = {
+             "audio_filepath": converted_audio_filepath,
+             "taskname": action,
+             "source_lang": source_lang,
+             "target_lang": source_lang if action=="asr" else target_lang,
+             "pnc": "no",
+             "answer": "predict",
+             "duration": str(duration),
+         }
+         manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
+         with open(manifest_filepath, 'w') as fout:
+             fout.write(json.dumps(manifest_data))

+         predicted_text = canary_model.transcribe(manifest_filepath)[0]
+         # if duration < 40:
+         #     predicted_text = canary_model.transcribe(manifest_filepath)[0]
+         # else:
+         #     predicted_text = get_buffered_pred_feat_multitaskAED(
+         #         frame_asr,
+         #         canary_model.cfg.preprocessor,
+         #         model_stride_in_secs,
+         #         canary_model.device,
+         #         manifest=manifest_filepath,
+         #     )[0].text
+
+     return predicted_text

+ # Function to convert text to speech using TTS
+ def gen_speech(text, lang):
+     set_seed(555) # Make it deterministic
+     match lang:
+         case "en":
+             model = "facebook/mms-tts-eng"
+         case "fr":
+             model = "facebook/mms-tts-fra"
+         case "de":
+             model = "facebook/mms-tts-deu"
+         case "es":
+             model = "facebook/mms-tts-spa"
+         case _:
+             model = "facebook/mms-tts"
+
+     # load TTS model
+     tts_model = VitsModel.from_pretrained(model)
+     tts_tokenizer = AutoTokenizer.from_pretrained(model)
+
+     input_text = tts_tokenizer(text, return_tensors="pt")
+     with torch.no_grad():
+         outputs = tts_model(**input_text)
+     waveform_np = outputs.waveform[0].cpu().numpy()
+     output_file = f"{str(uuid.uuid4())}.wav"
+     wav.write(output_file, rate=tts_model.config.sampling_rate, data=waveform_np)
+     return output_file

+ # Root function for Gradio interface
+ def start_process(audio_filepath, source_lang, target_lang):
+     transcription = gen_text(audio_filepath, "asr", source_lang, target_lang)
+     print("Done transcribing")
+     translation = gen_text(audio_filepath, "s2t_translation", source_lang, target_lang)
+     print("Done translation")
+     audio_output_filepath = gen_speech(translation, target_lang)
+     print("Done speaking")
+     return transcription, translation, audio_output_filepath


+ # Create Gradio interface
+ playground = gr.Blocks()
+
+ with playground:
+
      with gr.Row():
+         gr.Markdown("""
+         ## Your AI Translate Assistant
+         ### Gets input audio from user, transcribe and translate it. Convert back to speech.
+         - category: [Automatic Speech Recognition](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition), model: [nvidia/canary-1b](https://huggingface.co/nvidia/canary-1b)
+         - category: [Text-to-Speech](https://huggingface.co/models?pipeline_tag=text-to-speech), model: [facebook/mms-tts](https://huggingface.co/facebook/mms-tts)
+         """)

+     with gr.Row():
+         with gr.Column():
+             source_lang = gr.Dropdown(
+                 choices=["en", "de", "es", "fr"], value="en", label="Source Language"
+             )
+         with gr.Column():
+             target_lang = gr.Dropdown(
+                 choices=["en", "de", "es", "fr"], value="fr", label="Target Language"
+             )
+
+     with gr.Row():
+         with gr.Column():
+             input_audio = gr.Audio(sources=["microphone"], type="filepath", label="Input Audio")
+         with gr.Column():
+             translated_speech = gr.Audio(type="filepath", label="Generated Speech")
+
+     with gr.Row():
+         with gr.Column():
+             transcipted_text = gr.Textbox(label="Transcription")
+         with gr.Column():
+             translated_text = gr.Textbox(label="Translation")
+
+     with gr.Row():
+         with gr.Column():
+             submit_button = gr.Button(value="Start Process", variant="primary")
+         with gr.Column():
+             clear_button = gr.ClearButton(components=[input_audio, source_lang, target_lang, transcipted_text, translated_text, translated_speech], value="Clear")
+
+     with gr.Row():
+         gr.Examples(
+             examples=[
+                 ["sample_en.wav","en","fr"],
+                 ["sample_fr.wav","fr","de"],
+                 ["sample_de.wav","de","es"],
+                 ["sample_es.wav","es","en"]
+             ],
+             inputs=[input_audio, source_lang, target_lang],
+             outputs=[transcipted_text, translated_text, translated_speech],
+             run_on_click=True, cache_examples=True, fn=start_process
+         )
+
+     submit_button.click(start_process, inputs=[input_audio, source_lang, target_lang], outputs=[transcipted_text, translated_text, translated_speech])

+ playground.launch()
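
For reviewers who want to exercise the new pipeline outside the Gradio UI, a minimal smoke test could look like the sketch below. Everything in it is an assumption for illustration, not part of the commit: it presumes the helper functions can be imported without side effects (for example by temporarily guarding the Blocks construction and playground.launch() behind if __name__ == "__main__":, which the committed app.py does not do) and that the sample_en.wav clip already referenced in gr.Examples is available.

# Hypothetical smoke test, not included in this commit.
# Assumes app.py's UI construction and playground.launch() are guarded so that
# importing the module only defines gen_text, gen_speech, and start_process.
from app import start_process

# sample_en.wav is one of the example clips wired into gr.Examples.
transcription, translation, speech_path = start_process("sample_en.wav", "en", "fr")

print("Transcription:", transcription)  # English text from nvidia/canary-1b (taskname="asr")
print("Translation:", translation)      # French text, same model (taskname="s2t_translation")
print("Speech file:", speech_path)      # WAV written by gen_speech via facebook/mms-tts-fra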