sanchit-gandhi HF staff commited on
Commit
9f8c873
1 Parent(s): 746f081

Update asr_diarizer.py

Browse files
Files changed (1) hide show
  1. asr_diarizer.py +5 -9
asr_diarizer.py CHANGED
@@ -1,12 +1,10 @@
1
  from typing import List, Optional, Union
2
 
3
  import numpy as np
4
- import torch
5
- from torchaudio import functional as F
6
-
7
  import requests
8
-
9
  from pyannote.audio import Pipeline
 
10
  from transformers import pipeline
11
  from transformers.pipelines.audio_utils import ffmpeg_read
12
 
@@ -19,7 +17,7 @@ class ASRDiarizationPipeline:
19
  ):
20
  self.asr_pipeline = asr_pipeline
21
  self.diarization_pipeline = diarization_pipeline
22
-
23
  self.sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate
24
 
25
  @classmethod
@@ -138,7 +136,7 @@ class ASRDiarizationPipeline:
138
  end_timestamps = end_timestamps[upto_idx + 1 :]
139
 
140
  return segmented_preds
141
-
142
  def preprocess(self, inputs):
143
  if isinstance(inputs, str):
144
  if inputs.startswith("http://") or inputs.startswith("https://"):
@@ -169,9 +167,7 @@ class ASRDiarizationPipeline:
169
  in_sampling_rate = inputs.pop("sampling_rate")
170
  inputs = _inputs
171
  if in_sampling_rate != self.sampling_rate:
172
- inputs = F.resample(
173
- torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate
174
- ).numpy()
175
 
176
  if not isinstance(inputs, np.ndarray):
177
  raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
 
1
  from typing import List, Optional, Union
2
 
3
  import numpy as np
 
 
 
4
  import requests
5
+ import torch
6
  from pyannote.audio import Pipeline
7
+ from torchaudio import functional as F
8
  from transformers import pipeline
9
  from transformers.pipelines.audio_utils import ffmpeg_read
10
 
 
17
  ):
18
  self.asr_pipeline = asr_pipeline
19
  self.diarization_pipeline = diarization_pipeline
20
+
21
  self.sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate
22
 
23
  @classmethod
 
136
  end_timestamps = end_timestamps[upto_idx + 1 :]
137
 
138
  return segmented_preds
139
+
140
  def preprocess(self, inputs):
141
  if isinstance(inputs, str):
142
  if inputs.startswith("http://") or inputs.startswith("https://"):
 
167
  in_sampling_rate = inputs.pop("sampling_rate")
168
  inputs = _inputs
169
  if in_sampling_rate != self.sampling_rate:
170
+ inputs = F.resample(torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate).numpy()
 
 
171
 
172
  if not isinstance(inputs, np.ndarray):
173
  raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")