import torch.nn as nn


class ResBlock(nn.Module):
    """Residual block without BN.

    It has a style of:

    ::

        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feats (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Used to scale the residual before addition.
            Default: 1.0.
        bias (bool): Whether the convolutions use a bias term.
            Default: True.
        shortcut (bool): Whether to add the identity shortcut to the
            scaled residual. Default: True.
    """
    def __init__(self, num_feats=64, res_scale=1.0, bias=True, shortcut=True):
        super().__init__()
        self.res_scale = res_scale
        self.shortcut = shortcut
        # 3x3 convolutions with stride 1 and padding 1 preserve spatial size.
        self.conv1 = nn.Conv2d(num_feats, num_feats, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(num_feats, num_feats, 3, 1, 1, bias=bias)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        if self.shortcut:
            return identity + out * self.res_scale
        return out * self.res_scale


class ResBlockWrapper(ResBlock):
    """ResBlock operating on flattened token sequences, for use inside
    transformer-based architectures."""

    def __init__(self, num_feats, bias=True, shortcut=True):
        super().__init__(num_feats=num_feats, bias=bias, shortcut=shortcut)

    def forward(self, x, x_size):
        """Reshape (B, L, C) tokens to (B, C, H, W) with L = H * W,
        apply the parent block, and flatten back to (B, L, C)."""
        H, W = x_size
        B, L, C = x.shape
        x = x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = super().forward(x)
        x = x.flatten(2).permute(0, 2, 1)
        return x
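

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original file; the shapes
    # and hyper-parameters below are assumptions chosen for demonstration.
    import torch

    # Image-like input: (n, c, h, w).
    block = ResBlock(num_feats=64, res_scale=0.1)
    img = torch.randn(2, 64, 32, 32)
    assert block(img).shape == img.shape

    # Token-like input: (B, L, C) with L == H * W, as the wrapper expects.
    wrapper = ResBlockWrapper(num_feats=64)
    tokens = torch.randn(2, 32 * 32, 64)
    assert wrapper(tokens, (32, 32)).shape == tokens.shape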