import torch
import torch.nn as nn
from transformers import PreTrainedModel
from configuration_bigcodec import BigCodecConfig

# Make sure these module import paths are correct (the vq package must be importable).
from vq.codec_encoder import CodecEncoder_Transformer
from vq.codec_decoder_vocos import CodecDecoderVocos
from vq.module import SemanticEncoder
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel

class XCodec2Model(PreTrainedModel):
    config_class = BigCodecConfig

    def __init__(self, config: BigCodecConfig):
        super().__init__(config)

        # 1) Semantic model
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
            "facebook/w2v-bert-2.0",
            output_hidden_states=True
        )
        self.semantic_model.eval()

        self.SemanticEncoder_module = SemanticEncoder(
             config.semantic_hidden_size,
             config.semantic_hidden_size,
             config.semantic_hidden_size
        )

        # 2) Codec Encoder
        self.CodecEnc = CodecEncoder_Transformer()

        # 3) Codec Decoder
        self.generator = CodecDecoderVocos()  

        # 4) Two fully connected projection layers
        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_post_a = nn.Linear(2048, 1024)
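        # Feature extractor matching the w2v-bert semantic model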
        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
        self.feature_extractor = feature_extractor

    def forward(self, input_waveform, sample_rate=16000):
        """
        这里的 forward 不一定要叫 forward,也可以拆成别的方法;
        但是如果想兼容 pipeline,需要在 forward 里给出核心逻辑。
        
        参数:
          input_waveform: [batch_size, waveform_length]
          sample_rate: 默认 16000
        返回:
          重构后的语音音频 (Tensor)
        """
        # 1) Feature extraction
        # Padding, if needed, can be applied here.
        input_features = self.feature_extractor(
            input_waveform, 
            sampling_rate=sample_rate, 
            return_tensors="pt"
        ).input_features.to(self.device)  # [batch, frames, feat_dim]

        # 2) Semantic features
        semantic_output = self.semantic_model(input_features)
        semantic_hidden_16 = semantic_output.hidden_states[16]  # take the 16th hidden layer
        semantic_hidden_16 = semantic_hidden_16.transpose(1, 2)  # [batch, hidden_dim, frames]
        semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)

        # 3) codec encoder
        wav = input_waveform.unsqueeze(1).to(self.device)  # shape: [batch, 1, time]
        vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024] (illustrative shape)
        vq_emb = vq_emb.transpose(1, 2)  # -> [batch, 1024, frames]

        # Align the number of time frames with the semantic features.
        # This is a simplified alignment; a real implementation may need a proper resampling step.
        if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
            # Simple truncation or zero-padding both work; choose whichever fits your use case.
            min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
            vq_emb = vq_emb[:, :, :min_len]
            semantic_encoded = semantic_encoded[:, :, :min_len]

        # 4) Concatenate semantic and acoustic embeddings
        concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)  # [batch, 1024 + 1024, frames]

        # 5) fc_prior
        concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)

        # 6) Quantization stage of the decoder
        _, vq_code, _ = self.generator(concat_emb, vq=True)
        vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
        vq_post_emb = vq_post_emb.transpose(1, 2)

        # 7) fc_post_a
        vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)

        # 8) Finally decode back to a waveform
        recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]
        # recon_audio: [batch, time]
        return recon_audio
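
    # A minimal usage sketch for the full forward pass (illustrative, not part of
    # the original file): load the weights with from_pretrained and call the model
    # on a mono 16 kHz waveform tensor. The checkpoint path below is a placeholder.
    #
    #   model = XCodec2Model.from_pretrained("path/to/xcodec2-checkpoint").eval()
    #   wav = torch.randn(1, 16000)          # [batch, samples], 1 s of 16 kHz audio
    #   with torch.no_grad():
    #       recon = model(wav)               # reconstructed waveform, [batch, time]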

    def encode_code(self, input_waveform, sample_rate=16000):
        """
        将输入的音频编码为代码表示。

        参数:
          input_waveform: [batch_size, waveform_length]
          sample_rate: 默认 16000
        返回:
          编码后的代码 (Tensor)
        """
        with torch.no_grad():
            # 1) Feature extraction
            input_features = self.feature_extractor(
                input_waveform, 
                sampling_rate=sample_rate, 
                return_tensors="pt"
            ).input_features.to(self.device)  # [batch, frames, feat_dim]

            # 2) Semantic features
            semantic_output = self.semantic_model(input_features)
            semantic_hidden_16 = semantic_output.hidden_states[16]  # take the 16th hidden layer
            semantic_hidden_16 = semantic_hidden_16.transpose(1, 2)  # [batch, hidden_dim, frames]
            semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)

            # 3) codec encoder
            wav = input_waveform.unsqueeze(1).to(self.device)  # shape: [batch, 1, time]
            vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024] (illustrative shape)
            vq_emb = vq_emb.transpose(1, 2)  # -> [batch, 1024, frames]

            # Align the number of time frames with the semantic features (simplified handling).
            if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
                min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
                vq_emb = vq_emb[:, :, :min_len]
                semantic_encoded = semantic_encoded[:, :, :min_len]

            # 4) Concatenate semantic and acoustic embeddings
            concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)  # [batch, 2048, frames]

            # 5) fc_prior
            concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)

            # 6) Quantization stage of the decoder, returning the codes
            _, vq_code, _ = self.generator(concat_emb, vq=True)
            # vq_code: [batch, frames]
            return vq_code

    def decode_code(self, vq_code):
        """
        将编码后的代码解码回音频。

        参数:
          vq_code: 编码后的代码 (Tensor) [batch, frames]
        返回:
          解码后的音频 (Tensor) [batch, waveform_length]
        """
        with torch.no_grad():
            # Recover the quantized embeddings from the code indices
            vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
            vq_post_emb = vq_post_emb.transpose(1, 2)  # [batch, 1024, frames]

            # Project through fc_post_a
            vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)  # [batch, 1024, frames]

            # Decode back to a waveform
            recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]  # [batch, time]
            return recon_audio
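

# A minimal end-to-end sketch (illustrative and not part of the original model
# file; the checkpoint path is a placeholder). It runs encode_code followed by
# decode_code on dummy audio and prints the resulting shapes.
if __name__ == "__main__":
    model = XCodec2Model.from_pretrained("path/to/xcodec2-checkpoint").eval()

    # One second of dummy 16 kHz audio; replace with real audio loaded at 16 kHz.
    waveform = torch.randn(1, 16000)

    with torch.no_grad():
        codes = model.encode_code(waveform)   # discrete codes, per the docstring [batch, frames]
        recon = model.decode_code(codes)      # reconstructed waveform

    print("codes shape:", tuple(codes.shape))
    print("reconstructed audio shape:", tuple(recon.shape))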