narugo committed on
Commit
dfcc607
1 Parent(s): f474bbc

dev(narugo): better metrics

Browse files
Files changed (2) hide show
  1. app.py +5 -4
  2. index.py +22 -6
app.py CHANGED
@@ -6,8 +6,8 @@ from PIL import Image
6
  from index import query_character
7
 
8
 
9
- def _fn(image: Image.Image, count: int = 5):
10
- return query_character(image, count)
11
 
12
 
13
  if __name__ == '__main__':
@@ -15,7 +15,8 @@ if __name__ == '__main__':
15
  with gr.Row():
16
  with gr.Column():
17
  gr_input_image = gr.Image(type='pil', label='Original Image')
18
- gr_max_count = gr.Slider(minimum=1, maximum=20, step=1, value=5, label='Max Query Count')
 
19
  gr_submit = gr.Button(value='Submit', variant='primary')
20
 
21
  with gr.Column():
@@ -28,7 +29,7 @@ if __name__ == '__main__':
28
 
29
  gr_submit.click(
30
  _fn,
31
- inputs=[gr_input_image, gr_max_count],
32
  outputs=[gr_gallery, gr_table],
33
  )
34
 
 
6
  from index import query_character
7
 
8
 
9
+ def _fn(image: Image.Image, count: int = 10, threshold: float = 0.8):
10
+ return query_character(image, count, order_by='same_ratio', threshold=threshold)
11
 
12
 
13
  if __name__ == '__main__':
 
15
  with gr.Row():
16
  with gr.Column():
17
  gr_input_image = gr.Image(type='pil', label='Original Image')
18
+ gr_max_count = gr.Slider(minimum=1, maximum=30, step=1, value=10, label='Max Query Count')
19
+ gr_threshold = gr.Slider(minimum=0.0, maximum=0.99, step=0.01, value=0.8, label='Threshold')
20
  gr_submit = gr.Button(value='Submit', variant='primary')
21
 
22
  with gr.Column():
 
29
 
30
  gr_submit.click(
31
  _fn,
32
+ inputs=[gr_input_image, gr_max_count, gr_threshold],
33
  outputs=[gr_gallery, gr_table],
34
  )
35
 
index.py CHANGED
@@ -8,7 +8,7 @@ from autofaiss import build_index
8
  from hfutils.operate import get_hf_fs
9
  from huggingface_hub import hf_hub_download
10
  from imgutils.data import load_image
11
- from imgutils.metrics import ccip_batch_extract_features
12
 
13
  SRC_REPO = 'deepghs/character_index'
14
 
@@ -36,7 +36,7 @@ def gender_predict(p):
36
  return 'not_sure'
37
 
38
 
39
- def query_character(image: Image.Image, count: int = 5):
40
  (index, index_infos), tag_infos = _make_index()
41
  query = ccip_batch_extract_features([image])
42
  assert query.shape == (1, 768)
@@ -44,7 +44,7 @@ def query_character(image: Image.Image, count: int = 5):
44
  all_dists, all_indices = index.search(query, k=count)
45
  dists, indices = all_dists[0], all_indices[0]
46
 
47
- images, records = [], []
48
  for dist, idx in zip(dists, indices):
49
  info = tag_infos[idx]
50
  current_image = load_image(hf_hub_download(
@@ -52,14 +52,30 @@ def query_character(image: Image.Image, count: int = 5):
52
  repo_type='dataset',
53
  filename=f'{info["hprefix"]}/{info["short_tag"]}/1.webp'
54
  ))
55
- images.append((current_image, f'{info["tag"]} ({dist:.3f})'))
 
 
 
 
 
 
56
  records.append({
57
  'id': info['id'],
58
  'tag': info['tag'],
59
  'gender': gender_predict(info['gender']),
60
  'copyright': info['copyright'],
61
- 'score': dist,
 
 
62
  })
63
 
64
  df_records = pd.DataFrame(records)
65
- return images, df_records
 
 
 
 
 
 
 
 
 
8
  from hfutils.operate import get_hf_fs
9
  from huggingface_hub import hf_hub_download
10
  from imgutils.data import load_image
11
+ from imgutils.metrics import ccip_batch_extract_features, ccip_batch_differences, ccip_default_threshold
12
 
13
  SRC_REPO = 'deepghs/character_index'
14
 
 
36
  return 'not_sure'
37
 
38
 
39
+ def query_character(image: Image.Image, count: int = 5, order_by: str = 'same_ratio', threshold: float = 0.7):
40
  (index, index_infos), tag_infos = _make_index()
41
  query = ccip_batch_extract_features([image])
42
  assert query.shape == (1, 768)
 
44
  all_dists, all_indices = index.search(query, k=count)
45
  dists, indices = all_dists[0], all_indices[0]
46
 
47
+ images, records = {}, []
48
  for dist, idx in zip(dists, indices):
49
  info = tag_infos[idx]
50
  current_image = load_image(hf_hub_download(
 
52
  repo_type='dataset',
53
  filename=f'{info["hprefix"]}/{info["short_tag"]}/1.webp'
54
  ))
55
+ feats = np.load(hf_hub_download(
56
+ repo_id=SRC_REPO,
57
+ repo_type='dataset',
58
+ filename=f'{info["hprefix"]}/{info["short_tag"]}/feat.npy'
59
+ ))
60
+ diffs = ccip_batch_differences([query[0], *feats])[0, 1:]
61
+ images[info['tag']] = current_image
62
  records.append({
63
  'id': info['id'],
64
  'tag': info['tag'],
65
  'gender': gender_predict(info['gender']),
66
  'copyright': info['copyright'],
67
+ 'index_score': dist,
68
+ 'mean_diff': diffs.mean(),
69
+ 'same_ratio': (diffs < ccip_default_threshold()).mean(),
70
  })
71
 
72
  df_records = pd.DataFrame(records)
73
+ df_records = df_records.sort_values(
74
+ by=[order_by, 'index_score'] if order_by != 'index_score' else ['index_score'],
75
+ ascending=[False, False] if order_by != 'index_score' else [False],
76
+ )
77
+ df_records = df_records[df_records[order_by] >= threshold]
78
+ ret_images = []
79
+ for row_item in df_records.to_dict('records'):
80
+ ret_images.append((images[row_item['tag']], f'{row_item["tag"]} ({row_item[order_by]:.3f})'))
81
+ return ret_images, df_records