Spaces:
Runtime error
Runtime error
JunchuanYu
commited on
Commit
·
d7c2a2b
1
Parent(s):
91b8aa9
Upload 3 files
Browse files- app.py +91 -0
- meta-model.pth +3 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
import os
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import matplotlib
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
16 |
+
|
17 |
+
# st.markdown(
|
18 |
+
# """
|
19 |
+
# The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image.
|
20 |
+
# It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.
|
21 |
+
# """
|
22 |
+
# )
|
23 |
+
|
24 |
+
# suppress server-side GUI windows
|
25 |
+
matplotlib.pyplot.switch_backend('Agg')
|
26 |
+
|
27 |
+
# setup models
|
28 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
29 |
+
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
|
30 |
+
sam.to(device=device)
|
31 |
+
mask_generator = SamAutomaticMaskGenerator(sam)
|
32 |
+
predictor = SamPredictor(sam)
|
33 |
+
|
34 |
+
# copied from: https://github.com/facebookresearch/segment-anything
|
35 |
+
def show_anns(anns):
|
36 |
+
if len(anns) == 0:
|
37 |
+
return
|
38 |
+
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
39 |
+
ax = plt.gca()
|
40 |
+
ax.set_autoscale_on(False)
|
41 |
+
polygons = []
|
42 |
+
color = []
|
43 |
+
for ann in sorted_anns:
|
44 |
+
m = ann['segmentation']
|
45 |
+
img = np.ones((m.shape[0], m.shape[1], 3))
|
46 |
+
color_mask = np.random.random((1, 3)).tolist()[0]
|
47 |
+
for i in range(3):
|
48 |
+
img[:,:,i] = color_mask[i]
|
49 |
+
ax.imshow(np.dstack((img, m*0.35)))
|
50 |
+
|
51 |
+
# demo function
|
52 |
+
def segment_image(input_image):
|
53 |
+
|
54 |
+
# generate masks
|
55 |
+
masks = mask_generator.generate(input_image)
|
56 |
+
|
57 |
+
# add masks to image
|
58 |
+
plt.clf()
|
59 |
+
ppi = 100
|
60 |
+
height, width, _ = input_image.shape
|
61 |
+
plt.figure(figsize=(width / ppi, height / ppi)) # convert pixel to inches
|
62 |
+
plt.imshow(input_image)
|
63 |
+
show_anns(masks)
|
64 |
+
plt.axis('off')
|
65 |
+
|
66 |
+
# save and get figure
|
67 |
+
plt.savefig('output_figure.png', bbox_inches='tight')
|
68 |
+
output_image = cv2.imread('output_figure.png')
|
69 |
+
return Image.fromarray(output_image)
|
70 |
+
|
71 |
+
file = st.file_uploader("Upload File", type=["png", "jpg", "jpeg"])
|
72 |
+
|
73 |
+
col1, col2 = st.columns(2)
|
74 |
+
with col1:
|
75 |
+
input = Image.open('sample.jpg')
|
76 |
+
st.image(input, caption='Test Image provided by SAM github')
|
77 |
+
# ex = Image.open(images[0])
|
78 |
+
# st.image(ex, width=200)
|
79 |
+
|
80 |
+
with col2:
|
81 |
+
output = segment_image(input)
|
82 |
+
st.image(output)
|
83 |
+
|
84 |
+
if file is not None:
|
85 |
+
input = file
|
86 |
+
if st.button("Generate Mask"):
|
87 |
+
output = segment_image(input)
|
88 |
+
if output is not None:
|
89 |
+
st.subheader("Segmentation Result")
|
90 |
+
st.write(output.shape)
|
91 |
+
st.image(output, width=850)
|
meta-model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
|
3 |
+
size 375042383
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
opencv-python
|
3 |
+
matplotlib
|
4 |
+
numpy
|
5 |
+
torch
|
6 |
+
torchvision
|
7 |
+
git+https://github.com/facebookresearch/segment-anything.git
|