# Page header captured with this file: narugo · "Update app.py" · commit ecfa7a3 (verified) · 4.66 kB
import json
import os
from collections import defaultdict
from functools import lru_cache
from typing import List, Dict
import faiss
import gradio as gr
import numpy as np
from PIL import Image
from cheesechaser.datapool import YandeWebpDataPool, ZerochanWebpDataPool, GelbooruWebpDataPool, \
KonachanWebpDataPool, AnimePicturesWebpDataPool, DanbooruNewestWebpDataPool, Rule34WebpDataPool
from hfutils.operate import get_hf_fs, get_hf_client
from hfutils.utils import TemporaryDirectory
from imgutils.tagging import wd14
# Hugging Face repository that holds one sub-directory per search index.
_REPO_ID = 'deepghs/anime_sites_indices'

# Hugging Face handles, created once when the module is imported.
hf_fs = get_hf_fs()
hf_client = get_hf_client()

# Index that is pre-selected in the UI dropdown.
_DEFAULT_MODEL_NAME = 'SwinV2_v3_dgzyka_23325111_8GB'

# Every first-level directory of the repository that contains a `knn.index`
# file is offered as an index; only the directory name is kept.
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(index_file, _REPO_ID))
    for index_file in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]

# Site prefix of an image id -> data pool class used to download from that site.
_SITE_CLS = dict(
    danbooru=DanbooruNewestWebpDataPool,
    yandere=YandeWebpDataPool,
    zerochan=ZerochanWebpDataPool,
    gelbooru=GelbooruWebpDataPool,
    konachan=KonachanWebpDataPool,
    anime_pictures=AnimePicturesWebpDataPool,
    rule34=Rule34WebpDataPool,
)
def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
    """Download the given resources of one site and return them as images.

    :param site_name: Key of ``_SITE_CLS`` selecting the data pool to use.
    :param ids: Numeric resource ids to download from that site.
    :return: Mapping from numeric id to the loaded image. Only files that
        actually appear in the download directory are included, so ids that
        produced no file are absent from the result.
    :raises KeyError: If ``site_name`` is not a key of ``_SITE_CLS``.
    """
    with TemporaryDirectory() as workdir:
        pool = _SITE_CLS[site_name]()
        pool.batch_download_to_directory(resource_ids=ids, dst_dir=workdir)

        loaded = {}
        for filename in os.listdir(workdir):
            # The file stem is parsed as the numeric resource id.
            stem, _ = os.path.splitext(filename)
            numeric_id = int(stem)

            picture = Image.open(os.path.join(workdir, filename))
            # Image.open is lazy: read the pixel data now, while the file
            # still exists inside the temporary directory.
            picture.load()
            loaded[numeric_id] = picture

    return loaded
def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
    """Fetch images for ids of the form ``<site>_<number>``.

    The ids are grouped per site so each site is downloaded with a single
    ``_get_from_ids`` call.

    :param ids: Raw ids such as ``danbooru_123``.
    :return: Mapping from the raw id string to the loaded image; ids whose
        download produced no file are missing from the mapping.
    :raises ValueError: If an id has no ``_`` separator or a non-numeric suffix.
    """
    numbers_by_site = {}
    for raw_id in ids:
        # Split on the LAST underscore: site names may contain underscores
        # themselves (e.g. ``anime_pictures``).
        site, number_text = raw_id.rsplit('_', maxsplit=1)
        numbers_by_site.setdefault(site, []).append(int(number_text))

    images = {}
    for site, numbers in numbers_by_site.items():
        for number, image in _get_from_ids(site, numbers).items():
            images[f'{site}_{number}'] = image
    return images
@lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
    """Download and load the id table and FAISS index of one model.

    The result is memoised (up to 3 distinct ``(repo_id, model_name)`` pairs),
    so the returned objects are shared between calls and must be treated as
    read-only by callers.

    :param repo_id: Hugging Face model repository containing the index.
    :param model_name: Sub-directory of the repository holding ``ids.npy``,
        ``knn.index`` and ``infos.json``.
    :return: Tuple ``(image_ids, knn_index)`` where ``image_ids`` is the array
        loaded from ``ids.npy`` and ``knn_index`` is the FAISS index with the
        ``index_param`` setting from ``infos.json`` applied.
    """

    def _download(filename: str) -> str:
        # Fetch one file of this model and return its local path.
        return hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/{filename}',
        )

    image_ids = np.load(_download('ids.npy'))
    knn_index = faiss.read_index(_download('knn.index'))

    # Use a context manager so the file handle is closed deterministically,
    # and read as UTF-8 (the JSON default) instead of the locale encoding.
    with open(_download('infos.json'), encoding='utf-8') as info_file:
        index_param = json.load(info_file)["index_param"]

    faiss.ParameterSpace().set_index_parameters(knn_index, index_param)
    return image_ids, knn_index
def search(model_name: str, img_input, n_neighbours: int):
    """Find the images most similar to ``img_input`` in the selected index.

    :param model_name: Name of the index to query (sub-directory of ``_REPO_ID``).
    :param img_input: Query image as delivered by the ``gr.Image`` input;
        ``None`` when the user has not supplied one.
    :param n_neighbours: Maximum number of neighbours to retrieve.
    :return: List of ``(image, caption)`` tuples for the gallery, in the order
        returned by the index. Each caption is ``"<image id>/<score>"`` with the
        score formatted to two decimals. Neighbours whose image could not be
        downloaded are skipped.
    :raises gr.Error: If no input image was provided.
    """
    if img_input is None:
        # Without this guard the tagger fails with an opaque error.
        raise gr.Error('Please provide an input image first.')

    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)

    embedding = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
    queries = np.expand_dims(embedding, 0)  # single-row batch for faiss
    faiss.normalize_L2(queries)  # normalises in place

    # NOTE(review): the slider may deliver a float; faiss needs an integer k.
    dists, indexes = knn_index.search(queries, k=int(n_neighbours))

    # faiss pads the result with label -1 when fewer than k neighbours are
    # found; indexing with -1 would silently return the LAST id as a bogus hit.
    found = indexes[0] >= 0
    neighbours_ids = images_ids[indexes[0][found]]
    neighbours_dists = dists[0][found]

    ids_to_images = _get_from_raw_ids(neighbours_ids)

    results = []
    for image_id, dist in zip(neighbours_ids, neighbours_dists):
        if image_id in ids_to_images:
            results.append((ids_to_images[image_id], f"{image_id}/{dist:.2f}"))
    return results
# Build and launch the Gradio UI when the file is executed as a script.
# NOTE(review): the layout nesting below was reconstructed from a copy whose
# indentation had been flattened — confirm against the deployed UI.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Row():
            # Query image, handed to `search` as a PIL image.
            with gr.Column():
                img_input = gr.Image(type="pil", label="Input")
            # Search controls: index selector, result count, trigger button.
            with gr.Column():
                with gr.Row():
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        # Result gallery.
        with gr.Row():
            similar_images = gr.Gallery(label="Similar images", columns=[5])
        # The order of `inputs` must match the parameter order of `search`
        # (model_name, img_input, n_neighbours).
        find_btn.click(
            fn=search,
            inputs=[
                n_model,
                img_input,
                n_neighbours,
            ],
            outputs=[similar_images],
        )
    demo.queue().launch()