import torch
import torch.nn as nn

from constants import *  # expected to provide ANCHORS, Z_DIM, ENC_HIDDEN_DIM
"""
Class for custom activation.
"""
class SymReLU(nn.Module):
def __init__(self, inplace: bool = False):
super().__init__()
self.inplace = inplace
def forward(self, input):
return torch.min(torch.max(input, -torch.ones_like(input)), torch.ones_like(input))
def extra_repr(self) -> str:
inplace_str = 'inplace=True' if self.inplace else ''
return inplace_str
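
# A minimal usage sketch (not part of the original module): SymReLU saturates
# out-of-range activations at the bounds while passing in-range values through.
def _demo_symrelu():
    act = SymReLU()
    x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 3.0])
    # Expected: tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])
    print(act(x))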
"""
Class implementing YOLO-Stamp architecture described in https://link.springer.com/article/10.1134/S1054661822040046.
"""
class YOLOStamp(nn.Module):
def __init__(
self,
anchors=ANCHORS,
in_channels=3,
):
super().__init__()
self.register_buffer('anchors', torch.tensor(anchors))
self.act = SymReLU()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm1 = nn.BatchNorm2d(num_features=8)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm2 = nn.BatchNorm2d(num_features=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm3 = nn.BatchNorm2d(num_features=16)
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm4 = nn.BatchNorm2d(num_features=16)
        self.conv5 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm5 = nn.BatchNorm2d(num_features=16)
        self.conv6 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm6 = nn.BatchNorm2d(num_features=24)
        self.conv7 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm7 = nn.BatchNorm2d(num_features=24)
        self.conv8 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm8 = nn.BatchNorm2d(num_features=48)
        self.conv9 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm9 = nn.BatchNorm2d(num_features=48)
        self.conv10 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm10 = nn.BatchNorm2d(num_features=48)
        self.conv11 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.norm11 = nn.BatchNorm2d(num_features=64)
        self.conv12 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.norm12 = nn.BatchNorm2d(num_features=256)
        self.conv13 = nn.Conv2d(in_channels=256, out_channels=len(anchors) * 5, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
    def forward(self, x, head=True):  # `head` is accepted but currently unused
        # Match the input dtype to the model weights (e.g. under mixed precision).
        x = x.type(self.conv1.weight.dtype)
        x = self.act(self.pool(self.norm1(self.conv1(x))))
        x = self.act(self.pool(self.norm2(self.conv2(x))))
        x = self.act(self.pool(self.norm3(self.conv3(x))))
        x = self.act(self.pool(self.norm4(self.conv4(x))))
        x = self.act(self.pool(self.norm5(self.conv5(x))))
        x = self.act(self.norm6(self.conv6(x)))
        x = self.act(self.norm7(self.conv7(x)))
        x = self.act(self.pool(self.norm8(self.conv8(x))))
        x = self.act(self.norm9(self.conv9(x)))
        x = self.act(self.norm10(self.conv10(x)))
        x = self.act(self.norm11(self.conv11(x)))
        x = self.act(self.norm12(self.conv12(x)))
        x = self.conv13(x)
        # Reshape the raw detection map to (batch, grid_h, grid_w, anchors, 5),
        # with 5 values (box parameters and confidence) per anchor per cell.
        nb, _, nh, nw = x.shape
        x = x.permute(0, 2, 3, 1).view(nb, nh, nw, self.anchors.shape[0], 5)
        return x
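
# A minimal shape-check sketch (not part of the original module). The anchor
# values below are placeholders; the real ones come from constants.ANCHORS.
# The input side length must be divisible by 64, since the forward pass
# applies six stride-2 max-pools before the 1x1 detection head.
def _demo_yolostamp():
    anchors = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]  # placeholder (w, h) pairs
    model = YOLOStamp(anchors=anchors)
    x = torch.randn(2, 3, 256, 256)  # two RGB images; 256 = 4 * 64
    out = model(x)
    # Expected shape: (2, 4, 4, 3, 5) -> (batch, grid_h, grid_w, anchors, 5)
    print(out.shape)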


class Encoder(torch.nn.Module):
    '''
    Encoder Class
    Values:
        im_chan: the number of channels of the input image, a scalar
        output_chan: the number of channels of the latent encoding, a scalar
        hidden_dim: the inner dimension, a scalar
    '''

    def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
        super(Encoder, self).__init__()
        self.z_dim = output_chan
        # The final block outputs 2 * output_chan channels: one half for the
        # latent mean, the other for the (log of the) standard deviation.
        self.disc = torch.nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
            self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
            self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
        )
    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Returns a sequence of operations corresponding to an encoder block of the VAE:
        a convolution, followed by a batchnorm and an activation (both omitted in
        the final layer).
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        if not final_layer:
            return torch.nn.Sequential(
                torch.nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                torch.nn.BatchNorm2d(output_channels),
                torch.nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return torch.nn.Sequential(
                torch.nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )
    def forward(self, image):
        '''
        Completes a forward pass of the Encoder: given a batch of image tensors,
        returns the mean and standard deviation of the latent distribution.
        Parameters:
            image: an image tensor of shape (batch, im_chan, height, width)
        '''
        disc_pred = self.disc(image)
        encoding = disc_pred.view(len(disc_pred), -1)
        # By convention, and for numerical stability, the second half of the
        # encoding is treated as the log of the standard deviation of the
        # normal distribution; exp() maps it back to a strictly positive stddev.
        return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()
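
# A minimal usage sketch (not part of the original module), wiring the encoder
# into the usual VAE reparameterization step. The sizes below are placeholders;
# the module defaults use Z_DIM and ENC_HIDDEN_DIM from constants.
def _demo_encoder():
    enc = Encoder(im_chan=3, output_chan=16, hidden_dim=32)
    # With five stride-2, kernel-4, unpadded conv blocks, a 96x96 input is
    # reduced to a 1x1 map, so flattening yields exactly 2 * output_chan values.
    images = torch.randn(4, 3, 96, 96)
    mean, stddev = enc(images)
    z = mean + stddev * torch.randn_like(mean)  # reparameterization trick
    print(mean.shape, stddev.shape, z.shape)  # each: torch.Size([4, 16])


if __name__ == "__main__":
    _demo_symrelu()
    _demo_yolostamp()
    _demo_encoder()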