# RCAN single-image super-resolution — PyTorch model definition and CPU inference script.
import os
import torch
from torch import nn
from PIL import Image
from torchvision.transforms import ToTensor
# Hyperparameters for the RCAN network defined below.
NUM_RESIDUAL_GROUPS = 8   # residual groups in the residual-in-residual trunk
NUM_RESIDUAL_BLOCKS = 16  # residual channel-attention blocks per group
KERNEL_SIZE = 3           # spatial size of every feature convolution kernel
REDUCTION_RATIO = 16      # channel bottleneck factor in the attention branch
NUM_CHANNELS = 64         # feature channels used throughout the network
UPSCALE_FACTOR = 4        # super-resolution scale (PixelShuffle factor)
class ResidualChannelAttentionBlock(nn.Module):
def __init__(self, num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
super(ResidualChannelAttentionBlock, self).__init__()
self.feature_extractor = nn.Sequential(
nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2),
nn.ReLU(),
nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
)
self.channel_attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(num_channels, num_channels//reduction_ratio, kernel_size=1, stride=1),
nn.ReLU(),
# nn.BatchNorm2d(num_channels//reduction_ratio),
nn.Conv2d(num_channels//reduction_ratio, num_channels, kernel_size=1, stride=1),
nn.Sigmoid()
)
def forward(self, x):
block_input = x.clone()
residual = self.feature_extractor(x) # Feature extraction
rescale = self.channel_attention(residual) # Rescaling vector
block_output = block_input + (residual * rescale)
return block_output
class ResidualGroup(nn.Module):
    """A chain of RCABs plus a convolution, wrapped in a long skip connection.

    Args:
        num_residual_blocks: number of ResidualChannelAttentionBlock units.
        num_channels: feature channels in and out of the group.
        reduction_ratio: attention bottleneck factor, forwarded to each block.
        kernel_size: conv kernel size, forwarded to each block and the tail conv.
    """
    def __init__(self, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        self.residual_blocks = nn.Sequential(
            *[ResidualChannelAttentionBlock(num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
              for _ in range(num_residual_blocks)]
        )
        self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    def forward(self, x):
        """Return ``x + conv(blocks(x))`` with the same shape as ``x``."""
        residual = self.residual_blocks(x)   # Residual blocks
        residual = self.final_conv(residual)  # Final convolution
        # The addition allocates a new tensor; the original's input clone was redundant.
        return x + residual
class ResidualInResidual(nn.Module):
    """Residual-in-residual trunk: stacked ResidualGroups with a global skip.

    Maps shallow features to deep features of identical shape
    (N, num_channels, H, W).

    Args:
        num_residual_groups: number of ResidualGroup units in the trunk.
        num_residual_blocks: RCABs per group, forwarded to each group.
        num_channels: feature channels throughout.
        reduction_ratio: attention bottleneck factor, forwarded down.
        kernel_size: conv kernel size, forwarded down and used by the tail conv.
    """
    def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        self.residual_groups = nn.Sequential(
            *[ResidualGroup(num_residual_blocks=num_residual_blocks,
                            num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
              for _ in range(num_residual_groups)]
        )
        self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    def forward(self, x):
        """Return ``x + conv(groups(x))`` — deep features with the input's shape."""
        residual = self.residual_groups(x)    # Residual groups
        residual = self.final_conv(residual)  # Final convolution
        # Addition yields a new tensor; no defensive clone of the input is needed.
        return x + residual
class RCAN(nn.Module):
    """Residual Channel Attention Network for single-image super-resolution.

    Pipeline: shallow conv -> residual-in-residual trunk -> PixelShuffle
    upscaling -> reconstruction conv back to 3 (RGB) channels.

    Args:
        num_residual_groups: groups in the residual-in-residual trunk.
        num_residual_blocks: RCABs per group.
        num_channels: feature channels; must be divisible by
            ``upscale_factor ** 2`` for PixelShuffle to work.
        reduction_ratio: attention bottleneck factor.
        kernel_size: conv kernel size used throughout.
        upscale_factor: super-resolution scale (defaults to the module-level
            UPSCALE_FACTOR, so existing callers are unaffected).
    """
    def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE,
                 upscale_factor=UPSCALE_FACTOR):
        super().__init__()
        self.shallow_conv = nn.Conv2d(3, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
        self.residual_in_residual = ResidualInResidual(num_residual_groups=num_residual_groups, num_residual_blocks=num_residual_blocks,
                                                       num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
        # PixelShuffle trades channels for resolution: (C, H, W) -> (C/r^2, rH, rW).
        self.upscaling_module = nn.PixelShuffle(upscale_factor=upscale_factor)
        self.reconstruction_conv = nn.Conv2d(num_channels // (upscale_factor ** 2), 3, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    def forward(self, x):
        """Super-resolve a (N, 3, H, W) tensor to (N, 3, r*H, r*W)."""
        shallow_feature = self.shallow_conv(x)                     # Initial convolution
        deep_feature = self.residual_in_residual(shallow_feature)  # Residual in Residual
        upscaled_image = self.upscaling_module(deep_feature)       # Upscaling module
        reconstructed_image = self.reconstruction_conv(upscaled_image)  # Reconstruction
        return reconstructed_image
    def inference(self, x):
        """Super-resolve a PIL image and return the result as a PIL image.

        Args:
            x: input PIL image (assumed RGB — TODO confirm with callers).
        """
        tensor = ToTensor()(x).unsqueeze(0)  # add batch dimension
        output = self.forward(tensor).squeeze(0)
        # BUG FIX: clamp to [0, 1] before the uint8 cast. The network output is
        # unbounded; without clamping, out-of-range values wrap around when cast
        # to uint8 and produce speckle artifacts.
        output = output.clamp(0.0, 1.0)
        array = (output.permute(1, 2, 0).detach().numpy() * 255).astype('uint8')
        return Image.fromarray(array)
if __name__ == '__main__':
    # Demo entry point: load trained weights and super-resolve a sample image.
    # (The original line here ended with a stray "|" scrape artifact — a syntax
    # error — which is removed.)
    current_dir = os.path.dirname(os.path.realpath(__file__))
    model = RCAN()
    # map_location forces CPU so the script also runs on machines without a GPU.
    checkpoint_path = os.path.join(current_dir, 'rcan_checkpoint.pth')
    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    model.eval()
    with torch.no_grad():
        input_image = Image.open('images/demo.png')
        output_image = model.inference(input_image)