52Hz's picture
Update model/CMFNet.py
b97e969
raw
history blame
9.82 kB
import torch
import torch.nn as nn
from model.block import SAB, CAB, PAB, conv, SAM, conv3x3, conv_down
##########################################################################
## U-Net
# Number of attention blocks stacked at every U-Net level; Encoder and Decoder
# build each level as nn.Sequential over range(bn), i.e. exactly bn blocks.
bn: int = 2
class Encoder(nn.Module):
    """Three-level U-Net encoder built from a single attention-block type.

    Level k holds ``bn`` blocks of type ``block`` operating on
    ``n_feat + (k - 1) * scale_unetfeats`` channels. Between levels a
    ``DownSample`` halves the spatial size and widens the channels by
    ``scale_unetfeats``.

    Args:
        n_feat: channel width of level 1.
        kernel_size: forwarded to every attention block.
        reduction: forwarded to every attention block.
        act: activation module forwarded to every attention block.
        bias: forwarded to every attention block as ``bias=``.
        scale_unetfeats: extra channels added at each deeper level.
        block: one of ``'CAB'``, ``'PAB'``, ``'SAB'``, selecting the block class
            imported from ``model.block``.

    Raises:
        ValueError: if ``block`` is not one of the supported names.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, block):
        super(Encoder, self).__init__()
        block_types = {'CAB': CAB, 'PAB': PAB, 'SAB': SAB}
        if block not in block_types:
            # Previously this surfaced as a confusing AttributeError on encoder_level1.
            raise ValueError(
                f"Unsupported block type {block!r}; expected one of 'CAB', 'PAB' or 'SAB'"
            )
        block_cls = block_types[block]

        def make_level(channels):
            # bn identical blocks at one resolution; kept as nn.Sequential so the
            # state_dict keys stay encoder_levelK.<i>.*
            return nn.Sequential(
                *[block_cls(channels, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
            )

        # Same construction order as before: level1, level2, level3, down12, down23.
        self.encoder_level1 = make_level(n_feat)
        self.encoder_level2 = make_level(n_feat + scale_unetfeats)
        self.encoder_level3 = make_level(n_feat + (scale_unetfeats * 2))
        self.down12 = DownSample(n_feat, scale_unetfeats)
        self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats)

    def forward(self, x):
        """Encode ``x`` and return the per-level features ``[enc1, enc2, enc3]``.

        ``x`` presumably has ``n_feat`` channels (it feeds level 1 directly); the
        bilinear resampling in ``DownSample`` requires a 4-D tensor.
        """
        enc1 = self.encoder_level1(x)
        x = self.down12(enc1)
        enc2 = self.encoder_level2(x)
        x = self.down23(enc2)
        enc3 = self.encoder_level3(x)
        return [enc1, enc2, enc3]
class Decoder(nn.Module):
    """Three-level U-Net decoder with attention-refined skip connections.

    Mirrors ``Encoder``: level k holds ``bn`` blocks of type ``block`` on
    ``n_feat + (k - 1) * scale_unetfeats`` channels. Each encoder skip feature is
    passed through one extra attention block (``skip_attn1``/``skip_attn2``)
    before being added to the upsampled deeper decoder output by ``SkipUpSample``.

    Args:
        n_feat: channel width of level 1.
        kernel_size: forwarded to every attention block.
        reduction: forwarded to every attention block.
        act: activation module forwarded to every attention block.
        bias: forwarded to every attention block as ``bias=``.
        scale_unetfeats: extra channels added at each deeper level.
        block: one of ``'CAB'``, ``'PAB'``, ``'SAB'``, selecting the block class
            imported from ``model.block``.

    Raises:
        ValueError: if ``block`` is not one of the supported names.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, block):
        super(Decoder, self).__init__()
        block_types = {'CAB': CAB, 'PAB': PAB, 'SAB': SAB}
        if block not in block_types:
            # Previously this surfaced as a confusing AttributeError on decoder_level1.
            raise ValueError(
                f"Unsupported block type {block!r}; expected one of 'CAB', 'PAB' or 'SAB'"
            )
        block_cls = block_types[block]

        def make_level(channels):
            # bn identical blocks at one resolution; kept as nn.Sequential so the
            # state_dict keys stay decoder_levelK.<i>.*
            return nn.Sequential(
                *[block_cls(channels, kernel_size, reduction, bias=bias, act=act) for _ in range(bn)]
            )

        # Same construction order as before: levels 1-3, skip attentions, upsamplers.
        self.decoder_level1 = make_level(n_feat)
        self.decoder_level2 = make_level(n_feat + scale_unetfeats)
        self.decoder_level3 = make_level(n_feat + (scale_unetfeats * 2))
        self.skip_attn1 = block_cls(n_feat, kernel_size, reduction, bias=bias, act=act)
        self.skip_attn2 = block_cls(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act)
        self.up21 = SkipUpSample(n_feat, scale_unetfeats)
        self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats)

    def forward(self, outs):
        """Decode encoder features ``outs = [enc1, enc2, enc3]``.

        Returns ``[dec1, dec2, dec3]``, ordered from level 1 (the
        encoder's input resolution) to level 3 (the deepest).
        """
        enc1, enc2, enc3 = outs
        dec3 = self.decoder_level3(enc3)
        x = self.up32(dec3, self.skip_attn2(enc2))
        dec2 = self.decoder_level2(x)
        x = self.up21(dec2, self.skip_attn1(enc1))
        dec1 = self.decoder_level1(x)
        return [dec1, dec2, dec3]
##########################################################################
##---------- Resizing Modules ----------
class DownSample(nn.Module):
    """Halve the spatial size, then widen channels by ``s_factor``.

    Bilinear resampling at scale 0.5 (``align_corners=False``) is followed by a
    bias-free 1x1 convolution mapping ``in_channels`` -> ``in_channels + s_factor``.
    """

    def __init__(self, in_channels, s_factor):
        super().__init__()
        shrink = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
        widen = nn.Conv2d(
            in_channels,
            in_channels + s_factor,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # Attribute name and ordering fix the checkpoint key as 'down.1.weight'.
        self.down = nn.Sequential(shrink, widen)

    def forward(self, x):
        return self.down(x)
class UpSample(nn.Module):
    """Double the spatial size, then narrow channels by ``s_factor``.

    Bilinear resampling at scale 2 (``align_corners=False``) is followed by a
    bias-free 1x1 convolution mapping ``in_channels + s_factor`` -> ``in_channels``.
    """

    def __init__(self, in_channels, s_factor):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        narrow = nn.Conv2d(
            in_channels + s_factor,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # Attribute name and ordering fix the checkpoint key as 'up.1.weight'.
        self.up = nn.Sequential(enlarge, narrow)

    def forward(self, x):
        return self.up(x)
class SkipUpSample(nn.Module):
    """Upsample a deeper feature map and add a skip connection to it.

    ``x`` goes through 2x bilinear resampling and a bias-free 1x1 convolution
    (``in_channels + s_factor`` -> ``in_channels``); ``y`` is then added
    element-wise, so it must match the upsampled shape.
    """

    def __init__(self, in_channels, s_factor):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        narrow = nn.Conv2d(
            in_channels + s_factor,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # Attribute name and ordering fix the checkpoint key as 'up.1.weight'.
        self.up = nn.Sequential(enlarge, narrow)

    def forward(self, x, y):
        return self.up(x) + y
##########################################################################
# Mixed Residual Module
class Mix(nn.Module):
    """Learnable three-way blend of equally-shaped tensors.

    With ``s = sigmoid(w)`` the output is
    ``fea1 * (1 - s) / 2 + fea2 * s + feat3 * (1 - s) / 2``, so the middle input
    gets weight ``s`` and the outer two share the remainder equally.

    Args:
        m: initial value of the scalar parameter ``w`` (so the initial middle
            weight is ``sigmoid(m)``).
    """

    def __init__(self, m=1):
        super(Mix, self).__init__()
        # A single float32 learnable scalar, stored under the state_dict key 'w'.
        # (The original wrapped it in nn.Parameter twice, which was redundant.)
        self.w = nn.Parameter(torch.tensor([float(m)], dtype=torch.float32), requires_grad=True)
        self.mix_block = nn.Sigmoid()

    def forward(self, fea1, fea2, feat3):
        """Return ``(blended, factor)`` where ``factor = sigmoid(w)`` (shape ``(1,)``)."""
        factor = self.mix_block(self.w)
        other = (1 - factor) / 2
        # Broadcasting the (1,)-shaped weights is equivalent to expand_as here.
        output = fea1 * other + fea2 * factor + feat3 * other
        return output, factor
##########################################################################
# Architecture
class CMFNet(nn.Module):
    """Three-branch network fusing CAB, PAB and SAB U-Nets.

    The input feeds three independent branches:
    shallow conv stem -> ``Encoder`` -> ``Decoder`` -> ``SAM``.
    Each branch uses a different attention-block type. The branches are then
    fused twice:

    * the SAM image outputs go through a learnable ``Mix`` and then ``add123``;
    * the SAM feature outputs are concatenated along dim 1, then passed through
      ``concat123`` and ``tail``.

    The two fused results are summed.

    Args:
        in_c: channels of the input tensor.
        out_c: channels of the output tensor.
        n_feat: base feature width of every branch.
        scale_unetfeats: extra channels added per U-Net level.
        kernel_size: kernel size for the ``conv`` layers and the U-Net blocks
            (the SAM modules always use 3).
        reduction: forwarded to the attention blocks.
        bias: forwarded to every convolution / block as ``bias``.

    NOTE(review): H and W presumably need to be divisible by 4. Each U-Net
    halves the size twice and doubles it twice, and the skip additions
    require the shapes to line up again.
    """

    def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, kernel_size=3, reduction=4, bias=False):
        super().__init__()
        # One PReLU instance is shared by all three stems and passed to every
        # Encoder/Decoder, so a single set of slopes is learned for all of them.
        shared_act = nn.PReLU()
        stem_width = n_feat // 2

        def make_stem():
            return nn.Sequential(
                conv(in_c, stem_width, kernel_size, bias=bias),
                shared_act,
                conv(stem_width, n_feat, kernel_size, bias=bias),
            )

        # Keep this construction order: it fixes parameter registration order
        # and the RNG draws consumed by weight initialisation.
        self.shallow_feat1 = make_stem()
        self.shallow_feat2 = make_stem()
        self.shallow_feat3 = make_stem()

        unet_cfg = (n_feat, kernel_size, reduction, shared_act, bias, scale_unetfeats)
        self.stage1_encoder = Encoder(*unet_cfg, 'CAB')
        self.stage1_decoder = Decoder(*unet_cfg, 'CAB')
        self.stage2_encoder = Encoder(*unet_cfg, 'PAB')
        self.stage2_decoder = Decoder(*unet_cfg, 'PAB')
        self.stage3_encoder = Encoder(*unet_cfg, 'SAB')
        self.stage3_decoder = Decoder(*unet_cfg, 'SAB')

        self.sam1o = SAM(n_feat, kernel_size=3, bias=bias)
        self.sam2o = SAM(n_feat, kernel_size=3, bias=bias)
        self.sam3o = SAM(n_feat, kernel_size=3, bias=bias)

        self.mix = Mix(1)
        self.add123 = conv(out_c, out_c, kernel_size, bias=bias)
        self.concat123 = conv(n_feat * 3, n_feat, kernel_size, bias=bias)
        self.tail = conv(n_feat, out_c, kernel_size, bias=bias)

    def forward(self, x):
        # All three shallow stems run first, as in the original ordering.
        stem1 = self.shallow_feat1(x)
        stem2 = self.shallow_feat2(x)
        stem3 = self.shallow_feat3(x)

        # Each branch: U-Net, then SAM on the level-1 decoder output together
        # with the network input x. SAM returns a (features, image) pair here.
        dec1 = self.stage1_decoder(self.stage1_encoder(stem1))
        feat1, img1 = self.sam1o(dec1[0], x)
        dec2 = self.stage2_decoder(self.stage2_encoder(stem2))
        feat2, img2 = self.sam2o(dec2[0], x)
        dec3 = self.stage3_decoder(self.stage3_encoder(stem3))
        feat3, img3 = self.sam3o(dec3[0], x)

        # Image path: learnable blend of the three SAM images, then a conv.
        blended, _mix_weight = self.mix(img1, img2, img3)
        image_term = self.add123(blended)
        # Feature path: concatenate along dim 1, fuse, project to out_c channels.
        fused = self.concat123(torch.cat([feat1, feat2, feat3], 1))
        return self.tail(fused) + image_term