medical_imaging_segmentation / apps /project_model2.py
Kukulauren's picture
models and preprocessing
d679c76 verified
raw
history blame
3.05 kB
import torch
class DoubleConv(torch.nn.Module):
    """Two stacked 3x3x3 convolutions, each followed by a ReLU.

    This is the building block of every encoder/decoder stage of the U-Net.
    Both convolutions use kernel size 3 with padding 1 (stride 1), so the
    spatial size is preserved and only the channel count changes.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First conv maps in_channels -> out_channels, second keeps out_channels.
        for src_channels in (in_channels, out_channels):
            stages.append(torch.nn.Conv3d(src_channels, out_channels, 3, padding=1))
            stages.append(torch.nn.ReLU())
        # Kept as one Sequential named ``step`` so the state_dict keys
        # (step.0.*, step.2.*) match checkpoints saved by earlier versions.
        self.step = torch.nn.Sequential(*stages)

    def forward(self, X):
        """Apply conv -> ReLU -> conv -> ReLU to ``X`` and return the result."""
        return self.step(X)
class UNet(torch.nn.Module):
    """3D U-Net for volumetric segmentation.

    Three down-sampling stages, a bottleneck and three up-sampling stages,
    each built from a ``DoubleConv`` (two 3x3x3 convolutions + ReLU).
    Decoder stages concatenate the matching encoder features (skip
    connections) onto the upsampled features along the channel axis.

    Args:
        in_channels: channels of the input volume (default 1).
        num_classes: channels of the output logits (default 6 -- background,
            upper jaw, lower jaw, upper teeth, lower teeth, artery).

    Each spatial dimension of the input needs at least 8 voxels; otherwise
    the third 2x max-pool produces an empty tensor.
    """

    def __init__(self, in_channels=1, num_classes=6):
        """Set up the U-Net layers.

        Attribute names (layer1 ... layer8, maxpool) and their construction
        order are kept stable: they define the state_dict keys of saved
        checkpoints and the order in which weights are randomly initialised.
        """
        super().__init__()
        ############# DOWN SAMPLING #####################
        self.layer1 = DoubleConv(in_channels, 32)
        self.layer2 = DoubleConv(32, 64)
        self.layer3 = DoubleConv(64, 128)
        self.layer4 = DoubleConv(128, 256)  # bottleneck
        ############## UP SAMPLING #######################
        # Input channels = upsampled decoder features + skip features.
        self.layer5 = DoubleConv(256 + 128, 128)
        self.layer6 = DoubleConv(128 + 64, 64)
        self.layer7 = DoubleConv(64 + 32, 32)
        # 1x1x1 projection to one logit channel per class.
        self.layer8 = torch.nn.Conv3d(32, num_classes, 1)
        #################################################
        self.maxpool = torch.nn.MaxPool3d(2)

    def forward(self, x):
        """Return per-voxel class logits with the same spatial size as ``x``.

        ``x`` is expected as (batch, in_channels, D, H, W). No softmax is
        applied to the output.
        """
        ####### Encoder ###########
        x1 = self.layer1(x)
        x2 = self.layer2(self.maxpool(x1))
        x3 = self.layer3(self.maxpool(x2))
        ##### Bottleneck ##########
        x4 = self.layer4(self.maxpool(x3))
        ####### Decoder (with skip connections) ###
        x5 = self.layer5(self._upsample_and_concat(x4, x3))
        x6 = self.layer6(self._upsample_and_concat(x5, x2))
        x7 = self.layer7(self._upsample_and_concat(x6, x1))
        ####### Predicted segmentation logits #####
        return self.layer8(x7)

    @staticmethod
    def _upsample_and_concat(x, skip):
        """Upsample ``x`` to ``skip``'s spatial size and concatenate on channels."""
        # Upsampling to the skip tensor's exact size (instead of
        # scale_factor=2) keeps odd input sizes working: max-pooling floors,
        # so a plain 2x upsample would come out one voxel short and torch.cat
        # would fail. For sizes divisible by 8 the result is identical to
        # the former scale_factor=2 behaviour (align_corners=False default).
        upsampled = torch.nn.functional.interpolate(
            x, size=skip.shape[2:], mode="trilinear", align_corners=False
        )
        return torch.cat([upsampled, skip], dim=1)