|
|
|
|
|
|
|
|
|
|
|
|
|
"""Residual vector quantizer implementation.""" |
|
|
|
from dataclasses import dataclass, field |
|
import math |
|
import typing as tp |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from module.core_vq import ResidualVectorQuantization |
|
|
|
|
|
@dataclass
class QuantizedResult:
    """Plain record grouping the tensors produced by a quantization pass.

    The three leading fields are mandatory and positional; the remaining two
    are optional and fall back to ``None`` and a fresh empty ``dict``.

    NOTE(review): nothing in this file constructs this record, so the field
    descriptions below are inferred from the field names only -- confirm
    against the call sites.
    """

    # Required fields (no defaults) -- order defines the positional signature.
    quantized: torch.Tensor  # presumably the quantized representation
    codes: torch.Tensor  # presumably the codebook indices
    bandwidth: torch.Tensor  # presumably the bandwidth of the codes

    # Optional fields.
    penalty: tp.Optional[torch.Tensor] = field(default=None)
    # A factory is used so that instances never share one metrics dict.
    metrics: dict = field(default_factory=dict)
|
|
|
|
|
class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Stores the quantizer hyper-parameters as attributes and delegates the
    actual quantization, encoding and decoding to an internal
    ``ResidualVectorQuantization`` instance held in ``self.vq``.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration.
            Replace any codes that have an exponential moving average cluster
            size less than the specified threshold with a randomly selected
            vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        # Keep the hyper-parameters on the module so they can be inspected
        # later and used as defaults (see ``n_q`` in forward/encode).
        self.n_q = n_q
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )

    def forward(
        self,
        x: torch.Tensor,
        n_q: tp.Optional[int] = None,
        layers: tp.Optional[tp.List[int]] = None,
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, tp.Any]:
        """Residual vector quantization on the given input tensor.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize.
                Default: all quantizers (``self.n_q``).
            layers (list): Indices of the layers whose quantized output
                should be returned as well. Default: None.

        Returns:
            tuple: ``(quantized, codes, commit_loss, quantized_list)`` where
            ``quantized``, ``codes`` and ``quantized_list`` are passed through
            unchanged from ``self.vq`` and ``commit_loss`` is the mean of the
            commitment loss it returned.

        Raises:
            ValueError: If any index in ``layers`` is not smaller than the
                number of quantizers actually used for this call.
        """
        # A falsy ``n_q`` (None or 0) selects every quantizer.
        n_q = n_q if n_q else self.n_q
        if layers and max(layers) >= n_q:
            # Report the effective ``n_q`` the check was made against, not
            # ``self.n_q``: the two differ when the caller overrides ``n_q``.
            raise ValueError(
                f"Last layer index in layers: A {max(layers)}. "
                f"Number of quantizers in RVQ: B {n_q}. "
                "A must be less than B."
            )
        quantized, codes, commit_loss, quantized_list = self.vq(
            x, n_q=n_q, layers=layers
        )
        return quantized, codes, torch.mean(commit_loss), quantized_list

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        """Encode the given input tensor into codebook indices.

        Selects the number of quantizers to use and returns the indices
        produced by ``self.vq.encode``.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize.
                Default: all quantizers (``self.n_q``).
            st (int): Layer from which to start encoding. Default: 0.

        Returns:
            torch.Tensor: Whatever ``self.vq.encode`` returns (indices for
            each quantizer).
        """
        # Falsy values (None or 0) fall back to the defaults.
        n_q = n_q if n_q else self.n_q
        st = st or 0
        codes = self.vq.encode(x, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Decode the given codes to the quantized representation.

        Args:
            codes (torch.Tensor): Input indices for each quantizer.
            st (int): Layer from which to start decoding the codes. Default: 0.

        Returns:
            torch.Tensor: The quantized representation returned by
            ``self.vq.decode``.
        """
        quantized = self.vq.decode(codes, st=st)
        return quantized
|
|