from __future__ import unicode_literals

import os
import re
import unicodedata

import torch
from torch import nn

import streamlit as st
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
import scipy.spatial
import pyminizip

import transformers
from transformers import BertJapaneseTokenizer, BertModel
from huggingface_hub import hf_hub_download, snapshot_download
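

# The helpers below implement NEologd-style Japanese text normalization:
# full-width/half-width unification, folding of hyphen, long-vowel-mark and
# tilde variants, and removal of spaces between CJK blocks.  Queries are run
# through normalize_text() before encoding, presumably to match the
# preprocessing applied to the text encoder's training data.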
def unicode_normalize(cls, s):
    pt = re.compile("([{}]+)".format(cls))

    def norm(c):
        # NFKC-normalize only the segments that match the character class.
        return unicodedata.normalize("NFKC", c) if pt.match(c) else c

    s = "".join(norm(x) for x in re.split(pt, s))
    s = re.sub("－", "-", s)  # full-width hyphen-minus -> ASCII hyphen
    return s


def remove_extra_spaces(s):
    # Collapse runs of ASCII and ideographic (U+3000) spaces into one space.
    s = re.sub("[ 　]+", " ", s)
    blocks = "".join(
        (
            "\u4E00-\u9FFF",  # CJK unified ideographs
            "\u3040-\u309F",  # hiragana
            "\u30A0-\u30FF",  # katakana
            "\u3000-\u303F",  # CJK symbols and punctuation
            "\uFF00-\uFFEF",  # half-width and full-width forms
        )
    )
    basic_latin = "\u0000-\u007F"

    def remove_space_between(cls1, cls2, s):
        p = re.compile("([{}]) ([{}])".format(cls1, cls2))
        while p.search(s):
            s = p.sub(r"\1\2", s)
        return s

    # Drop spaces between CJK blocks, and between CJK and basic Latin.
    s = remove_space_between(blocks, blocks, s)
    s = remove_space_between(blocks, basic_latin, s)
    s = remove_space_between(basic_latin, blocks, s)
    return s


def normalize_neologd(s):
    s = s.strip()
    s = unicode_normalize("０-９Ａ-Ｚａ-ｚ｡-ﾟ", s)

    def maketrans(f, t):
        return {ord(x): ord(y) for x, y in zip(f, t)}

    s = re.sub("[˗֊‐‑‒–⁃⁻₋−]+", "-", s)  # normalize hyphens
    s = re.sub("[﹣－ｰ—―─━ー]+", "ー", s)  # normalize chōonpu (long-vowel marks)
    s = re.sub("[~∼∾〜〰～]+", "〜", s)  # normalize tildes
    # Map half-width punctuation to its full-width counterpart.
    s = s.translate(
        maketrans(
            "!\"#$%&'()*+,-./:;<=>?@[¥]^_`{|}~｡､･｢｣",
            "！”＃＄％＆’（）＊＋，－．／：；＜＝＞？＠［￥］＾＿｀｛｜｝〜。、・「」",
        )
    )

    s = remove_extra_spaces(s)
    # Fold full-width punctuation back to half-width, keeping ＝・「」 etc. as-is.
    s = unicode_normalize("！”＃＄％＆’（）＊＋，－．／：；＜＞？＠［￥］＾＿｀｛｜｝〜", s)
    s = re.sub("[’]", "'", s)
    s = re.sub("[”]", '"', s)
    s = s.lower()
    return s


def normalize_text(text):
    return normalize_neologd(text)
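

# Illustrative behaviour of normalize_text() (informal examples, not tests):
#   "ＡＢＣ１２３"     -> "abc123"          (full-width alphanumerics folded, lowercased)
#   "検索 キーワード"  -> "検索キーワード"  (space between CJK blocks removed)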


class ClipTextModel(nn.Module):
    def __init__(self, model_name_or_path, device=None):
        super(ClipTextModel, self).__init__()

        # Load the projection-head weights, either from a local directory or
        # from the Hugging Face Hub.  Loaded on CPU here; the whole module is
        # moved to the target device below.
        if os.path.exists(model_name_or_path):
            output_linear_state_dict = torch.load(
                os.path.join(model_name_or_path, "output_linear.bin"), map_location="cpu"
            )
        else:
            filename = hf_hub_download(repo_id=model_name_or_path, filename="output_linear.bin")
            output_linear_state_dict = torch.load(filename, map_location="cpu")

        self.model = BertModel.from_pretrained(model_name_or_path)
        config = self.model.config

        # Number of final hidden layers whose [CLS] vectors are concatenated.
        self.max_cls_depth = 6

        sentence_vector_size = output_linear_state_dict["bias"].shape[0]
        self.sentence_vector_size = sentence_vector_size
        # Projection from the concatenated [CLS] vectors to the sentence embedding.
        self.output_linear = nn.Linear(self.max_cls_depth * config.hidden_size, sentence_vector_size)
        self.output_linear.load_state_dict(output_linear_state_dict)

        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)

        self.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
    ):
        output_states = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            output_attentions=None,
            output_hidden_states=True,
            return_dict=True,
        )
        token_embeddings = output_states[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()  # unused below
        hidden_states = output_states["hidden_states"]

        # Concatenate the [CLS] vectors of the last `max_cls_depth` hidden layers.
        output_vectors = []
        for i in range(1, self.max_cls_depth + 1):
            cls_token = hidden_states[-1 * i][:, 0]
            output_vectors.append(cls_token)

        output_vector = torch.cat(output_vectors, dim=1)
        # Project the concatenated vector down to the sentence embedding size.
        logits = self.output_linear(output_vector)

        output = (logits,) + output_states[2:]
        return output
    @torch.no_grad()
    def encode_text(self, texts, batch_size=8, max_length=64):
        self.eval()
        all_embeddings = []
        iterator = range(0, len(texts), batch_size)
        for batch_idx in iterator:
            batch = texts[batch_idx:batch_idx + batch_size]

            # Tokenize the batch (padded to the longest sequence, truncated to
            # `max_length` tokens) and run it through the model.
            encoded_input = self.tokenizer.batch_encode_plus(
                batch, max_length=max_length, padding="longest",
                truncation=True, return_tensors="pt").to(self.device)
            model_output = self(**encoded_input)
            text_embeddings = model_output[0].cpu()

            all_embeddings.extend(text_embeddings)

        # (len(texts), sentence_vector_size) tensor of sentence embeddings.
        return torch.stack(all_embeddings)
    def save(self, output_dir):
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        torch.save(self.output_linear.state_dict(), os.path.join(output_dir, "output_linear.bin"))
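

# Illustrative standalone usage of ClipTextModel (a sketch; the repo id below
# is the same one the app loads):
#
#   text_model = ClipTextModel("sonoisa/clip-vit-b-32-japanese-v1")
#   vectors = text_model.encode_text([normalize_text("公園を散歩する犬")])
#   # `vectors` is a CPU float tensor of shape (1, sentence_vector_size).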


class DummyClipModel:
    """Minimal wrapper that only carries the text tower of the CLIP model."""

    def __init__(self, text_model):
        self.text_model = text_model


def encode_text(text, model):
    # Normalize the query, then embed it with the Japanese CLIP text encoder.
    text = normalize_text(text)
    text_embedding = model.text_model.encode_text([text]).numpy()
    return text_embedding
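

# The Parquet dataset loaded below already stores a precomputed CLIP image
# vector for every illustration, so only the text encoder has to run at query
# time; DummyClipModel simply carries that text tower.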


st.title("いらすと検索(日本語CLIPゼロショット)")  # "Illustration search (Japanese CLIP, zero-shot)"
description_text = st.empty()

# Load the model and the dataset only once per session.
if "model" not in st.session_state:
    description_text.text("日本語CLIPモデル読み込み中... ")  # "Loading the Japanese CLIP model..."
    device = "cuda" if torch.cuda.is_available() else "cpu"
    text_model = ClipTextModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)

    model = DummyClipModel(text_model)
    st.session_state.model = model

    # Extract the password-protected Parquet dataset.
    print("extract dataset")
    pyminizip.uncompress(
        "clip_zeroshot_irasuto_image_items_20210224.pq.zip", st.secrets["ZIP_PASSWORD"], None, 1
    )

    print("loading dataset")
    df = pq.read_table("clip_zeroshot_irasuto_image_items_20210224.parquet",
                       columns=["page", "description", "image_url", "image_vector"]).to_pandas()
    st.session_state.df = df

    # Precomputed image embeddings, stacked into an (N, D) matrix.
    image_vectors = np.stack(df["image_vector"])
    st.session_state.image_vectors = image_vectors

    print("finished loading model and dataset")

model = st.session_state.model
df = st.session_state.df
image_vectors = st.session_state.image_vectors

# Usage notes shown to the user: natural sentences work better than bare
# keywords, and images must be downloaded from the linked いらすとや page.
description_text.text("日本語CLIPモデル(ゼロショット)を用いて、説明文の意味が近い「いらすとや」画像を検索します。\nキーワードを列挙するよりも、自然な文章を入力した方が精度よく検索できます。\n画像は必ずリンク先の「いらすとや」さんのページを開き、そこからダウンロードしてください。")
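

# Search flow: the query is normalized and embedded with the text encoder,
# compared against all precomputed image vectors by cosine distance
# (scipy.spatial.distance.cdist), and the closest matches are rendered as
# linked thumbnails.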


def clear_result():
    result_text.text("")


prev_query = ""
query_input = st.text_input(label="説明文", value="", on_change=clear_result)  # "description" (query text)

closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100)  # number of results

search_button = st.button("検索")  # "search"

result_text = st.empty()

# Note: the script is re-run on every interaction, so `prev_query` is always
# the empty string here; any non-empty query therefore triggers a search even
# without pressing the button.
if search_button or prev_query != query_input:
    prev_query = query_input
    query_embedding = encode_text(query_input, model)

    # Cosine distances between the query embedding and all image vectors.
    distances = scipy.spatial.distance.cdist(
        query_embedding, image_vectors, metric="cosine"
    )[0]

    # Sort ascending: smaller cosine distance means more similar.
    results = zip(range(len(distances)), distances)
    results = sorted(results, key=lambda x: x[1])

    md_content = ""
    for i, (idx, distance) in enumerate(results[0:closest_n]):
        page_url = df.iloc[idx]["page"]
        desc = df.iloc[idx]["description"]
        img_url = df.iloc[idx]["image_url"]
        # Cosine distance lies in [0, 2]; it is shown halved, presumably to
        # map the displayed score into [0, 1].
        md_content += f"1. <div><a href='{page_url}' target='_blank' rel='noopener noreferrer'><img src='{img_url}' width='100'>{distance / 2:.4f}: {desc}</a></div>\n"

    result_text.markdown(md_content, unsafe_allow_html=True)