from __future__ import unicode_literals
import os
import re
import unicodedata
import torch
from torch import nn
import streamlit as st
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
import scipy.spatial
import pyminizip
import transformers
from transformers import BertJapaneseTokenizer, BertModel
from huggingface_hub import hf_hub_download, snapshot_download
# from PIL import Image
def unicode_normalize(cls, s):
    pt = re.compile("([{}]+)".format(cls))

    def norm(c):
        return unicodedata.normalize("NFKC", c) if pt.match(c) else c

    s = "".join(norm(x) for x in re.split(pt, s))
    s = re.sub("－", "-", s)  # fullwidth hyphen-minus -> ASCII hyphen
    return s
def remove_extra_spaces(s):
    s = re.sub("[ \u3000]+", " ", s)  # collapse runs of ASCII / ideographic spaces
    blocks = "".join(
        (
            "\u4E00-\u9FFF",  # CJK UNIFIED IDEOGRAPHS
            "\u3040-\u309F",  # HIRAGANA
            "\u30A0-\u30FF",  # KATAKANA
            "\u3000-\u303F",  # CJK SYMBOLS AND PUNCTUATION
            "\uFF00-\uFFEF",  # HALFWIDTH AND FULLWIDTH FORMS
        )
    )
    basic_latin = "\u0000-\u007F"

    def remove_space_between(cls1, cls2, s):
        p = re.compile("([{}]) ([{}])".format(cls1, cls2))
        while p.search(s):
            s = p.sub(r"\1\2", s)
        return s

    s = remove_space_between(blocks, blocks, s)
    s = remove_space_between(blocks, basic_latin, s)
    s = remove_space_between(basic_latin, blocks, s)
    return s
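# Quick sanity check (illustrative only, not executed):
# remove_extra_spaces("日本語 と English の 文") -> "日本語とEnglishの文"
# (spaces touching a CJK/kana character are dropped; spaces between two Latin words are kept).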
def normalize_neologd(s):
    s = s.strip()
    s = unicode_normalize("０-９Ａ-Ｚａ-ｚ｡-ﾟ", s)

    def maketrans(f, t):
        return {ord(x): ord(y) for x, y in zip(f, t)}

    s = re.sub("[˗֊‐‑‒–⁃⁻₋−]+", "-", s)  # normalize hyphens
    s = re.sub("[﹣－ｰ—―─━ー]+", "ー", s)  # normalize choonpus
    s = re.sub("[~∼∾〜〰～]+", "〜", s)  # normalize tildes (modified by Isao Sonobe)
    s = s.translate(
        maketrans(
            "!\"#$%&'()*+,-./:;<=>?@[¥]^_`{|}~｡､･｢｣",
            "！”＃＄％＆’（）＊＋，－．／：；＜＝＞？＠［￥］＾＿｀｛｜｝〜。、・「」",
        )
    )

    s = remove_extra_spaces(s)
    s = unicode_normalize("！”＃＄％＆’（）＊＋，－．／：；＜＞？＠［￥］＾＿｀｛｜｝〜", s)  # keep ＝,・,「,」
    s = re.sub("[’]", "'", s)
    s = re.sub("[”]", '"', s)
    s = s.lower()
    return s
def normalize_text(text):
    return normalize_neologd(text)
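# Example of the combined normalization (illustrative only, not executed):
# normalize_text("検索　キーワード！") -> "検索キーワード!"
# (the ideographic space between CJK characters is removed and the fullwidth "！" is folded to ASCII).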
class ClipTextModel(nn.Module):
    def __init__(self, model_name_or_path, device=None):
        super(ClipTextModel, self).__init__()

        if os.path.exists(model_name_or_path):
            # load from file system
            output_linear_state_dict = torch.load(os.path.join(model_name_or_path, "output_linear.bin"))
        else:
            # download from the Hugging Face model hub
            filename = hf_hub_download(repo_id=model_name_or_path, filename="output_linear.bin")
            output_linear_state_dict = torch.load(filename)

        self.model = BertModel.from_pretrained(model_name_or_path)
        config = self.model.config

        self.max_cls_depth = 6

        sentence_vector_size = output_linear_state_dict["bias"].shape[0]
        self.sentence_vector_size = sentence_vector_size
        # project the concatenated [CLS] vectors of the last max_cls_depth layers
        # down to the sentence-vector size expected by the CLIP text projection
        self.output_linear = nn.Linear(self.max_cls_depth * config.hidden_size, sentence_vector_size)
        # self.output_linear = nn.Linear(3 * config.hidden_size, sentence_vector_size)
        self.output_linear.load_state_dict(output_linear_state_dict)

        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)

        self.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
    ):
        output_states = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            output_attentions=None,
            output_hidden_states=True,
            return_dict=True,
        )
        token_embeddings = output_states[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        hidden_states = output_states["hidden_states"]

        # take the [CLS] embedding from each of the last max_cls_depth layers and concatenate them
        output_vectors = []
        for i in range(1, self.max_cls_depth + 1):
            cls_token = hidden_states[-1 * i][:, 0]
            output_vectors.append(cls_token)

        output_vector = torch.cat(output_vectors, dim=1)
        logits = self.output_linear(output_vector)

        output = (logits,) + output_states[2:]
        return output

    @torch.no_grad()
    def encode_text(self, texts, batch_size=8, max_length=64):
        self.eval()
        all_embeddings = []
        iterator = range(0, len(texts), batch_size)
        for batch_idx in iterator:
            batch = texts[batch_idx:batch_idx + batch_size]
            encoded_input = self.tokenizer.batch_encode_plus(
                batch, max_length=max_length, padding="longest",
                truncation=True, return_tensors="pt").to(self.device)
            model_output = self(**encoded_input)
            text_embeddings = model_output[0].cpu()
            all_embeddings.extend(text_embeddings)

        # return torch.stack(all_embeddings).numpy()
        return torch.stack(all_embeddings)

    def save(self, output_dir):
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        torch.save(self.output_linear.state_dict(), os.path.join(output_dir, "output_linear.bin"))
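# Usage sketch for the text encoder (commented out so it does not run at app startup;
# the repo id is the same one loaded further below, and the query string is just an example):
#
# _text_model = ClipTextModel("sonoisa/clip-vit-b-32-japanese-v1")
# _vec = _text_model.encode_text([normalize_text("クリスマスツリーのイラスト")])
# print(_vec.shape)  # expected (1, 512), matching the projection_dim of the commented-out vision model below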
# Image-side and full CLIP wrappers, kept for reference (the app only needs the text encoder):
#
# class ClipVisionModel(nn.Module):
#     def __init__(self, model_name_or_path, device=None):
#         super(ClipVisionModel, self).__init__()
#         if os.path.exists(model_name_or_path):
#             # load from file system
#             visual_projection_state_dict = torch.load(os.path.join(model_name_or_path, "visual_projection.bin"))
#         else:
#             # download from the Hugging Face model hub
#             filename = hf_hub_download(repo_id=model_name_or_path, filename="visual_projection.bin")
#             visual_projection_state_dict = torch.load(filename)
#         self.model = transformers.CLIPVisionModel.from_pretrained(model_name_or_path)
#         config = self.model.config
#         self.feature_extractor = transformers.CLIPFeatureExtractor.from_pretrained(model_name_or_path)
#         vision_embed_dim = config.hidden_size
#         projection_dim = 512
#         self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
#         self.visual_projection.load_state_dict(visual_projection_state_dict)
#         self.eval()
#         if device is None:
#             device = "cuda" if torch.cuda.is_available() else "cpu"
#         self.device = torch.device(device)
#         self.to(self.device)
#
#     def forward(
#         self,
#         pixel_values=None,
#         output_attentions=None,
#         output_hidden_states=None,
#         return_dict=None,
#     ):
#         output_states = self.model(
#             pixel_values=pixel_values,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#             return_dict=return_dict,
#         )
#         image_embeds = self.visual_projection(output_states[1])
#         return image_embeds
#
#     @torch.no_grad()
#     def encode_image(self, images, batch_size=8):
#         self.eval()
#         all_embeddings = []
#         iterator = range(0, len(images), batch_size)
#         for batch_idx in iterator:
#             batch = images[batch_idx:batch_idx + batch_size]
#             encoded_input = self.feature_extractor(batch, return_tensors="pt").to(self.device)
#             model_output = self(**encoded_input)
#             image_embeddings = model_output.cpu()
#             all_embeddings.extend(image_embeddings)
#         # return torch.stack(all_embeddings).numpy()
#         return torch.stack(all_embeddings)
#
#     @staticmethod
#     def remove_alpha_channel(image):
#         image.convert("RGBA")
#         alpha = image.convert('RGBA').split()[-1]
#         background = Image.new("RGBA", image.size, (255, 255, 255))
#         background.paste(image, mask=alpha)
#         image = background.convert("RGB")
#         return image
#
#     def save(self, output_dir):
#         self.model.save_pretrained(output_dir)
#         self.feature_extractor.save_pretrained(output_dir)
#         torch.save(self.visual_projection.state_dict(), os.path.join(output_dir, "visual_projection.bin"))
#
#
# class ClipModel(nn.Module):
#     def __init__(self, model_name_or_path, device=None):
#         super(ClipModel, self).__init__()
#         if os.path.exists(model_name_or_path):
#             # load from file system
#             repo_dir = model_name_or_path
#         else:
#             # download from the Hugging Face model hub
#             repo_dir = snapshot_download(model_name_or_path)
#         self.text_model = ClipTextModel(repo_dir, device=device)
#         self.vision_model = ClipVisionModel(os.path.join(repo_dir, "vision_model"), device=device)
#         with torch.no_grad():
#             logit_scale = nn.Parameter(torch.ones([]) * 2.6592)
#             logit_scale.set_(torch.load(os.path.join(repo_dir, "logit_scale.bin")).clone().cpu())
#         self.logit_scale = logit_scale
#         self.eval()
#         if device is None:
#             device = "cuda" if torch.cuda.is_available() else "cpu"
#         self.device = torch.device(device)
#         self.to(self.device)
#
#     def forward(self, pixel_values, input_ids, attention_mask, token_type_ids):
#         image_features = self.vision_model(pixel_values=pixel_values)
#         text_features = self.text_model(input_ids=input_ids,
#                                         attention_mask=attention_mask,
#                                         token_type_ids=token_type_ids)[0]
#         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
#         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
#         logit_scale = self.logit_scale.exp()
#         logits_per_image = logit_scale * image_features @ text_features.t()
#         logits_per_text = logits_per_image.t()
#         return logits_per_image, logits_per_text
#
#     def save(self, output_dir):
#         torch.save(self.logit_scale, os.path.join(output_dir, "logit_scale.bin"))
#         self.text_model.save(output_dir)
#         self.vision_model.save(os.path.join(output_dir, "vision_model"))
class DummyClipModel:
    """Lightweight stand-in for ClipModel that exposes only the text encoder."""
    def __init__(self, text_model):
        self.text_model = text_model


def encode_text(text, model):
    text = normalize_text(text)
    text_embedding = model.text_model.encode_text([text]).numpy()
    return text_embedding


# def encode_image(image_filename, model):
#     image = Image.open(image_filename)
#     image_embedding = model.vision_model.encode_image([image]).numpy()
#     return image_embedding
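# Sketch (commented out; illustrative only): once `model` is loaded below, two queries can be
# compared with the same cosine-distance metric the search block uses.
#
# _a = encode_text("赤いりんごのイラスト", model)
# _b = encode_text("青いりんごのイラスト", model)
# _dist = scipy.spatial.distance.cdist(_a, _b, metric="cosine")[0][0]  # smaller = more similar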
st.title("いらすと検索(日本語CLIPゼロショット)")  # "Irasuto search (Japanese CLIP, zero-shot)"
description_text = st.empty()

if "model" not in st.session_state:
    description_text.text("日本語CLIPモデル読み込み中... ")  # "Loading the Japanese CLIP model..."
    device = "cuda" if torch.cuda.is_available() else "cpu"
    text_model = ClipTextModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
    # model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
    model = DummyClipModel(text_model)
    st.session_state.model = model

    print("extract dataset")
    pyminizip.uncompress(
        "clip_zeroshot_irasuto_image_items_20210224.pq.zip", st.secrets["ZIP_PASSWORD"], None, 1
    )

    print("loading dataset")
    df = pq.read_table("clip_zeroshot_irasuto_image_items_20210224.parquet",
                       columns=["page", "description", "image_url", "image_vector"]).to_pandas()
    st.session_state.df = df

    # sentence_vectors = np.stack(df["sentence_vector"])
    image_vectors = np.stack(df["image_vector"])
    # st.session_state.sentence_vectors = sentence_vectors
    st.session_state.image_vectors = image_vectors
    print("finished loading model and dataset")

model = st.session_state.model
df = st.session_state.df
# sentence_vectors = st.session_state.sentence_vectors
image_vectors = st.session_state.image_vectors

# "Searches irasutoya images whose descriptions are semantically close to the input text,
#  using the Japanese CLIP model (zero-shot). Natural sentences work better than bare keywords.
#  Always open the linked irasutoya page and download images from there."
description_text.text("日本語CLIPモデル(ゼロショット)を用いて、説明文の意味が近い「いらすとや」画像を検索します。\nキーワードを列挙するよりも、自然な文章を入力した方が精度よく検索できます。\n画像は必ずリンク先の「いらすとや」さんのページを開き、そこからダウンロードしてください。")
def clear_result():
    result_text.text("")


prev_query = ""
query_input = st.text_input(label="説明文", value="", on_change=clear_result)  # "description" text box
closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100)  # number of results to show
search_button = st.button("検索")  # "search" button

result_text = st.empty()
if search_button or prev_query != query_input:
    prev_query = query_input
    query_embedding = encode_text(query_input, model)

    # cosine distance between the query embedding and every precomputed image vector
    distances = scipy.spatial.distance.cdist(
        query_embedding, image_vectors, metric="cosine"
    )[0]

    results = zip(range(len(distances)), distances)
    results = sorted(results, key=lambda x: x[1])

    md_content = ""
    for i, (idx, distance) in enumerate(results[0:closest_n]):
        page_url = df.iloc[idx]["page"]
        desc = df.iloc[idx]["description"]
        img_url = df.iloc[idx]["image_url"]
        # cosine distance lies in [0, 2]; dividing by 2 displays a score in [0, 1]
        md_content += f"1. <div><a href='{page_url}' target='_blank' rel='noopener noreferrer'><img src='{img_url}' width='100'>{distance / 2:.4f}: {desc}</a></div>\n"

    result_text.markdown(md_content, unsafe_allow_html=True)
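# Possible alternative (sketch, not used by the app): select the top-N matches with numpy
# instead of sorting the full (index, distance) list; functionally equivalent here.
#
# top_idx = np.argsort(distances)[:closest_n]
# top = [(int(i), float(distances[i])) for i in top_idx]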