File size: 6,672 Bytes
574a515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from configuration_bigcodec import BigCodecConfig
# 请确保这些模块路径是正确的
from vq.codec_encoder import CodecEncoder_Transformer
from vq.codec_decoder_vocos import CodecDecoderVocos
from vq.module import SemanticEncoder
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
class XCodec2Model(PreTrainedModel):
    """XCodec2 speech codec.

    Fuses semantic features from a pretrained w2v-BERT 2.0 model with the
    output of an acoustic codec encoder, quantizes the fused representation
    through ``self.generator`` and decodes the resulting codes back to audio.
    """

    config_class = BigCodecConfig

    def __init__(self, config: BigCodecConfig):
        super().__init__(config)
        # 1) Semantic model. All hidden states are requested because an
        #    intermediate layer (hidden_states[16]) is used as the semantic feature.
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
            "facebook/w2v-bert-2.0",
            output_hidden_states=True
        )
        # NOTE(review): eval() is applied only here; a later model.train() call
        # recurses into submodules and puts the semantic model back in train mode.
        self.semantic_model.eval()
        self.SemanticEncoder_module = SemanticEncoder(
            config.semantic_hidden_size,
            config.semantic_hidden_size,
            config.semantic_hidden_size
        )
        # 2) Acoustic codec encoder.
        self.CodecEnc = CodecEncoder_Transformer()
        # 3) Codec decoder; it also owns the quantizer used to map codes back
        #    to embeddings (see _decode_from_codes).
        self.generator = CodecDecoderVocos()
        # 4) Linear projections before (fc_prior) and after (fc_post_a) quantization.
        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_post_a = nn.Linear(2048, 1024)
        # Feature extractor paired with the semantic model's checkpoint.
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

    def _encode_to_codes(self, input_waveform, sample_rate):
        """Run both encoder branches, fuse them, quantize, and return the codes."""
        # --- Semantic branch ---
        # Give the feature extractor an explicit list of 1-D float32 CPU arrays,
        # one per batch item. A raw 2-D tensor is not recognised as a batch by
        # the w2v-BERT (SeamlessM4T) extractor: it is read as one multi-channel
        # clip, so only the first row would be featurized. A CUDA or
        # grad-tracking tensor also cannot be converted to NumPy directly.
        waveforms = list(input_waveform.detach().float().cpu().numpy())
        input_features = self.feature_extractor(
            waveforms,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features.to(self.device)
        semantic_output = self.semantic_model(input_features)
        # Tap hidden_states[16] and move the feature axis to dim 1
        # (HF hidden states are laid out as (batch, frames, hidden)).
        semantic_hidden_16 = semantic_output.hidden_states[16].transpose(1, 2)
        semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)

        # --- Acoustic branch ---
        wav = input_waveform.unsqueeze(1).to(self.device)  # (batch, 1, waveform_length)
        # Move the channel axis to dim 1 so it lines up with semantic_encoded.
        vq_emb = self.CodecEnc(wav).transpose(1, 2)

        # The branches can produce different frame counts; truncate both to the
        # shorter one so they can be concatenated along the channel axis.
        n_frames = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
        semantic_encoded = semantic_encoded[:, :, :n_frames]
        vq_emb = vq_emb[:, :, :n_frames]

        # Fuse the branches; their channel counts must sum to 2048 for fc_prior.
        concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)
        # nn.Linear acts on the last axis: project channels-last, then restore.
        concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)

        # Quantize through the decoder; only the code tensor is needed here.
        _, vq_code, _ = self.generator(concat_emb, vq=True)
        return vq_code

    def _decode_from_codes(self, vq_code):
        """Map codes back to quantized embeddings, project them and synthesize audio."""
        # The quantizer receives the codes with dims 1 and 2 swapped. Its output
        # goes straight into fc_post_a, so its last axis is the 2048-wide feature axis.
        vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
        # fc_post_a maps 2048 -> 1024 on the last axis; the decoder consumes this
        # channels-last tensor when called with vq=False.
        vq_post_emb = self.fc_post_a(vq_post_emb)
        return self.generator(vq_post_emb, vq=False)[0]

    def forward(self, input_waveform, sample_rate=16000):
        """Encode a waveform batch to discrete codes and decode it straight back.

        The core logic lives in ``forward`` so the model stays usable by
        ``pipeline``-style callers.

        Args:
            input_waveform: Waveform tensor laid out as (batch_size, waveform_length).
            sample_rate: Sampling rate passed to the feature extractor (default 16000).

        Returns:
            The reconstructed audio produced by ``self.generator`` (presumably
            (batch, time) -- TODO confirm against CodecDecoderVocos).
        """
        vq_code = self._encode_to_codes(input_waveform, sample_rate)
        return self._decode_from_codes(vq_code)

    def encode_code(self, input_waveform, sample_rate=16000):
        """Encode audio into discrete codes, without tracking gradients.

        Args:
            input_waveform: Waveform tensor laid out as (batch_size, waveform_length).
            sample_rate: Sampling rate passed to the feature extractor (default 16000).

        Returns:
            The code tensor produced by the quantizer. It has at least three
            dimensions, since ``decode_code`` swaps dims 1 and 2 before lookup.
        """
        with torch.no_grad():
            return self._encode_to_codes(input_waveform, sample_rate)

    def decode_code(self, vq_code):
        """Decode codes (as returned by ``encode_code``) back to audio, without gradients.

        Args:
            vq_code: Code tensor with at least three dimensions; dims 1 and 2
                are swapped before the quantizer lookup.

        Returns:
            The reconstructed audio produced by ``self.generator`` (presumably
            (batch, time) -- TODO confirm against CodecDecoderVocos).
        """
        with torch.no_grad():
            return self._decode_from_codes(vq_code)
|