Fangrui Liu
commited on
Commit
·
3f1124e
1
Parent(s):
e66a418
init repo
Browse files- README.md +1 -1
- TestSet.py +120 -0
- app.py +393 -0
- box_utils.py +133 -0
- card_model.py +94 -0
- classifier.py +121 -0
- query_model.py +108 -0
README.md
CHANGED
@@ -6,7 +6,7 @@ colorTo: purple
|
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
license: lgpl-3.0
|
11 |
---
|
12 |
|
|
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
license: lgpl-3.0
|
11 |
---
|
12 |
|
TestSet.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import requests
|
3 |
+
from io import BytesIO
|
4 |
+
from os import path
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
class TestImageSetOnline(Dataset):
|
9 |
+
""" Test Image set with hugging face CLIP preprocess interface
|
10 |
+
|
11 |
+
Args:
|
12 |
+
Dataset (torch.utils.data.Dataset):
|
13 |
+
"""
|
14 |
+
def __init__(self, processor, image_list, timeout_base=0.5, timeout_mul=2):
|
15 |
+
"""
|
16 |
+
Args:
|
17 |
+
processor (CLIP preprocessor): process data to a CLIP digestable format
|
18 |
+
image_list (pandas.DataFrame): pandas.DataFrame that contains image metadata
|
19 |
+
timeout_base (float, optional): initial timeout parameter. Defaults to 0.5.
|
20 |
+
timeout_mul (int, optional): multiplier on timeout every time reqeust fails. Defaults to 2.
|
21 |
+
"""
|
22 |
+
self.image_list = image_list
|
23 |
+
self.processor = processor
|
24 |
+
self.timeout_base = timeout_base
|
25 |
+
self.timeout = self.timeout_base
|
26 |
+
self.timeout_mul = timeout_mul
|
27 |
+
|
28 |
+
def __getitem__(self, index):
|
29 |
+
row = self.image_list[index]
|
30 |
+
url = str(row['coco_url'])
|
31 |
+
_id = str(row['id'])
|
32 |
+
txt, img = None, None
|
33 |
+
flag = True
|
34 |
+
while flag:
|
35 |
+
try:
|
36 |
+
# Get images online
|
37 |
+
response = requests.get(url)
|
38 |
+
img = Image.open(BytesIO(response.content))
|
39 |
+
img_s = img.size
|
40 |
+
if img.mode in ['L', 'CMYK', 'RGBA']:
|
41 |
+
# L is grayscale, CMYK uses alternative color channels
|
42 |
+
img = img.convert('RGB')
|
43 |
+
# Preprocess image
|
44 |
+
ret = self.processor(text=txt, images=img, return_tensor='pt')
|
45 |
+
img = ret['pixel_values'][0]
|
46 |
+
# If success, then there will be no need to run this again
|
47 |
+
flag = False
|
48 |
+
# Relief the timeout param
|
49 |
+
if self.timeout > self.timeout_base:
|
50 |
+
self.timeout /= self.timeout_mul
|
51 |
+
except Exception as e:
|
52 |
+
print(f"{_id} {url}: {str(e)}")
|
53 |
+
if type(e) is KeyboardInterrupt:
|
54 |
+
raise e
|
55 |
+
time.sleep(self.timeout)
|
56 |
+
# Tension the timeout param and turn into a new request
|
57 |
+
self.timeout *= self.timeout_mul
|
58 |
+
return _id, url, img, img_s
|
59 |
+
|
60 |
+
def get(self, url):
|
61 |
+
_id = url
|
62 |
+
txt, img = None, None
|
63 |
+
flag = True
|
64 |
+
while flag:
|
65 |
+
try:
|
66 |
+
# Get images online
|
67 |
+
response = requests.get(url)
|
68 |
+
img = Image.open(BytesIO(response.content))
|
69 |
+
img_s = img.size
|
70 |
+
if img.mode in ['L', 'CMYK', 'RGBA']:
|
71 |
+
# L is grayscale, CMYK uses alternative color channels
|
72 |
+
img = img.convert('RGB')
|
73 |
+
# Preprocess image
|
74 |
+
ret = self.processor(text=txt, images=img, return_tensor='pt')
|
75 |
+
img = ret['pixel_values'][0]
|
76 |
+
# If success, then there will be no need to run this again
|
77 |
+
flag = False
|
78 |
+
# Relief the timeout param
|
79 |
+
if self.timeout > self.timeout_base:
|
80 |
+
self.timeout /= self.timeout_mul
|
81 |
+
except Exception as e:
|
82 |
+
print(f"{_id} {url}: {str(e)}")
|
83 |
+
if type(e) is KeyboardInterrupt:
|
84 |
+
raise e
|
85 |
+
time.sleep(self.timeout)
|
86 |
+
# Tension the timeout param and turn into a new request
|
87 |
+
self.timeout *= self.timeout_mul
|
88 |
+
return _id, url, img, img_s
|
89 |
+
|
90 |
+
|
91 |
+
def __len__(self,):
|
92 |
+
return len(self.image_list)
|
93 |
+
|
94 |
+
def __add__(self, other):
|
95 |
+
self.image_list += other.image_list
|
96 |
+
return self
|
97 |
+
|
98 |
+
class TestImageSet(TestImageSetOnline):
|
99 |
+
def __init__(self, droot, processor, image_list, timeout_base=0.5, timeout_mul=2):
|
100 |
+
super().__init__(processor, image_list, timeout_base, timeout_mul)
|
101 |
+
self.droot = droot
|
102 |
+
|
103 |
+
def __getitem__(self, index):
|
104 |
+
row = self.image_list[index]
|
105 |
+
url = str(row['coco_url'])
|
106 |
+
_id = '_'.join([url.split('/')[-2], str(row['id'])])
|
107 |
+
txt, img = None, None
|
108 |
+
# Get images online
|
109 |
+
img = Image.open(path.join(self.droot,
|
110 |
+
url.split('http://images.cocodataset.org/')[1]))
|
111 |
+
img_s = img.size
|
112 |
+
if img.mode in ['L', 'CMYK', 'RGBA']:
|
113 |
+
# L is grayscale, CMYK uses alternative color channels
|
114 |
+
img = img.convert('RGB')
|
115 |
+
# Preprocess image
|
116 |
+
ret = self.processor(text=txt, images=img, return_tensor='pt')
|
117 |
+
img = ret['pixel_values'][0]
|
118 |
+
# If success, then there will be no need to run this again
|
119 |
+
return _id, url, img, img_s
|
120 |
+
|
app.py
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from time import time
|
2 |
+
import aiohttp
|
3 |
+
from io import BytesIO
|
4 |
+
import torch
|
5 |
+
import streamlit as st
|
6 |
+
import streamlit.components.v1 as components
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import logging
|
10 |
+
from os import environ
|
11 |
+
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
12 |
+
|
13 |
+
from myscaledb import Client
|
14 |
+
from classifier import Classifier, prompt2vec, tune, SplitLayer
|
15 |
+
from query_model import simple_query, topk_obj_query, rev_query
|
16 |
+
from card_model import card, obj_card, style
|
17 |
+
from box_utils import postprocess
|
18 |
+
|
19 |
+
environ['TOKENIZERS_PARALLELISM'] = 'true'
|
20 |
+
|
21 |
+
OBJ_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_objects"
|
22 |
+
IMG_DB_NAME = "mqdb_demo.coco_owl_vit_b_32_images"
|
23 |
+
MODEL_ID = 'google/owlvit-base-patch32'
|
24 |
+
DIMS = 512
|
25 |
+
|
26 |
+
qtime = 0
|
27 |
+
|
28 |
+
|
29 |
+
def build_model(name="google/owlvit-base-patch32"):
|
30 |
+
"""Model builder function
|
31 |
+
|
32 |
+
Args:
|
33 |
+
name (str, optional): Name for HuggingFace OwlViT model. Defaults to "google/owlvit-base-patch32".
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
(model, processor): OwlViT model and its processor for both image and text
|
37 |
+
"""
|
38 |
+
device = 'cpu'
|
39 |
+
if torch.cuda.is_available():
|
40 |
+
device = 'cuda'
|
41 |
+
model = OwlViTForObjectDetection.from_pretrained(name).to(device)
|
42 |
+
processor = OwlViTProcessor.from_pretrained(name)
|
43 |
+
return model, processor
|
44 |
+
|
45 |
+
|
46 |
+
@st.experimental_singleton(show_spinner=False)
|
47 |
+
def init_owlvit():
|
48 |
+
""" Initialize OwlViT Model
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
model, processor
|
52 |
+
"""
|
53 |
+
model, processor = build_model(MODEL_ID)
|
54 |
+
return model, processor
|
55 |
+
|
56 |
+
|
57 |
+
@st.experimental_singleton(show_spinner=False)
|
58 |
+
def init_db():
|
59 |
+
""" Initialize the Database Connection
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
meta_field: Meta field that records if an image is viewed or not
|
63 |
+
client: Database connection object
|
64 |
+
"""
|
65 |
+
meta = []
|
66 |
+
client = Client(
|
67 |
+
url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
|
68 |
+
# We can check if the connection is alive
|
69 |
+
assert client.is_alive()
|
70 |
+
return meta, client
|
71 |
+
|
72 |
+
|
73 |
+
def refresh_index():
|
74 |
+
""" Clean the session
|
75 |
+
"""
|
76 |
+
del st.session_state["meta"]
|
77 |
+
st.session_state.meta = []
|
78 |
+
st.session_state.query_num = 0
|
79 |
+
logging.info(f"Refresh for '{st.session_state.meta}'")
|
80 |
+
# Need to clear singleton function with streamlit API
|
81 |
+
init_db.clear()
|
82 |
+
# refresh session states
|
83 |
+
st.session_state.meta, st.session_state.index = init_db()
|
84 |
+
if 'clf' in st.session_state:
|
85 |
+
del st.session_state.clf
|
86 |
+
if 'xq' in st.session_state:
|
87 |
+
del st.session_state.xq
|
88 |
+
if 'topk_img_id' in st.session_state:
|
89 |
+
del st.session_state.topk_img_id
|
90 |
+
|
91 |
+
|
92 |
+
def query(xq, exclude_list=None):
|
93 |
+
""" Query matched w.r.t a given vector
|
94 |
+
|
95 |
+
In this part, we will retrieve A LOT OF data from the server,
|
96 |
+
including TopK boxes and their embeddings, the counterpart of non-TopK boxes in TopK images.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
xq (numpy.ndarray or list of floats): Query vector
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
matches: list of Records object. Keys referrring to selected columns group by images.
|
103 |
+
Exclude the user's viewlist.
|
104 |
+
img_matches: list of Records object. Containing other non-TopK but hit objects among TopK images.
|
105 |
+
side_matches: list of Records object. Containing REAL TopK objects disregard the user's view history
|
106 |
+
"""
|
107 |
+
attempt = 0
|
108 |
+
xq = xq
|
109 |
+
xq = xq / np.linalg.norm(xq, axis=-1, ord=2, keepdims=True)
|
110 |
+
status_bar = [st.empty(), st.empty()]
|
111 |
+
status_bar[0].write("Retrieving Another TopK Images...")
|
112 |
+
pbar = status_bar[1].progress(0)
|
113 |
+
while attempt < 3:
|
114 |
+
try:
|
115 |
+
matches = topk_obj_query(
|
116 |
+
st.session_state.index, xq, IMG_DB_NAME, OBJ_DB_NAME,
|
117 |
+
exclude_list=exclude_list, topk=5000)
|
118 |
+
img_ids = [r['img_id'] for r in matches]
|
119 |
+
if 'topk_img_id' not in st.session_state:
|
120 |
+
st.session_state.topk_img_id = img_ids
|
121 |
+
status_bar[0].write("Retrieving TopK Images...")
|
122 |
+
pbar.progress(25)
|
123 |
+
o_matches = rev_query(
|
124 |
+
st.session_state.index, xq, st.session_state.topk_img_id,
|
125 |
+
IMG_DB_NAME, OBJ_DB_NAME, thresh=0.1)
|
126 |
+
status_bar[0].write("Retrieving TopKs Objects...")
|
127 |
+
pbar.progress(50)
|
128 |
+
side_matches = simple_query(st.session_state.index, xq, IMG_DB_NAME, OBJ_DB_NAME,
|
129 |
+
thresh=-1, topk=5000)
|
130 |
+
status_bar[0].write(
|
131 |
+
"Retrieving Non-TopK in Another TopK Images...")
|
132 |
+
pbar.progress(75)
|
133 |
+
if len(img_ids) > 0:
|
134 |
+
img_matches = rev_query(
|
135 |
+
st.session_state.index, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME,
|
136 |
+
thresh=0.1)
|
137 |
+
else:
|
138 |
+
img_matches = []
|
139 |
+
status_bar[0].write("DONE!")
|
140 |
+
pbar.progress(100)
|
141 |
+
break
|
142 |
+
except Exception as e:
|
143 |
+
# force reload if we have trouble on connections or something else
|
144 |
+
logging.warning(str(e))
|
145 |
+
st.session_state.meta, st.session_state.index = init_db()
|
146 |
+
attempt += 1
|
147 |
+
matches = []
|
148 |
+
_ = [s.empty() for s in status_bar]
|
149 |
+
if len(matches) == 0:
|
150 |
+
logging.error(f"No matches found for '{OBJ_DB_NAME}'")
|
151 |
+
return matches, img_matches, side_matches, o_matches
|
152 |
+
|
153 |
+
|
154 |
+
@st.experimental_singleton(show_spinner=False)
|
155 |
+
def init_random_query():
|
156 |
+
"""Initialize a random query vector
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
xq: a random vector
|
160 |
+
"""
|
161 |
+
xq = np.random.rand(1, DIMS)
|
162 |
+
xq /= np.linalg.norm(xq, keepdims=True, axis=-1)
|
163 |
+
return xq
|
164 |
+
|
165 |
+
|
166 |
+
def submit(meta):
|
167 |
+
""" Tune the model w.r.t given score from user.
|
168 |
+
"""
|
169 |
+
# Only updating the meta if the train button is pressed
|
170 |
+
st.session_state.meta.extend(meta)
|
171 |
+
st.session_state.step += 1
|
172 |
+
matches = st.session_state.matched_boxes
|
173 |
+
X, y = list(zip(*((v[-1],
|
174 |
+
st.session_state.text_prompts.index(
|
175 |
+
st.session_state[f"label-{i}"])) for i, v in matches.items())))
|
176 |
+
st.session_state.xq = tune(st.session_state.clf,
|
177 |
+
X, y, iters=int(st.session_state.iters))
|
178 |
+
st.session_state.matches, \
|
179 |
+
st.session_state.img_matches, \
|
180 |
+
st.session_state.side_matches, \
|
181 |
+
st.session_state.o_matches = query(
|
182 |
+
st.session_state.xq, st.session_state.meta)
|
183 |
+
|
184 |
+
|
185 |
+
# st.set_page_config(layout="wide")
|
186 |
+
# To hack the streamlit style we define our own style.
|
187 |
+
# Boxes are drawn in SVGs.
|
188 |
+
st.write(style(), unsafe_allow_html=True)
|
189 |
+
|
190 |
+
with st.spinner("Connecting DB..."):
|
191 |
+
st.session_state.meta, st.session_state.index = init_db()
|
192 |
+
|
193 |
+
with st.spinner("Loading Models..."):
|
194 |
+
# Initialize model
|
195 |
+
model, tokenizer = init_owlvit()
|
196 |
+
|
197 |
+
# If its a fresh start... (query not set)
|
198 |
+
if 'xq' not in st.session_state:
|
199 |
+
with st.container():
|
200 |
+
st.title('Object Detection Safari')
|
201 |
+
start = [st.empty() for _ in range(8)]
|
202 |
+
start[0].info("""
|
203 |
+
We extracted boxes from **287,104** images in COCO Dataset, including its train / val / test /
|
204 |
+
unlabeled images, collecting **165,371,904 boxes** which are then filtered with common prompts.
|
205 |
+
You can search with almost any words or phrases you can think of. Please enjoy your journey of
|
206 |
+
an adventure to COCO.
|
207 |
+
""")
|
208 |
+
prompt = start[1].text_input(
|
209 |
+
"Prompt:", value="", placeholder="Examples: football, billboard, stop sign, watermark ...",)
|
210 |
+
with start[2].container():
|
211 |
+
st.write(
|
212 |
+
'You can search with multiple keywords. Plese separate with commas but with no space.')
|
213 |
+
st.write('For example: `cat,dog,tree`')
|
214 |
+
st.markdown('''
|
215 |
+
<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>
|
216 |
+
''',
|
217 |
+
unsafe_allow_html=True)
|
218 |
+
|
219 |
+
upld_model = start[4].file_uploader(
|
220 |
+
"Or you can upload your previous run!", type='onnx')
|
221 |
+
upld_btn = start[5].button(
|
222 |
+
"Use Loaded Weights", disabled=upld_model is None, on_click=refresh_index)
|
223 |
+
|
224 |
+
with start[3]:
|
225 |
+
col = st.columns(8)
|
226 |
+
has_no_prompt = (len(prompt) == 0 and upld_model is None)
|
227 |
+
prompt_xq = col[6].button("Prompt", disabled=len(
|
228 |
+
prompt) == 0, on_click=refresh_index)
|
229 |
+
random_xq = col[7].button(
|
230 |
+
"Random", disabled=not has_no_prompt, on_click=refresh_index)
|
231 |
+
matches = []
|
232 |
+
img_matches = []
|
233 |
+
if random_xq:
|
234 |
+
xq = init_random_query()
|
235 |
+
st.session_state.xq = xq
|
236 |
+
prompt = 'unknown'
|
237 |
+
st.session_state.text_prompts = prompt.split(',') + ['none']
|
238 |
+
_ = [elem.empty() for elem in start]
|
239 |
+
t0 = time()
|
240 |
+
matches, img_matches, side_matches, o_matches = query(
|
241 |
+
st.session_state.xq, st.session_state.meta)
|
242 |
+
t1 = time()
|
243 |
+
qtime = (t1-t0) * 1000
|
244 |
+
elif prompt_xq or upld_btn:
|
245 |
+
if upld_model is not None:
|
246 |
+
import onnx
|
247 |
+
from onnx import numpy_helper
|
248 |
+
_model = onnx.load(upld_model)
|
249 |
+
st.session_state.text_prompts = [
|
250 |
+
node.name for node in _model.graph.output] + ['none']
|
251 |
+
weights = _model.graph.initializer
|
252 |
+
xq = numpy_helper.to_array(weights[0]).T
|
253 |
+
assert xq.shape[0] == len(
|
254 |
+
st.session_state.text_prompts)-1 and xq.shape[1] == DIMS
|
255 |
+
st.session_state.xq = xq
|
256 |
+
_ = [elem.empty() for elem in start]
|
257 |
+
else:
|
258 |
+
logging.info(f"Input prompt is {prompt}")
|
259 |
+
st.session_state.text_prompts = prompt.split(',') + ['none']
|
260 |
+
input_ids, xq = prompt2vec(
|
261 |
+
st.session_state.text_prompts[:-1], model, tokenizer)
|
262 |
+
st.session_state.xq = xq
|
263 |
+
_ = [elem.empty() for elem in start]
|
264 |
+
t0 = time()
|
265 |
+
st.session_state.matches, \
|
266 |
+
st.session_state.img_matches, \
|
267 |
+
st.session_state.side_matches, \
|
268 |
+
st.session_state.o_matches = query(
|
269 |
+
st.session_state.xq, st.session_state.meta)
|
270 |
+
t1 = time()
|
271 |
+
qtime = (t1-t0) * 1000
|
272 |
+
|
273 |
+
# If its not a fresh start (query is set)
|
274 |
+
if 'xq' in st.session_state:
|
275 |
+
o_matches = st.session_state.o_matches
|
276 |
+
side_matches = st.session_state.side_matches
|
277 |
+
img_matches = st.session_state.img_matches
|
278 |
+
matches = st.session_state.matches
|
279 |
+
# initialize classifier
|
280 |
+
if 'clf' not in st.session_state:
|
281 |
+
st.session_state.clf = Classifier(st.session_state.xq)
|
282 |
+
st.session_state.step = 0
|
283 |
+
if qtime > 0:
|
284 |
+
st.info("Query done in {0:.2f} ms and returned {1:d} images with {2:d} boxes".format(
|
285 |
+
qtime, len(matches), sum([len(m["box_id"]) + len(im["box_id"]) for m, im in zip(matches, img_matches)])))
|
286 |
+
|
287 |
+
# export the model into executable ONNX
|
288 |
+
st.session_state.dnld_model = BytesIO()
|
289 |
+
torch.onnx.export(torch.nn.Sequential(st.session_state.clf.model, SplitLayer()),
|
290 |
+
torch.zeros([1, len(st.session_state.xq[0])]),
|
291 |
+
st.session_state.dnld_model,
|
292 |
+
input_names=['input'],
|
293 |
+
output_names=st.session_state.text_prompts[:-1])
|
294 |
+
|
295 |
+
dnld_nam = st.text_input('Download Name:',
|
296 |
+
f'{("_".join([i.replace(" ", "-") for i in st.session_state.text_prompts[:-1]]) if "text_prompts" in st.session_state else "model")}.onnx',
|
297 |
+
max_chars=50)
|
298 |
+
dnld_btn = st.download_button('Download your classifier!',
|
299 |
+
st.session_state.dnld_model,
|
300 |
+
dnld_nam)
|
301 |
+
# build up a sidebar to display REAL TopK in DB
|
302 |
+
# this will change during user's finetune. But sometime it would lead to bad results
|
303 |
+
side_bar_len = min(240 // len(st.session_state.text_prompts), 120)
|
304 |
+
with st.sidebar:
|
305 |
+
with st.expander("Top-K Images"):
|
306 |
+
with st.container():
|
307 |
+
boxes_w_img, _ = postprocess(o_matches, st.session_state.text_prompts,
|
308 |
+
None)
|
309 |
+
boxes_w_img = sorted(
|
310 |
+
boxes_w_img, key=lambda x: x[4], reverse=True)
|
311 |
+
for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
|
312 |
+
args = img_url, img_w, img_h, boxes
|
313 |
+
st.write(card(*args), unsafe_allow_html=True)
|
314 |
+
|
315 |
+
with st.expander("Top-K Objects", expanded=True):
|
316 |
+
side_cols = st.columns(
|
317 |
+
len(st.session_state.text_prompts[:-1]))
|
318 |
+
for _cols, m in zip(side_cols, side_matches):
|
319 |
+
with _cols.container():
|
320 |
+
for cx, cy, w, h, logit, img_url, img_w, img_h \
|
321 |
+
in zip(m['cx'], m['cy'], m['w'], m['h'], m['logit'],
|
322 |
+
m['img_url'], m['img_w'], m['img_h']):
|
323 |
+
st.write("{:s}: {:.4f}".format(
|
324 |
+
st.session_state.text_prompts[m['label']], logit))
|
325 |
+
_html = obj_card(
|
326 |
+
img_url, img_w, img_h, cx, cy, w, h, dst_len=side_bar_len)
|
327 |
+
components.html(
|
328 |
+
_html, side_bar_len, side_bar_len)
|
329 |
+
with st.container():
|
330 |
+
# Here let the user interact with batch labeling
|
331 |
+
with st.form("batch", clear_on_submit=False):
|
332 |
+
col = st.columns([1, 9])
|
333 |
+
|
334 |
+
# If there is nothing to show about
|
335 |
+
if len(matches) <= 0:
|
336 |
+
st.warning(
|
337 |
+
'Oops! We didn\'t find anything relevant to your query! Pleas try another one :/')
|
338 |
+
else:
|
339 |
+
st.session_state.iters = st.slider(
|
340 |
+
"Number of Iterations to Update", min_value=0, max_value=10, step=1, value=2)
|
341 |
+
# No matter what happened the user wants a way back
|
342 |
+
col[1].form_submit_button(
|
343 |
+
"Choose a new prompt", on_click=refresh_index)
|
344 |
+
|
345 |
+
# If there are things to show
|
346 |
+
if len(matches) > 0:
|
347 |
+
with st.container():
|
348 |
+
prompt_labels = st.session_state.text_prompts
|
349 |
+
|
350 |
+
# Post processing boxes regarding to their score, intersection
|
351 |
+
boxes_w_img, meta = postprocess(matches, st.session_state.text_prompts,
|
352 |
+
img_matches)
|
353 |
+
|
354 |
+
# Sort the result according to their relavancy
|
355 |
+
boxes_w_img = sorted(
|
356 |
+
boxes_w_img, key=lambda x: x[4], reverse=True)
|
357 |
+
|
358 |
+
st.session_state.matched_boxes = {}
|
359 |
+
# For each images in the retrieved images, DISPLAY
|
360 |
+
for img_id, img_url, img_w, img_h, img_score, boxes in boxes_w_img:
|
361 |
+
|
362 |
+
# prepare inputs for training
|
363 |
+
st.session_state.matched_boxes.update(
|
364 |
+
{b[0]: b for b in boxes})
|
365 |
+
args = img_url, img_w, img_h, boxes
|
366 |
+
|
367 |
+
# display boxes
|
368 |
+
with st.expander("{:s}: {:.4f}".format(img_id, img_score), expanded=True):
|
369 |
+
ind_b = 0
|
370 |
+
# 4 columns: (img, obj, obj, obj)
|
371 |
+
img_row = st.columns([4, 2, 2, 2])
|
372 |
+
img_row[0].write(
|
373 |
+
card(*args), unsafe_allow_html=True)
|
374 |
+
# crop objects out of the original image
|
375 |
+
for b in boxes:
|
376 |
+
_id, cx, cy, w, h, label, logit, is_selected, _ = b
|
377 |
+
with img_row[1 + ind_b % 3].container():
|
378 |
+
st.write(
|
379 |
+
"{:s}: {:.4f}".format(label, logit))
|
380 |
+
# quite hacky: with streamlit components API
|
381 |
+
_html = \
|
382 |
+
obj_card(img_url, img_w, img_h,
|
383 |
+
*b[1:5], dst_len=120)
|
384 |
+
components.html(_html, 120, 120)
|
385 |
+
# the user will choose the right label of the given object
|
386 |
+
st.selectbox(
|
387 |
+
"Class",
|
388 |
+
prompt_labels,
|
389 |
+
index=prompt_labels.index(label),
|
390 |
+
key=f"label-{_id}")
|
391 |
+
ind_b += 1
|
392 |
+
col[0].form_submit_button(
|
393 |
+
"Train!", on_click=lambda: submit(meta))
|
box_utils.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def cxywh2xywh(cx, cy, w, h):
|
5 |
+
""" CxCyWH format to XYWH format conversion
|
6 |
+
"""
|
7 |
+
x = cx - w / 2
|
8 |
+
y = cy - h / 2
|
9 |
+
return x, y, w, h
|
10 |
+
|
11 |
+
|
12 |
+
def cxywh2ltrb(cx, cy, w, h):
|
13 |
+
"""CxCyWH format to LeftRightTopBottom format
|
14 |
+
"""
|
15 |
+
l = cx - w / 2
|
16 |
+
t = cy - h / 2
|
17 |
+
r = cx + w / 2
|
18 |
+
b = cy + h / 2
|
19 |
+
return l, t, r, b
|
20 |
+
|
21 |
+
|
22 |
+
def iou(ba, bb):
|
23 |
+
"""Calculate Intersection-Over-Union
|
24 |
+
|
25 |
+
Args:
|
26 |
+
ba (tuple): CxCyWH format with score
|
27 |
+
bb (tuple): CxCyWH format with score
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
IoU with size of length of given box
|
31 |
+
"""
|
32 |
+
a_l, a_t, a_r, a_b, sa = ba
|
33 |
+
b_l, b_t, b_r, b_b, sb = bb
|
34 |
+
|
35 |
+
x1 = np.maximum(a_l, b_l)
|
36 |
+
y1 = np.maximum(a_t, b_t)
|
37 |
+
x2 = np.minimum(a_r, b_r)
|
38 |
+
y2 = np.minimum(a_b, b_b)
|
39 |
+
w = np.maximum(0, x2 - x1)
|
40 |
+
h = np.maximum(0, y2 - y1)
|
41 |
+
intersec = w * h
|
42 |
+
iou = (intersec) / (sa + sb - intersec)
|
43 |
+
return iou.squeeze()
|
44 |
+
|
45 |
+
|
46 |
+
def nms(cx, cy, w, h, s, iou_thresh=0.3):
|
47 |
+
"""Bounding box Non-maximum Suppression
|
48 |
+
|
49 |
+
Args:
|
50 |
+
cx, cy, w, h, s: CxCyWH Format with score boxes
|
51 |
+
iou_thresh (float, optional): IoU threshold. Defaults to 0.3.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
res: indexes of the selected boxes
|
55 |
+
"""
|
56 |
+
l, t, r, b = cxywh2ltrb(cx, cy, w, h)
|
57 |
+
areas = w * h
|
58 |
+
res = []
|
59 |
+
sort_ind = np.argsort(s, axis=-1)[::-1]
|
60 |
+
while sort_ind.shape[0] > 0:
|
61 |
+
i = sort_ind[0]
|
62 |
+
res.append(i)
|
63 |
+
|
64 |
+
_iou = iou((l[i], t[i], r[i], b[i], areas[i]),
|
65 |
+
(l[sort_ind[1:]], t[sort_ind[1:]],
|
66 |
+
r[sort_ind[1:]], b[sort_ind[1:]], areas[sort_ind[1:]]))
|
67 |
+
sel_ind = np.where(_iou <= iou_thresh)[0]
|
68 |
+
sort_ind = sort_ind[sel_ind + 1]
|
69 |
+
return res
|
70 |
+
|
71 |
+
|
72 |
+
def filter_nonpos(boxes, agnostic_ratio=0.5, class_ratio=0.7):
|
73 |
+
"""filter out insignificant boxes
|
74 |
+
|
75 |
+
Args:
|
76 |
+
boxes (list of records): returned query to be filtered
|
77 |
+
"""
|
78 |
+
ret = []
|
79 |
+
labelwise = {}
|
80 |
+
for _id, cx, cy, w, h, label, logit, is_selected, _ in boxes:
|
81 |
+
if label not in labelwise:
|
82 |
+
labelwise[label] = []
|
83 |
+
labelwise[label].append(logit)
|
84 |
+
labelwise = {l: max(s) for l, s in labelwise.items()}
|
85 |
+
agnostic = max([v for _, v in labelwise.items()])
|
86 |
+
for b in boxes:
|
87 |
+
_id, cx, cy, w, h, label, logit, is_selected, _ = b
|
88 |
+
if logit > class_ratio * labelwise[label] \
|
89 |
+
and logit > agnostic_ratio * agnostic:
|
90 |
+
ret.append(b)
|
91 |
+
return ret
|
92 |
+
|
93 |
+
|
94 |
+
def postprocess(matches, prompt_labels, img_matches=None):
|
95 |
+
meta = []
|
96 |
+
boxes_w_img = []
|
97 |
+
matches_ = {m['img_id']: m for m in matches}
|
98 |
+
if img_matches is not None:
|
99 |
+
img_matches_ = {m['img_id']: m for m in img_matches}
|
100 |
+
for k in matches_.keys():
|
101 |
+
m = matches_[k]
|
102 |
+
boxes = []
|
103 |
+
boxes += list(map(list, zip(m['box_id'], m['cx'], m['cy'], m['w'], m['h'],
|
104 |
+
[prompt_labels[int(l)]
|
105 |
+
for l in m['label']],
|
106 |
+
m['logit'], [1] *
|
107 |
+
len(m['box_id']),
|
108 |
+
list(np.array(m['cls_emb'])))))
|
109 |
+
if img_matches is not None:
|
110 |
+
img_m = img_matches_[k]
|
111 |
+
# and also those non-TopK hits and those non-topk are not anticipating training
|
112 |
+
boxes += [i for i in map(list, zip(img_m['box_id'], img_m['cx'], img_m['cy'], img_m['w'], img_m['h'],
|
113 |
+
[prompt_labels[int(
|
114 |
+
l)] for l in img_m['label']], img_m['logit'],
|
115 |
+
[0] * len(img_m['box_id']), list(np.array(img_m['cls_emb']))))
|
116 |
+
if i[0] not in [b[0] for b in boxes]]
|
117 |
+
# update record metadata after query
|
118 |
+
for b in boxes:
|
119 |
+
meta.append(b[0])
|
120 |
+
|
121 |
+
# remove some non-significant boxes
|
122 |
+
boxes = filter_nonpos(
|
123 |
+
boxes, agnostic_ratio=0.4, class_ratio=0.7)
|
124 |
+
|
125 |
+
# doing non-maximum suppression
|
126 |
+
cx, cy, w, h, s = list(map(lambda x: np.array(x),
|
127 |
+
list(zip(*[(*b[1:5], b[6]) for b in boxes]))))
|
128 |
+
ind = nms(cx, cy, w, h, s, 0.3)
|
129 |
+
boxes = [boxes[i] for i in ind]
|
130 |
+
img_score = img_m['img_score'] if img_matches is not None else m['img_score']
|
131 |
+
boxes_w_img.append(
|
132 |
+
(m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes))
|
133 |
+
return boxes_w_img, meta
|
card_model.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
from box_utils import cxywh2ltrb, cxywh2xywh
|
3 |
+
|
4 |
+
|
5 |
+
def style():
|
6 |
+
""" Style string for card models
|
7 |
+
"""
|
8 |
+
return """
|
9 |
+
<link
|
10 |
+
rel="stylesheet"
|
11 |
+
href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700&display=swap"
|
12 |
+
/>
|
13 |
+
<style>
|
14 |
+
.img-overlay-wrap {
|
15 |
+
position: relative;
|
16 |
+
display: inline-block;
|
17 |
+
}
|
18 |
+
.img-overlay-wrap {
|
19 |
+
position: relative;
|
20 |
+
display: inline-block;
|
21 |
+
/* <= shrinks container to image size */
|
22 |
+
transition: transform 150ms ease-in-out;
|
23 |
+
}
|
24 |
+
.img-overlay-wrap img {
|
25 |
+
/* <= optional, for responsiveness */
|
26 |
+
display: block;
|
27 |
+
max-width: 100%;
|
28 |
+
height: auto;
|
29 |
+
}
|
30 |
+
.img-overlay-wrap svg {
|
31 |
+
position: absolute;
|
32 |
+
top: 0;
|
33 |
+
left: 0;
|
34 |
+
}
|
35 |
+
</style>
|
36 |
+
"""
|
37 |
+
|
38 |
+
|
39 |
+
def card(img_url, img_w, img_h, boxes):
|
40 |
+
""" This is a hack to streamlit
|
41 |
+
Solution thanks to: https://discuss.streamlit.io/t/display-svg/172/5
|
42 |
+
Converting SVG to Base64 and display with <img> tag.
|
43 |
+
Also we used the
|
44 |
+
"""
|
45 |
+
_boxes = ""
|
46 |
+
for _id, cx, cy, w, h, label, logit, is_selected, _ in boxes:
|
47 |
+
x, y, w, h = cxywh2xywh(cx, cy, w, h)
|
48 |
+
x = round(img_w * x)
|
49 |
+
y = round(img_h * y)
|
50 |
+
w = round(img_w * w)
|
51 |
+
h = round(img_h * h)
|
52 |
+
logit = "%.3f" % logit
|
53 |
+
_boxes += f'''
|
54 |
+
<text fill="white" font-size="20" x="{x}" y="{y}" style="fill:white;opacity:0.7">{label}: {logit}</text>
|
55 |
+
<rect x="{x}" y="{y}" width="{w}" height="{h}" style="fill:none;stroke:{"red" if is_selected else "green"};
|
56 |
+
stroke-width:4;opacity:0.5" />
|
57 |
+
'''
|
58 |
+
_svg = f'''
|
59 |
+
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {img_w} {img_h}">
|
60 |
+
{_boxes}
|
61 |
+
</svg>
|
62 |
+
'''
|
63 |
+
_svg = r'<img style="position:absolute;top:0;left:0;" src="data:image/svg+xml;base64,%s"/>' % \
|
64 |
+
base64.b64encode(_svg.encode('utf-8')).decode('utf-8')
|
65 |
+
_img_d = f'''
|
66 |
+
<div class="img-overlay-wrap" width="{img_w}" height="{img_h}">
|
67 |
+
<img width="{img_w}" height="{img_h}" src="{img_url}">
|
68 |
+
{_svg}
|
69 |
+
</div>
|
70 |
+
'''
|
71 |
+
return _img_d
|
72 |
+
|
73 |
+
|
74 |
+
def obj_card(img_url, img_w, img_h, cx, cy, w, h, *args, dst_len=100):
|
75 |
+
"""object card for displaying cropped object
|
76 |
+
|
77 |
+
Args:
|
78 |
+
Retrieved image and object info
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
_obj_html: html string to display object
|
82 |
+
"""
|
83 |
+
w = img_w * w
|
84 |
+
h = img_h * h
|
85 |
+
s = max(w, h)
|
86 |
+
x = round(img_w * cx - s / 2)
|
87 |
+
y = round(img_h * cy - s / 2)
|
88 |
+
scale = dst_len / s
|
89 |
+
_obj_html = f'''
|
90 |
+
<div style="transform-origin:0 0;transform:scale({scale});">
|
91 |
+
<img src="{img_url}" style="margin:{-y}px 0px 0px {-x}px;">
|
92 |
+
</div>
|
93 |
+
'''
|
94 |
+
return _obj_html
|
classifier.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def extract_text_feature(prompt, model, processor, device='cpu'):
|
5 |
+
"""Extract text features
|
6 |
+
|
7 |
+
Args:
|
8 |
+
prompt: a single text query
|
9 |
+
model: OwlViT model
|
10 |
+
processor: OwlViT processor
|
11 |
+
device (str, optional): device to run. Defaults to 'cpu'.
|
12 |
+
"""
|
13 |
+
device = 'cpu'
|
14 |
+
if torch.cuda.is_available():
|
15 |
+
device = 'cuda'
|
16 |
+
with torch.no_grad():
|
17 |
+
input_ids = torch.as_tensor(processor(text=prompt)[
|
18 |
+
'input_ids']).to(device)
|
19 |
+
print(input_ids.device)
|
20 |
+
text_outputs = model.owlvit.text_model(
|
21 |
+
input_ids=input_ids,
|
22 |
+
attention_mask=None,
|
23 |
+
output_attentions=None,
|
24 |
+
output_hidden_states=None,
|
25 |
+
return_dict=None,
|
26 |
+
)
|
27 |
+
text_embeds = text_outputs[1]
|
28 |
+
text_embeds = model.owlvit.text_projection(text_embeds)
|
29 |
+
text_embeds /= text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
|
30 |
+
query_embeds = text_embeds
|
31 |
+
return input_ids, query_embeds
|
32 |
+
|
33 |
+
|
34 |
+
def prompt2vec(prompt: str, model, processor):
|
35 |
+
""" Convert prompt into a computational vector
|
36 |
+
|
37 |
+
Args:
|
38 |
+
prompt (str): Text to be tokenized
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
xq: vector from the tokenizer, representing the original prompt
|
42 |
+
"""
|
43 |
+
# inputs = tokenizer(prompt, return_tensors='pt')
|
44 |
+
# out = clip.get_text_features(**inputs)
|
45 |
+
input_ids, xq = extract_text_feature(prompt, model, processor)
|
46 |
+
input_ids = input_ids.detach().cpu().numpy()
|
47 |
+
xq = xq.detach().cpu().numpy()
|
48 |
+
return input_ids, xq
|
49 |
+
|
50 |
+
|
51 |
+
def tune(clf, X, y, iters=2):
|
52 |
+
""" Train the Zero-shot Classifier
|
53 |
+
|
54 |
+
Args:
|
55 |
+
X (numpy.ndarray): Input vectors (retreived vectors)
|
56 |
+
y (list of floats or numpy.ndarray): Scores given by user
|
57 |
+
iters (int, optional): iterations of updates to be run
|
58 |
+
"""
|
59 |
+
assert len(X) == len(y)
|
60 |
+
# train the classifier
|
61 |
+
clf.fit(X, y, iters=iters)
|
62 |
+
# extract new vector
|
63 |
+
return clf.get_weights()
|
64 |
+
|
65 |
+
|
66 |
+
class Classifier:
|
67 |
+
"""Multi-Class Zero-shot Classifier
|
68 |
+
This Classifier provides proxy regarding to the user's reaction to the probed images.
|
69 |
+
The proxy will replace the original query vector generated by prompted vector and finally
|
70 |
+
give the user a satisfying retrieval result.
|
71 |
+
|
72 |
+
This can be commonly seen in a recommendation system. The classifier will recommend more
|
73 |
+
precise result as it accumulating user's activity.
|
74 |
+
|
75 |
+
This is a multiclass classifier. For N queries it will set the all queries to the first-N classes
|
76 |
+
and the last one takes the negative one.
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, xq: list):
|
80 |
+
init_weight = torch.Tensor(xq)
|
81 |
+
self.num_class = xq.shape[0]
|
82 |
+
DIMS = xq.shape[1]
|
83 |
+
# note that the bias is ignored, as we only focus on the inner product result
|
84 |
+
self.model = torch.nn.Linear(DIMS, self.num_class, bias=False)
|
85 |
+
# convert initial query `xq` to tensor parameter to init weights
|
86 |
+
self.model.weight = torch.nn.Parameter(init_weight)
|
87 |
+
# init loss and optimizer
|
88 |
+
self.loss = torch.nn.BCEWithLogitsLoss()
|
89 |
+
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
|
90 |
+
|
91 |
+
def fit(self, X: list, y: list, iters: int = 5):
|
92 |
+
# convert X and y to tensor
|
93 |
+
X = torch.Tensor(X)
|
94 |
+
X /= torch.norm(X, p=2, dim=-1, keepdim=True)
|
95 |
+
y = torch.Tensor(y).long()
|
96 |
+
# Generate labels for binary classification and ignore outbound labels
|
97 |
+
non_ind = y > self.num_class
|
98 |
+
y = torch.nn.functional.one_hot(y % self.num_class, num_classes=self.num_class).float()
|
99 |
+
y[non_ind] = 0
|
100 |
+
for i in range(iters):
|
101 |
+
# zero gradients
|
102 |
+
self.optimizer.zero_grad()
|
103 |
+
# Normalize the weight before inference
|
104 |
+
# This will constrain the gradient or you will have an explosion on query vector
|
105 |
+
self.model.weight.data /= torch.norm(self.model.weight.data, p=2, dim=-1, keepdim=True)
|
106 |
+
# forward pass
|
107 |
+
out = self.model(X)
|
108 |
+
# compute loss
|
109 |
+
loss = self.loss(out, y)
|
110 |
+
# backward pass
|
111 |
+
loss.backward()
|
112 |
+
# update weights
|
113 |
+
self.optimizer.step()
|
114 |
+
|
115 |
+
def get_weights(self):
|
116 |
+
xq = self.model.weight.detach().numpy()
|
117 |
+
return xq
|
118 |
+
|
119 |
+
class SplitLayer(torch.nn.Module):
|
120 |
+
def forward(self, x):
|
121 |
+
return torch.split(x, 1, dim=-1)
|
query_model.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
|
4 |
+
def topk_obj_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME,
|
5 |
+
exclude_list=[], topk=10):
|
6 |
+
xq_s = [
|
7 |
+
f"[{', '.join([str(float(fnum)) for fnum in _xq.tolist() + [1]])}]" for _xq in xq]
|
8 |
+
exclude_list_str = ','.join([f'\'{i}\'' for i in exclude_list])
|
9 |
+
_cond = (f"WHERE obj_id NOT IN ({exclude_list_str})" if len(
|
10 |
+
exclude_list) > 0 else "")
|
11 |
+
_subq_str = []
|
12 |
+
_img_score_subq = []
|
13 |
+
for _l, _xq in enumerate(xq_s):
|
14 |
+
_img_score_subq.append(
|
15 |
+
f"arrayReduce('maxIf', logit, arrayMap(x->x={_l}, label))")
|
16 |
+
_subq_str.append(f"""
|
17 |
+
SELECT img_id, img_url, img_w, img_h, 1/(1+exp(-arraySum(arrayMap((x,y)->x*y, prelogit, {_xq})))) AS pred_logit,
|
18 |
+
obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {_l} AS l
|
19 |
+
FROM {OBJ_DB_NAME}
|
20 |
+
JOIN {IMG_DB_NAME}
|
21 |
+
ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
|
22 |
+
PREWHERE obj_id IN (
|
23 |
+
SELECT obj_id FROM (
|
24 |
+
SELECT obj_id, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist FROM {OBJ_DB_NAME}
|
25 |
+
ORDER BY dist DESC
|
26 |
+
) {_cond} LIMIT 10
|
27 |
+
)
|
28 |
+
""")
|
29 |
+
_subq_str = ' UNION ALL '.join(_subq_str)
|
30 |
+
_img_score_q = ','.join(_img_score_subq)
|
31 |
+
_img_score_q = f"arraySum(arrayFilter(x->NOT isNaN(x), array({_img_score_q}))) AS img_score"
|
32 |
+
q_str = f"""
|
33 |
+
SELECT img_id, img_url, img_w, img_h, groupArray(obj_id) AS box_id,
|
34 |
+
groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
|
35 |
+
groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
|
36 |
+
{_img_score_q}
|
37 |
+
FROM
|
38 |
+
({_subq_str})
|
39 |
+
GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
|
40 |
+
"""
|
41 |
+
xc = client.fetch(q_str)
|
42 |
+
return xc
|
43 |
+
|
44 |
+
|
45 |
+
def rev_query(client, xq, img_ids, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08):
|
46 |
+
xq_s = [
|
47 |
+
f"[{', '.join([str(float(fnum)) for fnum in _xq.tolist() + [1]])}]" for _xq in xq]
|
48 |
+
image_list = ','.join([f'\'{i}\'' for i in img_ids])
|
49 |
+
_thresh = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""
|
50 |
+
_subq_str = []
|
51 |
+
_img_score_subq = []
|
52 |
+
for _l, _xq in enumerate(xq_s):
|
53 |
+
_img_score_subq.append(
|
54 |
+
f"arrayReduce('maxIf', logit, arrayMap(x->x={_l}, label))")
|
55 |
+
_subq_str.append(f"""
|
56 |
+
SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h,
|
57 |
+
(1 / (1 + exp(-(arraySum(arrayMap((x,y)->x*y, prelogit, {_xq})))))) AS pred_logit,
|
58 |
+
obj_id, box_cx, box_cy, box_w, box_h, class_embedding, {_l} AS l
|
59 |
+
FROM {OBJ_DB_NAME}
|
60 |
+
JOIN {IMG_DB_NAME}
|
61 |
+
ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
|
62 |
+
PREWHERE img_id IN ({image_list})
|
63 |
+
{_thresh}
|
64 |
+
""")
|
65 |
+
_subq_str = ' UNION ALL '.join(_subq_str)
|
66 |
+
_img_score_q = ','.join(_img_score_subq)
|
67 |
+
_img_score_q = f"arraySum(arrayFilter(x->NOT isNaN(x), array({_img_score_q}))) AS img_score"
|
68 |
+
q_str = f"""
|
69 |
+
SELECT img_id, groupArray(obj_id) AS box_id, img_url, img_w, img_h,
|
70 |
+
groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
|
71 |
+
groupArray(pred_logit) AS logit, groupArray(l) as label, groupArray(class_embedding) AS cls_emb,
|
72 |
+
{_img_score_q}
|
73 |
+
FROM
|
74 |
+
({_subq_str})
|
75 |
+
GROUP BY img_id, img_url, img_w, img_h ORDER BY img_score DESC
|
76 |
+
"""
|
77 |
+
xc = client.fetch(q_str)
|
78 |
+
return xc
|
79 |
+
|
80 |
+
|
81 |
+
def simple_query(client, xq, IMG_DB_NAME, OBJ_DB_NAME, thresh=0.08, topk=10):
|
82 |
+
xq_s = [
|
83 |
+
f"[{', '.join([str(float(fnum)) for fnum in _xq.tolist() + [1]])}]" for _xq in xq]
|
84 |
+
res = []
|
85 |
+
subq_str = []
|
86 |
+
_thresh = f"WHERE pred_logit > {thresh}" if thresh > 0 else ""
|
87 |
+
for _l, _xq in enumerate(xq_s):
|
88 |
+
subq_str.append(
|
89 |
+
f"""
|
90 |
+
SELECT {OBJ_DB_NAME}.img_id AS img_id, img_url, img_w, img_h, prelogit,
|
91 |
+
obj_id, box_cx, box_cy, box_w, box_h, {_l} AS l, distance('topK={topk}', 'nprobe=32')(prelogit, {_xq}) AS dist
|
92 |
+
FROM {OBJ_DB_NAME}
|
93 |
+
JOIN {IMG_DB_NAME}
|
94 |
+
ON {IMG_DB_NAME}.img_id = {OBJ_DB_NAME}.img_id
|
95 |
+
{_thresh} LIMIT 10
|
96 |
+
""")
|
97 |
+
subq_str = " UNION ALL ".join(subq_str)
|
98 |
+
q_str = f"""
|
99 |
+
SELECT groupArray(img_url) AS img_url, groupArray(img_w) AS img_w, groupArray(img_h) AS img_h,
|
100 |
+
groupArray(box_cx) AS cx, groupArray(box_cy) AS cy, groupArray(box_w) AS w, groupArray(box_h) AS h,
|
101 |
+
l AS label, groupArray(dist) as d,
|
102 |
+
groupArray(1 / (1 + exp(-dist))) AS logit FROM (
|
103 |
+
{subq_str}
|
104 |
+
)
|
105 |
+
GROUP BY l
|
106 |
+
"""
|
107 |
+
res = client.fetch(q_str)
|
108 |
+
return res
|