|
import torch |
|
import torch.nn as nn |
|
from .block import Block |
|
|
|
class GroupNorm(Block):
    """``Block`` wrapper around an affine :class:`torch.nn.GroupNorm`.

    Args:
        num_channels: number of input channels (``torch.nn.GroupNorm``
            requires it to be divisible by ``num_groups``).
        num_groups: number of channel groups.
        eps: value added to the variance for numerical stability.
        *args, **kwargs: forwarded unchanged to ``Block.__init__``.
    """

    def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Honor the caller's eps (it was previously hard-coded to 1e-6 and
        # the argument silently ignored).
        self.norm = torch.nn.GroupNorm(
            num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True
        )

    def forward(self, x):
        """Normalize ``x`` with the wrapped ``torch.nn.GroupNorm`` module."""
        return self.norm(x)
|
|
|
def Normalize(in_channels, num_groups=32, eps=1e-6):
    """Build an affine :class:`torch.nn.GroupNorm` over ``in_channels`` channels.

    Args:
        in_channels: number of channels to normalize.
        num_groups: number of channel groups.
        eps: value added to the variance for numerical stability; the
            default keeps the previous hard-coded 1e-6.

    Returns:
        A new ``torch.nn.GroupNorm`` module with learnable affine parameters.
    """
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True
    )
|
|
|
class ActNorm(nn.Module):
    """Per-channel affine normalization with data-dependent initialization.

    Computes ``scale * (x + loc)`` with learnable per-channel ``loc`` and
    ``scale`` of shape ``(1, C, 1, 1)``. On the first call in training mode,
    ``loc`` and ``scale`` are set from that batch so the output has zero mean
    and (approximately) unit standard deviation per channel.

    Inputs are ``(N, C, H, W)`` or ``(N, C)``; 2-D inputs are temporarily
    expanded to ``(N, C, 1, 1)`` and squeezed back on output.

    Args:
        num_features: number of channels ``C``.
        logdet: if True, ``forward`` also returns the per-sample
            log-determinant of the transformation.
        affine: must be True; only the affine variant is implemented.
        allow_reverse_init: permit data-dependent initialization when the
            first training-mode call goes through ``reverse``.
    """

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine  # non-affine mode is not implemented
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # A buffer (not a plain attribute) so the "already initialized" flag
        # is stored in and restored from the module's state_dict.
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    @staticmethod
    def _as_4d(x):
        """Return ``(x4d, squeezed)``; 2-D ``(N, C)`` input becomes ``(N, C, 1, 1)``."""
        if len(x.shape) == 2:
            return x[:, :, None, None], True
        return x, False

    def initialize(self, input):
        """Set ``loc``/``scale`` so ``input`` maps to zero mean, unit std per channel.

        Statistics are taken over every batch and spatial position of each
        channel. ``torch.std`` is unbiased here, so a channel with only a
        single value produces a NaN scale.
        """
        input, _ = self._as_4d(input)
        with torch.no_grad():
            # (C, N*H*W): one row of samples per channel.
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = flatten.mean(1).view(1, -1, 1, 1)
            std = flatten.std(1).view(1, -1, 1, 1)

            self.loc.data.copy_(-mean)
            # 1e-6 avoids division by zero for constant channels.
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Apply ``scale * (input + loc)``, or its inverse when ``reverse`` is True.

        Returns ``h`` with the same shape as ``input``. When ``self.logdet`` is
        True (and ``reverse`` is False), returns ``(h, logdet)`` where
        ``logdet`` has shape ``(N,)``.
        """
        if reverse:
            return self.reverse(input)

        input, squeeze = self._as_4d(input)
        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            # Diagonal Jacobian: each channel's scale is applied at all H*W
            # positions, so its log|scale| is counted H*W times.
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        """Invert ``forward``: return ``output / scale - loc`` (same shape as ``output``).

        Raises:
            RuntimeError: if called in training mode before initialization
                while ``allow_reverse_init`` is False.
        """
        # Reshape to 4-D *before* any initialization: initialize() permutes
        # over four dims, so 2-D input must be expanded first.
        output, squeeze = self._as_4d(output)

        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            self.initialize(output)
            self.initialized.fill_(1)

        h = output / self.scale - self.loc
        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
|
|