Spaces:
Runtime error
Runtime error
import torch | |
import json | |
import gradio as gr | |
# Title and description used by the gr.Interface that serves `predict`.
# English summary of DESCRIPTION (kept in Japanese for users):
#   Computes similarity between the given Danbooru tags. For the supported tag
#   lists, see the respective text files under "Files" (same as Dart). Following
#   Dart, the model was trained for 2 epochs on the isek-ai/danbooru-tags-2023
#   dataset with shuffled tags; similarities come from the learned token embeddings.
# Each paragraph ends with an explicit "\n" plus a real newline, presumably to
# force a blank line (paragraph break) in the rendered description.
TITLE = "Danbooru Tag Similarity"
DESCRIPTION = """
与えられたダンボールタグの類似度を計算します。\n
対応するタグのリストはFilesからそれぞれのテキストファイルを参照してください。(Dartと同じです)。\n
Dartを参考に、isek-ai/danbooru-tags-2023データセットでタグをシャッフルして2エポック学習しました。\n
学習後のトークン埋め込みを元に計算しています。
"""
def _read_lines(path):
    """Return the lines of the UTF-8 text file at *path*, without line terminators."""
    # Explicit encoding: the default depends on the host locale and may not be UTF-8.
    with open(path, "r", encoding="utf-8") as f:
        return f.read().splitlines()


# Vocabulary saved as {"<token id>": "<token>"}; JSON object keys are always
# strings, hence the int() when building the reverse mapping.
with open("id_to_token.json", "r", encoding="utf-8") as f:
    id_to_token = json.load(f)
token_to_id = {v: int(k) for k, v in id_to_token.items()}

# Candidate tag lists, one tag per line.
populars = _read_lines("popular.txt")
characters = _read_lines("character.txt")
copyrights = _read_lines("copyright.txt")
generals = _read_lines("general.txt")
tags = characters + copyrights + generals

# Token embedding table, indexed by token id (rows are looked up via token_to_id).
# map_location="cpu" lets a tensor that was saved from a GPU load on CPU-only hosts.
token_embeddings = torch.load("token_embeddings.pt", map_location="cpu")
def predict(target_tag, sort_by, category, popular):
    """Score candidate tags by cosine similarity of their embeddings to *target_tag*.

    Args:
        target_tag: Tag to compare against; leading/trailing whitespace is ignored.
        sort_by: "descending" returns raw cosine similarities. Any other value
            (the UI offers "ascending") returns them negated, so a component that
            ranks by highest score shows the least similar tags first.
        category: "general", "character" or "copyright" restricts candidates to that
            tag list; any other value (the UI offers "all") uses every tag.
        popular: "only_popular" further restricts candidates to tags in popular.txt.

    Returns:
        dict mapping each candidate tag to its (possibly negated) similarity as a
        Python float. Candidates that are not in the vocabulary are skipped.

    Raises:
        gr.Error: if *target_tag* is not in the vocabulary.
    """
    target_tag = target_tag.strip()
    target_id = token_to_id.get(target_tag)
    if target_id is None:
        # gr.Error is displayed to the user instead of an opaque KeyError traceback.
        raise gr.Error(f"Unknown tag: {target_tag!r}")

    multiplier = 1 if sort_by == "descending" else -1
    target_embedding = token_embeddings[target_id].unsqueeze(0)
    # The single target row broadcasts against every embedding row -> one score per token id.
    sims = torch.cosine_similarity(target_embedding, token_embeddings, dim=1) * multiplier

    if category == "general":
        tag_list = generals
    elif category == "character":
        tag_list = characters
    elif category == "copyright":
        tag_list = copyrights
    else:
        tag_list = tags
    if popular == "only_popular":
        popular_set = set(populars)
        tag_list = [t for t in tag_list if t in popular_set]

    # Skip entries missing from the vocab (e.g. blank lines in a tag file)
    # instead of failing the whole request with a KeyError.
    tag_list = [t for t in tag_list if t in token_to_id]
    ids = torch.tensor([token_to_id[t] for t in tag_list], dtype=torch.long)
    # One gather + tolist() instead of a Python-level .item() call per tag.
    scores = sims[ids].tolist()
    return dict(zip(tag_list, scores))
# Input widgets, listed in the same order as predict's parameters:
# (target_tag, sort_by, category, popular).
input_components = [
    gr.Textbox(value="otoko no ko", label="Target tag"),
    gr.Radio(value="descending", label="Sort by", choices=["descending", "ascending"]),
    gr.Dropdown(label="category", value="all", choices=["all", "general", "character", "copyright"]),
    gr.Radio(value="all", label="Only popular tag (count>=1000)", choices=["all", "only_popular"]),
]

# predict returns {tag: score}; the Label component displays the 50 highest scores.
demo = gr.Interface(
    title=TITLE,
    description=DESCRIPTION,
    fn=predict,
    inputs=input_components,
    outputs=gr.Label(num_top_classes=50),
)
demo.launch()