narugo commited on
Commit
681a350
·
1 Parent(s): 3c7d8d9

dev(narugo): support multiple sites

Browse files
Files changed (1) hide show
  1. app.py +22 -11
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import os
 
3
  from functools import lru_cache
4
  from typing import List, Dict
5
 
@@ -23,10 +24,14 @@ _ALL_MODEL_NAMES = [
23
  for path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
24
  ]
25
 
 
 
 
26
 
27
- def _get_from_ids(ids: List[int]) -> Dict[int, Image.Image]:
 
28
  with TemporaryDirectory() as td:
29
- datapool = DanbooruWebpDataPool()
30
  datapool.batch_download_to_directory(
31
  resource_ids=ids,
32
  dst_dir=td,
@@ -42,13 +47,20 @@ def _get_from_ids(ids: List[int]) -> Dict[int, Image.Image]:
42
  return retval
43
 
44
 
45
- def _x(x):
46
- if isinstance(x, (int, np.integer)):
47
- return int(x)
48
- elif isinstance(x, (str, np.str_)):
49
- return int(str(x).split('_')[-1])
50
- else:
51
- raise ValueError(f'Invalid ID: {x!r}, type: {type(x)!r}')
 
 
 
 
 
 
 
52
 
53
 
54
  @lru_cache(maxsize=3)
@@ -85,11 +97,10 @@ def search(model_name: str, img_input, n_neighbours: int):
85
 
86
  dists, indexes = knn_index.search(embeddings, k=n_neighbours)
87
  neighbours_ids = images_ids[indexes][0]
88
- neighbours_ids = [_x(x) for x in neighbours_ids]
89
 
90
  captions = []
91
  images = []
92
- ids_to_images = _get_from_ids(neighbours_ids)
93
  for image_id, dist in zip(neighbours_ids, dists[0]):
94
  if image_id in ids_to_images:
95
  images.append(ids_to_images[image_id])
 
1
  import json
2
  import os
3
+ from collections import defaultdict
4
  from functools import lru_cache
5
  from typing import List, Dict
6
 
 
24
  for path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
25
  ]
26
 
27
+ _SITE_CLS = {
28
+ 'danbooru': DanbooruWebpDataPool,
29
+ }
30
 
31
+
32
+ def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
33
  with TemporaryDirectory() as td:
34
+ datapool = _SITE_CLS[site_name]()
35
  datapool.batch_download_to_directory(
36
  resource_ids=ids,
37
  dst_dir=td,
 
47
  return retval
48
 
49
 
50
+ def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
51
+ _sites = defaultdict(list)
52
+ for id_ in ids:
53
+ site_name, num_id = id_.split('_', maxsplit=1)
54
+ num_id = int(num_id)
55
+ _sites[site_name].append(num_id)
56
+
57
+ _retval = {}
58
+ for site_name, site_ids in _sites.items():
59
+ _retval.update({
60
+ f'{site_name}_{id_}': image
61
+ for id_, image in _get_from_ids(site_name, site_ids)
62
+ })
63
+ return _retval
64
 
65
 
66
  @lru_cache(maxsize=3)
 
97
 
98
  dists, indexes = knn_index.search(embeddings, k=n_neighbours)
99
  neighbours_ids = images_ids[indexes][0]
 
100
 
101
  captions = []
102
  images = []
103
+ ids_to_images = _get_from_raw_ids(neighbours_ids)
104
  for image_id, dist in zip(neighbours_ids, dists[0]):
105
  if image_id in ids_to_images:
106
  images.append(ids_to_images[image_id])