|
import streamlit as st |
|
from streamlit_drawable_canvas import st_canvas |
|
import os |
|
import utils |
|
from PIL import Image |
|
|
|
|
|
# Page configuration must be the first Streamlit call in the script run.
st.set_page_config(page_title="VAE MNIST Pytorch Lightning")

st.title("VAE Playground")

# Short introduction shown under the title.
st.markdown(
    "This is a simple Streamlit app to showcase how simple VAEs work."
)
|
|
|
def load_model_files(model_dir="./models/"):
    """Map display names to the checkpoint file names found in *model_dir*.

    Only entries whose name ends with ``.ckpt`` are considered. Each file
    name is converted to a display name with ``utils.parse_model_file_name``.

    Args:
        model_dir: Directory to scan. Defaults to ``"./models/"``, which is
            resolved relative to the current working directory.

    Returns:
        dict mapping display name -> bare checkpoint file name (not a full
        path), built in sorted file-name order. If two files produce the
        same display name, the later one in sorted order wins.

    Raises:
        FileNotFoundError: if *model_dir* does not exist (raised by
            ``os.listdir``).
    """
    # endswith (not a substring test) so "foo.ckpt.bak"-style leftovers are
    # ignored; sorted because os.listdir order is arbitrary and the caller
    # uses this order (and its first entry as the default) in a selectbox.
    checkpoints = sorted(
        name for name in os.listdir(model_dir) if name.endswith(".ckpt")
    )
    return {utils.parse_model_file_name(name): name for name in checkpoints}
|
|
|
|
|
# Discover the available checkpoints on every script run; `files` holds the
# display names offered in the model picker.
file_name_map = load_model_files()

files = list(file_name_map.keys())


# The second positional argument of st.header is the anchor id.
st.header("🖼️ Image Reconstruction", "recon")

if not files:
    # With no checkpoints the selectbox would have no options and the lookup
    # into file_name_map below would fail; stop this run with a readable
    # message instead.
    st.error("No model checkpoints (*.ckpt) found in ./models/.")
    st.stop()

with st.form("reconstruction"):

    model_name = st.selectbox("Choose Model:", files,
                              key="recon_model_select")

    # Translate the display name back to the checkpoint file name.
    recon_model_name = file_name_map[model_name]

    # 150x150 freehand drawing surface: white 8px strokes on black.
    recon_canvas = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=8,
        stroke_color="#FFFFFF",
        background_color="#000000",
        update_streamlit=True,
        height=150,
        width=150,
        drawing_mode="freedraw",
        key="recon_canvas",
    )

    submit = st.form_submit_button("Perform Reconstruction")

    if submit:
        recon_model = utils.load_model(recon_model_name)

        inp_tens = utils.canvas_to_tensor(recon_canvas)

        # The model returns a 3-tuple; only the last element is used here
        # (presumably the reconstruction — confirm against the model).
        _, _, out = recon_model(inp_tens)

        # NOTE(review): presumably rescales a [-1, 1] output to [0, 1] for
        # display — confirm against the model's output activation.
        out = (out+1)/2

        # Upscale back to the canvas size so input and output line up.
        out_img = utils.resize_img(utils.tensor_to_img(out), 150, 150)

# Shown outside the form; out_img only exists after a submission.
if submit:
    st.image(out_img)
|
|
|
|