Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
from torch.distributions.normal import Normal | |
from .constants import * | |
class Encoder(nn.Module):
    '''
    Encoder Class: a stack of strided convolutions that maps an image to the
    parameters of a diagonal Gaussian over the latent space (one mean and one
    standard deviation per latent dimension).
    Values:
        im_chan: the number of channels of the input image, a scalar
        output_chan: the dimension of the latent vector z, a scalar; the last
            convolution emits 2 * output_chan channels (means, then log-stddevs)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
        super().__init__()
        self.z_dim = output_chan
        # Attribute and submodule layout are kept as-is so state_dict keys
        # (disc.<i>.<j>.*) stay compatible with existing checkpoints.
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
            self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
            self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to an encoder block of the VAE:
        a convolution (no padding), followed by a batchnorm and a LeakyReLU(0.2) except in the last layer.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: whether we're on the final layer (affects activation and batchnorm)
        Returns:
            an nn.Sequential holding the block's layers
        '''
        if final_layer:
            # No normalization or activation: the raw outputs are the Gaussian
            # parameters (means and log-stddevs), which must be unconstrained.
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )
        return nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        '''
        Function for completing a forward pass of the Encoder: given a batch of images,
        returns the mean and standard deviation of the latent Gaussian for each image.
        Parameters:
            image: an image tensor with dimensions (batch_size, im_chan, im_height, im_width);
                its spatial size must reduce to 1 x 1 through the five unpadded
                convolutions (e.g. 118 x 118 -> 58 -> 28 -> 13 -> 5 -> 1)
        Returns:
            mean: a tensor of shape (batch_size, z_dim)
            stddev: a tensor of shape (batch_size, z_dim), obtained by exponentiating
                the network's raw output so it is positive
        Raises:
            ValueError: if the final feature map is not 1 x 1, in which case the
                channel split into means and stddevs would be meaningless
        '''
        disc_pred = self.disc(image)
        if disc_pred.shape[-2:] != (1, 1):
            raise ValueError(
                f"Encoder expected a 1x1 final feature map, got shape {tuple(disc_pred.shape)}; "
                "adjust the input image size"
            )
        encoding = disc_pred.flatten(start_dim=1)
        # The second half of the channels is interpreted as log(stddev): exp()
        # maps it to a positive scale and keeps the parameterization numerically
        # stable. (exp of a log-*variance* would give a variance, not a stddev.)
        mean, log_stddev = encoding[:, :self.z_dim], encoding[:, self.z_dim:]
        return mean, log_stddev.exp()
class Decoder(nn.Module):
    '''
    Decoder Class: a stack of transposed convolutions that grows a latent vector,
    treated as a 1 x 1 feature map, into an image with values in (0, 1).
    With the layer settings below the output is 118 x 118 pixels
    (1 -> 3 -> 6 -> 13 -> 28 -> 58 -> 118).
    Values:
        z_dim: the dimension of the latent (noise) vector, a scalar
        im_chan: the number of channels of the generated image, a scalar
        hidden_dim: the base width of the hidden layers, a scalar
    '''
    def __init__(self, z_dim=Z_DIM, im_chan=3, hidden_dim=DEC_HIDDEN_DIM):
        super().__init__()
        self.z_dim = z_dim
        # One entry per block, from latent input to image output:
        # (in_channels, out_channels, keyword overrides for make_gen_block).
        block_specs = [
            (z_dim, hidden_dim * 16, {}),
            (hidden_dim * 16, hidden_dim * 8, {"kernel_size": 4, "stride": 1}),
            (hidden_dim * 8, hidden_dim * 4, {}),
            (hidden_dim * 4, hidden_dim * 2, {"kernel_size": 4}),
            (hidden_dim * 2, hidden_dim, {"kernel_size": 4}),
            (hidden_dim, im_chan, {"kernel_size": 4, "final_layer": True}),
        ]
        blocks = [
            self.make_gen_block(in_ch, out_ch, **overrides)
            for in_ch, out_ch, overrides in block_specs
        ]
        self.gen = nn.Sequential(*blocks)

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Build one decoder stage: a transposed convolution (no padding), then either
        batchnorm + ReLU for hidden stages or a sigmoid for the output stage.
        Parameters:
            input_channels: channel count of the incoming feature map
            output_channels: channel count the stage should produce
            kernel_size: side length of the square transposed-convolution filter
            stride: stride of the transposed convolution
            final_layer: True for the output stage (sigmoid, no batchnorm)
        Returns:
            an nn.Sequential holding the stage's layers
        '''
        layers = [nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)]
        if final_layer:
            # Squash pixel values into (0, 1).
            layers.append(nn.Sigmoid())
        else:
            layers.extend([nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True)])
        return nn.Sequential(*layers)

    def forward(self, noise):
        '''
        Decode a batch of latent vectors into images.
        Parameters:
            noise: a latent tensor with dimensions (batch_size, z_dim)
        Returns:
            the generated images, shape (batch_size, im_chan, height, width)
        '''
        batch_size = len(noise)
        # Reinterpret each latent vector as a z_dim-channel, 1 x 1 feature map.
        seed_map = noise.view(batch_size, self.z_dim, 1, 1)
        return self.gen(seed_map)
class VAE(nn.Module):
    '''
    VAE Class: pairs an Encoder, which produces a diagonal Gaussian q(z | x)
    for each input image, with a Decoder, which reconstructs the image from
    one sample of that distribution.
    Values:
        z_dim: the dimension of the latent vector, a scalar
        im_chan: the number of channels of the input and reconstructed images,
            a scalar (defaults to 3, i.e. colour images)
    '''
    def __init__(self, z_dim=Z_DIM, im_chan=3):
        super().__init__()
        self.z_dim = z_dim
        self.encode = Encoder(im_chan=im_chan, output_chan=z_dim)
        self.decode = Decoder(z_dim=z_dim, im_chan=im_chan)

    def forward(self, images):
        '''
        Function for completing a forward pass of the VAE: encode the images into a
        latent distribution, draw one sample per image, and decode it.
        Parameters:
            images: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)
        Returns:
            decoding: the autoencoded (reconstructed) images
            q_dist: the Normal distribution over z produced by the encoder
        '''
        q_mean, q_stddev = self.encode(images)
        q_dist = Normal(q_mean, q_stddev)
        # rsample() draws a reparameterized sample (mean + stddev * eps), so the
        # reconstruction loss can backpropagate into the encoder's outputs.
        z_sample = q_dist.rsample()
        decoding = self.decode(z_sample)
        return decoding, q_dist