hengjie yang commited on
Commit
a03d4c1
·
1 Parent(s): 4ecc033

Complete overhaul of audio processing and embedding extraction

Browse files
Files changed (1) hide show
  1. src/deploy/voice_clone.py +100 -47
src/deploy/voice_clone.py CHANGED
@@ -41,6 +41,40 @@ class VoiceCloneSystem:
41
 
42
  print("模型加载完成!")
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def extract_speaker_embedding(
45
  self,
46
  audio_paths: List[Union[str, Path]]
@@ -57,33 +91,42 @@ class VoiceCloneSystem:
57
  embeddings = []
58
 
59
  for audio_path in audio_paths:
60
- # 加载音频
61
- waveform, sr = torchaudio.load(str(audio_path))
62
-
63
- # 重采样到16kHz
64
- if sr != 16000:
65
- waveform = torchaudio.functional.resample(waveform, sr, 16000)
66
-
67
- # 确保音频是单声道
68
- if waveform.shape[0] > 1:
69
- waveform = torch.mean(waveform, dim=0, keepdim=True)
70
-
71
- # 提取特征
72
- with torch.no_grad():
73
- embedding = self.speaker_encoder.encode_batch(waveform.to(self.device))
74
- # 调整维度:从 [1, 1, 1, 512] 转换为 [1, 512]
75
- embedding = embedding.squeeze() # 移除所有维度为1的维度
76
- if embedding.dim() == 1:
77
- embedding = embedding.unsqueeze(0) # 确保是 [1, 512]
78
- embeddings.append(embedding)
 
 
 
 
 
 
79
 
80
  # 计算平均特征
81
- mean_embedding = torch.mean(torch.stack(embeddings), dim=0)
 
 
82
  if mean_embedding.dim() == 1:
83
- mean_embedding = mean_embedding.unsqueeze(0) # 确保是 [1, 512]
84
 
85
- # 打印维度信息以便调试
86
  print(f"Final embedding shape: {mean_embedding.shape}")
 
87
  return mean_embedding
88
 
89
  def generate_speech(
@@ -101,21 +144,26 @@ class VoiceCloneSystem:
101
  Returns:
102
  生成的语音波形
103
  """
104
- # 处理输入文本
105
- inputs = self.processor(text=text, return_tensors="pt")
106
-
107
- # 确保说话人特征维度正确
108
- if speaker_embedding.dim() != 2 or speaker_embedding.size(1) != 512:
109
- raise ValueError(f"Speaker embedding should have shape [1, 512], but got {speaker_embedding.shape}")
110
-
111
- # 生成语音
112
- speech = self.tts_model.generate_speech(
113
- inputs["input_ids"].to(self.device),
114
- speaker_embedding.to(self.device),
115
- vocoder=self.vocoder
116
- )
117
-
118
- return speech
 
 
 
 
 
119
 
120
  def clone_voice(
121
  self,
@@ -140,6 +188,7 @@ class VoiceCloneSystem:
140
  speech = self.generate_speech(text, speaker_embedding)
141
 
142
  return speech
 
143
  except Exception as e:
144
  print(f"Error in clone_voice: {str(e)}")
145
  raise
@@ -158,13 +207,17 @@ class VoiceCloneSystem:
158
  output_path: 输出文件路径
159
  sample_rate: 采样率
160
  """
161
- # 确保输出目录存在
162
- output_path = Path(output_path)
163
- output_path.parent.mkdir(parents=True, exist_ok=True)
164
-
165
- # 保存音频
166
- torchaudio.save(
167
- str(output_path),
168
- waveform.unsqueeze(0).cpu(),
169
- sample_rate
170
- )
 
 
 
 
 
41
 
42
  print("模型加载完成!")
43
 
44
+ def process_audio(self, waveform: torch.Tensor, sr: int) -> torch.Tensor:
45
+ """
46
+ 处理音频:重采样和转换为单声道
47
+
48
+ Args:
49
+ waveform: 输入音频波形
50
+ sr: 采样率
51
+
52
+ Returns:
53
+ 处理后的音频波形
54
+ """
55
+ # 重采样到16kHz
56
+ if sr != 16000:
57
+ waveform = torchaudio.functional.resample(waveform, sr, 16000)
58
+
59
+ # 确保音频是单声道
60
+ if waveform.shape[0] > 1:
61
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
62
+
63
+ # 标准化音频长度(3秒)
64
+ target_length = 16000 * 3
65
+ current_length = waveform.shape[1]
66
+
67
+ if current_length > target_length:
68
+ # 如果太长,截取中间部分
69
+ start = (current_length - target_length) // 2
70
+ waveform = waveform[:, start:start + target_length]
71
+ elif current_length < target_length:
72
+ # 如果太短,用0填充
73
+ padding = torch.zeros(1, target_length - current_length)
74
+ waveform = torch.cat([waveform, padding], dim=1)
75
+
76
+ return waveform
77
+
78
  def extract_speaker_embedding(
79
  self,
80
  audio_paths: List[Union[str, Path]]
 
91
  embeddings = []
92
 
93
  for audio_path in audio_paths:
94
+ try:
95
+ # 加载音频
96
+ waveform, sr = torchaudio.load(str(audio_path))
97
+
98
+ # 处理音频
99
+ waveform = self.process_audio(waveform, sr)
100
+
101
+ # 提取特征
102
+ with torch.no_grad():
103
+ # 确保输入维度正确 [batch, time]
104
+ if waveform.dim() == 2:
105
+ waveform = waveform.squeeze(0)
106
+
107
+ # 提取特征并处理维度
108
+ embedding = self.speaker_encoder.encode_batch(waveform.unsqueeze(0).to(self.device))
109
+ embedding = embedding.squeeze() # 移除所有维度为1的维度
110
+
111
+ # 打印中间结果
112
+ print(f"Raw embedding shape: {embedding.shape}")
113
+
114
+ embeddings.append(embedding)
115
+
116
+ except Exception as e:
117
+ print(f"Error processing audio {audio_path}: {str(e)}")
118
+ raise
119
 
120
  # 计算平均特征
121
+ mean_embedding = torch.stack(embeddings).mean(dim=0)
122
+
123
+ # 确保最终维度正确 [1, 512]
124
  if mean_embedding.dim() == 1:
125
+ mean_embedding = mean_embedding.unsqueeze(0)
126
 
127
+ # 打印最终维度
128
  print(f"Final embedding shape: {mean_embedding.shape}")
129
+
130
  return mean_embedding
131
 
132
  def generate_speech(
 
144
  Returns:
145
  生成的语音波形
146
  """
147
+ try:
148
+ # 处理输入文本
149
+ inputs = self.processor(text=text, return_tensors="pt")
150
+
151
+ # 确保说话人特征维度正确
152
+ if speaker_embedding.dim() != 2 or speaker_embedding.size(1) != 512:
153
+ raise ValueError(f"Speaker embedding should have shape [1, 512], but got {speaker_embedding.shape}")
154
+
155
+ # 生成语音
156
+ speech = self.tts_model.generate_speech(
157
+ inputs["input_ids"].to(self.device),
158
+ speaker_embedding.to(self.device),
159
+ vocoder=self.vocoder
160
+ )
161
+
162
+ return speech
163
+
164
+ except Exception as e:
165
+ print(f"Error in generate_speech: {str(e)}")
166
+ raise
167
 
168
  def clone_voice(
169
  self,
 
188
  speech = self.generate_speech(text, speaker_embedding)
189
 
190
  return speech
191
+
192
  except Exception as e:
193
  print(f"Error in clone_voice: {str(e)}")
194
  raise
 
207
  output_path: 输出文件路径
208
  sample_rate: 采样率
209
  """
210
+ try:
211
+ # 确保输出目录存在
212
+ output_path = Path(output_path)
213
+ output_path.parent.mkdir(parents=True, exist_ok=True)
214
+
215
+ # 保存音频
216
+ torchaudio.save(
217
+ str(output_path),
218
+ waveform.unsqueeze(0).cpu(),
219
+ sample_rate
220
+ )
221
+ except Exception as e:
222
+ print(f"Error saving audio: {str(e)}")
223
+ raise