import torch
from modules.wavenet import WN
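# Redecoder: decodes discrete prosody/content codes, conditioned on a global
# timbre vector, into 1024-way frame-level outputs (the codebook size), using
# either a non-autoregressive WaveNet or an autoregressive Mamba backbone.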
class Redecoder(torch.nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_p_codebooks = args.n_p_codebooks  # number of prosody codebooks
        self.n_c_codebooks = args.n_c_codebooks  # number of content codebooks
        self.codebook_size = 1024  # entries per codebook
        self.encoder_type = args.encoder_type
        if args.encoder_type == "wavenet":
            # Non-autoregressive WaveNet backbone; the timbre vector enters
            # through the global conditioning channel (gin_channels).
            self.embed_dim = args.wavenet_embed_dim
            self.encoder = WN(hidden_channels=self.embed_dim, kernel_size=5,
                              dilation_rate=1, n_layers=16, gin_channels=1024,
                              p_dropout=0.2, causal=args.decoder_causal)
            self.conv_out = torch.nn.Conv1d(self.embed_dim, 1024, 1)
        elif args.encoder_type == "mamba":
            from modules.mamba import Mambo
            # Autoregressive Mamba backbone; it consumes raw code ids, so the
            # v2 forward path is used instead of the embedding-sum path.
            self.embed_dim = args.mamba_embed_dim
            self.encoder = Mambo(d_model=self.embed_dim, n_layer=24, vocab_size=1024,
                                 prob_random_mask_prosody=args.prob_random_mask_prosody,
                                 prob_random_mask_content=args.prob_random_mask_content)
            self.conv_out = torch.nn.Linear(self.embed_dim, 1024)
            self.forward = self.forward_v2  # route calls to the Mamba path
        else:
            raise NotImplementedError(f"Unsupported encoder type: {args.encoder_type}")
        # One embedding table per codebook; code embeddings are summed over
        # codebooks in the WaveNet forward path.
        self.prosody_embed = torch.nn.ModuleList(
            [torch.nn.Embedding(self.codebook_size, self.embed_dim)
             for _ in range(self.n_p_codebooks)])
        self.content_embed = torch.nn.ModuleList(
            [torch.nn.Embedding(self.codebook_size, self.embed_dim)
             for _ in range(self.n_c_codebooks)])

    def forward(self, p_code, c_code, timbre_vec, use_p_code=True, use_c_code=True, n_c=2):
        # p_code: (B, n_p_codebooks, T) prosody code ids
        # c_code: (B, n_c_codebooks, T) content code ids
        # timbre_vec: (B, 1024) global timbre embedding
        B, _, T = p_code.size()
        p_embed = torch.zeros(B, T, self.embed_dim, device=p_code.device)
        c_embed = torch.zeros(B, T, self.embed_dim, device=c_code.device)
        if use_p_code:
            # Sum embeddings over all prosody codebooks.
            for i in range(self.n_p_codebooks):
                p_embed += self.prosody_embed[i](p_code[:, i, :])
        if use_c_code:
            # Only the first n_c content codebooks contribute.
            for i in range(n_c):
                c_embed += self.content_embed[i](c_code[:, i, :])
        x = p_embed + c_embed
        # WN expects (B, C, T); the timbre vector is passed as the global
        # conditioning input g.
        x = self.encoder(x.transpose(1, 2),
                         x_mask=torch.ones(B, 1, T, device=p_code.device),
                         g=timbre_vec.unsqueeze(2))
        x = self.conv_out(x)
        return x

    def forward_v2(self, p_code, c_code, timbre_vec, use_p_code=True, use_c_code=True, n_c=2):
        # Mamba path: Mambo embeds the raw code ids itself, so prosody and
        # content codes are simply stacked along the codebook dimension.
        # (use_p_code/use_c_code/n_c are kept for interface compatibility;
        # random code masking is handled inside Mambo.)
        x = self.encoder(torch.cat([p_code, c_code], dim=1), timbre_vec)
        x = self.conv_out(x).transpose(1, 2)
        return x

    @torch.no_grad()
    def generate(self, prompt_ids, input_ids, prompt_context, timbre, use_p_code=True, use_c_code=True, n_c=2):
        from modules.mamba import InferenceParams
        assert self.encoder_type == "mamba", "generate() is only implemented for the mamba encoder"
        inference_params = InferenceParams(max_seqlen=8192, max_batch_size=1)
        # Run the prompt once to initialize the recurrent state.
        prompt_out = self.encoder(prompt_ids, prompt_context, timbre, inference_params=inference_params)
        # Decode step by step, feeding each step's output back in as the
        # context for the next step.
        outputs = []
        for i in range(input_ids.size(-1)):
            input_id = input_ids[..., i]
            prompt_out = self.encoder(input_id, prompt_out, timbre, inference_params=inference_params)
            outputs.append(prompt_out)
        # The original snippet did not return anything; concatenate the
        # per-step outputs along the time axis (assuming (B, 1, D) steps).
        return torch.cat(outputs, dim=1)
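

# --- Minimal usage sketch (illustrative; not part of the original module) ---
# Exercises the WaveNet path with made-up hyperparameters; the args fields
# below are assumptions and should be replaced by the real training config.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        encoder_type="wavenet",
        n_p_codebooks=1,        # assumed value
        n_c_codebooks=2,        # assumed value
        wavenet_embed_dim=512,  # assumed value
        decoder_causal=False,   # assumed value
    )
    model = Redecoder(args).eval()

    B, T = 1, 100
    p_code = torch.randint(0, 1024, (B, args.n_p_codebooks, T))  # prosody ids
    c_code = torch.randint(0, 1024, (B, args.n_c_codebooks, T))  # content ids
    timbre = torch.randn(B, 1024)  # global timbre vector (gin_channels=1024)

    with torch.no_grad():
        out = model(p_code, c_code, timbre)
    print(out.shape)  # expected: (B, 1024, T)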