Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch.nn as nn | |
import torchvision.transforms as T | |
from torchvision.utils import make_grid | |
import torch | |
# Device on which both TorchScript generators are placed and the latent noise
# is sampled. Kept as a plain string; `.to()` and `torch.randn(device=...)`
# both accept it.
device: str = "cpu"
def inference_gan(*, num_images: int = 30, nrow: int = 8):
    """Sample a grid of images from the TorchScript MNIST GAN generator.

    Loads ``mnist-G-torchscript.pt``, feeds it ``num_images`` standard-normal
    latent vectors of length 256, reshapes the output to single-channel
    28x28 images and tiles them into one grid image.

    Args:
        num_images: Number of latent vectors to sample (default 30).
        nrow: Number of images per grid row (default 8).

    Returns:
        The image grid converted by ``T.functional.to_pil_image``.
    """
    # NOTE(review): Streamlit re-executes the script on every interaction, so
    # this re-reads the model from disk each time. Consider caching the loaded
    # module (e.g. st.cache_resource, if the installed Streamlit provides it).
    generator = torch.jit.load("mnist-G-torchscript.pt").to(device)
    # Pure sampling: no_grad skips building an autograd graph that is never
    # used, and yields a tensor to_pil_image can convert without .detach().
    with torch.no_grad():
        noise = torch.randn(num_images, 256, device=device)
        images = generator(noise)
    # reshape (not view) still works if the generator returns a
    # non-contiguous tensor; 1 channel, 28x28 per image.
    images = images.reshape(-1, 1, 28, 28)
    # NOTE(review): no denormalisation is applied here (inference_dcgan does
    # apply one). If this generator ends in tanh, the negative values will
    # not map to valid pixels in to_pil_image — confirm the output range.
    grid = make_grid(images, nrow=nrow)
    return T.functional.to_pil_image(grid)
def inference_dcgan(*, num_images: int = 64, nrow: int = 8):
    """Sample a grid of images from the TorchScript anime-face DCGAN generator.

    Loads ``animefacedataset-G2-torchscript.pt``, feeds it ``num_images``
    standard-normal latent tensors of shape (128, 1, 1), reshapes the output
    to 3-channel 64x64 images, maps them from [-1, 1] to [0, 1] and tiles
    them into one grid image.

    Args:
        num_images: Number of latent vectors to sample (default 64).
        nrow: Number of images per grid row (default 8).

    Returns:
        The image grid converted by ``T.functional.to_pil_image``.
    """
    # NOTE(review): Streamlit re-executes the script on every interaction, so
    # this re-reads the model from disk each time. Consider caching the loaded
    # module (e.g. st.cache_resource, if the installed Streamlit provides it).
    generator = torch.jit.load("animefacedataset-G2-torchscript.pt").to(device)

    def denorm(img_tensors):
        """Invert a (x - 0.5) / 0.5 normalisation, mapping [-1, 1] to [0, 1]."""
        mean, std = 0.5, 0.5
        return img_tensors * std + mean

    # Pure sampling: no_grad skips building an autograd graph that is never
    # used, and yields a tensor to_pil_image can convert without .detach().
    with torch.no_grad():
        noise = torch.randn(num_images, 128, 1, 1, device=device)
        images = generator(noise)
    # reshape (not view) still works on a non-contiguous generator output.
    images = images.reshape(-1, 3, 64, 64)
    # Presumably the generator emits values in [-1, 1] (TODO confirm). The
    # clamp ensures anything outside that range cannot overflow the uint8
    # conversion performed by to_pil_image. Denormalising before make_grid
    # keeps the grid padding at 0 (black).
    grid = make_grid(denorm(images).clamp(0.0, 1.0), nrow=nrow)
    return T.functional.to_pil_image(grid)
def inference_both():
    """Generate one GAN grid and one DCGAN grid.

    Returns:
        A ``(gan_image, dcgan_image)`` tuple holding the results of
        ``inference_gan()`` and ``inference_dcgan()``. Callers that ignore
        the return value (e.g. a button callback) are unaffected.
    """
    # Same call order as before: the GAN samples its noise first.
    gan_image = inference_gan()
    dcgan_image = inference_dcgan()
    return gan_image, dcgan_image
st.markdown("# Image Generation with GANs and DCGANs")
# Clicking a Streamlit button re-runs this whole script, and the st.image
# calls below then sample fresh images. No on_click callback is needed: one
# that generates images would do all the work a second time and discard it.
st.button("Generate Images")
# NOTE(review): newer Streamlit releases deprecate use_column_width in favour
# of use_container_width — check against the installed version.
st.image(inference_dcgan(), caption="", use_column_width=True)
st.image(inference_gan(), caption="", use_column_width=True)