JunchuanYu commited on
Commit
8901ad6
·
1 Parent(s): 39d0cbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -45
app.py CHANGED
@@ -1,19 +1,17 @@
1
- import streamlit as st
2
- from PIL import Image
3
- import numpy as np
4
- import cv2
5
  import os
6
-
7
  import cv2
8
  import matplotlib
9
  import matplotlib.pyplot as plt
10
  import numpy as np
11
  import torch
 
12
 
13
  from PIL import Image
14
 
15
  from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
16
 
 
 
17
  st.markdown(
18
  """
19
  # Segment Anything Model (SAM)
@@ -26,17 +24,15 @@ st.markdown(
26
  """
27
  )
28
 
29
- # suppress server-side GUI windows
30
- matplotlib.pyplot.switch_backend('Agg')
31
-
32
- # setup models
33
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
34
- sam = sam_model_registry["vit_b"](checkpoint="./meta-model.pth")
35
  sam.to(device=device)
36
  mask_generator = SamAutomaticMaskGenerator(sam)
37
  predictor = SamPredictor(sam)
38
 
39
- # copied from: https://github.com/facebookresearch/segment-anything
40
  def show_anns(anns):
41
  if len(anns) == 0:
42
  return
@@ -53,44 +49,27 @@ def show_anns(anns):
53
  img[:,:,i] = color_mask[i]
54
  ax.imshow(np.dstack((img, m*0.35)))
55
 
56
- # demo function
57
- def segment_image(input_image):
58
-
59
- # generate masks
60
- masks = mask_generator.generate(input_image)
61
-
62
- # add masks to image
63
  plt.clf()
64
  ppi = 100
65
- height, width, _ = input_image.shape
66
- plt.figure(figsize=(width / ppi, height / ppi)) # convert pixel to inches
67
- plt.imshow(input_image)
68
  show_anns(masks)
69
  plt.axis('off')
 
 
 
70
 
71
- # save and get figure
72
- plt.savefig('output_figure.png', bbox_inches='tight')
73
- output_image = cv2.imread('output_figure.png')
74
- return Image.fromarray(output_image)
75
-
76
- file = st.file_uploader("Upload File", type=["png", "jpg", "jpeg"])
77
 
78
- col1, col2 = st.columns(2)
79
- with col1:
80
- input = Image.open('sample.jpg')
81
- st.image(input, caption='Test Image provided by SAM github')
82
- # ex = Image.open(images[0])
83
- # st.image(ex, width=200)
84
 
85
- with col2:
86
- output = segment_image(input)
87
- st.image(output)
88
 
89
- if file is not None:
90
- input = file
91
- if st.button("Generate Mask"):
92
- output = segment_image(input)
93
- if output is not None:
94
- st.subheader("Segmentation Result")
95
- st.write(output.shape)
96
- st.image(output, width=850)
 
 
 
 
 
1
  import os
 
2
  import cv2
3
  import matplotlib
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  import torch
7
+ import gradio as gr
8
 
9
  from PIL import Image
10
 
11
  from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
12
 
13
+ matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
14
+
15
  st.markdown(
16
  """
17
  # Segment Anything Model (SAM)
 
24
  """
25
  )
26
 
27
+ #setup model
28
+ sam_checkpoint = "meta-model.pth"
29
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
30
+ model_type = "default"
31
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
 
32
  sam.to(device=device)
33
  mask_generator = SamAutomaticMaskGenerator(sam)
34
  predictor = SamPredictor(sam)
35
 
 
36
  def show_anns(anns):
37
  if len(anns) == 0:
38
  return
 
49
  img[:,:,i] = color_mask[i]
50
  ax.imshow(np.dstack((img, m*0.35)))
51
 
52
+ def segment_image(image):
53
+ masks = mask_generator.generate(image)
 
 
 
 
 
54
  plt.clf()
55
  ppi = 100
56
+ height, width, _ = image.shape
57
+ plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
58
+ plt.imshow(image)
59
  show_anns(masks)
60
  plt.axis('off')
61
+ plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
62
+ output = cv2.imread('output.png')
63
+ return Image.fromarray(output)
64
 
65
+ with gr.Blocks() as demo:
66
+ # gr.Markdown("## Segment-anything Demo")
 
 
 
 
67
 
68
+ with gr.Row():
69
+ image = gr.Image()
70
+ image_output = gr.Image()
 
 
 
71
 
72
+ segment_image_button = gr.Button("Segment")
73
+ segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
 
74
 
75
+ demo.launch()