jhj0517 committed on
Commit
e4c9d55
·
unverified ·
2 Parent(s): 79933ea 072ec01

Merge pull request #208 from jhj0517/fix/diarization-type

Browse files
modules/diarize/audio_loader.py CHANGED
@@ -2,6 +2,8 @@ import os
2
  import subprocess
3
  from functools import lru_cache
4
  from typing import Optional, Union
 
 
5
 
6
  import numpy as np
7
  import torch
@@ -24,32 +26,43 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
24
  TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
25
 
26
 
27
- def load_audio(file: str, sr: int = SAMPLE_RATE):
28
  """
29
- Open an audio file and read as mono waveform, resampling as necessary
30
 
31
  Parameters
32
  ----------
33
- file: str
34
- The audio file to open
35
 
36
  sr: int
37
- The sample rate to resample the audio if necessary
38
 
39
  Returns
40
  -------
41
  A NumPy array containing the audio waveform, in float32 dtype.
42
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  try:
44
- # Launches a subprocess to decode audio while down-mixing and resampling as necessary.
45
- # Requires the ffmpeg CLI to be installed.
46
  cmd = [
47
  "ffmpeg",
48
  "-nostdin",
49
  "-threads",
50
  "0",
51
  "-i",
52
- file,
53
  "-f",
54
  "s16le",
55
  "-ac",
@@ -63,6 +76,9 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
63
  out = subprocess.run(cmd, capture_output=True, check=True).stdout
64
  except subprocess.CalledProcessError as e:
65
  raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
 
 
 
66
 
67
  return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
68
 
 
2
  import subprocess
3
  from functools import lru_cache
4
  from typing import Optional, Union
5
+ from scipy.io.wavfile import write
6
+ import tempfile
7
 
8
  import numpy as np
9
  import torch
 
26
  TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
27
 
28
 
29
+ def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
30
  """
31
+ Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.
32
 
33
  Parameters
34
  ----------
35
+ file: Union[str, np.ndarray]
36
+ The audio file to open or a numpy array containing the audio data.
37
 
38
  sr: int
39
+ The sample rate to resample the audio if necessary.
40
 
41
  Returns
42
  -------
43
  A NumPy array containing the audio waveform, in float32 dtype.
44
  """
45
+ if isinstance(file, np.ndarray):
46
+ if file.dtype != np.float32:
47
+ file = file.astype(np.float32)
48
+ if file.ndim > 1:
49
+ file = np.mean(file, axis=1)
50
+
51
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
52
+ write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16))
53
+ temp_file_path = temp_file.name
54
+ temp_file.close()
55
+ else:
56
+ temp_file_path = file
57
+
58
  try:
 
 
59
  cmd = [
60
  "ffmpeg",
61
  "-nostdin",
62
  "-threads",
63
  "0",
64
  "-i",
65
+ temp_file_path,
66
  "-f",
67
  "s16le",
68
  "-ac",
 
76
  out = subprocess.run(cmd, capture_output=True, check=True).stdout
77
  except subprocess.CalledProcessError as e:
78
  raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
79
+ finally:
80
+ if isinstance(file, np.ndarray):
81
+ os.remove(temp_file_path)
82
 
83
  return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
84
 
modules/diarize/diarizer.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import torch
3
- from typing import List
 
4
  import time
5
  import logging
6
 
@@ -20,7 +21,7 @@ class Diarizer:
20
  self.pipe = None
21
 
22
  def run(self,
23
- audio: str,
24
  transcribed_result: List[dict],
25
  use_auth_token: str,
26
  device: str
 
1
  import os
2
  import torch
3
+ from typing import List, Union, BinaryIO
4
+ import numpy as np
5
  import time
6
  import logging
7
 
 
21
  self.pipe = None
22
 
23
  def run(self,
24
+ audio: Union[str, BinaryIO, np.ndarray],
25
  transcribed_result: List[dict],
26
  use_auth_token: str,
27
  device: str