Spaces:
Afrinetwork
/
Running on TPU v5e

Afrinetwork7 committed on
Commit
7826b7e
1 Parent(s): 118252d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -22
app.py CHANGED
@@ -50,37 +50,57 @@ logger.debug(f"Compiled in {compile_time}s")
50
 
51
  @app.post("/transcribe_audio")
52
  async def transcribe_chunked_audio(audio_file: UploadFile, task: str = "transcribe", return_timestamps: bool = False):
53
- logger.debug("Loading audio file...")
 
 
 
54
  if not audio_file:
55
  logger.warning("No audio file")
56
  raise HTTPException(status_code=400, detail="No audio file submitted!")
57
- file_size_mb = os.stat(audio_file.filename).st_size / (1024 * 1024)
 
 
 
 
 
 
 
 
 
58
  if file_size_mb > FILE_LIMIT_MB:
59
- logger.warning("Max file size exceeded")
60
  raise HTTPException(status_code=400, detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.")
61
 
62
  try:
63
  logger.debug(f"Opening audio file: {audio_file.filename}")
64
  with open(audio_file.filename, "rb") as f:
65
  inputs = f.read()
 
66
  except Exception as e:
67
- logger.error("Error reading audio file:", exc_info=True)
68
- raise HTTPException(status_code=500, detail="Error reading audio file")
69
 
70
- logger.debug("Performing ffmpeg read on audio file")
71
- inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
72
- inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
73
- logger.debug("Done loading audio file")
 
 
 
 
74
 
 
75
  try:
76
- logger.debug("Calling tqdm_generate to transcribe audio")
77
  text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
 
78
  except Exception as e:
79
- logger.error("Error transcribing audio:", exc_info=True)
80
- raise HTTPException(status_code=500, detail="Error transcribing audio")
81
 
 
82
  return {"text": text, "runtime": runtime}
83
 
 
84
  @app.post("/transcribe_youtube")
85
  async def transcribe_youtube(yt_url: str = Form(...), task: str = "transcribe", return_timestamps: bool = False):
86
  logger.debug("Loading YouTube file...")
@@ -121,29 +141,49 @@ async def transcribe_youtube(yt_url: str = Form(...), task: str = "transcribe",
121
  return {"html_embed": html_embed_str, "text": text, "runtime": runtime}
122
 
123
  def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
 
 
124
  inputs_len = inputs["array"].shape[0]
 
 
125
  all_chunk_start_idx = np.arange(0, inputs_len, step)
126
  num_samples = len(all_chunk_start_idx)
127
  num_batches = math.ceil(num_samples / BATCH_SIZE)
 
128
 
129
  logger.debug("Preprocessing audio for inference")
130
- dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
 
 
 
 
 
 
131
  model_outputs = []
132
  start_time = time.time()
133
- logger.debug("Transcribing...")
134
- # iterate over our chunked audio samples - always predict timestamps to reduce hallucinations
135
- for batch in dataloader:
136
- logger.debug(f"Processing batch of {len(batch)} samples")
137
- model_outputs.append(pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True))
 
 
 
 
 
 
 
138
  runtime = time.time() - start_time
139
- logger.debug("Done transcription")
140
 
141
  logger.debug("Post-processing transcription results")
142
  try:
143
  post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
 
144
  except Exception as e:
145
- logger.error("Error post-processing transcription:", exc_info=True)
146
- raise HTTPException(status_code=500, detail="Error post-processing transcription")
 
147
  text = post_processed["text"]
148
  if return_timestamps:
149
  timestamps = post_processed.get("chunks")
@@ -152,7 +192,8 @@ def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
152
  for chunk in timestamps
153
  ]
154
  text = "\n".join(str(feature) for feature in timestamps)
155
- logger.debug("Done post-processing")
 
156
  return text, runtime
157
 
158
  def _return_yt_html_embed(yt_url):
 
50
 
51
  @app.post("/transcribe_audio")
52
  async def transcribe_chunked_audio(audio_file: UploadFile, task: str = "transcribe", return_timestamps: bool = False):
53
+ logger.debug("Starting transcribe_chunked_audio function")
54
+ logger.debug(f"Received parameters - task: {task}, return_timestamps: {return_timestamps}")
55
+
56
+ logger.debug("Checking for audio file...")
57
  if not audio_file:
58
  logger.warning("No audio file")
59
  raise HTTPException(status_code=400, detail="No audio file submitted!")
60
+
61
+ logger.debug(f"Audio file received: {audio_file.filename}")
62
+
63
+ try:
64
+ file_size_mb = os.stat(audio_file.filename).st_size / (1024 * 1024)
65
+ logger.debug(f"File size: {file_size_mb:.2f}MB")
66
+ except Exception as e:
67
+ logger.error(f"Error getting file size: {str(e)}", exc_info=True)
68
+ raise HTTPException(status_code=500, detail="Error checking file size")
69
+
70
  if file_size_mb > FILE_LIMIT_MB:
71
+ logger.warning(f"Max file size exceeded: {file_size_mb:.2f}MB > {FILE_LIMIT_MB}MB")
72
  raise HTTPException(status_code=400, detail=f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB.")
73
 
74
  try:
75
  logger.debug(f"Opening audio file: {audio_file.filename}")
76
  with open(audio_file.filename, "rb") as f:
77
  inputs = f.read()
78
+ logger.debug("Audio file read successfully")
79
  except Exception as e:
80
+ logger.error(f"Error reading audio file: {str(e)}", exc_info=True)
81
+ raise HTTPException(status_code=500, detail=f"Error reading audio file: {str(e)}")
82
 
83
+ try:
84
+ logger.debug("Performing ffmpeg read on audio file")
85
+ inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
86
+ inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
87
+ logger.debug("ffmpeg read completed successfully")
88
+ except Exception as e:
89
+ logger.error(f"Error in ffmpeg read: {str(e)}", exc_info=True)
90
+ raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")
91
 
92
+ logger.debug("Calling tqdm_generate to transcribe audio")
93
  try:
 
94
  text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
95
+ logger.debug(f"Transcription completed. Runtime: {runtime:.2f}s")
96
  except Exception as e:
97
+ logger.error(f"Error in tqdm_generate: {str(e)}", exc_info=True)
98
+ raise HTTPException(status_code=500, detail=f"Error transcribing audio: {str(e)}")
99
 
100
+ logger.debug("Transcribe_chunked_audio function completed successfully")
101
  return {"text": text, "runtime": runtime}
102
 
103
+
104
  @app.post("/transcribe_youtube")
105
  async def transcribe_youtube(yt_url: str = Form(...), task: str = "transcribe", return_timestamps: bool = False):
106
  logger.debug("Loading YouTube file...")
 
141
  return {"html_embed": html_embed_str, "text": text, "runtime": runtime}
142
 
143
  def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
144
+ logger.debug(f"Starting tqdm_generate - task: {task}, return_timestamps: {return_timestamps}")
145
+
146
  inputs_len = inputs["array"].shape[0]
147
+ logger.debug(f"Input array length: {inputs_len}")
148
+
149
  all_chunk_start_idx = np.arange(0, inputs_len, step)
150
  num_samples = len(all_chunk_start_idx)
151
  num_batches = math.ceil(num_samples / BATCH_SIZE)
152
+ logger.debug(f"Number of samples: {num_samples}, Number of batches: {num_batches}")
153
 
154
  logger.debug("Preprocessing audio for inference")
155
+ try:
156
+ dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
157
+ logger.debug("Preprocessing completed successfully")
158
+ except Exception as e:
159
+ logger.error(f"Error in preprocessing: {str(e)}", exc_info=True)
160
+ raise
161
+
162
  model_outputs = []
163
  start_time = time.time()
164
+ logger.debug("Starting transcription...")
165
+
166
+ try:
167
+ for i, batch in enumerate(dataloader):
168
+ logger.debug(f"Processing batch {i+1}/{num_batches} with {len(batch)} samples")
169
+ batch_output = pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True)
170
+ model_outputs.append(batch_output)
171
+ logger.debug(f"Batch {i+1} processed successfully")
172
+ except Exception as e:
173
+ logger.error(f"Error during batch processing: {str(e)}", exc_info=True)
174
+ raise
175
+
176
  runtime = time.time() - start_time
177
+ logger.debug(f"Transcription completed in {runtime:.2f}s")
178
 
179
  logger.debug("Post-processing transcription results")
180
  try:
181
  post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
182
+ logger.debug("Post-processing completed successfully")
183
  except Exception as e:
184
+ logger.error(f"Error in post-processing: {str(e)}", exc_info=True)
185
+ raise
186
+
187
  text = post_processed["text"]
188
  if return_timestamps:
189
  timestamps = post_processed.get("chunks")
 
192
  for chunk in timestamps
193
  ]
194
  text = "\n".join(str(feature) for feature in timestamps)
195
+
196
+ logger.debug("tqdm_generate function completed successfully")
197
  return text, runtime
198
 
199
  def _return_yt_html_embed(yt_url):