Spaces:
Sleeping
Sleeping
ravi.naik
commited on
Commit
•
10f4748
1
Parent(s):
3b41a3f
Added FastSAM code and supporting UI
Browse files- FastSAM-x.pt +3 -0
- app.py +168 -47
- experiments/clip.ipynb +0 -0
- experiments/sam.ipynb +121 -0
- fastsam/__init__.py +9 -0
- fastsam/decoder.py +131 -0
- fastsam/model.py +106 -0
- fastsam/predict.py +56 -0
- fastsam/prompt.py +455 -0
- fastsam/utils.py +86 -0
- requirements.txt +20 -3
FastSAM-x.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c0be4e7ddbe4c15333d15a859c676d053c486d0a746a3be6a7a9790d52a9b6d7
|
3 |
+
size 144943063
|
app.py
CHANGED
@@ -2,8 +2,15 @@ import gradio as gr
|
|
2 |
from PIL import Image
|
3 |
from transformers import CLIPProcessor, CLIPModel
|
4 |
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
prediction_image = None
|
|
|
7 |
|
8 |
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
9 |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
@@ -48,59 +55,173 @@ def predict(text):
|
|
48 |
return {output: gr.update(value=results)}
|
49 |
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
)
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
)
|
70 |
-
|
71 |
-
|
72 |
-
file_types=["image"],
|
73 |
-
file_count="multiple",
|
74 |
)
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
value
|
80 |
-
label="
|
81 |
-
|
82 |
-
elem_id="gallery_sample",
|
83 |
-
columns=3,
|
84 |
-
rows=2,
|
85 |
-
height="auto",
|
86 |
-
object_fit="contain",
|
87 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
|
103 |
-
|
104 |
|
105 |
|
106 |
-
app.launch()
|
|
|
2 |
from PIL import Image
|
3 |
from transformers import CLIPProcessor, CLIPModel
|
4 |
|
5 |
+
from fastsam import FastSAM, FastSAMPrompt
|
6 |
+
|
7 |
+
project_path = "/home/ravi.naik/learning/era/s19"
|
8 |
+
sam_model = FastSAM(f"{project_path}/FastSAM-x.pt")
|
9 |
+
|
10 |
+
DEVICE = "cpu"
|
11 |
+
sample_images = [f"{project_path}/sample_images/{i}.jpg" for i in range(5)]
|
12 |
prediction_image = None
|
13 |
+
sam_prediction_image = None
|
14 |
|
15 |
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
16 |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
|
|
55 |
return {output: gr.update(value=results)}
|
56 |
|
57 |
|
58 |
+
def show_hide_sam_text(status):
|
59 |
+
if status == "Text Based":
|
60 |
+
return {sam_input_text: gr.update(visible=True)}
|
61 |
+
return {sam_input_text: gr.update(visible=False)}
|
62 |
+
|
63 |
+
|
64 |
+
def set_prediction_image_sam(evt: gr.SelectData, gallery):
|
65 |
+
global sam_prediction_image
|
66 |
+
if isinstance(gallery[evt.index], dict):
|
67 |
+
sam_prediction_image = gallery[evt.index]["name"]
|
68 |
+
else:
|
69 |
+
sam_prediction_image = gallery[evt.index][0]["name"]
|
70 |
+
|
71 |
+
|
72 |
+
def sam_predict(radio, text):
|
73 |
+
output_path = f"{project_path}/output/sam_results.jpg"
|
74 |
+
everything_results = sam_model(
|
75 |
+
sam_prediction_image,
|
76 |
+
device=DEVICE,
|
77 |
+
retina_masks=True,
|
78 |
+
imgsz=1024,
|
79 |
+
conf=0.4,
|
80 |
+
iou=0.9,
|
81 |
+
)
|
82 |
+
prompt_process = FastSAMPrompt(
|
83 |
+
sam_prediction_image, everything_results, device=DEVICE
|
84 |
+
)
|
85 |
+
ann = prompt_process.everything_prompt()
|
86 |
+
if radio == "Text Based":
|
87 |
+
ann = prompt_process.text_prompt(text=text)
|
88 |
+
|
89 |
+
prompt_process.plot(
|
90 |
+
annotations=ann,
|
91 |
+
output_path=output_path,
|
92 |
)
|
93 |
+
|
94 |
+
return {sam_output: gr.update(value=output_path)}
|
95 |
+
|
96 |
+
|
97 |
+
with gr.Blocks() as app:
|
98 |
+
gr.Markdown("## ERA Session19 - FastSAM & CLIP Inference with Gradio")
|
99 |
+
with gr.Tab("FastSAM"):
|
100 |
+
gr.Markdown("### ERA Session19 - Image Segmentation with FastSAM")
|
101 |
+
gr.Markdown(
|
102 |
+
"""Please an image or select one of the sample images.
|
103 |
+
Select either segment everything or text based segmentation.
|
104 |
+
Enter the text if you opt for segment based on text and hit Submit.
|
105 |
+
"""
|
106 |
+
)
|
107 |
+
with gr.Row():
|
108 |
+
with gr.Column():
|
109 |
+
with gr.Box():
|
110 |
+
with gr.Group():
|
111 |
+
upload_gallery = gr.Gallery(
|
112 |
+
value=None,
|
113 |
+
label="Uploaded images",
|
114 |
+
show_label=False,
|
115 |
+
elem_id="gallery_upload",
|
116 |
+
columns=5,
|
117 |
+
rows=2,
|
118 |
+
height="auto",
|
119 |
+
object_fit="contain",
|
120 |
+
)
|
121 |
+
upload_button = gr.UploadButton(
|
122 |
+
"Click to Upload images",
|
123 |
+
file_types=["image"],
|
124 |
+
file_count="multiple",
|
125 |
+
)
|
126 |
+
upload_button.upload(upload_file, upload_button, upload_gallery)
|
127 |
+
|
128 |
+
with gr.Group():
|
129 |
+
sample_gallery = gr.Gallery(
|
130 |
+
value=sample_images,
|
131 |
+
label="Sample images",
|
132 |
+
show_label=False,
|
133 |
+
elem_id="gallery_sample",
|
134 |
+
columns=3,
|
135 |
+
rows=2,
|
136 |
+
height="auto",
|
137 |
+
object_fit="contain",
|
138 |
+
)
|
139 |
+
|
140 |
+
upload_gallery.select(
|
141 |
+
set_prediction_image_sam, inputs=[upload_gallery]
|
142 |
)
|
143 |
+
sample_gallery.select(
|
144 |
+
set_prediction_image_sam, inputs=[sample_gallery]
|
|
|
|
|
145 |
)
|
146 |
+
with gr.Box():
|
147 |
+
radio = gr.Radio(
|
148 |
+
choices=["Segment Everything", "Text Based"],
|
149 |
+
value="Segment Everything",
|
150 |
+
type="value",
|
151 |
+
label="Select a Segmentation approach",
|
152 |
+
interactive=True,
|
|
|
|
|
|
|
|
|
|
|
153 |
)
|
154 |
+
sam_input_text = gr.TextArea(
|
155 |
+
label="Segementation Input",
|
156 |
+
placeholder="Please enter some text",
|
157 |
+
interactive=True,
|
158 |
+
visible=False,
|
159 |
+
)
|
160 |
+
radio.change(
|
161 |
+
show_hide_sam_text, inputs=[radio], outputs=[sam_input_text]
|
162 |
+
)
|
163 |
|
164 |
+
sam_submit_btn = gr.Button(value="Submit")
|
165 |
+
with gr.Column():
|
166 |
+
with gr.Box():
|
167 |
+
sam_output = gr.Image(value=None, label="Segmentation Results")
|
168 |
+
|
169 |
+
sam_submit_btn.click(
|
170 |
+
sam_predict, inputs=[radio, sam_input_text], outputs=[sam_output]
|
171 |
)
|
172 |
+
with gr.Tab("CLIP"):
|
173 |
+
gr.Markdown("### ERA Session19 - Zero Shot Classification with CLIP")
|
174 |
+
gr.Markdown(
|
175 |
+
"Please an image or select one of the sample images. Type some classification labels separated by comma. For ex: dog, cat"
|
176 |
+
)
|
177 |
+
with gr.Row():
|
178 |
+
with gr.Column():
|
179 |
+
with gr.Box():
|
180 |
+
with gr.Group():
|
181 |
+
upload_gallery = gr.Gallery(
|
182 |
+
value=None,
|
183 |
+
label="Uploaded images",
|
184 |
+
show_label=False,
|
185 |
+
elem_id="gallery_upload",
|
186 |
+
columns=5,
|
187 |
+
rows=2,
|
188 |
+
height="auto",
|
189 |
+
object_fit="contain",
|
190 |
+
)
|
191 |
+
upload_button = gr.UploadButton(
|
192 |
+
"Click to Upload images",
|
193 |
+
file_types=["image"],
|
194 |
+
file_count="multiple",
|
195 |
+
)
|
196 |
+
upload_button.upload(upload_file, upload_button, upload_gallery)
|
197 |
+
|
198 |
+
with gr.Group():
|
199 |
+
sample_gallery = gr.Gallery(
|
200 |
+
value=sample_images,
|
201 |
+
label="Sample images",
|
202 |
+
show_label=False,
|
203 |
+
elem_id="gallery_sample",
|
204 |
+
columns=3,
|
205 |
+
rows=2,
|
206 |
+
height="auto",
|
207 |
+
object_fit="contain",
|
208 |
+
)
|
209 |
+
|
210 |
+
upload_gallery.select(set_prediction_image, inputs=[upload_gallery])
|
211 |
+
sample_gallery.select(set_prediction_image, inputs=[sample_gallery])
|
212 |
+
with gr.Box():
|
213 |
+
input_text = gr.TextArea(
|
214 |
+
label="Classification Text",
|
215 |
+
placeholder="Please enter comma separated text",
|
216 |
+
interactive=True,
|
217 |
+
)
|
218 |
|
219 |
+
submit_btn = gr.Button(value="Submit")
|
220 |
+
with gr.Column():
|
221 |
+
with gr.Box():
|
222 |
+
output = gr.Label(value=None, label="Classification Results")
|
223 |
|
224 |
+
submit_btn.click(predict, inputs=[input_text], outputs=[output])
|
225 |
|
226 |
|
227 |
+
app.launch(debug=True, show_error=True)
|
experiments/clip.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
experiments/sam.ipynb
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"from fastsam import FastSAM, FastSAMPrompt\n"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 3,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"model = FastSAM(\"FastSAM-x.pt\")\n",
|
19 |
+
"IMAGE_PATH = \"./sample_images/3.jpg\"\n"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": 5,
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"DEVICE = \"cpu\""
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "code",
|
33 |
+
"execution_count": 6,
|
34 |
+
"metadata": {},
|
35 |
+
"outputs": [
|
36 |
+
{
|
37 |
+
"name": "stderr",
|
38 |
+
"output_type": "stream",
|
39 |
+
"text": [
|
40 |
+
"\n",
|
41 |
+
"image 1/1 /home/ravi.naik/learning/era/s19/sample_images/3.jpg: 704x1024 5 objects, 5524.6ms\n",
|
42 |
+
"Speed: 77.9ms preprocess, 5524.6ms inference, 75.1ms postprocess per image at shape (1, 3, 1024, 1024)\n"
|
43 |
+
]
|
44 |
+
}
|
45 |
+
],
|
46 |
+
"source": [
|
47 |
+
"everything_results = model(\n",
|
48 |
+
" IMAGE_PATH,\n",
|
49 |
+
" device=DEVICE,\n",
|
50 |
+
" retina_masks=True,\n",
|
51 |
+
" imgsz=1024,\n",
|
52 |
+
" conf=0.4,\n",
|
53 |
+
" iou=0.9,\n",
|
54 |
+
")\n",
|
55 |
+
"prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": 8,
|
61 |
+
"metadata": {},
|
62 |
+
"outputs": [
|
63 |
+
{
|
64 |
+
"name": "stderr",
|
65 |
+
"output_type": "stream",
|
66 |
+
"text": [
|
67 |
+
"100%|███████████████████████████████████████| 338M/338M [00:32<00:00, 10.9MiB/s]\n"
|
68 |
+
]
|
69 |
+
}
|
70 |
+
],
|
71 |
+
"source": [
|
72 |
+
"# everything prompt\n",
|
73 |
+
"ann = prompt_process.everything_prompt()\n",
|
74 |
+
"\n",
|
75 |
+
"# bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]\n",
|
76 |
+
"# ann = prompt_process.box_prompt(bbox=[[200, 200, 300, 300]])\n",
|
77 |
+
"\n",
|
78 |
+
"# text prompt\n",
|
79 |
+
"ann = prompt_process.text_prompt(text=\"a photo of a dog\")\n",
|
80 |
+
"\n",
|
81 |
+
"# point prompt\n",
|
82 |
+
"# points default [[0,0]] [[x1,y1],[x2,y2]]\n",
|
83 |
+
"# point_label default [0] [1,0] 0:background, 1:foreground\n",
|
84 |
+
"# ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])\n",
|
85 |
+
"\n",
|
86 |
+
"prompt_process.plot(\n",
|
87 |
+
" annotations=ann,\n",
|
88 |
+
" output_path=\"./output/dog.jpg\",\n",
|
89 |
+
")"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "code",
|
94 |
+
"execution_count": null,
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [],
|
97 |
+
"source": []
|
98 |
+
}
|
99 |
+
],
|
100 |
+
"metadata": {
|
101 |
+
"kernelspec": {
|
102 |
+
"display_name": "Python 3",
|
103 |
+
"language": "python",
|
104 |
+
"name": "python3"
|
105 |
+
},
|
106 |
+
"language_info": {
|
107 |
+
"codemirror_mode": {
|
108 |
+
"name": "ipython",
|
109 |
+
"version": 3
|
110 |
+
},
|
111 |
+
"file_extension": ".py",
|
112 |
+
"mimetype": "text/x-python",
|
113 |
+
"name": "python",
|
114 |
+
"nbconvert_exporter": "python",
|
115 |
+
"pygments_lexer": "ipython3",
|
116 |
+
"version": "3.10.12"
|
117 |
+
}
|
118 |
+
},
|
119 |
+
"nbformat": 4,
|
120 |
+
"nbformat_minor": 2
|
121 |
+
}
|
fastsam/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
from .model import FastSAM
|
4 |
+
from .predict import FastSAMPredictor
|
5 |
+
from .prompt import FastSAMPrompt
|
6 |
+
# from .val import FastSAMValidator
|
7 |
+
from .decoder import FastSAMDecoder
|
8 |
+
|
9 |
+
__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder'
|
fastsam/decoder.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .model import FastSAM
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import clip
|
5 |
+
from typing import Optional, List, Tuple, Union
|
6 |
+
|
7 |
+
|
8 |
+
class FastSAMDecoder:
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
model: FastSAM,
|
12 |
+
device: str='cpu',
|
13 |
+
conf: float=0.4,
|
14 |
+
iou: float=0.9,
|
15 |
+
imgsz: int=1024,
|
16 |
+
retina_masks: bool=True,
|
17 |
+
):
|
18 |
+
self.model = model
|
19 |
+
self.device = device
|
20 |
+
self.retina_masks = retina_masks
|
21 |
+
self.imgsz = imgsz
|
22 |
+
self.conf = conf
|
23 |
+
self.iou = iou
|
24 |
+
self.image = None
|
25 |
+
self.image_embedding = None
|
26 |
+
|
27 |
+
def run_encoder(self, image):
|
28 |
+
if isinstance(image,str):
|
29 |
+
image = np.array(Image.open(image))
|
30 |
+
self.image = image
|
31 |
+
image_embedding = self.model(
|
32 |
+
self.image,
|
33 |
+
device=self.device,
|
34 |
+
retina_masks=self.retina_masks,
|
35 |
+
imgsz=self.imgsz,
|
36 |
+
conf=self.conf,
|
37 |
+
iou=self.iou
|
38 |
+
)
|
39 |
+
return image_embedding[0].numpy()
|
40 |
+
|
41 |
+
def run_decoder(
|
42 |
+
self,
|
43 |
+
image_embedding,
|
44 |
+
point_prompt: Optional[np.ndarray]=None,
|
45 |
+
point_label: Optional[np.ndarray]=None,
|
46 |
+
box_prompt: Optional[np.ndarray]=None,
|
47 |
+
text_prompt: Optional[str]=None,
|
48 |
+
)->np.ndarray:
|
49 |
+
self.image_embedding = image_embedding
|
50 |
+
if point_prompt is not None:
|
51 |
+
ann = self.point_prompt(points=point_prompt, pointlabel=point_label)
|
52 |
+
return ann
|
53 |
+
elif box_prompt is not None:
|
54 |
+
ann = self.box_prompt(bbox=box_prompt)
|
55 |
+
return ann
|
56 |
+
elif text_prompt is not None:
|
57 |
+
ann = self.text_prompt(text=text_prompt)
|
58 |
+
return ann
|
59 |
+
else:
|
60 |
+
return None
|
61 |
+
|
62 |
+
def box_prompt(self, bbox):
|
63 |
+
assert (bbox[2] != 0 and bbox[3] != 0)
|
64 |
+
masks = self.image_embedding.masks.data
|
65 |
+
target_height = self.image.shape[0]
|
66 |
+
target_width = self.image.shape[1]
|
67 |
+
h = masks.shape[1]
|
68 |
+
w = masks.shape[2]
|
69 |
+
if h != target_height or w != target_width:
|
70 |
+
bbox = [
|
71 |
+
int(bbox[0] * w / target_width),
|
72 |
+
int(bbox[1] * h / target_height),
|
73 |
+
int(bbox[2] * w / target_width),
|
74 |
+
int(bbox[3] * h / target_height), ]
|
75 |
+
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
|
76 |
+
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
|
77 |
+
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
|
78 |
+
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
|
79 |
+
|
80 |
+
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
|
81 |
+
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
82 |
+
|
83 |
+
masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2))
|
84 |
+
orig_masks_area = np.sum(masks, axis=(1, 2))
|
85 |
+
|
86 |
+
union = bbox_area + orig_masks_area - masks_area
|
87 |
+
IoUs = masks_area / union
|
88 |
+
max_iou_index = np.argmax(IoUs)
|
89 |
+
|
90 |
+
return np.array([masks[max_iou_index].cpu().numpy()])
|
91 |
+
|
92 |
+
def point_prompt(self, points, pointlabel): # numpy
|
93 |
+
|
94 |
+
masks = self._format_results(self.image_embedding[0], 0)
|
95 |
+
target_height = self.image.shape[0]
|
96 |
+
target_width = self.image.shape[1]
|
97 |
+
h = masks[0]['segmentation'].shape[0]
|
98 |
+
w = masks[0]['segmentation'].shape[1]
|
99 |
+
if h != target_height or w != target_width:
|
100 |
+
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
|
101 |
+
onemask = np.zeros((h, w))
|
102 |
+
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
|
103 |
+
for i, annotation in enumerate(masks):
|
104 |
+
if type(annotation) == dict:
|
105 |
+
mask = annotation['segmentation']
|
106 |
+
else:
|
107 |
+
mask = annotation
|
108 |
+
for i, point in enumerate(points):
|
109 |
+
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
110 |
+
onemask[mask] = 1
|
111 |
+
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
|
112 |
+
onemask[mask] = 0
|
113 |
+
onemask = onemask >= 1
|
114 |
+
return np.array([onemask])
|
115 |
+
|
116 |
+
def _format_results(self, result, filter=0):
|
117 |
+
annotations = []
|
118 |
+
n = len(result.masks.data)
|
119 |
+
for i in range(n):
|
120 |
+
annotation = {}
|
121 |
+
mask = result.masks.data[i] == 1.0
|
122 |
+
|
123 |
+
if np.sum(mask) < filter:
|
124 |
+
continue
|
125 |
+
annotation['id'] = i
|
126 |
+
annotation['segmentation'] = mask
|
127 |
+
annotation['bbox'] = result.boxes.data[i]
|
128 |
+
annotation['score'] = result.boxes.conf[i]
|
129 |
+
annotation['area'] = annotation['segmentation'].sum()
|
130 |
+
annotations.append(annotation)
|
131 |
+
return annotations
|
fastsam/model.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
FastSAM model interface.
|
4 |
+
|
5 |
+
Usage - Predict:
|
6 |
+
from ultralytics import FastSAM
|
7 |
+
|
8 |
+
model = FastSAM('last.pt')
|
9 |
+
results = model.predict('ultralytics/assets/bus.jpg')
|
10 |
+
"""
|
11 |
+
|
12 |
+
from ultralytics.yolo.cfg import get_cfg
|
13 |
+
from ultralytics.yolo.engine.exporter import Exporter
|
14 |
+
from ultralytics.yolo.engine.model import YOLO
|
15 |
+
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
|
16 |
+
from ultralytics.yolo.utils.checks import check_imgsz
|
17 |
+
|
18 |
+
from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
|
19 |
+
from .predict import FastSAMPredictor
|
20 |
+
|
21 |
+
|
22 |
+
class FastSAM(YOLO):
|
23 |
+
|
24 |
+
@smart_inference_mode()
|
25 |
+
def predict(self, source=None, stream=False, **kwargs):
|
26 |
+
"""
|
27 |
+
Perform prediction using the YOLO model.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
31 |
+
Accepts all source types accepted by the YOLO model.
|
32 |
+
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
33 |
+
**kwargs : Additional keyword arguments passed to the predictor.
|
34 |
+
Check the 'configuration' section in the documentation for all available options.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
38 |
+
"""
|
39 |
+
if source is None:
|
40 |
+
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
41 |
+
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
42 |
+
overrides = self.overrides.copy()
|
43 |
+
overrides['conf'] = 0.25
|
44 |
+
overrides.update(kwargs) # prefer kwargs
|
45 |
+
overrides['mode'] = kwargs.get('mode', 'predict')
|
46 |
+
assert overrides['mode'] in ['track', 'predict']
|
47 |
+
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
|
48 |
+
self.predictor = FastSAMPredictor(overrides=overrides)
|
49 |
+
self.predictor.setup_model(model=self.model, verbose=False)
|
50 |
+
try:
|
51 |
+
return self.predictor(source, stream=stream)
|
52 |
+
except Exception as e:
|
53 |
+
return None
|
54 |
+
|
55 |
+
def train(self, **kwargs):
|
56 |
+
"""Function trains models but raises an error as FastSAM models do not support training."""
|
57 |
+
raise NotImplementedError("Currently, the training codes are on the way.")
|
58 |
+
|
59 |
+
def val(self, **kwargs):
|
60 |
+
"""Run validation given dataset."""
|
61 |
+
overrides = dict(task='segment', mode='val')
|
62 |
+
overrides.update(kwargs) # prefer kwargs
|
63 |
+
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
64 |
+
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
65 |
+
validator = FastSAM(args=args)
|
66 |
+
validator(model=self.model)
|
67 |
+
self.metrics = validator.metrics
|
68 |
+
return validator.metrics
|
69 |
+
|
70 |
+
@smart_inference_mode()
|
71 |
+
def export(self, **kwargs):
|
72 |
+
"""
|
73 |
+
Export model.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
77 |
+
"""
|
78 |
+
overrides = dict(task='detect')
|
79 |
+
overrides.update(kwargs)
|
80 |
+
overrides['mode'] = 'export'
|
81 |
+
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
82 |
+
args.task = self.task
|
83 |
+
if args.imgsz == DEFAULT_CFG.imgsz:
|
84 |
+
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
85 |
+
if args.batch == DEFAULT_CFG.batch:
|
86 |
+
args.batch = 1 # default to 1 if not modified
|
87 |
+
return Exporter(overrides=args)(model=self.model)
|
88 |
+
|
89 |
+
def info(self, detailed=False, verbose=True):
|
90 |
+
"""
|
91 |
+
Logs model info.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
detailed (bool): Show detailed information about model.
|
95 |
+
verbose (bool): Controls verbosity.
|
96 |
+
"""
|
97 |
+
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
|
98 |
+
|
99 |
+
def __call__(self, source=None, stream=False, **kwargs):
|
100 |
+
"""Calls the 'predict' function with given arguments to perform object detection."""
|
101 |
+
return self.predict(source, stream, **kwargs)
|
102 |
+
|
103 |
+
def __getattr__(self, attr):
|
104 |
+
"""Raises error if object has no requested attribute."""
|
105 |
+
name = self.__class__.__name__
|
106 |
+
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
fastsam/predict.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ultralytics.yolo.engine.results import Results
|
4 |
+
from ultralytics.yolo.utils import DEFAULT_CFG, ops
|
5 |
+
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
6 |
+
from .utils import bbox_iou
|
7 |
+
|
8 |
+
class FastSAMPredictor(DetectionPredictor):
|
9 |
+
|
10 |
+
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
11 |
+
super().__init__(cfg, overrides, _callbacks)
|
12 |
+
self.args.task = 'segment'
|
13 |
+
|
14 |
+
def postprocess(self, preds, img, orig_imgs):
|
15 |
+
"""TODO: filter by classes."""
|
16 |
+
p = ops.non_max_suppression(preds[0],
|
17 |
+
self.args.conf,
|
18 |
+
self.args.iou,
|
19 |
+
agnostic=self.args.agnostic_nms,
|
20 |
+
max_det=self.args.max_det,
|
21 |
+
nc=len(self.model.names),
|
22 |
+
classes=self.args.classes)
|
23 |
+
|
24 |
+
results = []
|
25 |
+
if len(p) == 0 or len(p[0]) == 0:
|
26 |
+
print("No object detected.")
|
27 |
+
return results
|
28 |
+
|
29 |
+
full_box = torch.zeros_like(p[0][0])
|
30 |
+
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
|
31 |
+
full_box = full_box.view(1, -1)
|
32 |
+
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
|
33 |
+
if critical_iou_index.numel() != 0:
|
34 |
+
full_box[0][4] = p[0][critical_iou_index][:,4]
|
35 |
+
full_box[0][6:] = p[0][critical_iou_index][:,6:]
|
36 |
+
p[0][critical_iou_index] = full_box
|
37 |
+
|
38 |
+
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
|
39 |
+
for i, pred in enumerate(p):
|
40 |
+
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
41 |
+
path = self.batch[0]
|
42 |
+
img_path = path[i] if isinstance(path, list) else path
|
43 |
+
if not len(pred): # save empty boxes
|
44 |
+
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
|
45 |
+
continue
|
46 |
+
if self.args.retina_masks:
|
47 |
+
if not isinstance(orig_imgs, torch.Tensor):
|
48 |
+
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
49 |
+
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
|
50 |
+
else:
|
51 |
+
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
|
52 |
+
if not isinstance(orig_imgs, torch.Tensor):
|
53 |
+
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
54 |
+
results.append(
|
55 |
+
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
|
56 |
+
return results
|
fastsam/prompt.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import cv2
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from .utils import image_to_np_ndarray
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
try:
|
11 |
+
import clip # for linear_assignment
|
12 |
+
|
13 |
+
except (ImportError, AssertionError, AttributeError):
|
14 |
+
from ultralytics.yolo.utils.checks import check_requirements
|
15 |
+
|
16 |
+
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
|
17 |
+
import clip
|
18 |
+
|
19 |
+
|
20 |
+
class FastSAMPrompt:
|
21 |
+
|
22 |
+
def __init__(self, image, results, device='cuda'):
|
23 |
+
if isinstance(image, str) or isinstance(image, Image.Image):
|
24 |
+
image = image_to_np_ndarray(image)
|
25 |
+
self.device = device
|
26 |
+
self.results = results
|
27 |
+
self.img = image
|
28 |
+
|
29 |
+
def _segment_image(self, image, bbox):
|
30 |
+
if isinstance(image, Image.Image):
|
31 |
+
image_array = np.array(image)
|
32 |
+
else:
|
33 |
+
image_array = image
|
34 |
+
segmented_image_array = np.zeros_like(image_array)
|
35 |
+
x1, y1, x2, y2 = bbox
|
36 |
+
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
|
37 |
+
segmented_image = Image.fromarray(segmented_image_array)
|
38 |
+
black_image = Image.new('RGB', image.size, (255, 255, 255))
|
39 |
+
# transparency_mask = np.zeros_like((), dtype=np.uint8)
|
40 |
+
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
|
41 |
+
transparency_mask[y1:y2, x1:x2] = 255
|
42 |
+
transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
|
43 |
+
black_image.paste(segmented_image, mask=transparency_mask_image)
|
44 |
+
return black_image
|
45 |
+
|
46 |
+
def _format_results(self, result, filter=0):
|
47 |
+
annotations = []
|
48 |
+
n = len(result.masks.data)
|
49 |
+
for i in range(n):
|
50 |
+
annotation = {}
|
51 |
+
mask = result.masks.data[i] == 1.0
|
52 |
+
|
53 |
+
if torch.sum(mask) < filter:
|
54 |
+
continue
|
55 |
+
annotation['id'] = i
|
56 |
+
annotation['segmentation'] = mask.cpu().numpy()
|
57 |
+
annotation['bbox'] = result.boxes.data[i]
|
58 |
+
annotation['score'] = result.boxes.conf[i]
|
59 |
+
annotation['area'] = annotation['segmentation'].sum()
|
60 |
+
annotations.append(annotation)
|
61 |
+
return annotations
|
62 |
+
|
63 |
+
def filter_masks(annotations): # filte the overlap mask
|
64 |
+
annotations.sort(key=lambda x: x['area'], reverse=True)
|
65 |
+
to_remove = set()
|
66 |
+
for i in range(0, len(annotations)):
|
67 |
+
a = annotations[i]
|
68 |
+
for j in range(i + 1, len(annotations)):
|
69 |
+
b = annotations[j]
|
70 |
+
if i != j and j not in to_remove:
|
71 |
+
# check if
|
72 |
+
if b['area'] < a['area']:
|
73 |
+
if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
|
74 |
+
to_remove.add(j)
|
75 |
+
|
76 |
+
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
|
77 |
+
|
78 |
+
def _get_bbox_from_mask(self, mask):
|
79 |
+
mask = mask.astype(np.uint8)
|
80 |
+
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
81 |
+
x1, y1, w, h = cv2.boundingRect(contours[0])
|
82 |
+
x2, y2 = x1 + w, y1 + h
|
83 |
+
if len(contours) > 1:
|
84 |
+
for b in contours:
|
85 |
+
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
|
86 |
+
# Merge multiple bounding boxes into one.
|
87 |
+
x1 = min(x1, x_t)
|
88 |
+
y1 = min(y1, y_t)
|
89 |
+
x2 = max(x2, x_t + w_t)
|
90 |
+
y2 = max(y2, y_t + h_t)
|
91 |
+
h = y2 - y1
|
92 |
+
w = x2 - x1
|
93 |
+
return [x1, y1, x2, y2]
|
94 |
+
|
95 |
+
def plot_to_result(self,
|
96 |
+
annotations,
|
97 |
+
bboxes=None,
|
98 |
+
points=None,
|
99 |
+
point_label=None,
|
100 |
+
mask_random_color=True,
|
101 |
+
better_quality=True,
|
102 |
+
retina=False,
|
103 |
+
withContours=True) -> np.ndarray:
|
104 |
+
if isinstance(annotations[0], dict):
|
105 |
+
annotations = [annotation['segmentation'] for annotation in annotations]
|
106 |
+
image = self.img
|
107 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
108 |
+
original_h = image.shape[0]
|
109 |
+
original_w = image.shape[1]
|
110 |
+
if sys.platform == "darwin":
|
111 |
+
plt.switch_backend("TkAgg")
|
112 |
+
plt.figure(figsize=(original_w / 100, original_h / 100))
|
113 |
+
# Add subplot with no margin.
|
114 |
+
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
115 |
+
plt.margins(0, 0)
|
116 |
+
plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
117 |
+
plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
118 |
+
|
119 |
+
plt.imshow(image)
|
120 |
+
if better_quality:
|
121 |
+
if isinstance(annotations[0], torch.Tensor):
|
122 |
+
annotations = np.array(annotations.cpu())
|
123 |
+
for i, mask in enumerate(annotations):
|
124 |
+
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
|
125 |
+
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
|
126 |
+
if self.device == 'cpu':
|
127 |
+
annotations = np.array(annotations)
|
128 |
+
self.fast_show_mask(
|
129 |
+
annotations,
|
130 |
+
plt.gca(),
|
131 |
+
random_color=mask_random_color,
|
132 |
+
bboxes=bboxes,
|
133 |
+
points=points,
|
134 |
+
pointlabel=point_label,
|
135 |
+
retinamask=retina,
|
136 |
+
target_height=original_h,
|
137 |
+
target_width=original_w,
|
138 |
+
)
|
139 |
+
else:
|
140 |
+
if isinstance(annotations[0], np.ndarray):
|
141 |
+
annotations = torch.from_numpy(annotations)
|
142 |
+
self.fast_show_mask_gpu(
|
143 |
+
annotations,
|
144 |
+
plt.gca(),
|
145 |
+
random_color=mask_random_color,
|
146 |
+
bboxes=bboxes,
|
147 |
+
points=points,
|
148 |
+
pointlabel=point_label,
|
149 |
+
retinamask=retina,
|
150 |
+
target_height=original_h,
|
151 |
+
target_width=original_w,
|
152 |
+
)
|
153 |
+
if isinstance(annotations, torch.Tensor):
|
154 |
+
annotations = annotations.cpu().numpy()
|
155 |
+
if withContours:
|
156 |
+
contour_all = []
|
157 |
+
temp = np.zeros((original_h, original_w, 1))
|
158 |
+
for i, mask in enumerate(annotations):
|
159 |
+
if type(mask) == dict:
|
160 |
+
mask = mask['segmentation']
|
161 |
+
annotation = mask.astype(np.uint8)
|
162 |
+
if not retina:
|
163 |
+
annotation = cv2.resize(
|
164 |
+
annotation,
|
165 |
+
(original_w, original_h),
|
166 |
+
interpolation=cv2.INTER_NEAREST,
|
167 |
+
)
|
168 |
+
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
169 |
+
for contour in contours:
|
170 |
+
contour_all.append(contour)
|
171 |
+
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
|
172 |
+
color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
|
173 |
+
contour_mask = temp / 255 * color.reshape(1, 1, -1)
|
174 |
+
plt.imshow(contour_mask)
|
175 |
+
|
176 |
+
plt.axis('off')
|
177 |
+
fig = plt.gcf()
|
178 |
+
plt.draw()
|
179 |
+
|
180 |
+
try:
|
181 |
+
buf = fig.canvas.tostring_rgb()
|
182 |
+
except AttributeError:
|
183 |
+
fig.canvas.draw()
|
184 |
+
buf = fig.canvas.tostring_rgb()
|
185 |
+
cols, rows = fig.canvas.get_width_height()
|
186 |
+
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
|
187 |
+
result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
|
188 |
+
plt.close()
|
189 |
+
return result
|
190 |
+
|
191 |
+
# Remark for refactoring: IMO a function should do one thing only, storing the image and plotting should be seperated and do not necessarily need to be class functions but standalone utility functions that the user can chain in his scripts to have more fine-grained control.
|
192 |
+
def plot(self,
|
193 |
+
annotations,
|
194 |
+
output_path,
|
195 |
+
bboxes=None,
|
196 |
+
points=None,
|
197 |
+
point_label=None,
|
198 |
+
mask_random_color=True,
|
199 |
+
better_quality=True,
|
200 |
+
retina=False,
|
201 |
+
withContours=True):
|
202 |
+
if len(annotations) == 0:
|
203 |
+
return None
|
204 |
+
result = self.plot_to_result(
|
205 |
+
annotations,
|
206 |
+
bboxes,
|
207 |
+
points,
|
208 |
+
point_label,
|
209 |
+
mask_random_color,
|
210 |
+
better_quality,
|
211 |
+
retina,
|
212 |
+
withContours,
|
213 |
+
)
|
214 |
+
|
215 |
+
path = os.path.dirname(os.path.abspath(output_path))
|
216 |
+
if not os.path.exists(path):
|
217 |
+
os.makedirs(path)
|
218 |
+
result = result[:, :, ::-1]
|
219 |
+
cv2.imwrite(output_path, result)
|
220 |
+
|
221 |
+
# CPU post process
|
222 |
+
def fast_show_mask(
|
223 |
+
self,
|
224 |
+
annotation,
|
225 |
+
ax,
|
226 |
+
random_color=False,
|
227 |
+
bboxes=None,
|
228 |
+
points=None,
|
229 |
+
pointlabel=None,
|
230 |
+
retinamask=True,
|
231 |
+
target_height=960,
|
232 |
+
target_width=960,
|
233 |
+
):
|
234 |
+
msak_sum = annotation.shape[0]
|
235 |
+
height = annotation.shape[1]
|
236 |
+
weight = annotation.shape[2]
|
237 |
+
#Sort annotations based on area.
|
238 |
+
areas = np.sum(annotation, axis=(1, 2))
|
239 |
+
sorted_indices = np.argsort(areas)
|
240 |
+
annotation = annotation[sorted_indices]
|
241 |
+
|
242 |
+
index = (annotation != 0).argmax(axis=0)
|
243 |
+
if random_color:
|
244 |
+
color = np.random.random((msak_sum, 1, 1, 3))
|
245 |
+
else:
|
246 |
+
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
|
247 |
+
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
|
248 |
+
visual = np.concatenate([color, transparency], axis=-1)
|
249 |
+
mask_image = np.expand_dims(annotation, -1) * visual
|
250 |
+
|
251 |
+
show = np.zeros((height, weight, 4))
|
252 |
+
h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
|
253 |
+
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
254 |
+
# Use vectorized indexing to update the values of 'show'.
|
255 |
+
show[h_indices, w_indices, :] = mask_image[indices]
|
256 |
+
if bboxes is not None:
|
257 |
+
for bbox in bboxes:
|
258 |
+
x1, y1, x2, y2 = bbox
|
259 |
+
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
|
260 |
+
# draw point
|
261 |
+
if points is not None:
|
262 |
+
plt.scatter(
|
263 |
+
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
|
264 |
+
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
|
265 |
+
s=20,
|
266 |
+
c='y',
|
267 |
+
)
|
268 |
+
plt.scatter(
|
269 |
+
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
|
270 |
+
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
|
271 |
+
s=20,
|
272 |
+
c='m',
|
273 |
+
)
|
274 |
+
|
275 |
+
if not retinamask:
|
276 |
+
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
|
277 |
+
ax.imshow(show)
|
278 |
+
|
279 |
+
def fast_show_mask_gpu(
|
280 |
+
self,
|
281 |
+
annotation,
|
282 |
+
ax,
|
283 |
+
random_color=False,
|
284 |
+
bboxes=None,
|
285 |
+
points=None,
|
286 |
+
pointlabel=None,
|
287 |
+
retinamask=True,
|
288 |
+
target_height=960,
|
289 |
+
target_width=960,
|
290 |
+
):
|
291 |
+
msak_sum = annotation.shape[0]
|
292 |
+
height = annotation.shape[1]
|
293 |
+
weight = annotation.shape[2]
|
294 |
+
areas = torch.sum(annotation, dim=(1, 2))
|
295 |
+
sorted_indices = torch.argsort(areas, descending=False)
|
296 |
+
annotation = annotation[sorted_indices]
|
297 |
+
# Find the index of the first non-zero value at each position.
|
298 |
+
index = (annotation != 0).to(torch.long).argmax(dim=0)
|
299 |
+
if random_color:
|
300 |
+
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
|
301 |
+
else:
|
302 |
+
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
|
303 |
+
30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
|
304 |
+
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
|
305 |
+
visual = torch.cat([color, transparency], dim=-1)
|
306 |
+
mask_image = torch.unsqueeze(annotation, -1) * visual
|
307 |
+
# Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form.
|
308 |
+
show = torch.zeros((height, weight, 4)).to(annotation.device)
|
309 |
+
try:
|
310 |
+
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
|
311 |
+
except:
|
312 |
+
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
|
313 |
+
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
314 |
+
# Use vectorized indexing to update the values of 'show'.
|
315 |
+
show[h_indices, w_indices, :] = mask_image[indices]
|
316 |
+
show_cpu = show.cpu().numpy()
|
317 |
+
if bboxes is not None:
|
318 |
+
for bbox in bboxes:
|
319 |
+
x1, y1, x2, y2 = bbox
|
320 |
+
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
|
321 |
+
# draw point
|
322 |
+
if points is not None:
|
323 |
+
plt.scatter(
|
324 |
+
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
|
325 |
+
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
|
326 |
+
s=20,
|
327 |
+
c='y',
|
328 |
+
)
|
329 |
+
plt.scatter(
|
330 |
+
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
|
331 |
+
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
|
332 |
+
s=20,
|
333 |
+
c='m',
|
334 |
+
)
|
335 |
+
if not retinamask:
|
336 |
+
show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
|
337 |
+
ax.imshow(show_cpu)
|
338 |
+
|
339 |
+
# clip
|
340 |
+
@torch.no_grad()
|
341 |
+
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
|
342 |
+
preprocessed_images = [preprocess(image).to(device) for image in elements]
|
343 |
+
tokenized_text = clip.tokenize([search_text]).to(device)
|
344 |
+
stacked_images = torch.stack(preprocessed_images)
|
345 |
+
image_features = model.encode_image(stacked_images)
|
346 |
+
text_features = model.encode_text(tokenized_text)
|
347 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
348 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
349 |
+
probs = 100.0 * image_features @ text_features.T
|
350 |
+
return probs[:, 0].softmax(dim=0)
|
351 |
+
|
352 |
+
def _crop_image(self, format_results):
|
353 |
+
|
354 |
+
image = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
|
355 |
+
ori_w, ori_h = image.size
|
356 |
+
annotations = format_results
|
357 |
+
mask_h, mask_w = annotations[0]['segmentation'].shape
|
358 |
+
if ori_w != mask_w or ori_h != mask_h:
|
359 |
+
image = image.resize((mask_w, mask_h))
|
360 |
+
cropped_boxes = []
|
361 |
+
cropped_images = []
|
362 |
+
not_crop = []
|
363 |
+
filter_id = []
|
364 |
+
# annotations, _ = filter_masks(annotations)
|
365 |
+
# filter_id = list(_)
|
366 |
+
for _, mask in enumerate(annotations):
|
367 |
+
if np.sum(mask['segmentation']) <= 100:
|
368 |
+
filter_id.append(_)
|
369 |
+
continue
|
370 |
+
bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox
|
371 |
+
cropped_boxes.append(self._segment_image(image, bbox))
|
372 |
+
# cropped_boxes.append(segment_image(image,mask["segmentation"]))
|
373 |
+
cropped_images.append(bbox) # Save the bounding box of the cropped image.
|
374 |
+
|
375 |
+
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
|
376 |
+
|
377 |
+
def box_prompt(self, bbox=None, bboxes=None):
|
378 |
+
if self.results == None:
|
379 |
+
return []
|
380 |
+
assert bbox or bboxes
|
381 |
+
if bboxes is None:
|
382 |
+
bboxes = [bbox]
|
383 |
+
max_iou_index = []
|
384 |
+
for bbox in bboxes:
|
385 |
+
assert (bbox[2] != 0 and bbox[3] != 0)
|
386 |
+
masks = self.results[0].masks.data
|
387 |
+
target_height = self.img.shape[0]
|
388 |
+
target_width = self.img.shape[1]
|
389 |
+
h = masks.shape[1]
|
390 |
+
w = masks.shape[2]
|
391 |
+
if h != target_height or w != target_width:
|
392 |
+
bbox = [
|
393 |
+
int(bbox[0] * w / target_width),
|
394 |
+
int(bbox[1] * h / target_height),
|
395 |
+
int(bbox[2] * w / target_width),
|
396 |
+
int(bbox[3] * h / target_height), ]
|
397 |
+
bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
|
398 |
+
bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
|
399 |
+
bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
|
400 |
+
bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
|
401 |
+
|
402 |
+
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
|
403 |
+
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
404 |
+
|
405 |
+
masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
|
406 |
+
orig_masks_area = torch.sum(masks, dim=(1, 2))
|
407 |
+
|
408 |
+
union = bbox_area + orig_masks_area - masks_area
|
409 |
+
IoUs = masks_area / union
|
410 |
+
max_iou_index.append(int(torch.argmax(IoUs)))
|
411 |
+
max_iou_index = list(set(max_iou_index))
|
412 |
+
return np.array(masks[max_iou_index].cpu().numpy())
|
413 |
+
|
414 |
+
def point_prompt(self, points, pointlabel): # numpy
|
415 |
+
if self.results == None:
|
416 |
+
return []
|
417 |
+
masks = self._format_results(self.results[0], 0)
|
418 |
+
target_height = self.img.shape[0]
|
419 |
+
target_width = self.img.shape[1]
|
420 |
+
h = masks[0]['segmentation'].shape[0]
|
421 |
+
w = masks[0]['segmentation'].shape[1]
|
422 |
+
if h != target_height or w != target_width:
|
423 |
+
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
|
424 |
+
onemask = np.zeros((h, w))
|
425 |
+
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
|
426 |
+
for i, annotation in enumerate(masks):
|
427 |
+
if type(annotation) == dict:
|
428 |
+
mask = annotation['segmentation']
|
429 |
+
else:
|
430 |
+
mask = annotation
|
431 |
+
for i, point in enumerate(points):
|
432 |
+
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
433 |
+
onemask[mask] = 1
|
434 |
+
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
|
435 |
+
onemask[mask] = 0
|
436 |
+
onemask = onemask >= 1
|
437 |
+
return np.array([onemask])
|
438 |
+
|
439 |
+
def text_prompt(self, text):
|
440 |
+
if self.results == None:
|
441 |
+
return []
|
442 |
+
format_results = self._format_results(self.results[0], 0)
|
443 |
+
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
|
444 |
+
clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
|
445 |
+
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
|
446 |
+
max_idx = scores.argsort()
|
447 |
+
max_idx = max_idx[-1]
|
448 |
+
max_idx += sum(np.array(filter_id) <= int(max_idx))
|
449 |
+
return np.array([annotations[max_idx]['segmentation']])
|
450 |
+
|
451 |
+
def everything_prompt(self):
|
452 |
+
if self.results == None:
|
453 |
+
return []
|
454 |
+
return self.results[0].masks.data
|
455 |
+
|
fastsam/utils.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
|
6 |
+
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
|
7 |
+
'''Adjust bounding boxes to stick to image border if they are within a certain threshold.
|
8 |
+
Args:
|
9 |
+
boxes: (n, 4)
|
10 |
+
image_shape: (height, width)
|
11 |
+
threshold: pixel threshold
|
12 |
+
Returns:
|
13 |
+
adjusted_boxes: adjusted bounding boxes
|
14 |
+
'''
|
15 |
+
|
16 |
+
# Image dimensions
|
17 |
+
h, w = image_shape
|
18 |
+
|
19 |
+
# Adjust boxes
|
20 |
+
boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor(
|
21 |
+
0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1
|
22 |
+
boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor(
|
23 |
+
0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1
|
24 |
+
boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor(
|
25 |
+
w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2
|
26 |
+
boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor(
|
27 |
+
h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2
|
28 |
+
|
29 |
+
return boxes
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def convert_box_xywh_to_xyxy(box):
|
34 |
+
x1 = box[0]
|
35 |
+
y1 = box[1]
|
36 |
+
x2 = box[0] + box[2]
|
37 |
+
y2 = box[1] + box[3]
|
38 |
+
return [x1, y1, x2, y2]
|
39 |
+
|
40 |
+
|
41 |
+
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
|
42 |
+
'''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
|
43 |
+
Args:
|
44 |
+
box1: (4, )
|
45 |
+
boxes: (n, 4)
|
46 |
+
Returns:
|
47 |
+
high_iou_indices: Indices of boxes with IoU > thres
|
48 |
+
'''
|
49 |
+
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
|
50 |
+
# obtain coordinates for intersections
|
51 |
+
x1 = torch.max(box1[0], boxes[:, 0])
|
52 |
+
y1 = torch.max(box1[1], boxes[:, 1])
|
53 |
+
x2 = torch.min(box1[2], boxes[:, 2])
|
54 |
+
y2 = torch.min(box1[3], boxes[:, 3])
|
55 |
+
|
56 |
+
# compute the area of intersection
|
57 |
+
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
|
58 |
+
|
59 |
+
# compute the area of both individual boxes
|
60 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
61 |
+
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
62 |
+
|
63 |
+
# compute the area of union
|
64 |
+
union = box1_area + box2_area - intersection
|
65 |
+
|
66 |
+
# compute the IoU
|
67 |
+
iou = intersection / union # Should be shape (n, )
|
68 |
+
if raw_output:
|
69 |
+
if iou.numel() == 0:
|
70 |
+
return 0
|
71 |
+
return iou
|
72 |
+
|
73 |
+
# get indices of boxes with IoU > thres
|
74 |
+
high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
|
75 |
+
|
76 |
+
return high_iou_indices
|
77 |
+
|
78 |
+
|
79 |
+
def image_to_np_ndarray(image):
|
80 |
+
if type(image) is str:
|
81 |
+
return np.array(Image.open(image))
|
82 |
+
elif issubclass(type(image), Image.Image):
|
83 |
+
return np.array(image)
|
84 |
+
elif type(image) is np.ndarray:
|
85 |
+
return image
|
86 |
+
return None
|
requirements.txt
CHANGED
@@ -1,4 +1,21 @@
|
|
1 |
transformers
|
2 |
-
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
transformers
|
2 |
+
# Base-----------------------------------
|
3 |
+
matplotlib>=3.2.2
|
4 |
+
opencv-python>=4.6.0
|
5 |
+
Pillow>=7.1.2
|
6 |
+
PyYAML>=5.3.1
|
7 |
+
requests>=2.23.0
|
8 |
+
scipy>=1.4.1
|
9 |
+
torch>=1.7.0
|
10 |
+
torchvision>=0.8.1
|
11 |
+
tqdm>=4.64.0
|
12 |
+
|
13 |
+
pandas>=1.1.4
|
14 |
+
seaborn>=0.11.0
|
15 |
+
|
16 |
+
gradio
|
17 |
+
|
18 |
+
# Ultralytics-----------------------------------
|
19 |
+
ultralytics == 8.0.120
|
20 |
+
|
21 |
+
clip @ git+https://github.com/openai/CLIP.git
|