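# NOTE: the lines below appear to be residue from the web page this file was
# copied from (author, commit message, commit hash, page controls, file size).
# They are kept as comments so that the module remains valid Python.
# sadjava's picture
# changed to pipelines
# fd52b7f
# raw
# history blame
# No virus
# 6.17 kB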
import torch.nn as nn
from torch.distributions.normal import Normal
from .constants import *
class Encoder(nn.Module):
    '''
    Encoder Class
    Maps an image to the parameters of a diagonal normal distribution over the
    latent space (the approximate posterior q(z | image) of the VAE).
    Values:
        im_chan: the number of channels of the input image, a scalar
        output_chan: the dimension of the latent vector z, a scalar; the last
            convolution emits 2 * output_chan channels (means first, then
            log-standard-deviations)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
        super(Encoder, self).__init__()
        self.z_dim = output_chan
        # Every block is an unpadded 4x4 convolution with stride 2, so the spatial
        # size shrinks at each step (e.g. 118 -> 58 -> 28 -> 13 -> 5 -> 1). The input
        # size must reduce to 1x1 at the end; forward() checks this.
        # The attribute keeps the name `disc` because renaming it would change the
        # state_dict keys of saved checkpoints.
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
            self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
            self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to an encoder block of the VAE:
        a convolution followed by a batchnorm and a LeakyReLU(0.2). The final layer is a bare
        convolution with no batchnorm and no activation, so its outputs are unconstrained.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the Encoder: given a batch of images,
        returns the mean and standard deviation of q(z | image).
        Parameters:
            image: an image tensor with dimensions (batch_size, im_chan, im_height, im_width);
                the height and width must reduce to 1x1 through the convolution stack
        Returns:
            mean: a tensor of shape (batch_size, z_dim)
            stddev: a tensor of shape (batch_size, z_dim), the elementwise exp of the
                raw network output
        Raises:
            ValueError: if the final feature map is not a batched 1x1 map, in which case
                splitting it into mean / stddev would mix spatial positions
        '''
        disc_pred = self.disc(image)
        if disc_pred.dim() != 4 or disc_pred.shape[-2:] != (1, 1):
            raise ValueError(
                "Encoder expects a batched (N, C, H, W) image whose spatial size reduces "
                f"to 1x1 after the convolutions; got a feature map of shape {tuple(disc_pred.shape)}"
            )
        encoding = disc_pred.view(len(disc_pred), -1)
        # The second half of the channels is treated as the *log of the standard
        # deviation* (unconstrained network output). exp() makes it positive, and the
        # result is returned and used directly as the stddev, not as a variance.
        return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()
class Decoder(nn.Module):
    '''
    Decoder Class
    Maps a batch of latent vectors back to images through a stack of transposed
    convolutions (no padding). Starting from a 1x1 map, the spatial size grows
    1 -> 3 -> 6 -> 13 -> 28 -> 58 -> 118.
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_chan: the number of channels of the output image, a scalar
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, z_dim=Z_DIM, im_chan=3, hidden_dim=DEC_HIDDEN_DIM):
        super().__init__()
        self.z_dim = z_dim
        # (in_channels, out_channels, kernel_size, stride) for every block except the last.
        # Blocks are created in the same order as listed, so nn.Sequential indices
        # (and therefore state_dict keys) stay gen.0 ... gen.5.
        layer_specs = [
            (z_dim, hidden_dim * 16, 3, 2),
            (hidden_dim * 16, hidden_dim * 8, 4, 1),
            (hidden_dim * 8, hidden_dim * 4, 3, 2),
            (hidden_dim * 4, hidden_dim * 2, 4, 2),
            (hidden_dim * 2, hidden_dim, 4, 2),
        ]
        stages = [
            self.make_gen_block(c_in, c_out, kernel_size=k, stride=s)
            for c_in, c_out, k, s in layer_specs
        ]
        # The output block maps to image channels and squashes values with a Sigmoid.
        stages.append(self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True))
        self.gen = nn.Sequential(*stages)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Build one Decoder block: a transposed convolution, followed either by
        batchnorm + ReLU (intermediate blocks) or by a Sigmoid (the final block).
        Parameters:
            input_channels: number of channels of the incoming feature map
            output_channels: number of channels the block should produce
            kernel_size: side length of the square transposed-convolution filter
            stride: stride of the transposed convolution
            final_layer: when True, skip batchnorm and use Sigmoid instead of ReLU
        '''
        modules = [nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)]
        if final_layer:
            modules.append(nn.Sigmoid())
        else:
            modules.extend([nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True)])
        return nn.Sequential(*modules)

    def forward(self, noise):
        '''
        Decode a batch of latent vectors into images.
        Parameters:
            noise: a noise tensor with dimensions (batch_size, z_dim)
        Returns:
            the generated images, with values in (0, 1) from the final Sigmoid
        '''
        # Treat each latent vector as a z_dim-channel 1x1 feature map.
        as_map = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(as_map)
class VAE(nn.Module):
    '''
    VAE Class
    Variational autoencoder: an Encoder maps images to a diagonal normal distribution
    q(z | images), one latent sample per image is drawn with the reparameterization
    trick, and a Decoder maps that sample back to image space.
    Values:
        z_dim: the dimension of the noise (latent) vector, a scalar
        im_chan: the number of channels of the input and reconstructed image, a scalar
            (defaults to 3)
    The Encoder and Decoder are built with their own default hidden dimensions.
    '''
    def __init__(self, z_dim=Z_DIM, im_chan=3):
        super(VAE, self).__init__()
        self.z_dim = z_dim
        self.encode = Encoder(im_chan, z_dim)
        self.decode = Decoder(z_dim, im_chan)

    def forward(self, images):
        '''
        Function for completing a forward pass of the VAE: given a batch of images,
        encodes them, samples a latent vector per image, and decodes the samples.
        Parameters:
            images: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)
        Returns:
            decoding: the autoencoded (reconstructed) images
            q_dist: the Normal distribution over z produced by the encoder, useful
                e.g. for computing a KL term against a prior
        '''
        q_mean, q_stddev = self.encode(images)
        q_dist = Normal(q_mean, q_stddev)
        # rsample() draws mean + stddev * eps, so gradients flow back into the encoder.
        z_sample = q_dist.rsample()  # Sample once from each distribution
        decoding = self.decode(z_sample)
        return decoding, q_dist