import torch


class DoubleConv(torch.nn.Module):
    """Two consecutive 3x3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes (``in_channels`` ->
    ``out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        """Build the conv-ReLU-conv-ReLU stack.

        Args:
            in_channels: Number of channels of the incoming tensor.
            out_channels: Number of channels produced by both convolutions.
        """
        super().__init__()
        self.step = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels, out_channels, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv3d(out_channels, out_channels, 3, padding=1),
            torch.nn.ReLU(),
        )

    def forward(self, X):
        """Apply the two convolutions to ``X`` and return the result."""
        return self.step(X)


class UNet(torch.nn.Module):
    """3D U-Net for segmentation.

    The network uses three down-sampling steps, one intermediate step and
    three up-sampling steps. Every step is a ``DoubleConv``; the up-sampling
    steps receive the matching encoder output via a skip connection.
    """

    def __init__(self, in_channels=1, num_classes=6):
        """Set up the U-Net structure.

        Args:
            in_channels: Channels of the input volume (default 1).
            num_classes: Number of output channels / segmentation classes.
                The default of 6 corresponds to: background, upper jaw,
                lower jaw, upper teeth, lower teeth, artery.
        """
        super().__init__()

        # ---- Down-sampling path ----
        self.layer1 = DoubleConv(in_channels, 32)
        self.layer2 = DoubleConv(32, 64)
        self.layer3 = DoubleConv(64, 128)
        self.layer4 = DoubleConv(128, 256)

        # ---- Up-sampling path ----
        # Input channels are the upsampled features plus the concatenated
        # skip connection from the encoder step of the same resolution.
        self.layer5 = DoubleConv(256 + 128, 128)
        self.layer6 = DoubleConv(128 + 64, 64)
        self.layer7 = DoubleConv(64 + 32, 32)
        # 1x1x1 convolution producing one output channel per class.
        self.layer8 = torch.nn.Conv3d(32, num_classes, 1)

        self.maxpool = torch.nn.MaxPool3d(2)

    @staticmethod
    def _upsample_like(features, skip):
        """Trilinearly resize ``features`` to the spatial size of ``skip``.

        Resizing to the skip tensor's size (instead of a fixed factor of 2)
        gives the same result when the sizes are exact multiples, and keeps
        the concatenation valid when max-pooling rounded an odd size down.
        """
        return torch.nn.functional.interpolate(
            features, size=skip.shape[2:], mode="trilinear", align_corners=False
        )

    def forward(self, x):
        """Compute the per-voxel class scores for ``x``.

        Args:
            x: 5-D tensor with ``in_channels`` channels in dimension 1 and
                three spatial dimensions. Each spatial dimension must be at
                least 8 so that three 2x poolings leave a non-empty volume.

        Returns:
            Tensor with ``num_classes`` channels and the same spatial size
            as ``x`` (raw scores, no softmax applied).
        """
        # ---- Down-sampling: keep pre-pool outputs for the skip connections
        x1 = self.layer1(x)
        x2 = self.layer2(self.maxpool(x1))
        x3 = self.layer3(self.maxpool(x2))

        # ---- Intermediate layer at the lowest resolution
        x4 = self.layer4(self.maxpool(x3))

        # ---- Up-sampling: resize, concatenate skip connection, convolve
        x5 = self.layer5(torch.cat([self._upsample_like(x4, x3), x3], dim=1))
        x6 = self.layer6(torch.cat([self._upsample_like(x5, x2), x2], dim=1))
        x7 = self.layer7(torch.cat([self._upsample_like(x6, x1), x1], dim=1))

        # ---- Predicted segmentation scores
        return self.layer8(x7)