import torch.nn as nn class ResBlock(nn.Module): """Residual block without BN. It has a style of: :: ---Conv-ReLU-Conv-+- |________________| Args: num_feats (int): Channel number of intermediate features. Default: 64. res_scale (float): Used to scale the residual before addition. Default: 1.0. """ def __init__(self, num_feats=64, res_scale=1.0, bias=True, shortcut=True): super().__init__() self.res_scale = res_scale self.shortcut = shortcut self.conv1 = nn.Conv2d(num_feats, num_feats, 3, 1, 1, bias=bias) self.conv2 = nn.Conv2d(num_feats, num_feats, 3, 1, 1, bias=bias) self.relu = nn.ReLU(inplace=True) def forward(self, x): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ identity = x out = self.conv2(self.relu(self.conv1(x))) if self.shortcut: return identity + out * self.res_scale else: return out * self.res_scale class ResBlockWrapper(ResBlock): "Used for transformers" def __init__(self, num_feats, bias=True, shortcut=True): super(ResBlockWrapper, self).__init__( num_feats=num_feats, bias=bias, shortcut=shortcut ) def forward(self, x, x_size): H, W = x_size B, L, C = x.shape x = x.view(B, H, W, C).permute(0, 3, 1, 2) x = super(ResBlockWrapper, self).forward(x) x = x.flatten(2).permute(0, 2, 1) return x