|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import namedtuple |
|
from typing import NamedTuple, Optional, Tuple |
|
import torch |
|
from torch import nn |
|
|
|
|
|
def _run_kernel(x: torch.Tensor, mean: torch.Tensor, tx: torch.Tensor): |
|
if x.ndim <= 3: |
|
x = x - mean |
|
x = x @ tx.T |
|
elif x.ndim == 4: |
|
x = x - mean.reshape(1, -1, 1, 1) |
|
kernel = tx.reshape(*tx.shape, 1, 1) |
|
x = torch.nn.functional.conv2d(x, weight=kernel, bias=None, stride=1, padding=0) |
|
else: |
|
raise ValueError(f'Unsupported input dimension: {x.ndim}, shape: {x.shape}') |
|
return x |
|
|
|
|
|
class FeatureNormalizer(nn.Module): |
|
def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32): |
|
super().__init__() |
|
|
|
self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype)) |
|
self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype)) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = _run_kernel(x, self.mean, self.tx) |
|
return x |
|
|
|
|
|
class InterFeatState(NamedTuple): |
|
y: torch.Tensor |
|
alpha: torch.Tensor |
|
|
|
|
|
class IntermediateFeatureNormalizerBase(nn.Module): |
|
def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState: |
|
raise NotImplementedError() |
|
|
|
|
|
class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase): |
|
def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32): |
|
super().__init__() |
|
self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype)) |
|
|
|
rot = torch.eye(embed_dim, dtype=dtype) |
|
if rot_per_layer: |
|
rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1) |
|
|
|
self.register_buffer('rotation', rot.contiguous()) |
|
self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype)) |
|
|
|
def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState: |
|
if rot_index is None: |
|
rot_index = index |
|
|
|
if skip: |
|
assert x.ndim == 3, f'Cannot use the `skip` parameter when the `x` tensor isn\'t 3-dimensional.' |
|
prefix, x = x[:, :skip], x[:, skip:] |
|
|
|
rotation = self._get_rotation(rot_index) |
|
y = _run_kernel(x, self.means[index], rotation) |
|
|
|
alpha = self.alphas[index] |
|
if skip: |
|
alpha = torch.cat([ |
|
torch.ones(skip, dtype=alpha.dtype, device=alpha.device), |
|
alpha[None].expand(y.shape[1]), |
|
]).reshape(1, -1, 1) |
|
y = torch.cat([prefix, y], dim=1) |
|
else: |
|
if x.ndim == 3: |
|
alpha = alpha.reshape(1, 1, 1).expand(1, y.shape[1], 1) |
|
elif x.ndim == 4: |
|
alpha = alpha.reshape(1, 1, 1, 1).expand(1, 1, *y.shape[2:]) |
|
else: |
|
raise ValueError(f'Unsupported input dimension: {x.ndim}') |
|
|
|
return InterFeatState(y, alpha) |
|
|
|
def _get_rotation(self, rot_index: int) -> torch.Tensor: |
|
if self.rotation.ndim == 2: |
|
return self.rotation |
|
return self.rotation[rot_index] |
|
|
|
|
|
class NullIntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase): |
|
instances = dict() |
|
|
|
def __init__(self, dtype: torch.dtype, device: torch.device): |
|
super().__init__() |
|
self.register_buffer('alpha', torch.tensor(1, dtype=dtype, device=device)) |
|
|
|
@staticmethod |
|
def get_instance(dtype: torch.dtype, device: torch.device): |
|
instance = NullIntermediateFeatureNormalizer.instances.get((dtype, device), None) |
|
if instance is None: |
|
instance = NullIntermediateFeatureNormalizer(dtype, device) |
|
NullIntermediateFeatureNormalizer.instances[(dtype, device)] = instance |
|
return instance |
|
|
|
def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState: |
|
return InterFeatState(x, self.alpha) |
|
|