Spaces:
Runtime error
Runtime error
Commit
·
f61800c
1
Parent(s):
a98ddcd
Create new file
Browse files
UNet.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
height, width = 512, 512
|
2 |
+
class Block(Module):
|
3 |
+
def __init__(self, inChannels, outChannels):
|
4 |
+
super().__init__()
|
5 |
+
# store the convolution and RELU layers
|
6 |
+
self.conv1 = Conv2d(inChannels, outChannels, 3)
|
7 |
+
self.relu = ReLU()
|
8 |
+
self.conv2 = Conv2d(outChannels, outChannels, 3)
|
9 |
+
def forward(self, x):
|
10 |
+
# apply CONV => RELU => CONV block to the inputs and return it
|
11 |
+
return self.conv2(self.relu(self.conv1(x)))
|
12 |
+
|
13 |
+
class Encoder(Module):
|
14 |
+
def __init__(self, channels=(3, 16, 32, 64)):
|
15 |
+
super().__init__()
|
16 |
+
# store the encoder blocks and maxpooling layer
|
17 |
+
self.encBlocks = ModuleList([Block(channels[i], channels[i + 1]) for i in range(len(channels) - 1)])
|
18 |
+
self.pool = MaxPool2d(2)
|
19 |
+
def forward(self, x):
|
20 |
+
# initialize an empty list to store the intermediate outputs
|
21 |
+
blockOutputs = []
|
22 |
+
# loop through the encoder blocks
|
23 |
+
for block in self.encBlocks:
|
24 |
+
# pass the inputs through the current encoder block, store
|
25 |
+
# the outputs, and then apply maxpooling on the output
|
26 |
+
x = block(x)
|
27 |
+
blockOutputs.append(x)
|
28 |
+
x = self.pool(x)
|
29 |
+
# return the list containing the intermediate outputs
|
30 |
+
return blockOutputs
|
31 |
+
|
32 |
+
class Decoder(Module):
|
33 |
+
def __init__(self, channels=(64, 32, 16)):
|
34 |
+
super().__init__()
|
35 |
+
# initialize the number of channels, upsampler blocks, and
|
36 |
+
# decoder blocks
|
37 |
+
self.channels = channels
|
38 |
+
self.upconvs = ModuleList([ConvTranspose2d(channels[i], channels[i + 1], 2, 2) for i in range(len(channels) - 1)])
|
39 |
+
self.dec_blocks = ModuleList([Block(channels[i], channels[i + 1]) for i in range(len(channels) - 1)])
|
40 |
+
def forward(self, x, encFeatures):
|
41 |
+
# loop through the number of channels
|
42 |
+
for i in range(len(self.channels) - 1):
|
43 |
+
# pass the inputs through the upsampler blocks
|
44 |
+
x = self.upconvs[i](x)
|
45 |
+
# crop the current features from the encoder blocks,
|
46 |
+
# concatenate them with the current upsampled features,
|
47 |
+
# and pass the concatenated output through the current
|
48 |
+
# decoder block
|
49 |
+
encFeat = self.crop(encFeatures[i], x)
|
50 |
+
x = torch.cat([x, encFeat], dim=1)
|
51 |
+
x = self.dec_blocks[i](x)
|
52 |
+
# return the final decoder output
|
53 |
+
return x
|
54 |
+
def crop(self, encFeatures, x):
|
55 |
+
# grab the dimensions of the inputs, and crop the encoder
|
56 |
+
# features to match the dimensions
|
57 |
+
(_, _, H, W) = x.shape
|
58 |
+
encFeatures = CenterCrop([H, W])(encFeatures)
|
59 |
+
# return the cropped features
|
60 |
+
return encFeatures
|
61 |
+
|
62 |
+
class UNet(Module):
|
63 |
+
def __init__(self, encChannels=(3, 64, 128, 256, 512, 1024), decChannels=(1024, 512, 256, 128, 64),
|
64 |
+
nbClasses=1, retainDim=True, outSize=(height, width)):
|
65 |
+
super().__init__()
|
66 |
+
# initialize the encoder and decoder
|
67 |
+
self.encoder = Encoder(encChannels)
|
68 |
+
self.decoder = Decoder(decChannels)
|
69 |
+
# initialize the regression head and store the class variables
|
70 |
+
self.head = Conv2d(decChannels[-1], nbClasses, 1)
|
71 |
+
self.retainDim = retainDim
|
72 |
+
self.outSize = outSize
|
73 |
+
def forward(self, x):
|
74 |
+
# grab the features from the encoder
|
75 |
+
encFeatures = self.encoder(x)
|
76 |
+
# pass the encoder features through decoder making sure that
|
77 |
+
# their dimensions are suited for concatenation
|
78 |
+
decFeatures = self.decoder(encFeatures[::-1][0], encFeatures[::-1][1:])
|
79 |
+
# pass the decoder features through the regression head to
|
80 |
+
# obtain the segmentation mask
|
81 |
+
map_ = self.head(decFeatures)
|
82 |
+
# check to see if we are retaining the original output
|
83 |
+
# dimensions and if so, then resize the output to match them
|
84 |
+
if self.retainDim:
|
85 |
+
map_ = F.interpolate(map_, self.outSize)
|
86 |
+
# return the segmentation map
|
87 |
+
return map_
|