|
import torch |
|
import torch.nn as nn |
|
from transformers import PreTrainedModel |
|
from configuration_bigcodec import BigCodecConfig |
|
|
|
|
|
from vq.codec_encoder import CodecEncoder_Transformer |
|
from vq.codec_decoder_vocos import CodecDecoderVocos |
|
from vq.module import SemanticEncoder |
|
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel |
|
|
|
class XCodec2Model(PreTrainedModel):
    """XCodec2 speech codec.

    Fuses a frozen-in-eval Wav2Vec2-BERT *semantic* stream with a learned
    *acoustic* codec-encoder stream, vector-quantizes the fused embedding,
    and reconstructs audio with a Vocos-style decoder.

    Public API:
        forward(input_waveform, sample_rate)     -> reconstructed audio tensor
        encode_code(input_waveform, sample_rate) -> VQ code indices
        decode_code(vq_code)                     -> reconstructed audio tensor
    """

    config_class = BigCodecConfig

    def __init__(self, config: BigCodecConfig):
        super().__init__(config)

        # Semantic branch: Wav2Vec2-BERT with all hidden states exposed so we
        # can tap an intermediate layer (index 16) below.  Kept in eval mode;
        # NOTE(review): parameters are not explicitly frozen here — confirm
        # whether training scripts rely on that.
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
            "facebook/w2v-bert-2.0",
            output_hidden_states=True,
        )
        self.semantic_model.eval()

        # Projects the tapped semantic hidden states (same width in/out).
        self.SemanticEncoder_module = SemanticEncoder(
            config.semantic_hidden_size,
            config.semantic_hidden_size,
            config.semantic_hidden_size,
        )

        # Acoustic branch encoder and the VQ + Vocos decoder.
        self.CodecEnc = CodecEncoder_Transformer()
        self.generator = CodecDecoderVocos()

        # Fusion projections.  2048 is the concatenated channel width
        # (presumably semantic + acoustic halves — TODO confirm against the
        # vq module channel sizes).
        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_post_a = nn.Linear(2048, 1024)

        # Feature extractor matching the semantic model's expected input.
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            "facebook/w2v-bert-2.0"
        )

    def _encode_to_prior(self, input_waveform, sample_rate):
        """Compute the fused (semantic + acoustic) embedding fed to the VQ.

        Shared by `forward` and `encode_code`; does NOT disable gradients —
        callers choose their own grad context.

        Args:
            input_waveform: [batch_size, waveform_length] audio.
            sample_rate: sampling rate passed to the feature extractor.

        Returns:
            Tensor of shape [batch, 2048, frames] (after `fc_prior`).
        """
        # --- semantic stream ---
        input_features = self.feature_extractor(
            input_waveform,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).input_features.to(self.device)

        semantic_output = self.semantic_model(input_features)
        # Tap hidden state at index 16 and move channels to dim 1:
        # (batch, frames, hidden) -> (batch, hidden, frames).
        semantic_hidden_16 = semantic_output.hidden_states[16].transpose(1, 2)
        semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)

        # --- acoustic stream ---
        wav = input_waveform.unsqueeze(1).to(self.device)
        vq_emb = self.CodecEnc(wav).transpose(1, 2)

        # The two streams may disagree by a frame or two; trim both to the
        # shorter length so they can be concatenated.
        if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
            min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
            vq_emb = vq_emb[:, :, :min_len]
            semantic_encoded = semantic_encoded[:, :, :min_len]

        # Concatenate along channels, then mix with the prior projection
        # (Linear acts on the last dim, hence the transpose round-trip).
        concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)
        return self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)

    def _decode_vq_code(self, vq_code):
        """Map VQ code indices back to a waveform.

        Shared by `forward` and `decode_code`; does NOT disable gradients.

        Args:
            vq_code: code indices, [batch, frames] (as produced by the
                generator's quantizer).

        Returns:
            Reconstructed audio tensor.
        """
        # Look up quantizer embeddings for the indices.
        vq_post_emb = self.generator.quantizer.get_output_from_indices(
            vq_code.transpose(1, 2)
        )
        vq_post_emb = vq_post_emb.transpose(1, 2)

        # Post-projection before the decoder.  NOTE(review): the transpose
        # pairs here cancel out; kept to mirror the established layout
        # convention of the original implementation.
        vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)

        # Run the Vocos decoder (vq=False = decode an embedding, no quantize).
        return self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]

    def forward(self, input_waveform, sample_rate=16000):
        """Encode and immediately decode a waveform (autoencoder round trip).

        Kept as `forward` for pipeline compatibility.  Unlike `encode_code`/
        `decode_code`, this runs WITH gradients enabled.

        Args:
            input_waveform: [batch_size, waveform_length] audio.
            sample_rate: input sampling rate (default 16000).

        Returns:
            Reconstructed audio (Tensor).
        """
        concat_emb = self._encode_to_prior(input_waveform, sample_rate)
        _, vq_code, _ = self.generator(concat_emb, vq=True)
        return self._decode_vq_code(vq_code)

    def encode_code(self, input_waveform, sample_rate=16000):
        """Encode audio into VQ code indices (inference only, no gradients).

        Args:
            input_waveform: [batch_size, waveform_length] audio.
            sample_rate: input sampling rate (default 16000).

        Returns:
            VQ code indices (Tensor).
        """
        with torch.no_grad():
            concat_emb = self._encode_to_prior(input_waveform, sample_rate)
            _, vq_code, _ = self.generator(concat_emb, vq=True)
            return vq_code

    def decode_code(self, vq_code):
        """Decode VQ code indices back to audio (inference only, no gradients).

        Args:
            vq_code: code indices (Tensor), [batch, frames].

        Returns:
            Decoded audio (Tensor), [batch, waveform_length].
        """
        with torch.no_grad():
            return self._decode_vq_code(vq_code)
|
|