# NOTE(review): lines above this module's imports were Hugging Face Spaces
# page chrome (status banner, file size, gutter line numbers) left over from
# extraction; they are not part of the source and have been removed.
import torch.nn as nn
from torch.distributions.normal import Normal
from .constants import *
class Encoder(nn.Module):
    '''
    Encoder Class

    Convolutional encoder that maps a batch of images to the parameters
    (mean, stddev) of a diagonal Gaussian over the latent code.

    Values:
        im_chan: the number of channels of the input image, a scalar
        output_chan: the dimension of the latent code z, a scalar
        hidden_dim: the base width of the convolutional stack, a scalar
    '''
    def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
        super(Encoder, self).__init__()
        self.z_dim = output_chan
        widths = (hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8)
        stages = [self.make_disc_block(im_chan, widths[0])]
        for w_in, w_out in zip(widths, widths[1:]):
            stages.append(self.make_disc_block(w_in, w_out))
        # Final stage emits 2 * z_dim channels: one half for the mean,
        # the other for the (log-scale) spread — see forward().
        stages.append(self.make_disc_block(widths[-1], output_chan * 2, final_layer=True))
        self.disc = nn.Sequential(*stages)

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return one encoder stage of the VAE: a strided convolution
        followed by batchnorm and LeakyReLU, or a bare convolution for the
        final stage.

        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            # No normalization or activation on the head that produces the
            # distribution parameters.
            return nn.Sequential(conv)
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        '''
        Function for completing a forward pass of the Encoder: given a batch
        of images, returns the mean and stddev of the latent distribution.

        Parameters:
            image: an image tensor with dimensions (batch, im_chan, height, width)

        Returns:
            mean: tensor of shape (batch, z_dim)
            stddev: tensor of shape (batch, z_dim); the second half of the
                flattened conv output is exponentiated so the stddev is
                always positive (the raw network output is a log-scale
                quantity, which also helps numerical stability)
        '''
        features = self.disc(image)
        flat = features.view(len(features), -1)
        mean = flat[:, :self.z_dim]
        stddev = flat[:, self.z_dim:].exp()
        return mean, stddev
class Decoder(nn.Module):
    '''
    Decoder Class

    Transposed-convolutional decoder that maps a latent vector back to an
    image with pixel values in [0, 1].

    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_chan: the number of channels of the output image, a scalar
        hidden_dim: the base width of the deconvolutional stack, a scalar
    '''
    def __init__(self, z_dim=Z_DIM, im_chan=3, hidden_dim=DEC_HIDDEN_DIM):
        super(Decoder, self).__init__()
        self.z_dim = z_dim
        # Stage configurations: (in_channels, out_channels, extra kwargs).
        stage_cfgs = [
            (z_dim, hidden_dim * 16, {}),
            (hidden_dim * 16, hidden_dim * 8, dict(kernel_size=4, stride=1)),
            (hidden_dim * 8, hidden_dim * 4, {}),
            (hidden_dim * 4, hidden_dim * 2, dict(kernel_size=4)),
            (hidden_dim * 2, hidden_dim, dict(kernel_size=4)),
            (hidden_dim, im_chan, dict(kernel_size=4, final_layer=True)),
        ]
        self.gen = nn.Sequential(
            *(self.make_gen_block(c_in, c_out, **kw) for c_in, c_out, kw in stage_cfgs)
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Function to return one decoder stage of the VAE: a transposed
        convolution followed by batchnorm and ReLU, or by a sigmoid for the
        final stage.

        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        deconv = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            # Sigmoid squashes output pixels into [0, 1].
            return nn.Sequential(deconv, nn.Sigmoid())
        return nn.Sequential(
            deconv,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        '''
        Function for completing a forward pass of the Decoder: given a batch
        of latent vectors, returns the generated images.

        Parameters:
            noise: a noise tensor with dimensions (batch_size, z_dim)
        '''
        # Reshape each vector into a z_dim x 1 x 1 "image" for the deconv stack.
        latent = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(latent)
class VAE(nn.Module):
    '''
    VAE Class

    Ties an Encoder and a Decoder together: images are encoded into the
    parameters of an approximate posterior, a latent sample is drawn, and
    the sample is decoded back into an image.

    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_chan: the number of channels of the output image, a scalar
    '''
    def __init__(self, z_dim=Z_DIM, im_chan=3):
        super(VAE, self).__init__()
        self.z_dim = z_dim
        self.encode = Encoder(im_chan, z_dim)
        self.decode = Decoder(z_dim, im_chan)

    def forward(self, images):
        '''
        Function for completing a forward pass of the VAE: given a batch of
        images, returns the reconstruction and the latent distribution.

        Parameters:
            images: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)

        Returns:
            decoding: the autoencoded image
            q_dist: the z-distribution of the encoding, a Normal built from
                the encoder's (mean, stddev) outputs
        '''
        q_mean, q_stddev = self.encode(images)
        q_dist = Normal(q_mean, q_stddev)
        # rsample() applies the reparameterization trick, so gradients can
        # flow through the sampling step back into the encoder.
        z_sample = q_dist.rsample()
        decoding = self.decode(z_sample)
        return decoding, q_dist