Spaces:
Sleeping
Sleeping
File size: 2,653 Bytes
06f26d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-07-12 20:35:28
import math
from torch import nn
import torch.nn.functional as F
class SRCNN(nn.Module):
    """Plain convolutional network that upsamples its input with PixelShuffle.

    Layout: a 5x5 ``head`` convolution, ``depth - 1`` (5x5 conv + LeakyReLU)
    pairs in ``body``, and a ``tail`` made of ``k`` (3x3 conv + LeakyReLU +
    PixelShuffle(2)) stages followed by a final 5x5 convolution, where
    ``k = floor(log2(sf))``.

    Every convolution uses "same" padding, so only the PixelShuffle stages
    change the spatial size: an input of shape ``(N, in_chns, H, W)`` yields
    ``(N, out_chns, H * 2**k, W * 2**k)``.

    Args:
        in_chns: Number of input channels.
        out_chns: Number of output channels; defaults to ``in_chns``.
        num_chns: Number of feature channels used inside the network.
        depth: ``depth - 1`` conv/activation pairs are placed in ``body``.
        sf: Requested upscaling factor. Only the power-of-two part is
            realised: a value that is not a power of two is floored to the
            nearest lower power of two (e.g. ``sf=3`` builds a x2 network).
    """

    def __init__(self, in_chns, out_chns=None, num_chns=64, depth=8, sf=4):
        super().__init__()
        self.sf = sf
        out_chns = in_chns if out_chns is None else out_chns

        # math.log2 is more accurate than math.log(sf, 2), which goes through
        # a floating-point division and could be truncated to the wrong
        # stage count by int().
        num_upsample = int(math.log2(sf))

        self.head = nn.Conv2d(in_chns, num_chns, kernel_size=5, padding=2)

        body = []
        for _ in range(depth - 1):
            body.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.Sequential(*body)

        tail = []
        for _ in range(num_upsample):
            # Expand to 4x channels so PixelShuffle(2) returns to num_chns
            # channels while doubling height and width.
            tail.append(nn.Conv2d(num_chns, num_chns * 4, kernel_size=3, padding=1))
            tail.append(nn.LeakyReLU(0.2, inplace=True))
            tail.append(nn.PixelShuffle(2))
        tail.append(nn.Conv2d(num_chns, out_chns, kernel_size=5, padding=2))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Run ``head``, ``body`` and ``tail`` in sequence and return the result."""
        y = self.head(x)
        y = self.body(y)
        y = self.tail(y)
        return y
class SRCNNFSR(nn.Module):
    """Convolutional network that works at reduced resolution and restores it.

    With ``n = floor(log2(down_scale_factor))``:

    * ``head`` applies ``n`` PixelUnshuffle(2) stages, each followed by a 3x3
      convolution to ``num_chns`` channels. Every stage except the last adds
      a 5x5 convolution down to ``num_chns // 4`` channels, so that the next
      PixelUnshuffle(2) brings the count back to ``num_chns``.
    * ``body`` holds ``depth - 1`` (5x5 conv + LeakyReLU) pairs.
    * ``tail`` applies ``n`` (3x3 conv + LeakyReLU + PixelShuffle(2)) stages,
      each dividing the channel count by 4, then a final 5x5 convolution back
      to ``in_chns`` channels.

    An input of shape ``(N, in_chns, H, W)`` therefore produces an output of
    the same shape. ``H`` and ``W`` must be divisible by ``2**n`` for the
    PixelUnshuffle stages.

    Args:
        in_chns: Number of input (and output) channels.
        down_scale_factor: Internal downscaling factor. Only its power-of-two
            part is used (the value is floored to ``2**n``).
        num_chns: Number of feature channels; must be divisible by
            ``4**max(n, 1)``.
        depth: ``depth - 1`` conv/activation pairs are placed in ``body``.
        sf: Stored on the instance as ``self.sf``; not used to build layers.

    Raises:
        AssertionError: If ``num_chns`` is not divisible by ``4**max(n, 1)``.
    """

    def __init__(self, in_chns, down_scale_factor=2, num_chns=64, depth=8, sf=4):
        super().__init__()
        self.sf = sf

        # math.log2 is more accurate than math.log(x, 2), which goes through
        # a floating-point division before int() truncates it.
        num_stages = int(math.log2(down_scale_factor))

        # Each tail stage ends in PixelShuffle(2), which divides the channel
        # count by 4, so num_chns has to survive num_stages such divisions.
        # Checking only `% 4` lets e.g. (down_scale_factor=4, num_chns=12)
        # build a network that cannot run. AssertionError is kept for
        # backward compatibility; the explicit raise survives `python -O`.
        required = 4 ** max(num_stages, 1)
        if num_chns % required != 0:
            raise AssertionError(
                f"num_chns ({num_chns}) must be divisible by {required} "
                f"for down_scale_factor={down_scale_factor}"
            )

        head = []
        in_chns_shuffle = in_chns * 4  # channels after the first PixelUnshuffle(2)
        for ii in range(num_stages):
            head.append(nn.PixelUnshuffle(2))
            head.append(nn.Conv2d(in_chns_shuffle, num_chns, kernel_size=3, padding=1))
            if ii + 1 < num_stages:
                # Squeeze to num_chns // 4 so the next unshuffle yields num_chns.
                head.append(nn.Conv2d(num_chns, num_chns // 4, kernel_size=5, padding=2))
            head.append(nn.LeakyReLU(0.2, inplace=True))
            in_chns_shuffle = num_chns
        self.head = nn.Sequential(*head)

        body = []
        for _ in range(depth - 1):
            body.append(nn.Conv2d(num_chns, num_chns, kernel_size=5, padding=2))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.Sequential(*body)

        tail = []
        tail_chns = num_chns  # running channel count through the tail
        for _ in range(num_stages):
            tail.append(nn.Conv2d(tail_chns, tail_chns, kernel_size=3, padding=1))
            tail.append(nn.LeakyReLU(0.2, inplace=True))
            tail.append(nn.PixelShuffle(2))
            tail_chns //= 4
        tail.append(nn.Conv2d(tail_chns, in_chns, kernel_size=5, padding=2))
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        """Run ``head``, ``body`` and ``tail`` in sequence and return the result."""
        y = self.head(x)
        y = self.body(y)
        y = self.tail(y)
        return y
|