# Author: minhalvp
# Last commit: "Update app.py"
import functools

import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.utils import make_grid
# Device that the generators are moved to and the latent noise is sampled on.
device = "cpu"
@functools.lru_cache(maxsize=1)
def _load_gan_generator():
    """Load the MNIST TorchScript generator once and reuse it across reruns."""
    # map_location lets a checkpoint that was saved on a GPU load on `device`.
    return torch.jit.load("mnist-G-torchscript.pt", map_location=device).to(device)


def inference_gan(num_images=30, nrow=8):
    """Sample digit images from the fully-connected MNIST GAN generator.

    Args:
        num_images: Number of latent vectors to sample (default 30).
        nrow: Number of images per row in the tiled grid (default 8).

    Returns:
        A PIL image containing all samples tiled by ``make_grid``.
    """
    generator = _load_gan_generator()
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        # 256 is the latent width this generator has always been fed here.
        x = torch.randn(num_images, 256, device=device)
        y = generator(x)
    # Reshape the generator output into single-channel 28x28 images.
    y = y.view(-1, 1, 28, 28)
    grid = make_grid(y, nrow=nrow)
    # to_pil_image multiplies float tensors by 255 and casts to uint8, so values
    # outside [0, 1] would overflow into garbage pixels; clamp first.
    # NOTE(review): if this generator ends in tanh (outputs in [-1, 1]) it
    # probably wants the same `* 0.5 + 0.5` denorm as inference_dcgan — confirm.
    return T.functional.to_pil_image(grid.clamp(0, 1))
@functools.lru_cache(maxsize=1)
def _load_dcgan_generator():
    """Load the anime-face DCGAN TorchScript generator once and reuse it."""
    # map_location lets a checkpoint that was saved on a GPU load on `device`.
    return torch.jit.load(
        "animefacedataset-G2-torchscript.pt", map_location=device
    ).to(device)


def inference_dcgan(num_images=64, nrow=8):
    """Sample 64x64 RGB images from the anime-face DCGAN generator.

    Args:
        num_images: Number of latent vectors to sample (default 64).
        nrow: Number of images per row in the tiled grid (default 8).

    Returns:
        A PIL image containing all samples tiled by ``make_grid``.
    """
    generator = _load_dcgan_generator()
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        # The convolutional generator takes latents shaped (N, 128, 1, 1).
        x = torch.randn(num_images, 128, 1, 1, device=device)
        y = generator(x)
    y = y.view(-1, 3, 64, 64)
    # Undo the mean=0.5 / std=0.5 normalisation: x * std + mean.
    grid = make_grid(y * 0.5 + 0.5, nrow=nrow)
    # to_pil_image multiplies float tensors by 255 and casts to uint8; clamp so
    # any slightly out-of-range value cannot overflow.
    return T.functional.to_pil_image(grid.clamp(0, 1))
def inference_both():
    """Render one fresh sample grid from each generator: DCGAN first, then GAN."""
    st.image(inference_dcgan(), caption="", use_column_width=True)
    st.image(inference_gan(), caption="", use_column_width=True)


st.markdown("# Image Generation with GANs and DCGANs")
# Streamlit reruns this whole script on every widget interaction, so pressing
# the button already redraws the page with new samples. Also registering
# inference_both as an on_click callback (which runs before that rerun) would
# draw the grids an extra time.
st.button("Generate Images")
inference_both()