Spaces:
Runtime error
Runtime error
Commit
•
9f8c873
1
Parent(s):
746f081
Update asr_diarizer.py
Browse files- asr_diarizer.py +5 -9
asr_diarizer.py
CHANGED
@@ -1,12 +1,10 @@
|
|
1 |
from typing import List, Optional, Union
|
2 |
|
3 |
import numpy as np
|
4 |
-
import torch
|
5 |
-
from torchaudio import functional as F
|
6 |
-
|
7 |
import requests
|
8 |
-
|
9 |
from pyannote.audio import Pipeline
|
|
|
10 |
from transformers import pipeline
|
11 |
from transformers.pipelines.audio_utils import ffmpeg_read
|
12 |
|
@@ -19,7 +17,7 @@ class ASRDiarizationPipeline:
|
|
19 |
):
|
20 |
self.asr_pipeline = asr_pipeline
|
21 |
self.diarization_pipeline = diarization_pipeline
|
22 |
-
|
23 |
self.sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate
|
24 |
|
25 |
@classmethod
|
@@ -138,7 +136,7 @@ class ASRDiarizationPipeline:
|
|
138 |
end_timestamps = end_timestamps[upto_idx + 1 :]
|
139 |
|
140 |
return segmented_preds
|
141 |
-
|
142 |
def preprocess(self, inputs):
|
143 |
if isinstance(inputs, str):
|
144 |
if inputs.startswith("http://") or inputs.startswith("https://"):
|
@@ -169,9 +167,7 @@ class ASRDiarizationPipeline:
|
|
169 |
in_sampling_rate = inputs.pop("sampling_rate")
|
170 |
inputs = _inputs
|
171 |
if in_sampling_rate != self.sampling_rate:
|
172 |
-
inputs = F.resample(
|
173 |
-
torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate
|
174 |
-
).numpy()
|
175 |
|
176 |
if not isinstance(inputs, np.ndarray):
|
177 |
raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
|
|
|
1 |
from typing import List, Optional, Union
|
2 |
|
3 |
import numpy as np
|
|
|
|
|
|
|
4 |
import requests
|
5 |
+
import torch
|
6 |
from pyannote.audio import Pipeline
|
7 |
+
from torchaudio import functional as F
|
8 |
from transformers import pipeline
|
9 |
from transformers.pipelines.audio_utils import ffmpeg_read
|
10 |
|
|
|
17 |
):
|
18 |
self.asr_pipeline = asr_pipeline
|
19 |
self.diarization_pipeline = diarization_pipeline
|
20 |
+
|
21 |
self.sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate
|
22 |
|
23 |
@classmethod
|
|
|
136 |
end_timestamps = end_timestamps[upto_idx + 1 :]
|
137 |
|
138 |
return segmented_preds
|
139 |
+
|
140 |
def preprocess(self, inputs):
|
141 |
if isinstance(inputs, str):
|
142 |
if inputs.startswith("http://") or inputs.startswith("https://"):
|
|
|
167 |
in_sampling_rate = inputs.pop("sampling_rate")
|
168 |
inputs = _inputs
|
169 |
if in_sampling_rate != self.sampling_rate:
|
170 |
+
inputs = F.resample(torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate).numpy()
|
|
|
|
|
171 |
|
172 |
if not isinstance(inputs, np.ndarray):
|
173 |
raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
|