lewistape committed on
Commit
4b20bb0
·
verified ·
1 Parent(s): 5154443

Update tts.py

Browse files
Files changed (1) hide show
  1. tts.py +164 -406
tts.py CHANGED
@@ -1,416 +1,174 @@
1
  import os
2
- from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security, Form
3
- from fastapi.security.api_key import APIKeyHeader, APIKey
4
- from fastapi.responses import JSONResponse
5
- from pydantic import BaseModel
6
- from typing import Optional
7
- import numpy as np
8
- import io
9
- import soundfile as sf
10
- import base64
11
- import logging
12
- import torch
13
- import librosa
14
- from pathlib import Path
15
- from pydub import AudioSegment
16
- from moviepy.editor import VideoFileClip
17
- import traceback
18
- from logging.handlers import RotatingFileHandler
19
- import boto3
20
- from botocore.exceptions import NoCredentialsError
21
- import time
22
  import tempfile
23
- import magic
24
-
25
- # Import functions from other modules
26
- from asr import transcribe, ASR_LANGUAGES, ASR_SAMPLING_RATE
27
- from tts import synthesize, TTS_LANGUAGES
28
- from lid import identify
29
-
30
# --- Logging ---------------------------------------------------------------
# Root logging is configured at INFO; this module's logger additionally writes
# to a size-rotated file (10 MB per file, 5 backups kept).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

file_handler = RotatingFileHandler(
    "app.log",
    maxBytes=10_000_000,
    backupCount=5,
)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# --- Application object ----------------------------------------------------
app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")

# --- Settings read from the environment (each is None when unset) ----------
S3_BUCKET = os.getenv("S3_BUCKET")
S3_REGION = os.getenv("S3_REGION")
S3_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")

# Shared secret expected in the "X-API-Key" header. auto_error=False means a
# missing header reaches the dependency as None instead of failing early.
API_KEY = os.getenv("API_KEY")
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

# --- S3 client, created once at import time --------------------------------
s3_client = boto3.client(
    "s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
60
-
61
- # Define request models
62
class AudioRequest(BaseModel):
    # JSON body for the endpoints that receive media inline.

    # Base64 encoded audio or video data
    audio: str

    # Optional language; defaults to None when the client omits it.
    language: Optional[str] = None
65
-
66
class TTSRequest(BaseModel):
    # JSON body for the text-to-speech endpoint.

    # Text to be spoken.
    text: str

    # Optional language; defaults to None when the client omits it.
    language: Optional[str] = None

    # Speed factor; 1.0 unless the client overrides it.
    speed: float = 1.0
70
-
71
class LanguageRequest(BaseModel):
    # JSON body for the language-listing endpoints.

    # Optional filter string; defaults to None when the client omits it.
    language: Optional[str] = None
73
-
74
async def get_api_key(api_key_header: str = Security(api_key_header)):
    """Validate the API key sent in the request header.

    Args:
        api_key_header: Value extracted by the ``APIKeyHeader`` security
            scheme. It is ``None`` when the client sent no header, because
            the scheme is created with ``auto_error=False``.

    Returns:
        The presented key when it matches the configured ``API_KEY``.

    Raises:
        HTTPException: 403 when the header is missing, when no ``API_KEY`` is
            configured on the server, or when the two values differ.
    """
    import secrets  # function-scope import keeps this block self-contained

    # Fail closed: with a plain ``==`` an unset API_KEY (None) compared equal
    # to a missing header (None) and let every unauthenticated request in.
    if API_KEY and api_key_header:
        # Constant-time comparison so response timing does not leak how much
        # of the key matched. Bytes are used because compare_digest only
        # accepts ASCII-only str.
        if secrets.compare_digest(
            api_key_header.encode("utf-8"), API_KEY.encode("utf-8")
        ):
            return api_key_header
    raise HTTPException(status_code=403, detail="Could not validate credentials")
78
-
79
def extract_audio_from_file(input_bytes):
    """Decode raw audio or video bytes into a mono waveform.

    The bytes are written to a temporary file, then three decoders are tried
    in order: soundfile, moviepy (audio track of a video), and pydub. The
    first one that succeeds wins. The temporary file is always removed.

    Args:
        input_bytes: Raw contents of the uploaded or decoded media file.

    Returns:
        Tuple ``(audio_array, sample_rate)`` where ``audio_array`` is a 1-D
        numpy array (channels averaged) and ``sample_rate`` is the rate
        reported by the decoder that succeeded.

    Raises:
        ValueError: If none of the decoders can read the file.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_file:
        temp_file.write(input_bytes)
        temp_file_path = temp_file.name

    try:
        # Log file info
        file_info = magic.from_file(temp_file_path, mime=True)
        logger.info(f"Received file of type: {file_info}")

        # Try reading with soundfile first
        try:
            audio_array, sample_rate = sf.read(temp_file_path)
            # soundfile returns (frames, channels) for multi-channel input;
            # down-mix so every branch hands back a mono signal.
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            logger.info(f"Successfully read audio with soundfile. Shape: {audio_array.shape}, Sample rate: {sample_rate}")
            return audio_array, sample_rate
        except Exception as e:
            logger.info(f"Could not read with soundfile: {str(e)}")

        # Try reading as video
        try:
            video = VideoFileClip(temp_file_path)
            try:
                audio = video.audio
                if audio is not None:
                    audio_array = audio.to_soundarray()
                    sample_rate = audio.fps
                    if audio_array.ndim > 1:
                        audio_array = audio_array.mean(axis=1)
                    audio_array = audio_array.astype(np.float32)
                    # Peak-normalise, but leave silent/empty audio untouched
                    # instead of dividing by zero and producing NaNs.
                    peak = np.max(np.abs(audio_array)) if audio_array.size else 0.0
                    if peak > 0:
                        audio_array /= peak
                    logger.info(f"Successfully extracted audio from video. Shape: {audio_array.shape}, Sample rate: {sample_rate}")
                    return audio_array, sample_rate
                logger.info("Video file contains no audio")
            finally:
                # Release the reader on every path (success, no audio, error).
                video.close()
        except Exception as e:
            logger.info(f"Could not read as video: {str(e)}")

        # Try reading with pydub
        try:
            audio = AudioSegment.from_file(temp_file_path)
            audio_array = np.array(audio.get_array_of_samples())
            # Scale signed integer samples to [-1, 1) for any sample width
            # (1 byte -> 2**7, 2 bytes -> 2**15, 4 bytes -> 2**31).
            full_scale = float(2 ** (8 * audio.sample_width - 1))
            audio_array = audio_array.astype(np.float32) / full_scale
            # Samples are interleaved per channel; average them to mono.
            if audio.channels > 1:
                audio_array = audio_array.reshape((-1, audio.channels)).mean(axis=1)
            logger.info(f"Successfully read audio with pydub. Shape: {audio_array.shape}, Sample rate: {audio.frame_rate}")
            return audio_array, audio.frame_rate
        except Exception as e:
            logger.info(f"Could not read with pydub: {str(e)}")

        raise ValueError(f"Unsupported file format: {file_info}")
    finally:
        os.unlink(temp_file_path)
129
-
130
@app.post("/transcribe")
async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
    # Transcribe base64-encoded media from a JSON body.
    # Success: 200 with "transcription" and "processing_time_seconds".
    # Any failure: 500 with the error text and traceback under "details".
    started_at = time.time()
    try:
        raw_media = base64.b64decode(request.audio)
        waveform, source_rate = extract_audio_from_file(raw_media)

        # Cast to float32, then resample only when the rates differ.
        waveform = waveform.astype(np.float32)
        if source_rate != ASR_SAMPLING_RATE:
            waveform = librosa.resample(waveform, orig_sr=source_rate, target_sr=ASR_SAMPLING_RATE)

        # Without an explicit language, use the result of language identification.
        language = identify(waveform) if request.language is None else request.language
        result = transcribe(waveform, language)

        elapsed = time.time() - started_at
        return JSONResponse(content={"transcription": result, "processing_time_seconds": elapsed})
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}", exc_info=True)
        failure = {"error": str(e), "traceback": traceback.format_exc()}
        elapsed = time.time() - started_at
        payload = {
            "message": "An error occurred during transcription",
            "details": failure,
            "processing_time_seconds": elapsed,
        }
        return JSONResponse(status_code=500, content=payload)
164
-
165
@app.post("/transcribe_file")
async def transcribe_audio_file(
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    api_key: APIKey = Depends(get_api_key)
):
    # Transcribe an uploaded media file (multipart form).
    # Success: 200 with "transcription" and "processing_time_seconds".
    # Any failure: 500 with the error text and traceback under "details".
    started_at = time.time()
    try:
        uploaded_bytes = await file.read()
        waveform, source_rate = extract_audio_from_file(uploaded_bytes)

        # Cast to float32, then resample only when the rates differ.
        waveform = waveform.astype(np.float32)
        if source_rate != ASR_SAMPLING_RATE:
            waveform = librosa.resample(waveform, orig_sr=source_rate, target_sr=ASR_SAMPLING_RATE)

        # Without an explicit language, use the result of language identification.
        chosen_language = identify(waveform) if language is None else language
        result = transcribe(waveform, chosen_language)

        elapsed = time.time() - started_at
        return JSONResponse(content={"transcription": result, "processing_time_seconds": elapsed})
    except Exception as e:
        logger.error(f"Error in transcribe_audio_file: {str(e)}", exc_info=True)
        failure = {"error": str(e), "traceback": traceback.format_exc()}
        elapsed = time.time() - started_at
        payload = {
            "message": "An error occurred during transcription",
            "details": failure,
            "processing_time_seconds": elapsed,
        }
        return JSONResponse(status_code=500, content=payload)
203
 
204
@app.post("/synthesize")
async def synthesize_speech(request: TTSRequest, api_key: APIKey = Depends(get_api_key)):
    # Synthesize speech for request.text, upload the WAV to S3 and return its URL.
    #
    # Responses:
    #   200 -> {"audio_url", "processing_time_seconds"}
    #   400 -> invalid input (empty text, blank or unsupported language, speed
    #          outside [0.5, 2.0], or silent / missing synthesis output)
    #   500 -> S3 upload failure (raised as HTTPException) or any other error
    import uuid  # function-scope import keeps this block self-contained

    start_time = time.time()
    logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
    try:
        if request.language is None:
            # If no language is provided, default to English
            lang_code = "eng"
        else:
            # Extract the ISO code from the full language name. A blank
            # string has no tokens; report it as bad input (400) rather than
            # letting an IndexError surface as a 500.
            language_tokens = request.language.split()
            if not language_tokens:
                raise ValueError("Language cannot be empty")
            lang_code = language_tokens[0]

        # Input validation
        if not request.text:
            raise ValueError("Text cannot be empty")
        if lang_code not in TTS_LANGUAGES:
            raise ValueError(f"Unsupported language: {lang_code}")
        if not 0.5 <= request.speed <= 2.0:
            raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")

        logger.info(f"Calling synthesize function with lang_code: {lang_code}")
        result, filtered_text = synthesize(request.text, lang_code, request.speed)
        logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")

        if result is None:
            logger.error("Synthesize function returned None")
            raise ValueError("Synthesis failed to produce audio")

        sample_rate, audio = result
        logger.info(f"Synthesis result: sample_rate={sample_rate}, audio_shape={audio.shape if isinstance(audio, np.ndarray) else 'not numpy array'}, audio_dtype={audio.dtype if isinstance(audio, np.ndarray) else type(audio)}")

        logger.info("Converting audio to numpy array")
        audio = np.array(audio, dtype=np.float32)
        logger.info(f"Converted audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Normalizing audio")
        max_value = np.max(np.abs(audio))
        if max_value == 0:
            logger.warning("Audio array is all zeros")
            raise ValueError("Generated audio is silent (all zeros)")
        audio = audio / max_value
        logger.info(f"Normalized audio range: [{audio.min()}, {audio.max()}]")

        logger.info("Converting to int16")
        audio = (audio * 32767).astype(np.int16)
        logger.info(f"Int16 audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Writing audio to buffer")
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format='wav')
        buffer.seek(0)
        logger.info(f"Buffer size: {buffer.getbuffer().nbytes} bytes")

        # Generate a unique filename. The timestamp alone only has one-second
        # resolution, so concurrent requests could overwrite each other's
        # object; the random suffix prevents that.
        filename = f"synthesized_audio_{int(time.time())}_{uuid.uuid4().hex}.wav"

        # Upload to S3 without ACL
        try:
            s3_client.upload_fileobj(
                buffer,
                S3_BUCKET,
                filename,
                ExtraArgs={'ContentType': 'audio/wav'}
            )
            logger.info(f"File uploaded successfully to S3: {filename}")

            # Generate the public URL with the correct format
            url = f"https://s3.{S3_REGION}.amazonaws.com/{S3_BUCKET}/{filename}"
            logger.info(f"Public URL generated: {url}")

            processing_time = time.time() - start_time
            return JSONResponse(content={"audio_url": url, "processing_time_seconds": processing_time})

        except NoCredentialsError:
            logger.error("AWS credentials not available or invalid")
            raise HTTPException(status_code=500, detail="Could not upload file to S3: Missing or invalid credentials")
        except Exception as e:
            logger.error(f"Failed to upload to S3: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Could not upload file to S3: {str(e)}")

    except HTTPException:
        # Let the deliberate S3 HTTP errors reach FastAPI unchanged instead
        # of being re-wrapped by the generic handler below.
        raise
    except ValueError as ve:
        logger.error(f"ValueError in synthesize_speech: {str(ve)}", exc_info=True)
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=400,
            content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
        )
    except Exception as e:
        logger.error(f"Unexpected error in synthesize_speech: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "type": type(e).__name__,
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
        )
303
-
304
@app.post("/identify")
async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
    # Run language identification on base64-encoded media from a JSON body.
    # Success: 200 with "language_identification" and "processing_time_seconds".
    # Any failure: 500 with the error text and traceback under "details".
    started_at = time.time()
    try:
        raw_media = base64.b64decode(request.audio)
        # The decoder's sample rate is not used by this endpoint.
        waveform, _source_rate = extract_audio_from_file(raw_media)
        result = identify(waveform)
        elapsed = time.time() - started_at
        return JSONResponse(content={"language_identification": result, "processing_time_seconds": elapsed})
    except Exception as e:
        logger.error(f"Error in identify_language: {str(e)}", exc_info=True)
        failure = {"error": str(e), "traceback": traceback.format_exc()}
        elapsed = time.time() - started_at
        payload = {
            "message": "An error occurred during language identification",
            "details": failure,
            "processing_time_seconds": elapsed,
        }
        return JSONResponse(status_code=500, content=payload)
324
 
325
@app.post("/identify_file")
async def identify_language_file(
    file: UploadFile = File(...),
    api_key: APIKey = Depends(get_api_key)
):
    # Run language identification on an uploaded media file (multipart form).
    # Success: 200 with "language_identification" and "processing_time_seconds".
    # Any failure: 500 with the error text and traceback under "details".
    started_at = time.time()
    try:
        uploaded_bytes = await file.read()
        # The decoder's sample rate is not used by this endpoint.
        waveform, _source_rate = extract_audio_from_file(uploaded_bytes)
        result = identify(waveform)
        elapsed = time.time() - started_at
        return JSONResponse(content={"language_identification": result, "processing_time_seconds": elapsed})
    except Exception as e:
        logger.error(f"Error in identify_language_file: {str(e)}", exc_info=True)
        failure = {"error": str(e), "traceback": traceback.format_exc()}
        elapsed = time.time() - started_at
        payload = {
            "message": "An error occurred during language identification",
            "details": failure,
            "processing_time_seconds": elapsed,
        }
        return JSONResponse(status_code=500, content=payload)
348
 
349
@app.post("/asr_languages")
async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
    # List ASR languages, optionally restricted to those starting with
    # request.language (case-insensitive). None or "" returns the full set.
    started_at = time.time()
    try:
        prefix = request.language
        if prefix:
            needle = prefix.lower()
            matching_languages = [
                lang for lang in ASR_LANGUAGES if lang.lower().startswith(needle)
            ]
        else:
            # No filter supplied: hand back everything.
            matching_languages = ASR_LANGUAGES

        elapsed = time.time() - started_at
        return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": elapsed})
    except Exception as e:
        logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
        failure = {"error": str(e), "traceback": traceback.format_exc()}
        elapsed = time.time() - started_at
        payload = {
            "message": "An error occurred while fetching ASR languages",
            "details": failure,
            "processing_time_seconds": elapsed,
        }
        return JSONResponse(status_code=500, content=payload)
372
-
373
@app.post("/tts_languages")
async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
    # List TTS languages, optionally restricted to those starting with
    # request.language (case-insensitive). None or "" returns the full set.
    started_at = time.time()
    try:
        prefix = request.language
        if prefix:
            needle = prefix.lower()
            matching_languages = [
                lang for lang in TTS_LANGUAGES if lang.lower().startswith(needle)
            ]
        else:
            # No filter supplied: hand back everything.
            matching_languages = TTS_LANGUAGES

        elapsed = time.time() - started_at
        return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": elapsed})
    except Exception as e:
        logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
        failure = {"error": str(e), "traceback": traceback.format_exc()}
        elapsed = time.time() - started_at
        payload = {
            "message": "An error occurred while fetching TTS languages",
            "details": failure,
            "processing_time_seconds": elapsed,
        }
        return JSONResponse(status_code=500, content=payload)
396
-
397
@app.get("/health")
async def health_check():
    # Unauthenticated probe: always answers with a fixed "ok" status.
    status_report = {"status": "ok"}
    return status_report
400
-
401
@app.get("/")
async def root():
    # Unauthenticated landing route: welcome text, version, and route list.
    available_endpoints = [
        "/transcribe",
        "/transcribe_file",
        "/synthesize",
        "/identify",
        "/identify_file",
        "/asr_languages",
        "/tts_languages",
        "/health",
    ]
    info = {
        "message": "Welcome to the MMS Speech Technology API",
        "version": "1.0",
        "endpoints": available_endpoints,
    }
    return info
 
1
  import os
2
+ import re
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import tempfile
4
+ import torch
5
+ import sys
6
+ import gradio as gr
7
+ import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ from huggingface_hub import hf_hub_download
 
10
 
11
+ # Setup TTS env
12
+ if "vits" not in sys.path:
13
+ sys.path.append("vits")
14
 
15
+ from vits import commons, utils
16
+ from vits.models import SynthesizerTrn
 
 
 
 
17
 
18
# ISO code -> language name, loaded once at import time from the bundled
# language table (one "<iso><whitespace><name>" record per line).
TTS_LANGUAGES = {}
with open("data/tts/all_langs.tsv", encoding="utf-8") as f:
    for line in f:
        # A trailing or stray blank line carries no record.
        if not line.strip():
            continue
        # Split on the first run of whitespace so both tab- and
        # space-delimited records parse; the name may itself contain spaces.
        iso, name = line.split(None, 1)
        TTS_LANGUAGES[iso.strip()] = name.strip()
 
 
 
 
 
 
 
 
23
 
24
class TextMapper(object):
    """Map text to the integer symbol IDs defined by a vocabulary file.

    The vocabulary file holds one symbol per line; a symbol's ID is its
    zero-based line index. The vocabulary must contain the space character.
    """

    def __init__(self, vocab_file):
        """Load the symbol table.

        Args:
            vocab_file: Path to a UTF-8 text file with one symbol per line.

        Raises:
            ValueError: If the vocabulary has no " " (space) symbol.
        """
        # Context manager so the file handle is closed deterministically.
        with open(vocab_file, encoding="utf-8") as f:
            self.symbols = [x.replace("\n", "") for x in f.readlines()]
        self.SPACE_ID = self.symbols.index(" ")
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self._id_to_symbol = dict(enumerate(self.symbols))

    def text_to_sequence(self, text, cleaner_names):
        """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

        Args:
            text: string to convert to a sequence
            cleaner_names: accepted for interface compatibility; not used here

        Returns:
            List of integers corresponding to the symbols in the text

        Raises:
            KeyError: If the stripped text contains a symbol that is not in
                the vocabulary (run ``filter_oov`` first to avoid this).
        """
        return [self._symbol_to_id[symbol] for symbol in text.strip()]

    def uromanize(self, text, uroman_pl):
        """Romanize ``text`` with the uroman Perl script.

        Args:
            text: Text to romanize.
            uroman_pl: Path to ``uroman.pl``.

        Returns:
            The first line of uroman's output with whitespace runs collapsed
            to single spaces, or ``""`` when uroman prints nothing.

        Raises:
            subprocess.CalledProcessError: If the Perl process exits non-zero.
            FileNotFoundError: If ``perl`` is not on ``PATH``.
        """
        import subprocess  # function-scope import keeps this block self-contained

        iso = "xxx"
        # Argument list (no shell): paths with spaces or shell metacharacters
        # cannot break or hijack the command. Text travels over stdin/stdout
        # as UTF-8 instead of through locale-encoded temp files.
        completed = subprocess.run(
            ["perl", uroman_pl, "-l", iso],
            input=text,
            stdout=subprocess.PIPE,
            encoding="utf-8",
            check=True,
        )
        first_line = completed.stdout.split("\n", 1)[0]
        return re.sub(r"\s+", " ", first_line).strip()

    def get_text(self, text, hps):
        """Encode ``text`` as a LongTensor of symbol IDs.

        Args:
            text: Text made only of in-vocabulary symbols.
            hps: Hyper-parameter object; ``hps.data.text_cleaners`` and
                ``hps.data.add_blank`` are read.

        Returns:
            ``torch.LongTensor`` of IDs, with ID 0 interspersed between
            symbols when ``hps.data.add_blank`` is truthy.
        """
        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return torch.LongTensor(text_norm)

    def filter_oov(self, text, lang=None):
        """Drop every character of ``text`` that is not in the vocabulary.

        Language-specific character fix-ups (``preprocess_char``) run first.

        Args:
            text: Input text.
            lang: Optional language code forwarded to ``preprocess_char``.

        Returns:
            The text with out-of-vocabulary characters removed.
        """
        text = self.preprocess_char(text, lang=lang)
        val_chars = self._symbol_to_id
        return "".join(ch for ch in text if ch in val_chars)

    def preprocess_char(self, text, lang=None):
        """
        Special treatment of characters in certain languages.

        For ``lang == "ron"``, "ț" (comma below) is replaced by "ţ" (cedilla)
        and the result is printed; other languages pass through unchanged.
        """
        if lang == "ron":
            text = text.replace("ț", "ţ")
            print(f"{lang} (ț -> ţ): {text}")
        return text
86
+
87
def synthesize(text=None, lang=None, speed=None):
    """Synthesize speech for ``text`` with the VITS checkpoint of ``lang``.

    The vocabulary, config and generator checkpoint are fetched from the
    ``facebook/mms-tts`` repository on the Hugging Face Hub, and a fresh
    ``SynthesizerTrn`` is built on every call.

    Args:
        text: Text to speak.
        lang: Language label whose first whitespace-separated token is the
            model code, e.g. ``"eng"`` or ``"eng (English)"``.
        speed: Speaking-rate factor; ``None`` means ``1.0``. The model's
            ``length_scale`` is set to ``1.0 / speed``.

    Returns:
        ``((sampling_rate, waveform), filtered_text)``: ``waveform`` is a
        float32 numpy array and ``filtered_text`` is the lower-cased text
        (romanized when the model requires it) with out-of-vocabulary
        characters removed, i.e. what was actually synthesized.

    Raises:
        ValueError: If ``text`` is missing, ``lang`` is missing or blank, or
            ``speed`` is not positive.
        FileNotFoundError: If the downloaded config or the local ``uroman``
            directory is missing.
    """
    if speed is None:
        speed = 1.0

    # Validate up front: the defaults are None, which would otherwise fail
    # with AttributeError / ZeroDivisionError deep inside the function.
    if text is None:
        raise ValueError("text must be provided")
    if not lang or not lang.split():
        raise ValueError("lang must be a non-empty language code")
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed}")

    lang_code = lang.split()[0]

    vocab_file = hf_hub_download(
        repo_id="facebook/mms-tts",
        filename="vocab.txt",
        subfolder=f"models/{lang_code}",
    )
    config_file = hf_hub_download(
        repo_id="facebook/mms-tts",
        filename="config.json",
        subfolder=f"models/{lang_code}",
    )
    g_pth = hf_hub_download(
        repo_id="facebook/mms-tts",
        filename="G_100000.pth",
        subfolder=f"models/{lang_code}",
    )

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print(f"Run inference with {device}")

    # Explicit raise instead of assert: asserts are stripped under ``-O``.
    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"{config_file} doesn't exist")
    hps = utils.get_hparams_from_file(config_file)
    text_mapper = TextMapper(vocab_file)

    net_g = SynthesizerTrn(
        len(text_mapper.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    ).to(device)
    net_g.eval()

    _ = utils.load_checkpoint(g_pth, net_g, None)

    # Romanize the input first when the training file list ends in ".uroman".
    is_uroman = hps.data.training_files.split(".")[-1] == "uroman"

    if is_uroman:
        uroman_dir = "uroman"
        if not os.path.exists(uroman_dir):
            raise FileNotFoundError(f"{uroman_dir} doesn't exist")
        uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
        text = text_mapper.uromanize(text, uroman_pl)

    text = text.lower()
    # Pass the bare code: the language-specific fix-ups compare against codes
    # such as "ron" and never matched a full label like "ron (Romanian)".
    text = text_mapper.filter_oov(text, lang=lang_code)
    stn_tst = text_mapper.get_text(text, hps).to(device)

    # Use autocast for mixed-precision inference, but only on CUDA; enabling
    # CUDA autocast on other devices just emits a warning.
    with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
            hyp = (
                net_g.infer(
                    x_tst,
                    x_tst_lengths,
                    noise_scale=0.667,
                    noise_scale_w=0.8,
                    length_scale=1.0 / speed,
                )[0][0, 0]
                .cpu()
                .float()  # Convert to float32 for numpy
                .numpy()
            )

    return (hps.data.sampling_rate, hyp), text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
# Sample inputs as [text, "<code> (<language name>)", speed] rows, all at
# speed 1.0. Presumably consumed by the demo UI; confirm against the caller.
TTS_EXAMPLES = [
    [sample_text, language_label, 1.0]
    for sample_text, language_label in (
        ("I am going to the store.", "eng (English)"),
        ("안녕하세요.", "kor (Korean)"),
        ("क्या मुझे पीने का पानी मिल सकता है?", "hin (Hindi)"),
        ("Tanış olmağıma çox şadam", "azj-script_latin (Azerbaijani, North)"),
        ("Mu zo murna a cikin ƙasar.", "hau (Hausa)"),
    )
]