Spaces:
Running
Running
dev(narugo): support multiple sites
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import json
|
2 |
import os
|
|
|
3 |
from functools import lru_cache
|
4 |
from typing import List, Dict
|
5 |
|
@@ -23,10 +24,14 @@ _ALL_MODEL_NAMES = [
|
|
23 |
for path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
|
24 |
]
|
25 |
|
|
|
|
|
|
|
26 |
|
27 |
-
|
|
|
28 |
with TemporaryDirectory() as td:
|
29 |
-
datapool =
|
30 |
datapool.batch_download_to_directory(
|
31 |
resource_ids=ids,
|
32 |
dst_dir=td,
|
@@ -42,13 +47,20 @@ def _get_from_ids(ids: List[int]) -> Dict[int, Image.Image]:
|
|
42 |
return retval
|
43 |
|
44 |
|
45 |
-
def
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
|
54 |
@lru_cache(maxsize=3)
|
@@ -85,11 +97,10 @@ def search(model_name: str, img_input, n_neighbours: int):
|
|
85 |
|
86 |
dists, indexes = knn_index.search(embeddings, k=n_neighbours)
|
87 |
neighbours_ids = images_ids[indexes][0]
|
88 |
-
neighbours_ids = [_x(x) for x in neighbours_ids]
|
89 |
|
90 |
captions = []
|
91 |
images = []
|
92 |
-
ids_to_images =
|
93 |
for image_id, dist in zip(neighbours_ids, dists[0]):
|
94 |
if image_id in ids_to_images:
|
95 |
images.append(ids_to_images[image_id])
|
|
|
1 |
import json
|
2 |
import os
|
3 |
+
from collections import defaultdict
|
4 |
from functools import lru_cache
|
5 |
from typing import List, Dict
|
6 |
|
|
|
24 |
for path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
|
25 |
]
|
26 |
|
27 |
+
_SITE_CLS = {
|
28 |
+
'danbooru': DanbooruWebpDataPool,
|
29 |
+
}
|
30 |
|
31 |
+
|
32 |
+
def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
|
33 |
with TemporaryDirectory() as td:
|
34 |
+
datapool = _SITE_CLS[site_name]()
|
35 |
datapool.batch_download_to_directory(
|
36 |
resource_ids=ids,
|
37 |
dst_dir=td,
|
|
|
47 |
return retval
|
48 |
|
49 |
|
50 |
+
def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
|
51 |
+
_sites = defaultdict(list)
|
52 |
+
for id_ in ids:
|
53 |
+
site_name, num_id = id_.split('_', maxsplit=1)
|
54 |
+
num_id = int(num_id)
|
55 |
+
_sites[site_name].append(num_id)
|
56 |
+
|
57 |
+
_retval = {}
|
58 |
+
for site_name, site_ids in _sites.items():
|
59 |
+
_retval.update({
|
60 |
+
f'{site_name}_{id_}': image
|
61 |
+
for id_, image in _get_from_ids(site_name, site_ids)
|
62 |
+
})
|
63 |
+
return _retval
|
64 |
|
65 |
|
66 |
@lru_cache(maxsize=3)
|
|
|
97 |
|
98 |
dists, indexes = knn_index.search(embeddings, k=n_neighbours)
|
99 |
neighbours_ids = images_ids[indexes][0]
|
|
|
100 |
|
101 |
captions = []
|
102 |
images = []
|
103 |
+
ids_to_images = _get_from_raw_ids(neighbours_ids)
|
104 |
for image_id, dist in zip(neighbours_ids, dists[0]):
|
105 |
if image_id in ids_to_images:
|
106 |
images.append(ids_to_images[image_id])
|