asahi417 committed
Commit 9151f3b · 1 Parent(s): 11553ea

add stability ts

Files changed (1)
app.py +3 -3
app.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import time
 import tempfile
+from copy import deepcopy
 from math import floor
 from typing import Optional, List, Dict, Any
 
@@ -161,12 +162,11 @@ def get_prediction(inputs, prompt: Optional[str], punctuate_text: bool = True, s
     generate_kwargs = {"language": "japanese", "task": "transcribe"}
     if prompt:
         generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
-    prediction = pipe(inputs, return_timestamps=True, generate_kwargs=generate_kwargs)
+    prediction = pipe(deepcopy(inputs), return_timestamps=True, generate_kwargs=generate_kwargs)
     if stabilize_timestamp:
         prediction['chunks'] = fix_timestamp(pipeline_output=prediction['chunks'],
                                              audio=inputs["array"],
-                                             sample_rate=inputs["sampling_rate"]
-                                             )
+                                             sample_rate=inputs["sampling_rate"])
     if punctuate_text:
         prediction['chunks'] = PUNCTUATOR.punctuate(prediction['chunks'])
     text = "".join([c['text'] for c in prediction['chunks']])