Spaces:
Runtime error
Runtime error
JunchuanYu
commited on
Commit
·
8901ad6
1
Parent(s):
39d0cbe
Update app.py
Browse files
app.py
CHANGED
@@ -1,19 +1,17 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
from PIL import Image
|
3 |
-
import numpy as np
|
4 |
-
import cv2
|
5 |
import os
|
6 |
-
|
7 |
import cv2
|
8 |
import matplotlib
|
9 |
import matplotlib.pyplot as plt
|
10 |
import numpy as np
|
11 |
import torch
|
|
|
12 |
|
13 |
from PIL import Image
|
14 |
|
15 |
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
16 |
|
|
|
|
|
17 |
st.markdown(
|
18 |
"""
|
19 |
# Segment Anything Model (SAM)
|
@@ -26,17 +24,15 @@ st.markdown(
|
|
26 |
"""
|
27 |
)
|
28 |
|
29 |
-
#
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
sam = sam_model_registry["vit_b"](checkpoint="./meta-model.pth")
|
35 |
sam.to(device=device)
|
36 |
mask_generator = SamAutomaticMaskGenerator(sam)
|
37 |
predictor = SamPredictor(sam)
|
38 |
|
39 |
-
# copied from: https://github.com/facebookresearch/segment-anything
|
40 |
def show_anns(anns):
|
41 |
if len(anns) == 0:
|
42 |
return
|
@@ -53,44 +49,27 @@ def show_anns(anns):
|
|
53 |
img[:,:,i] = color_mask[i]
|
54 |
ax.imshow(np.dstack((img, m*0.35)))
|
55 |
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
# generate masks
|
60 |
-
masks = mask_generator.generate(input_image)
|
61 |
-
|
62 |
-
# add masks to image
|
63 |
plt.clf()
|
64 |
ppi = 100
|
65 |
-
height, width, _ =
|
66 |
-
plt.figure(figsize=(width / ppi, height / ppi))
|
67 |
-
plt.imshow(
|
68 |
show_anns(masks)
|
69 |
plt.axis('off')
|
|
|
|
|
|
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
output_image = cv2.imread('output_figure.png')
|
74 |
-
return Image.fromarray(output_image)
|
75 |
-
|
76 |
-
file = st.file_uploader("Upload File", type=["png", "jpg", "jpeg"])
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
st.image(input, caption='Test Image provided by SAM github')
|
82 |
-
# ex = Image.open(images[0])
|
83 |
-
# st.image(ex, width=200)
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
st.image(output)
|
88 |
|
89 |
-
|
90 |
-
input = file
|
91 |
-
if st.button("Generate Mask"):
|
92 |
-
output = segment_image(input)
|
93 |
-
if output is not None:
|
94 |
-
st.subheader("Segmentation Result")
|
95 |
-
st.write(output.shape)
|
96 |
-
st.image(output, width=850)
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
import cv2
|
3 |
import matplotlib
|
4 |
import matplotlib.pyplot as plt
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
+
import gradio as gr
|
8 |
|
9 |
from PIL import Image
|
10 |
|
11 |
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
12 |
|
13 |
+
matplotlib.pyplot.switch_backend('Agg') # for matplotlib to work in gradio
|
14 |
+
|
15 |
st.markdown(
|
16 |
"""
|
17 |
# Segment Anything Model (SAM)
|
|
|
24 |
"""
|
25 |
)
|
26 |
|
27 |
+
#setup model
|
28 |
+
sam_checkpoint = "meta-model.pth"
|
29 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # use GPU if available
|
30 |
+
model_type = "default"
|
31 |
+
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
|
|
32 |
sam.to(device=device)
|
33 |
mask_generator = SamAutomaticMaskGenerator(sam)
|
34 |
predictor = SamPredictor(sam)
|
35 |
|
|
|
36 |
def show_anns(anns):
|
37 |
if len(anns) == 0:
|
38 |
return
|
|
|
49 |
img[:,:,i] = color_mask[i]
|
50 |
ax.imshow(np.dstack((img, m*0.35)))
|
51 |
|
52 |
+
def segment_image(image):
|
53 |
+
masks = mask_generator.generate(image)
|
|
|
|
|
|
|
|
|
|
|
54 |
plt.clf()
|
55 |
ppi = 100
|
56 |
+
height, width, _ = image.shape
|
57 |
+
plt.figure(figsize=(width / ppi, height / ppi), dpi=ppi)
|
58 |
+
plt.imshow(image)
|
59 |
show_anns(masks)
|
60 |
plt.axis('off')
|
61 |
+
plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
|
62 |
+
output = cv2.imread('output.png')
|
63 |
+
return Image.fromarray(output)
|
64 |
|
65 |
+
with gr.Blocks() as demo:
|
66 |
+
# gr.Markdown("## Segment-anything Demo")
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
with gr.Row():
|
69 |
+
image = gr.Image()
|
70 |
+
image_output = gr.Image()
|
|
|
|
|
|
|
71 |
|
72 |
+
segment_image_button = gr.Button("Segment")
|
73 |
+
segment_image_button.click(segment_image, inputs=[image], outputs=image_output)
|
|
|
74 |
|
75 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|