Spaces:
Sleeping
Sleeping
import streamlit as st | |
from fibsem.segmentation.model import load_model | |
from fibsem.structures import FibsemImage | |
import matplotlib.pyplot as plt | |
import glob | |
from PIL import Image | |
import numpy as np | |
# Configure the page first: older Streamlit releases require set_page_config
# to be the first Streamlit command executed by the script.
st.set_page_config(layout="wide")
st.title("Autolamella Demo")
# Several files can be uploaded at once (accept_multiple_files=True in the
# sidebar uploader), so the instructions refer to "images".
st.write("This is a demo space of the Autolamella models. Upload FIBSEM (tiff) images, and select a model to segment them.")
# Sidebar input 1: the TIFF images to segment (zero or more uploaded files).
filenames = st.sidebar.file_uploader(
    "Upload an image",
    type=["tiff", "tif"],
    accept_multiple_files=True,
)

# Sidebar input 2: which segmentation checkpoint to run.
_CHECKPOINT_OPTIONS = [
    "autolamella-mega-20240107.pt",
    "autolamella-waffle-20240107.pt",
    "autolamella-serial-liftout-20240107.pt",
]
checkpoint = st.sidebar.selectbox("Select a model checkpoint", _CHECKPOINT_OPTIONS)
# Sidebar help text describing the selectable checkpoints.
st.sidebar.header("Available Models")
st.sidebar.write("""The following models are available for inference. They are trained on different datasets, and may perform differently on different samples.""")
# write a markdown list, listing each of the models
st.sidebar.write("""
* autolamella-waffle*
* autolamella-serial-liftout*
* autolamella-mega*""")
# A single plain string: the original `best."""` was really this string
# implicitly concatenated with an empty "" literal.
st.sidebar.write("If you have a new sample, try all of them and see which one works best.")
# Fall back to the bundled example images when nothing has been uploaded.
# The truthiness check covers both an empty list and a None return from the
# uploader.
if not filenames:
    st.write("No files uploaded, using default data")
    # collect TIFF files from the example/ directory, matching both
    # extensions the sidebar uploader accepts ("tif" and "tiff")
    filenames = sorted(glob.glob("example/*.tif") + glob.glob("example/*.tiff"))
st.header("Segmentation Results")


@st.cache_resource
def _load_model_cached(checkpoint_name):
    """Load ``checkpoint_name`` once and reuse it across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction; without
    caching, the checkpoint would be reloaded from disk each time.
    """
    return load_model(checkpoint_name)


def _display_name(f):
    """Return a readable label for a path string or an uploaded file object.

    Uploaded files expose a ``.name`` attribute; the default example data is a
    list of plain path strings, which are returned unchanged.
    """
    return getattr(f, "name", f)


# load model (cached per checkpoint name)
model = _load_model_cached(checkpoint)

_N_COLS = 4
_RESIZE_WH = (1536, 1024)  # PIL resize takes (width, height)

if filenames:
    cols = st.columns(_N_COLS)
    for i, fname in enumerate(filenames):
        col = cols[i % _N_COLS]
        col.write(f"#### File: {_display_name(fname)} ({i+1}/{len(filenames)})")
        col.write("Running inference... ")

        # load image, resize to 1536 x 1024, and segment
        image = FibsemImage.load(fname)
        image.data = np.asarray(Image.fromarray(image.data).resize(_RESIZE_WH))
        mask = model.inference(image.data, rgb=True)

        # overlay the predicted mask on the grayscale image
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image.data, cmap="gray")
        ax.imshow(mask, alpha=0.5)
        ax.axis("off")
        col.pyplot(fig, use_container_width=True)
        # close the figure so figures don't accumulate across files and reruns
        plt.close(fig)
else:
    st.write("No images to segment: upload a TIFF file or add examples to the example/ directory.")