JunchuanYu commited on
Commit
d7c2a2b
·
1 Parent(s): 91b8aa9

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +91 -0
  2. meta-model.pth +3 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import numpy as np
4
+ import cv2
5
+ import os
6
+
7
+ import cv2
8
+ import matplotlib
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import torch
12
+
13
+ from PIL import Image
14
+
15
+ from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
16
+
17
+ # st.markdown(
18
+ # """
19
+ # The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image.
20
+ # It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.
21
+ # """
22
+ # )
23
+
24
+ # suppress server-side GUI windows
25
+ matplotlib.pyplot.switch_backend('Agg')
26
+
27
+ # setup models
28
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
+ sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
30
+ sam.to(device=device)
31
+ mask_generator = SamAutomaticMaskGenerator(sam)
32
+ predictor = SamPredictor(sam)
33
+
34
+ # copied from: https://github.com/facebookresearch/segment-anything
35
+ def show_anns(anns):
36
+ if len(anns) == 0:
37
+ return
38
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
39
+ ax = plt.gca()
40
+ ax.set_autoscale_on(False)
41
+ polygons = []
42
+ color = []
43
+ for ann in sorted_anns:
44
+ m = ann['segmentation']
45
+ img = np.ones((m.shape[0], m.shape[1], 3))
46
+ color_mask = np.random.random((1, 3)).tolist()[0]
47
+ for i in range(3):
48
+ img[:,:,i] = color_mask[i]
49
+ ax.imshow(np.dstack((img, m*0.35)))
50
+
51
+ # demo function
52
+ def segment_image(input_image):
53
+
54
+ # generate masks
55
+ masks = mask_generator.generate(input_image)
56
+
57
+ # add masks to image
58
+ plt.clf()
59
+ ppi = 100
60
+ height, width, _ = input_image.shape
61
+ plt.figure(figsize=(width / ppi, height / ppi)) # convert pixel to inches
62
+ plt.imshow(input_image)
63
+ show_anns(masks)
64
+ plt.axis('off')
65
+
66
+ # save and get figure
67
+ plt.savefig('output_figure.png', bbox_inches='tight')
68
+ output_image = cv2.imread('output_figure.png')
69
+ return Image.fromarray(output_image)
70
+
71
+ file = st.file_uploader("Upload File", type=["png", "jpg", "jpeg"])
72
+
73
+ col1, col2 = st.columns(2)
74
+ with col1:
75
+ input = Image.open('sample.jpg')
76
+ st.image(input, caption='Test Image provided by SAM github')
77
+ # ex = Image.open(images[0])
78
+ # st.image(ex, width=200)
79
+
80
+ with col2:
81
+ output = segment_image(input)
82
+ st.image(output)
83
+
84
+ if file is not None:
85
+ input = file
86
+ if st.button("Generate Mask"):
87
+ output = segment_image(input)
88
+ if output is not None:
89
+ st.subheader("Segmentation Result")
90
+ st.write(output.shape)
91
+ st.image(output, width=850)
meta-model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ opencv-python
3
+ matplotlib
4
+ numpy
5
+ torch
6
+ torchvision
7
+ git+https://github.com/facebookresearch/segment-anything.git