import torch
from torch import nn
from .factorized_vector_quantize import FactorizedVectorQuantize

class ResidualVQ(nn.Module):
    """Residual vector quantization: each quantizer in the stack codes the
    residual left over by the previous one, and their outputs are summed."""

    def __init__(
        self,
        *,
        num_quantizers,
        codebook_size,
        **kwargs
    ):
        super().__init__()
        VQ = FactorizedVectorQuantize
        # A single int codebook size is shared across all quantizers.
        if isinstance(codebook_size, int):
            codebook_size = [codebook_size] * num_quantizers
        self.layers = nn.ModuleList(
            [VQ(codebook_size=size, **kwargs) for size in codebook_size]
        )
        self.num_quantizers = num_quantizers

    def forward(self, x):
        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []

        for layer in self.layers:
            # Each quantizer codes the residual left by the previous layers.
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss.mean())

        # Stack per-layer losses and indices along a new leading quantizer axis.
        all_losses, all_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, all_indices, all_losses

    def vq2emb(self, vq, proj=True):
        # vq: [B, T, num_quantizers] code indices; sum the per-layer embeddings.
        quantized_out = 0.
        for idx, layer in enumerate(self.layers):
            quantized = layer.vq2emb(vq[:, :, idx], proj=proj)
            quantized_out = quantized_out + quantized
        return quantized_out

    def get_emb(self):
        # Codebook embeddings from every quantizer layer.
        return [layer.get_emb() for layer in self.layers]
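
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module). The extra kwargs are
# forwarded verbatim to FactorizedVectorQuantize, whose signature is not shown
# in this file, so `dim=256` below is a hypothetical placeholder for whatever
# that quantizer actually expects.
#
#   rvq = ResidualVQ(num_quantizers=8, codebook_size=1024, dim=256)
#   x = torch.randn(4, 100, 256)          # [B, T, D] input features
#   quantized, indices, losses = rvq(x)
#   # quantized: [B, T, D]                sum of per-layer quantized residuals
#   # indices:   [num_quantizers, B, T]   stacked code indices, one per layer
#   # losses:    [num_quantizers]         mean VQ loss from each layer
# ---------------------------------------------------------------------------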