# Last commit by patrickcleeve: "Update app.py" (dbb8c20, verified)
import streamlit as st
from fibsem.segmentation.model import load_model
from fibsem.structures import FibsemImage
import matplotlib.pyplot as plt
import glob
from PIL import Image
import numpy as np
# Configure the page (wide layout) before emitting any other Streamlit output;
# Streamlit expects set_page_config to precede other st.* calls.
st.set_page_config(layout="wide")
st.title("Autolamella Demo")
# Short introduction shown at the top of the page.
st.write(
    "This is a demo space of the Autolamella models. "
    "Upload FIBSEM (tiff) images and select a model to segment them."
)
# Sidebar input 1: the FIBSEM TIFF images to segment. Several files may be
# uploaded at once; the result is rebound below if nothing was uploaded.
filenames = st.sidebar.file_uploader(
    "Upload an image",
    type=["tiff", "tif"],
    accept_multiple_files=True,
)

# Sidebar input 2: which segmentation checkpoint to run. The selected name is
# later handed to load_model.
_CHECKPOINT_OPTIONS = [
    "autolamella-mega-20240107.pt",
    "autolamella-waffle-20240107.pt",
    "autolamella-serial-liftout-20240107.pt",
]
checkpoint = st.sidebar.selectbox("Select a model checkpoint", _CHECKPOINT_OPTIONS)
# Sidebar reference section describing the selectable model checkpoints.
st.sidebar.header("Available Models")
st.sidebar.write(
    "The following models are available for inference. They are trained on "
    "different datasets, and may perform differently on different samples."
)
# Markdown bullet list naming each model family.
# NOTE(review): the trailing '*' after each name has no matching opener, so
# Markdown renders it as a literal asterisk; confirm whether that is intended.
st.sidebar.write("""
* autolamella-waffle*
* autolamella-serial-liftout*
* autolamella-mega*""")
st.sidebar.write("If you have a new sample, try all of them and see which one works best.")
# Fall back to the bundled example images when nothing has been uploaded.
# `not filenames` treats both an empty list and None as "no uploads".
if not filenames:
    st.write("No files uploaded, using default data")
    # Collect example/*.tif and example/*.tiff (the same two extensions the
    # uploader accepts), in a stable sorted order.
    filenames = sorted(glob.glob("example/*.tif") + glob.glob("example/*.tiff"))
st.header("Segmentation Results")

# Number of side-by-side result columns; images are placed round-robin.
NUM_RESULT_COLUMNS = 4


@st.cache_resource
def _load_cached_model(checkpoint_name):
    """Load the segmentation model for ``checkpoint_name`` once and reuse it.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the checkpoint would be reloaded on each rerun. The cache
    is keyed on ``checkpoint_name``.
    """
    return load_model(checkpoint_name)


# load model (cached per selected checkpoint)
model = _load_cached_model(checkpoint)

if filenames:
    cols = st.columns(NUM_RESULT_COLUMNS)
    for i, fname in enumerate(filenames):
        col_id = i % NUM_RESULT_COLUMNS
        # Uploaded entries are UploadedFile objects (which have .name); the
        # example fallback yields plain path strings.
        display_name = getattr(fname, "name", fname)
        cols[col_id].write(f"#### File: {display_name} ({i+1}/{len(filenames)})")
        cols[col_id].write("Running inference... ")

        # load image and segment
        image = FibsemImage.load(fname)
        # PIL's resize takes (width, height): the array becomes 1024 rows x 1536 columns.
        image.data = np.asarray(Image.fromarray(image.data).resize((1536, 1024)))
        mask = model.inference(image.data, rgb=True)

        # Plot the mask over the grayscale image.
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.imshow(image.data, cmap="gray")
            ax.imshow(mask, alpha=0.5)
            ax.axis("off")
            cols[col_id].pyplot(fig, use_container_width=True)
        finally:
            # pyplot keeps every open figure alive until closed; without this
            # figures pile up across images and reruns.
            plt.close(fig)