"""Streamlit demo that samples image grids from two pretrained TorchScript GANs."""

import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.utils import make_grid

device = "cpu"

# Streamlit re-executes this script on every interaction, so a plain
# module-level cache would not survive between reruns; ``st.cache_resource``
# does. Older Streamlit versions without it fall back to no caching.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_generator(path):
    """Load the TorchScript generator stored at *path* onto ``device``.

    The module is switched to eval mode. Because the loaded object is cached
    and reused, any train-mode normalisation layers would otherwise update
    their running statistics on every sampling call.
    """
    generator = torch.jit.load(path, map_location=device)
    generator.eval()
    return generator


def _denorm(img_tensors, mean=0.5, std=0.5):
    """Undo a ``Normalize(mean, std)`` transform: ``x * std + mean``."""
    return img_tensors * std + mean


def _to_pil_grid(images, nrow=8):
    """Tile *images* into a grid and convert it to a PIL image.

    Values are clamped to [0, 1] first. ``to_pil_image`` converts float
    tensors with ``mul(255).byte()``, which wraps out-of-range values around
    instead of saturating them.
    """
    grid = make_grid(images.clamp(0, 1), nrow=nrow)
    return T.functional.to_pil_image(grid)


@torch.inference_mode()
def inference_gan():
    """Sample 30 images from the MNIST generator.

    Returns:
        PIL.Image.Image: the samples tiled 8 per row.
    """
    generator = _load_generator("mnist-G-torchscript.pt")
    latent = torch.randn(30, 256, device=device)
    # Reshape the flat generator output into 1-channel 28x28 images.
    images = generator(latent).view(-1, 1, 28, 28)
    # NOTE(review): unlike the DCGAN path, no denormalisation is applied here.
    # If this generator ends in tanh (output in [-1, 1]), negative pixels are
    # clipped to black — confirm the generator's output range.
    return _to_pil_grid(images, nrow=8)


@torch.inference_mode()
def inference_dcgan():
    """Sample 64 images from the anime-face DCGAN generator.

    Returns:
        PIL.Image.Image: the denormalised samples tiled 8 per row.
    """
    generator = _load_generator("animefacedataset-G2-torchscript.pt")
    latent = torch.randn(64, 128, 1, 1, device=device)
    # Reshape to 3-channel 64x64 images.
    images = generator(latent).view(-1, 3, 64, 64)
    # Map outputs normalised with mean=std=0.5 back to [0, 1].
    return _to_pil_grid(_denorm(images), nrow=8)


def inference_both():
    """Generate fresh samples from both generators.

    Returns:
        tuple: ``(gan_image, dcgan_image)`` as PIL images.
    """
    return inference_gan(), inference_dcgan()


st.markdown("# Image Generation with GANs and DCGANs")
# Clicking the button already triggers a script rerun, which regenerates the
# images below. An ``on_click`` callback would sample both models a second
# time and throw the results away.
st.button("Generate Images")
st.image(inference_dcgan(), caption="", use_column_width=True)
st.image(inference_gan(), caption="", use_column_width=True)