patrickcleeve commited on
Commit
dbb8c20
·
verified ·
1 Parent(s): 3350d1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -42
app.py CHANGED
@@ -14,22 +14,9 @@ st.set_page_config(layout="wide")
14
  st.title("Autolamella Demo")
15
  st.write("This is a demo space of the Autolamella models. Upload FIBSEM (tiff) image, and select a model to segment them.")
16
 
17
- st.header(f"Models")
18
- st.write("""The following models are available for inference. They are trained on different datasets, and may perform differently on different samples.
19
- \nautolamella-waffle* is trained on waffle method data.
20
- \nautolamella-serial-liftout* is trained on serial liftout data.
21
- \nautolamella-mega* is trained on a combination of waffle, autoliftout and serial liftout data.
22
- If you have a new sample, try all of them and see which one works best.""")
23
-
24
  # get filenames
25
  filenames = st.sidebar.file_uploader("Upload an image", type=["tiff", "tif"], accept_multiple_files=True)
26
 
27
- # get default data from path
28
- if len(filenames) == 0:
29
- st.write("No files uploaded, using default data")
30
- # get all tiff files in current directory
31
- filenames = sorted(glob.glob("*.tif"))
32
-
33
  # get model
34
  checkpoint = st.sidebar.selectbox(
35
  "Select a model checkpoint",
@@ -37,8 +24,24 @@ checkpoint = st.sidebar.selectbox(
37
  "autolamella-mega-20240107.pt",
38
  "autolamella-waffle-20240107.pt",
39
  "autolamella-serial-liftout-20240107.pt",
40
- ],
41
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  st.header(f"Segmentation Results")
44
 
@@ -46,34 +49,27 @@ st.header(f"Segmentation Results")
46
  model = load_model(checkpoint)
47
 
48
  if filenames:
49
- # if there are multiple files, show them in a grid of 4 images
50
- if len(filenames) > 1:
51
- # create a grid of 4 images
52
- n_rows = len(filenames) // 4 + 1
53
- if len(filenames) % 4 == 0:
54
- n_rows -= 1
55
- fig, ax = plt.subplots(n_rows, 4, figsize=(20, 4 * n_rows))
56
- ax = ax.flatten()
57
- for i, fname in enumerate(filenames):
58
-
59
- # load image, segment, and save
60
- image = FibsemImage.load(fname)
61
-
62
- # resize to 1536 x 1024
63
- image.data = np.asarray(Image.fromarray(image.data).resize((1536, 1024)))
64
 
65
- mask = model.inference(image.data, rgb=True)
66
 
67
- # plot
68
- ax[i].imshow(image.data, cmap="gray")
69
- ax[i].imshow(mask, alpha=0.5)
70
- # ax[i].set_title(fname.name)
71
- ax[i].axis("off")
72
-
73
- # remove extra axes
74
- for i in range(len(filenames), len(ax)):
75
- ax[i].axis("off")
76
 
77
- plt.subplots_adjust(wspace=0, hspace=0)
78
- st.pyplot(fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
 
 
14
  st.title("Autolamella Demo")
15
  st.write("This is a demo space of the Autolamella models. Upload FIBSEM (tiff) image, and select a model to segment them.")
16
 
 
 
 
 
 
 
 
17
  # get filenames
18
  filenames = st.sidebar.file_uploader("Upload an image", type=["tiff", "tif"], accept_multiple_files=True)
19
 
 
 
 
 
 
 
20
  # get model
21
  checkpoint = st.sidebar.selectbox(
22
  "Select a model checkpoint",
 
24
  "autolamella-mega-20240107.pt",
25
  "autolamella-waffle-20240107.pt",
26
  "autolamella-serial-liftout-20240107.pt",
27
+ ],)
28
+
29
+ st.sidebar.header("Available Models")
30
+ st.sidebar.write("""The following models are available for inference. They are trained on different datasets, and may perform differently on different samples.""")
31
+
32
+ # write a markdown list, listing each of the models
33
+ st.sidebar.write("""
34
+ * autolamella-waffle*
35
+ * autolamella-serial-liftout*
36
+ * autolamella-mega*""")
37
+ st.sidebar.write("If you have a new sample, try all of them and see which one works best.""")
38
+
39
+ # get default data from path
40
+ if len(filenames) == 0:
41
+ st.write("No files uploaded, using default data")
42
+ # get all tiff files in current directory
43
+ filenames = sorted(glob.glob("example/*.tif"))
44
+
45
 
46
  st.header(f"Segmentation Results")
47
 
 
49
  model = load_model(checkpoint)
50
 
51
  if filenames:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ cols = st.columns(4)
54
 
55
+ for i, fname in enumerate(filenames):
 
 
 
 
 
 
 
 
56
 
57
+ col_id = i % 4
58
+ cols[col_id].write(f"#### File: {fname} ({i+1}/{len(filenames)})")
59
+ cols[col_id].write(f"Running inference... ")
60
+ # load image, segment, and save
61
+ image = FibsemImage.load(fname)
62
+
63
+ # resize to 1536 x 1024
64
+ image.data = np.asarray(Image.fromarray(image.data).resize((1536, 1024)))
65
+
66
+ mask = model.inference(image.data, rgb=True)
67
+
68
+ # plot
69
+ fig = plt.figure(figsize=(10, 10))
70
+ plt.imshow(image.data, cmap="gray")
71
+ plt.imshow(mask, alpha=0.5)
72
+ plt.axis("off")
73
+
74
 
75
+ cols[col_id].pyplot(fig, use_container_width=True)