"""XCodec2 speech codec wrapped as a Hugging Face ``PreTrainedModel``."""

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel

from configuration_bigcodec import BigCodecConfig

# Make sure these module paths are correct.
from vq.codec_encoder import CodecEncoder_Transformer
from vq.codec_decoder_vocos import CodecDecoderVocos
from vq.module import SemanticEncoder

# Checkpoint used for both the semantic model and its feature extractor.
_SEMANTIC_CHECKPOINT = "facebook/w2v-bert-2.0"
# Index into the semantic model's ``hidden_states`` tuple used as features.
_SEMANTIC_LAYER_INDEX = 16


class XCodec2Model(PreTrainedModel):
    """Speech codec combining a semantic branch with an acoustic codec branch.

    The semantic branch (Wav2Vec2-BERT hidden states passed through
    ``SemanticEncoder``) and the acoustic branch (``CodecEncoder_Transformer``)
    are concatenated on the channel axis, projected by ``fc_prior`` and
    quantized by the quantizer inside ``CodecDecoderVocos``. Decoding maps
    codes back to embeddings, projects them with ``fc_post_a`` and runs the
    decoder to obtain a waveform.

    Submodule attribute names are kept exactly as in the original definition
    because they determine the parameter names in saved checkpoints.
    """

    config_class = BigCodecConfig

    def __init__(self, config: BigCodecConfig):
        """Build all submodules.

        Args:
            config: Model configuration; only ``semantic_hidden_size`` is read
                here.
        """
        super().__init__(config)

        # 1) Semantic model. It is switched to eval mode at construction time.
        # NOTE(review): a later ``.train()`` on the parent module switches it
        # back to train mode -- confirm whether it is meant to stay frozen.
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
            _SEMANTIC_CHECKPOINT,
            output_hidden_states=True,
        )
        self.semantic_model.eval()

        self.SemanticEncoder_module = SemanticEncoder(
            config.semantic_hidden_size,
            config.semantic_hidden_size,
            config.semantic_hidden_size,
        )

        # 2) Codec encoder
        self.CodecEnc = CodecEncoder_Transformer()

        # 3) Codec decoder (also owns the quantizer)
        self.generator = CodecDecoderVocos()

        # 4) Two fully connected layers: one before quantization, one after.
        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_post_a = nn.Linear(2048, 1024)

        feature_extractor = AutoFeatureExtractor.from_pretrained(_SEMANTIC_CHECKPOINT)
        self.feature_extractor = feature_extractor

    def _encode_to_codes(self, input_waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Run the encoder path shared by ``forward`` and ``encode_code``.

        Gradient tracking is left to the caller, so ``forward`` stays
        differentiable while ``encode_code`` wraps this in ``torch.no_grad``.
        """
        # 1) Feature extraction (padding, if needed, would go here).
        # NOTE(review): the waveform is handed to the feature extractor as-is;
        # confirm it accepts a batched tensor living on the model's device.
        input_features = self.feature_extractor(
            input_waveform,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).input_features.to(self.device)  # [batch, frames, feat_dim] per original note

        # 2) Semantic branch: take one hidden layer and move channels to dim 1.
        semantic_output = self.semantic_model(input_features)
        semantic_hidden = semantic_output.hidden_states[_SEMANTIC_LAYER_INDEX]
        semantic_hidden = semantic_hidden.transpose(1, 2)  # [batch, hidden_dim, frames]
        semantic_encoded = self.SemanticEncoder_module(semantic_hidden)

        # 3) Acoustic branch: add a channel axis, encode, move channels to dim 1.
        wav = input_waveform.unsqueeze(1).to(self.device)  # [batch, 1, time]
        vq_emb = self.CodecEnc(wav)
        # Presumably [batch, 1024, frames] after the transpose (the original
        # author's annotation was marked as an example only) -- TODO confirm.
        vq_emb = vq_emb.transpose(1, 2)

        # Align both branches on the frame axis by truncating to the shorter
        # one. When the lengths already match the slices are no-ops.
        frame_count = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
        vq_emb = vq_emb[:, :, :frame_count]
        semantic_encoded = semantic_encoded[:, :, :frame_count]

        # 4) Concatenate on the channel axis. ``fc_prior`` is Linear(2048, 2048),
        # so the concatenated channel size must be 2048.
        concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)

        # 5) fc_prior acts on the last axis, hence the transpose round trip.
        concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)

        # 6) Quantizer part of the decoder; only the codes are kept.
        _, vq_code, _ = self.generator(concat_emb, vq=True)
        return vq_code

    def _decode_from_codes(self, vq_code: torch.Tensor) -> torch.Tensor:
        """Run the decoder path shared by ``forward`` and ``decode_code``.

        The original implementation interleaved several ``transpose(1, 2)``
        pairs that cancelled each other out; they are dropped here, so every
        layer receives exactly the same tensor as before.
        """
        # Look up the quantized embeddings. The result feeds ``fc_post_a``
        # (Linear(2048, 1024)) directly, so its last axis must be 2048.
        quantized = self.generator.quantizer.get_output_from_indices(
            vq_code.transpose(1, 2)
        )

        # 7) fc_post_a: last axis 2048 -> 1024
        post_emb = self.fc_post_a(quantized)

        # 8) Decode to a waveform; the first element of the result is the audio.
        return self.generator(post_emb, vq=False)[0]

    def forward(self, input_waveform: torch.Tensor, sample_rate: int = 16000) -> torch.Tensor:
        """Encode a waveform to codes and reconstruct it.

        The core logic lives in ``forward`` so that the model remains usable
        through pipeline-style callers; ``encode_code`` and ``decode_code``
        expose the two halves separately.

        Args:
            input_waveform: Audio tensor, ``[batch_size, waveform_length]``.
            sample_rate: Sampling rate passed to the feature extractor.
                Defaults to 16000.

        Returns:
            The reconstructed audio tensor (``[batch, time]`` according to the
            original annotation -- TODO confirm against the decoder).
        """
        vq_code = self._encode_to_codes(input_waveform, sample_rate)
        return self._decode_from_codes(vq_code)

    def encode_code(self, input_waveform: torch.Tensor, sample_rate: int = 16000) -> torch.Tensor:
        """Encode input audio into its discrete code representation.

        Runs without gradient tracking.

        Args:
            input_waveform: Audio tensor, ``[batch_size, waveform_length]``.
            sample_rate: Sampling rate passed to the feature extractor.
                Defaults to 16000.

        Returns:
            The code tensor produced by the quantizer. ``decode_code`` calls
            ``transpose(1, 2)`` on it, so it must have at least three
            dimensions (presumably ``[batch, n_quantizers, frames]`` --
            TODO confirm).
        """
        with torch.no_grad():
            return self._encode_to_codes(input_waveform, sample_rate)

    def decode_code(self, vq_code: torch.Tensor) -> torch.Tensor:
        """Decode discrete codes back into audio.

        Runs without gradient tracking.

        Args:
            vq_code: Code tensor as returned by ``encode_code``. It needs at
                least three dimensions because axes 1 and 2 are swapped before
                the codebook lookup.

        Returns:
            The decoded audio tensor (``[batch, time]`` according to the
            original annotation -- TODO confirm against the decoder).
        """
        with torch.no_grad():
            return self._decode_from_codes(vq_code)