dev(narugo): better metrics
Browse files
app.py
CHANGED
@@ -6,8 +6,8 @@ from PIL import Image
|
|
6 |
from index import query_character
|
7 |
|
8 |
|
9 |
-
def _fn(image: Image.Image, count: int =
|
10 |
-
return query_character(image, count)
|
11 |
|
12 |
|
13 |
if __name__ == '__main__':
|
@@ -15,7 +15,8 @@ if __name__ == '__main__':
|
|
15 |
with gr.Row():
|
16 |
with gr.Column():
|
17 |
gr_input_image = gr.Image(type='pil', label='Original Image')
|
18 |
-
gr_max_count = gr.Slider(minimum=1, maximum=
|
|
|
19 |
gr_submit = gr.Button(value='Submit', variant='primary')
|
20 |
|
21 |
with gr.Column():
|
@@ -28,7 +29,7 @@ if __name__ == '__main__':
|
|
28 |
|
29 |
gr_submit.click(
|
30 |
_fn,
|
31 |
-
inputs=[gr_input_image, gr_max_count],
|
32 |
outputs=[gr_gallery, gr_table],
|
33 |
)
|
34 |
|
|
|
6 |
from index import query_character
|
7 |
|
8 |
|
9 |
+
def _fn(image: Image.Image, count: int = 10, threshold: float = 0.8):
|
10 |
+
return query_character(image, count, order_by='same_ratio', threshold=threshold)
|
11 |
|
12 |
|
13 |
if __name__ == '__main__':
|
|
|
15 |
with gr.Row():
|
16 |
with gr.Column():
|
17 |
gr_input_image = gr.Image(type='pil', label='Original Image')
|
18 |
+
gr_max_count = gr.Slider(minimum=1, maximum=30, step=1, value=10, label='Max Query Count')
|
19 |
+
gr_threshold = gr.Slider(minimum=0.0, maximum=0.99, step=0.01, value=0.8, label='Threshold')
|
20 |
gr_submit = gr.Button(value='Submit', variant='primary')
|
21 |
|
22 |
with gr.Column():
|
|
|
29 |
|
30 |
gr_submit.click(
|
31 |
_fn,
|
32 |
+
inputs=[gr_input_image, gr_max_count, gr_threshold],
|
33 |
outputs=[gr_gallery, gr_table],
|
34 |
)
|
35 |
|
index.py
CHANGED
@@ -8,7 +8,7 @@ from autofaiss import build_index
|
|
8 |
from hfutils.operate import get_hf_fs
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
from imgutils.data import load_image
|
11 |
-
from imgutils.metrics import ccip_batch_extract_features
|
12 |
|
13 |
SRC_REPO = 'deepghs/character_index'
|
14 |
|
@@ -36,7 +36,7 @@ def gender_predict(p):
|
|
36 |
return 'not_sure'
|
37 |
|
38 |
|
39 |
-
def query_character(image: Image.Image, count: int = 5):
|
40 |
(index, index_infos), tag_infos = _make_index()
|
41 |
query = ccip_batch_extract_features([image])
|
42 |
assert query.shape == (1, 768)
|
@@ -44,7 +44,7 @@ def query_character(image: Image.Image, count: int = 5):
|
|
44 |
all_dists, all_indices = index.search(query, k=count)
|
45 |
dists, indices = all_dists[0], all_indices[0]
|
46 |
|
47 |
-
images, records =
|
48 |
for dist, idx in zip(dists, indices):
|
49 |
info = tag_infos[idx]
|
50 |
current_image = load_image(hf_hub_download(
|
@@ -52,14 +52,30 @@ def query_character(image: Image.Image, count: int = 5):
|
|
52 |
repo_type='dataset',
|
53 |
filename=f'{info["hprefix"]}/{info["short_tag"]}/1.webp'
|
54 |
))
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
records.append({
|
57 |
'id': info['id'],
|
58 |
'tag': info['tag'],
|
59 |
'gender': gender_predict(info['gender']),
|
60 |
'copyright': info['copyright'],
|
61 |
-
'
|
|
|
|
|
62 |
})
|
63 |
|
64 |
df_records = pd.DataFrame(records)
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from hfutils.operate import get_hf_fs
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
from imgutils.data import load_image
|
11 |
+
from imgutils.metrics import ccip_batch_extract_features, ccip_batch_differences, ccip_default_threshold
|
12 |
|
13 |
SRC_REPO = 'deepghs/character_index'
|
14 |
|
|
|
36 |
return 'not_sure'
|
37 |
|
38 |
|
39 |
+
def query_character(image: Image.Image, count: int = 5, order_by: str = 'same_ratio', threshold: float = 0.7):
|
40 |
(index, index_infos), tag_infos = _make_index()
|
41 |
query = ccip_batch_extract_features([image])
|
42 |
assert query.shape == (1, 768)
|
|
|
44 |
all_dists, all_indices = index.search(query, k=count)
|
45 |
dists, indices = all_dists[0], all_indices[0]
|
46 |
|
47 |
+
images, records = {}, []
|
48 |
for dist, idx in zip(dists, indices):
|
49 |
info = tag_infos[idx]
|
50 |
current_image = load_image(hf_hub_download(
|
|
|
52 |
repo_type='dataset',
|
53 |
filename=f'{info["hprefix"]}/{info["short_tag"]}/1.webp'
|
54 |
))
|
55 |
+
feats = np.load(hf_hub_download(
|
56 |
+
repo_id=SRC_REPO,
|
57 |
+
repo_type='dataset',
|
58 |
+
filename=f'{info["hprefix"]}/{info["short_tag"]}/feat.npy'
|
59 |
+
))
|
60 |
+
diffs = ccip_batch_differences([query[0], *feats])[0, 1:]
|
61 |
+
images[info['tag']] = current_image
|
62 |
records.append({
|
63 |
'id': info['id'],
|
64 |
'tag': info['tag'],
|
65 |
'gender': gender_predict(info['gender']),
|
66 |
'copyright': info['copyright'],
|
67 |
+
'index_score': dist,
|
68 |
+
'mean_diff': diffs.mean(),
|
69 |
+
'same_ratio': (diffs < ccip_default_threshold()).mean(),
|
70 |
})
|
71 |
|
72 |
df_records = pd.DataFrame(records)
|
73 |
+
df_records = df_records.sort_values(
|
74 |
+
by=[order_by, 'index_score'] if order_by != 'index_score' else ['index_score'],
|
75 |
+
ascending=[False, False] if order_by != 'index_score' else [False],
|
76 |
+
)
|
77 |
+
df_records = df_records[df_records[order_by] >= threshold]
|
78 |
+
ret_images = []
|
79 |
+
for row_item in df_records.to_dict('records'):
|
80 |
+
ret_images.append((images[row_item['tag']], f'{row_item["tag"]} ({row_item[order_by]:.3f})'))
|
81 |
+
return ret_images, df_records
|