import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models

from .senet import se_resnext50_32x4d, senet154
from .dpn import dpn92


class ConvReluBN(nn.Module):
    """3x3 convolution, then BatchNorm, then ReLU (despite the name, BN precedes ReLU)."""

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ConvReluBN, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      padding=kernel_size // 2),  # 'same' padding for odd kernel sizes
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)


class ConvRelu(nn.Module):
    """3x3 convolution followed by ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ConvRelu, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      padding=kernel_size // 2),  # 'same' padding for odd kernel sizes
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)


class SCSEModule(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation.

    According to https://arxiv.org/pdf/1808.08127.pdf, concatenating the two
    attention branches works better than summing them.
    """

    def __init__(self, channels, reduction=16, concat=False):
        super(SCSEModule, self).__init__()
        # Channel SE branch: global pool -> bottleneck MLP -> per-channel gate
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()
        # Spatial SE branch: 1x1 conv -> per-pixel gate
        self.spatial_se = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sigmoid())
        self.concat = concat

    def forward(self, x):
        module_input = x

        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        chn_se = self.sigmoid(x)
        chn_se = chn_se * module_input

        spa_se = self.spatial_se(module_input)
        spa_se = module_input * spa_se

        if self.concat:
            return torch.cat([chn_se, spa_se], dim=1)  # doubles the channel count
        else:
            return chn_se + spa_se


class SeResNext50_Unet_Loc(nn.Module):
    """Localization U-Net with an SE-ResNeXt50 (32x4d) encoder; 3-channel input, 1-channel output."""

    def __init__(self, pretrained='imagenet', **kwargs):
        super(SeResNext50_Unet_Loc, self).__init__()

        encoder_filters = [64, 256, 512, 1024, 2048]
        decoder_filters = np.asarray([64, 96, 128, 256, 512]) // 2

        self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1])
        self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1])
        self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2])
        self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2])
        self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3])
        self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3])
        self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4])
        self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4])
        self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5])

        self.res = nn.Conv2d(decoder_filters[-5], 1, 1, stride=1, padding=0)

        self._initialize_weights()

        encoder = se_resnext50_32x4d(pretrained=pretrained)

        # Kept for reference: a 6-channel stem initialized from the pretrained 3-channel weights.
        # conv1_new = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # _w = encoder.layer0.conv1.state_dict()
        # _w['weight'] = torch.cat([0.5 * _w['weight'], 0.5 * _w['weight']], 1)
        # conv1_new.load_state_dict(_w)

        self.conv1 = nn.Sequential(encoder.layer0.conv1, encoder.layer0.bn1, encoder.layer0.relu1)
        self.conv2 = nn.Sequential(encoder.pool, encoder.layer1)
        self.conv3 = encoder.layer2
        self.conv4 = encoder.layer3
        self.conv5 = encoder.layer4

    def forward(self, x):
        batch_size, C, H, W = x.shape

        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)

        # Decoder: upsample, refine, then fuse the matching encoder skip connection.
        dec6 = self.conv6(F.interpolate(enc5, scale_factor=2))
        dec6 = self.conv6_2(torch.cat([dec6, enc4], 1))

        dec7 = self.conv7(F.interpolate(dec6, scale_factor=2))
        dec7 = self.conv7_2(torch.cat([dec7, enc3], 1))

        dec8 = self.conv8(F.interpolate(dec7, scale_factor=2))
        dec8 = self.conv8_2(torch.cat([dec8, enc2], 1))

        dec9 = self.conv9(F.interpolate(dec8, scale_factor=2))
        dec9 = self.conv9_2(torch.cat([dec9, enc1], 1))

        dec10 = self.conv10(F.interpolate(dec9, scale_factor=2))

        return self.res(dec10)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
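# --- Illustrative sketch (not part of the original module) -------------------
# Quick shape checks for SCSEModule and the localization model. The input size
# (divisible by 32) and pretrained=None (assumed to skip the weight download,
# as in the `pretrainedmodels` convention the bundled senet module follows)
# are assumptions of this demo, not requirements enforced above.
def _demo_scse_and_loc():
    x = torch.randn(2, 64, 32, 32)
    assert SCSEModule(64, concat=True)(x).shape == (2, 128, 32, 32)   # branches concatenated
    assert SCSEModule(64, concat=False)(x).shape == (2, 64, 32, 32)   # branches summed

    model = SeResNext50_Unet_Loc(pretrained=None).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 256, 256))
    assert out.shape == (1, 1, 256, 256)  # full-resolution single-channel mask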
class SeResNext50_Unet_Double(nn.Module):
    """Siamese variant: the same encoder-decoder is applied with shared weights to the
    two 3-channel halves of a 6-channel input, and the final 1x1 convolution fuses
    both feature maps into a 5-class output."""

    def __init__(self, pretrained='imagenet', **kwargs):
        super(SeResNext50_Unet_Double, self).__init__()

        encoder_filters = [64, 256, 512, 1024, 2048]
        decoder_filters = np.asarray([64, 96, 128, 256, 512]) // 2

        self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1])
        self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1])
        self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2])
        self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2])
        self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3])
        self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3])
        self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4])
        self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4])
        self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5])

        self.res = nn.Conv2d(decoder_filters[-5] * 2, 5, 1, stride=1, padding=0)

        self._initialize_weights()

        encoder = se_resnext50_32x4d(pretrained=pretrained)

        # Kept for reference: a 6-channel stem initialized from the pretrained 3-channel weights.
        # conv1_new = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # _w = encoder.layer0.conv1.state_dict()
        # _w['weight'] = torch.cat([0.5 * _w['weight'], 0.5 * _w['weight']], 1)
        # conv1_new.load_state_dict(_w)

        self.conv1 = nn.Sequential(encoder.layer0.conv1, encoder.layer0.bn1, encoder.layer0.relu1)
        self.conv2 = nn.Sequential(encoder.pool, encoder.layer1)
        self.conv3 = encoder.layer2
        self.conv4 = encoder.layer3
        self.conv5 = encoder.layer4

    def forward1(self, x):
        batch_size, C, H, W = x.shape

        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)

        dec6 = self.conv6(F.interpolate(enc5, scale_factor=2))
        dec6 = self.conv6_2(torch.cat([dec6, enc4], 1))

        dec7 = self.conv7(F.interpolate(dec6, scale_factor=2))
        dec7 = self.conv7_2(torch.cat([dec7, enc3], 1))

        dec8 = self.conv8(F.interpolate(dec7, scale_factor=2))
        dec8 = self.conv8_2(torch.cat([dec8, enc2], 1))

        dec9 = self.conv9(F.interpolate(dec8, scale_factor=2))
        dec9 = self.conv9_2(torch.cat([dec9, enc1], 1))

        dec10 = self.conv10(F.interpolate(dec9, scale_factor=2))

        return dec10

    def forward(self, x):
        dec10_0 = self.forward1(x[:, :3, :, :])  # first 3-channel image
        dec10_1 = self.forward1(x[:, 3:, :, :])  # second 3-channel image

        dec10 = torch.cat([dec10_0, dec10_1], 1)

        return self.res(dec10)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
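# --- Illustrative sketch (not part of the original module) -------------------
# Usage of the siamese "Double" head: two 3-channel images are stacked into a
# 6-channel tensor, the shared trunk runs once per image, and the 1x1 head
# emits 5 classes. Image size and pretrained=None are assumptions of the demo.
def _demo_double():
    img_a = torch.randn(1, 3, 256, 256)
    img_b = torch.randn(1, 3, 256, 256)
    model = SeResNext50_Unet_Double(pretrained=None).eval()
    with torch.no_grad():
        out = model(torch.cat([img_a, img_b], dim=1))  # channels 0-2 and 3-5
    assert out.shape == (1, 5, 256, 256)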
class Dpn92_Unet_Loc(nn.Module):
    """Localization U-Net with a DPN-92 encoder and SCSE attention in the decoder."""

    def __init__(self, pretrained='imagenet+5k', **kwargs):
        super(Dpn92_Unet_Loc, self).__init__()

        encoder_filters = [64, 336, 704, 1552, 2688]
        decoder_filters = np.asarray([64, 96, 128, 256, 512]) // 2

        # SCSEModule(concat=True) doubles its channel count, hence the `* 2`
        # on the input side of each following convN.
        self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1])
        self.conv6_2 = nn.Sequential(
            ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]),
            SCSEModule(decoder_filters[-1], reduction=16, concat=True))
        self.conv7 = ConvRelu(decoder_filters[-1] * 2, decoder_filters[-2])
        self.conv7_2 = nn.Sequential(
            ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]),
            SCSEModule(decoder_filters[-2], reduction=16, concat=True))
        self.conv8 = ConvRelu(decoder_filters[-2] * 2, decoder_filters[-3])
        self.conv8_2 = nn.Sequential(
            ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3]),
            SCSEModule(decoder_filters[-3], reduction=16, concat=True))
        self.conv9 = ConvRelu(decoder_filters[-3] * 2, decoder_filters[-4])
        self.conv9_2 = nn.Sequential(
            ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]),
            SCSEModule(decoder_filters[-4], reduction=16, concat=True))
        self.conv10 = ConvRelu(decoder_filters[-4] * 2, decoder_filters[-5])

        self.res = nn.Conv2d(decoder_filters[-5], 1, 1, stride=1, padding=0)

        self._initialize_weights()

        encoder = dpn92(pretrained=pretrained)

        # Kept for reference: a 6-channel stem initialized from the pretrained 3-channel weights.
        # conv1_new = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # _w = encoder.blocks['conv1_1'].conv.state_dict()
        # _w['weight'] = torch.cat([0.5 * _w['weight'], 0.5 * _w['weight']], 1)
        # conv1_new.load_state_dict(_w)

        self.conv1 = nn.Sequential(
            encoder.blocks['conv1_1'].conv,  # conv
            encoder.blocks['conv1_1'].bn,  # bn
            encoder.blocks['conv1_1'].act,  # relu
        )
        self.conv2 = nn.Sequential(
            encoder.blocks['conv1_1'].pool,  # maxpool
            *[b for k, b in encoder.blocks.items() if k.startswith('conv2_')]
        )
        self.conv3 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv3_')])
        self.conv4 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv4_')])
        self.conv5 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv5_')])

    def forward(self, x):
        batch_size, C, H, W = x.shape

        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)

        # DPN stages return (residual_path, dense_path) tuples; merge them along
        # the channel dimension before using the maps as skip connections.
        enc1 = (torch.cat(enc1, dim=1) if isinstance(enc1, tuple) else enc1)
        enc2 = (torch.cat(enc2, dim=1) if isinstance(enc2, tuple) else enc2)
        enc3 = (torch.cat(enc3, dim=1) if isinstance(enc3, tuple) else enc3)
        enc4 = (torch.cat(enc4, dim=1) if isinstance(enc4, tuple) else enc4)
        enc5 = (torch.cat(enc5, dim=1) if isinstance(enc5, tuple) else enc5)

        dec6 = self.conv6(F.interpolate(enc5, scale_factor=2))
        dec6 = self.conv6_2(torch.cat([dec6, enc4], 1))

        dec7 = self.conv7(F.interpolate(dec6, scale_factor=2))
        dec7 = self.conv7_2(torch.cat([dec7, enc3], 1))

        dec8 = self.conv8(F.interpolate(dec7, scale_factor=2))
        dec8 = self.conv8_2(torch.cat([dec8, enc2], 1))

        dec9 = self.conv9(F.interpolate(dec8, scale_factor=2))
        dec9 = self.conv9_2(torch.cat([dec9, enc1], 1))

        dec10 = self.conv10(F.interpolate(dec9, scale_factor=2))

        return self.res(dec10)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
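# --- Illustrative sketch (not part of the original module) -------------------
# DPN blocks carry two paths (residual + dense), returned as a tuple; the
# forward above concatenates them channel-wise before the skip connections.
# The per-path channel split below is made up for illustration; only the total
# (336, i.e. encoder_filters[1] for dpn92) is taken from the code above.
def _demo_dpn_tuple_merge():
    feat = (torch.randn(1, 256, 64, 64), torch.randn(1, 80, 64, 64))  # hypothetical split
    merged = torch.cat(feat, dim=1) if isinstance(feat, tuple) else feat
    assert merged.shape == (1, 336, 64, 64)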
class Dpn92_Unet_Double(nn.Module):
    def __init__(self, pretrained='imagenet+5k', **kwargs):
        super(Dpn92_Unet_Double, self).__init__()

        encoder_filters = [64, 336, 704, 1552, 2688]
        decoder_filters = np.asarray([64, 96, 128, 256, 512]) // 2

        self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1])
        self.conv6_2 = nn.Sequential(
            ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1]),
            SCSEModule(decoder_filters[-1], reduction=16, concat=True))
        self.conv7 = ConvRelu(decoder_filters[-1] * 2, decoder_filters[-2])
        self.conv7_2 = nn.Sequential(
            ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2]),
            SCSEModule(decoder_filters[-2], reduction=16, concat=True))
        self.conv8 = ConvRelu(decoder_filters[-2] * 2, decoder_filters[-3])
        self.conv8_2 = nn.Sequential(
            ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3]),
            SCSEModule(decoder_filters[-3], reduction=16, concat=True))
        self.conv9 = ConvRelu(decoder_filters[-3] * 2, decoder_filters[-4])
        self.conv9_2 = nn.Sequential(
            ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4]),
            SCSEModule(decoder_filters[-4], reduction=16, concat=True))
        self.conv10 = ConvRelu(decoder_filters[-4] * 2, decoder_filters[-5])

        self.res = nn.Conv2d(decoder_filters[-5] * 2, 5, 1, stride=1, padding=0)

        self._initialize_weights()

        encoder = dpn92(pretrained=pretrained)

        # conv1_new = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # _w = encoder.blocks['conv1_1'].conv.state_dict()
        # _w['weight'] = torch.cat([0.5 * _w['weight'], 0.5 * _w['weight']], 1)
        # conv1_new.load_state_dict(_w)

        self.conv1 = nn.Sequential(
            encoder.blocks['conv1_1'].conv,  # conv
            encoder.blocks['conv1_1'].bn,  # bn
            encoder.blocks['conv1_1'].act,  # relu
        )
        self.conv2 = nn.Sequential(
            encoder.blocks['conv1_1'].pool,  # maxpool
            *[b for k, b in encoder.blocks.items() if k.startswith('conv2_')]
        )
        self.conv3 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv3_')])
        self.conv4 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv4_')])
        self.conv5 = nn.Sequential(*[b for k, b in encoder.blocks.items() if k.startswith('conv5_')])

    def forward1(self, x):
        batch_size, C, H, W = x.shape

        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)

        enc1 = (torch.cat(enc1, dim=1) if isinstance(enc1, tuple) else enc1)
        enc2 = (torch.cat(enc2, dim=1) if isinstance(enc2, tuple) else enc2)
        enc3 = (torch.cat(enc3, dim=1) if isinstance(enc3, tuple) else enc3)
        enc4 = (torch.cat(enc4, dim=1) if isinstance(enc4, tuple) else enc4)
        enc5 = (torch.cat(enc5, dim=1) if isinstance(enc5, tuple) else enc5)

        dec6 = self.conv6(F.interpolate(enc5, scale_factor=2))
        dec6 = self.conv6_2(torch.cat([dec6, enc4], 1))

        dec7 = self.conv7(F.interpolate(dec6, scale_factor=2))
        dec7 = self.conv7_2(torch.cat([dec7, enc3], 1))

        dec8 = self.conv8(F.interpolate(dec7, scale_factor=2))
        dec8 = self.conv8_2(torch.cat([dec8, enc2], 1))

        dec9 = self.conv9(F.interpolate(dec8, scale_factor=2))
        dec9 = self.conv9_2(torch.cat([dec9, enc1], 1))

        dec10 = self.conv10(F.interpolate(dec9, scale_factor=2))

        return dec10

    def forward(self, x):
        dec10_0 = self.forward1(x[:, :3, :, :])
        dec10_1 = self.forward1(x[:, 3:, :, :])

        dec10 = torch.cat([dec10_0, dec10_1], 1)

        return self.res(dec10)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
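# --- Illustrative sketch (not part of the original module) -------------------
# Why the DPN decoders use `decoder_filters[i] * 2` inputs: every convN_2 ends
# with SCSEModule(concat=True), which doubles the channel count before the next
# convN. A minimal check of that arithmetic on dummy tensors:
def _demo_dpn_decoder_widths():
    decoder_filters = np.asarray([64, 96, 128, 256, 512]) // 2
    block = nn.Sequential(
        ConvRelu(int(decoder_filters[-1]) + 1552, int(decoder_filters[-1])),  # conv6_2's ConvRelu
        SCSEModule(int(decoder_filters[-1]), reduction=16, concat=True))      # doubles channels
    y = block(torch.randn(1, int(decoder_filters[-1]) + 1552, 8, 8))
    assert y.shape[1] == decoder_filters[-1] * 2  # 512, consumed by self.conv7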
class Res34_Unet_Loc(nn.Module):
    """Localization U-Net with a torchvision ResNet-34 encoder."""

    def __init__(self, pretrained=True, **kwargs):
        super(Res34_Unet_Loc, self).__init__()

        encoder_filters = [64, 64, 128, 256, 512]
        decoder_filters = np.asarray([48, 64, 96, 160, 320])

        self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1])
        self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1])
        self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2])
        self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2])
        self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3])
        self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3])
        self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4])
        self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4])
        self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5])

        self.res = nn.Conv2d(decoder_filters[-5], 1, 1, stride=1, padding=0)

        self._initialize_weights()

        encoder = torchvision.models.resnet34(pretrained=pretrained)

        self.conv1 = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
        self.conv2 = nn.Sequential(encoder.maxpool, encoder.layer1)
        self.conv3 = encoder.layer2
        self.conv4 = encoder.layer3
        self.conv5 = encoder.layer4

    def forward(self, x):
        batch_size, C, H, W = x.shape

        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)

        dec6 = self.conv6(F.interpolate(enc5, scale_factor=2))
        dec6 = self.conv6_2(torch.cat([dec6, enc4], 1))

        dec7 = self.conv7(F.interpolate(dec6, scale_factor=2))
        dec7 = self.conv7_2(torch.cat([dec7, enc3], 1))

        dec8 = self.conv8(F.interpolate(dec7, scale_factor=2))
        dec8 = self.conv8_2(torch.cat([dec8, enc2], 1))

        dec9 = self.conv9(F.interpolate(dec8, scale_factor=2))
        dec9 = self.conv9_2(torch.cat([dec9, enc1], 1))

        dec10 = self.conv10(F.interpolate(dec9, scale_factor=2))

        return self.res(dec10)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class Res34_Unet_Double(nn.Module):
    """Siamese 5-class variant of Res34_Unet_Loc; see SeResNext50_Unet_Double."""

    def __init__(self, pretrained=True, **kwargs):
        super(Res34_Unet_Double, self).__init__()

        encoder_filters = [64, 64, 128, 256, 512]
        decoder_filters = np.asarray([48, 64, 96, 160, 320])

        self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1])
        self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1])
        self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2])
        self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2])
        self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3])
        self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3])
        self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4])
        self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4])
        self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5])

        self.res = nn.Conv2d(decoder_filters[-5] * 2, 5, 1, stride=1, padding=0)

        self._initialize_weights()

        encoder = torchvision.models.resnet34(pretrained=pretrained)

        self.conv1 = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
        self.conv2 = nn.Sequential(encoder.maxpool, encoder.layer1)
        self.conv3 = encoder.layer2
        self.conv4 = encoder.layer3
        self.conv5 = encoder.layer4

    def forward1(self, x):
        batch_size, C, H, W = x.shape

        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)

        dec6 = self.conv6(F.interpolate(enc5, scale_factor=2))
        dec6 = self.conv6_2(torch.cat([dec6, enc4], 1))

        dec7 = self.conv7(F.interpolate(dec6, scale_factor=2))
        dec7 = self.conv7_2(torch.cat([dec7, enc3], 1))

        dec8 = self.conv8(F.interpolate(dec7, scale_factor=2))
        dec8 = self.conv8_2(torch.cat([dec8, enc2], 1))

        dec9 = self.conv9(F.interpolate(dec8, scale_factor=2))
        dec9 = self.conv9_2(torch.cat([dec9, enc1], 1))

        dec10 = self.conv10(F.interpolate(dec9, scale_factor=2))

        return dec10

    def forward(self, x):
        dec10_0 = self.forward1(x[:, :3, :, :])
        dec10_1 = self.forward1(x[:, 3:, :, :])

        dec10 = torch.cat([dec10_0, dec10_1], 1)

        return self.res(dec10)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
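# --- Illustrative sketch (not part of the original module) -------------------
# The ResNet-34 variants keep the same U-Net wiring but use an un-halved
# decoder width ([48, 64, 96, 160, 320]). pretrained=False and the input size
# are assumptions of this demo.
def _demo_res34():
    model = Res34_Unet_Loc(pretrained=False).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 256, 256))
    assert out.shape == (1, 1, 256, 256)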
class SeNet154_Unet_Loc(nn.Module):
    """Localization U-Net with a SENet-154 encoder (three-conv stem, hence the
    nine-layer conv1 below and the 128-channel first skip)."""

    def __init__(self, pretrained='imagenet', **kwargs):
        super(SeNet154_Unet_Loc, self).__init__()

        encoder_filters = [128, 256, 512, 1024, 2048]
        decoder_filters = np.asarray([48, 64, 96, 160, 320])

        self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1])
        self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1])
        self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2])
        self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2])
        self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3])
        self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3])
        self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4])
        self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4])
        self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5])

        self.res = nn.Conv2d(decoder_filters[-5], 1, 1, stride=1, padding=0)

        self._initialize_weights()

        encoder = senet154(pretrained=pretrained)

        # Kept for reference: a 9-channel stem initialized from the pretrained 3-channel weights.
        # conv1_new = nn.Conv2d(9, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        # _w = encoder.layer0.conv1.state_dict()
        # _w['weight'] = torch.cat([0.8 * _w['weight'], 0.1 * _w['weight'], 0.1 * _w['weight']], 1)
        # conv1_new.load_state_dict(_w)

        self.conv1 = nn.Sequential(encoder.layer0.conv1, encoder.layer0.bn1, encoder.layer0.relu1,
                                   encoder.layer0.conv2, encoder.layer0.bn2, encoder.layer0.relu2,
                                   encoder.layer0.conv3, encoder.layer0.bn3, encoder.layer0.relu3)
        self.conv2 = nn.Sequential(encoder.pool, encoder.layer1)
        self.conv3 = encoder.layer2
        self.conv4 = encoder.layer3
        self.conv5 = encoder.layer4

    def forward(self, x):
        batch_size, C, H, W = x.shape

        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)

        dec6 = self.conv6(F.interpolate(enc5, scale_factor=2))
        dec6 = self.conv6_2(torch.cat([dec6, enc4], 1))

        dec7 = self.conv7(F.interpolate(dec6, scale_factor=2))
        dec7 = self.conv7_2(torch.cat([dec7, enc3], 1))

        dec8 = self.conv8(F.interpolate(dec7, scale_factor=2))
        dec8 = self.conv8_2(torch.cat([dec8, enc2], 1))

        dec9 = self.conv9(F.interpolate(dec8, scale_factor=2))
        dec9 = self.conv9_2(torch.cat([dec9, enc1], 1))

        dec10 = self.conv10(F.interpolate(dec9, scale_factor=2))

        return self.res(dec10)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
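# --- Illustrative sketch (not part of the original module) -------------------
# SENet-154 replaces the usual single 7x7 stem with three 3x3 convolutions
# (3->64 stride 2, 64->64, 64->128), which is why encoder_filters starts at 128
# and self.conv1 above chains nine layers. A standalone stem restating that;
# the layer sizes are assumed to match the bundled senet module.
def _demo_senet154_stem():
    stem = nn.Sequential(
        nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
        nn.Conv2d(64, 128, 3, stride=1, padding=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
    )
    assert stem(torch.randn(1, 3, 64, 64)).shape == (1, 128, 32, 32)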
class SeNet154_Unet_Double(nn.Module):
    def __init__(self, pretrained='imagenet', **kwargs):
        super(SeNet154_Unet_Double, self).__init__()

        encoder_filters = [128, 256, 512, 1024, 2048]
        decoder_filters = np.asarray([48, 64, 96, 160, 320])

        self.conv6 = ConvRelu(encoder_filters[-1], decoder_filters[-1])
        self.conv6_2 = ConvRelu(decoder_filters[-1] + encoder_filters[-2], decoder_filters[-1])
        self.conv7 = ConvRelu(decoder_filters[-1], decoder_filters[-2])
        self.conv7_2 = ConvRelu(decoder_filters[-2] + encoder_filters[-3], decoder_filters[-2])
        self.conv8 = ConvRelu(decoder_filters[-2], decoder_filters[-3])
        self.conv8_2 = ConvRelu(decoder_filters[-3] + encoder_filters[-4], decoder_filters[-3])
        self.conv9 = ConvRelu(decoder_filters[-3], decoder_filters[-4])
        self.conv9_2 = ConvRelu(decoder_filters[-4] + encoder_filters[-5], decoder_filters[-4])
        self.conv10 = ConvRelu(decoder_filters[-4], decoder_filters[-5])

        self.res = nn.Conv2d(decoder_filters[-5] * 2, 5, 1, stride=1, padding=0)

        self._initialize_weights()

        encoder = senet154(pretrained=pretrained)

        # conv1_new = nn.Conv2d(9, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        # _w = encoder.layer0.conv1.state_dict()
        # _w['weight'] = torch.cat([0.8 * _w['weight'], 0.1 * _w['weight'], 0.1 * _w['weight']], 1)
        # conv1_new.load_state_dict(_w)

        self.conv1 = nn.Sequential(encoder.layer0.conv1, encoder.layer0.bn1, encoder.layer0.relu1,
                                   encoder.layer0.conv2, encoder.layer0.bn2, encoder.layer0.relu2,
                                   encoder.layer0.conv3, encoder.layer0.bn3, encoder.layer0.relu3)
        self.conv2 = nn.Sequential(encoder.pool, encoder.layer1)
        self.conv3 = encoder.layer2
        self.conv4 = encoder.layer3
        self.conv5 = encoder.layer4

    def forward1(self, x):
        batch_size, C, H, W = x.shape

        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)
        enc5 = self.conv5(enc4)

        dec6 = self.conv6(F.interpolate(enc5, scale_factor=2))
        dec6 = self.conv6_2(torch.cat([dec6, enc4], 1))

        dec7 = self.conv7(F.interpolate(dec6, scale_factor=2))
        dec7 = self.conv7_2(torch.cat([dec7, enc3], 1))

        dec8 = self.conv8(F.interpolate(dec7, scale_factor=2))
        dec8 = self.conv8_2(torch.cat([dec8, enc2], 1))

        dec9 = self.conv9(F.interpolate(dec8, scale_factor=2))
        dec9 = self.conv9_2(torch.cat([dec9, enc1], 1))

        dec10 = self.conv10(F.interpolate(dec9, scale_factor=2))

        return dec10

    def forward(self, x):
        dec10_0 = self.forward1(x[:, :3, :, :])
        dec10_1 = self.forward1(x[:, 3:, :, :])

        dec10 = torch.cat([dec10_0, dec10_1], 1)

        return self.res(dec10)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
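# --- Illustrative sketch (not part of the original module) -------------------
# Minimal smoke test for the SENet-154 variants. pretrained=None is assumed to
# skip the weight download, following the `pretrainedmodels` convention.
if __name__ == '__main__':
    loc = SeNet154_Unet_Loc(pretrained=None).eval()
    dbl = SeNet154_Unet_Double(pretrained=None).eval()
    with torch.no_grad():
        print(loc(torch.randn(1, 3, 256, 256)).shape)  # expected: (1, 1, 256, 256)
        print(dbl(torch.randn(1, 6, 256, 256)).shape)  # expected: (1, 5, 256, 256)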