from __future__ import unicode_literals |
import os |
import re |
import unicodedata |
import torch |
from torch import nn |
import streamlit as st |
import pandas as pd |
import pyarrow as pa |
import pyarrow.parquet as pq |
import numpy as np |
import scipy.spatial |
import pyminizip |
import transformers |
from transformers import BertJapaneseTokenizer, BertModel |
from huggingface_hub import hf_hub_download, snapshot_download |
from PIL import Image |
def unicode_normalize(cls, s):
    """NFKC-normalize only the runs of ``s`` made of characters in ``cls``.

    Args:
        cls: Body of a regex character class (e.g. ``"0-9a-z"``); it is
            wrapped in ``[...]``, so it must be valid inside one.
        s: Text to normalize.

    Returns:
        ``s`` with every maximal run of ``cls`` characters NFKC-normalized,
        all other characters unchanged, and every FULLWIDTH HYPHEN-MINUS
        (U+FF0D) replaced by an ASCII ``-``.
    """
    pt = re.compile("([{}]+)".format(cls))

    def norm(c):
        # Splitting on a capturing group keeps the matched runs in the result;
        # only those runs match ``pt`` and get normalized.
        return unicodedata.normalize("NFKC", c) if pt.match(c) else c

    s = "".join(norm(x) for x in re.split(pt, s))
    # Non-ASCII characters in these patterns are spelled as \u escapes so that
    # editors or Unicode-normalizing tools cannot fold them into their ASCII
    # look-alikes (which would turn this replacement into a no-op).
    s = s.replace("\uFF0D", "-")
    return s


def remove_extra_spaces(s):
    """Collapse spaces and drop the ones adjacent to Japanese characters.

    Runs of ASCII spaces and IDEOGRAPHIC SPACEs (U+3000) become one ASCII
    space. A space is then removed when it sits between two CJK / kana /
    fullwidth characters, or between one of those and a Basic Latin
    character; a space between two Basic Latin characters is kept.
    """
    s = re.sub("[ \u3000]+", " ", s)
    blocks = "".join(
        (
            "\u4E00-\u9FFF",  # CJK Unified Ideographs
            "\u3040-\u309F",  # Hiragana
            "\u30A0-\u30FF",  # Katakana
            "\u3000-\u303F",  # CJK Symbols and Punctuation
            "\uFF00-\uFFEF",  # Halfwidth and Fullwidth Forms
        )
    )
    basic_latin = "\u0000-\u007F"

    def remove_space_between(cls1, cls2, s):
        p = re.compile("([{}]) ([{}])".format(cls1, cls2))
        # Repeat because matches overlap: in "a b c" the first pass consumes
        # "b", so the "b c" pair is only handled on the next pass.
        while p.search(s):
            s = p.sub(r"\1\2", s)
        return s

    s = remove_space_between(blocks, blocks, s)
    s = remove_space_between(blocks, basic_latin, s)
    s = remove_space_between(basic_latin, blocks, s)
    return s


def normalize_neologd(s):
    """Normalize Japanese text following the mecab-ipadic-NEologd recipe.

    Steps:
    - Strip the text, then NFKC-normalize fullwidth digits/Latin letters and
      halfwidth katakana.
    - Unify hyphen, long-vowel-mark and tilde variants.
    - Remove spaces around Japanese text (see ``remove_extra_spaces``).
    - Convert fullwidth ASCII punctuation back to ASCII. FULLWIDTH EQUALS
      SIGN, the ideographic full stop/comma, KATAKANA MIDDLE DOT and the
      corner brackets are kept.
    - Map curly quotes to ASCII quotes and lower-case the result.

    Unlike the upstream recipe, which deletes tildes, tilde variants are
    unified to WAVE DASH (U+301C), and the result is lower-cased.
    """
    s = s.strip()
    # Fullwidth digits / Latin letters and the halfwidth katakana block.
    s = unicode_normalize("\uFF10-\uFF19\uFF21-\uFF3A\uFF41-\uFF5A\uFF61-\uFF9F", s)

    # Hyphen look-alikes -> ASCII hyphen-minus.
    s = re.sub("[\u02D7\u058A\u2010\u2011\u2012\u2013\u2043\u207B\u208B\u2212]+", "-", s)
    # Long-vowel-mark (choonpu) look-alikes; a run collapses to one U+30FC.
    s = re.sub("[\uFE63\uFF0D\uFF70\u2014\u2015\u2500\u2501\u30FC]+", "\u30FC", s)
    # Tilde / wave-dash variants; a run collapses to one WAVE DASH (U+301C).
    s = re.sub("[~\u223C\u223E\u301C\u3030\uFF5E]+", "\u301C", s)

    # ASCII punctuation (YEN SIGN in place of backslash) and halfwidth CJK
    # punctuation -> fullwidth forms, so that remove_extra_spaces treats them
    # as Japanese characters when deciding which spaces to drop.
    s = s.translate(
        str.maketrans(
            "!\"#$%&'()*+,-./:;<=>?@[\u00A5]^_`{|}~\uFF61\uFF64\uFF65\uFF62\uFF63",
            "\uFF01\u201D\uFF03\uFF04\uFF05\uFF06\u2019\uFF08\uFF09\uFF0A\uFF0B\uFF0C"
            "\uFF0D\uFF0E\uFF0F\uFF1A\uFF1B\uFF1C\uFF1D\uFF1E\uFF1F\uFF20\uFF3B\uFFE5"
            "\uFF3D\uFF3E\uFF3F\uFF40\uFF5B\uFF5C\uFF5D\u301C\u3002\u3001\u30FB\u300C\u300D",
        )
    )
    s = remove_extra_spaces(s)
    # Back to ASCII via NFKC. FULLWIDTH EQUALS SIGN (U+FF1D) is deliberately
    # absent from this class, as are the ideographic full stop/comma, middle
    # dot and corner brackets, so those stay as they are.
    s = unicode_normalize(
        "\uFF01\u201D\uFF03\uFF04\uFF05\uFF06\u2019\uFF08\uFF09\uFF0A\uFF0B\uFF0C"
        "\uFF0D\uFF0E\uFF0F\uFF1A\uFF1B\uFF1C\uFF1E\uFF1F\uFF20\uFF3B\uFFE5\uFF3D"
        "\uFF3E\uFF3F\uFF40\uFF5B\uFF5C\uFF5D\u301C",
        s,
    )
    # Curly quotes (produced by the translation above or present in the
    # input) -> ASCII quotes.
    s = s.replace("\u2019", "'").replace("\u201D", '"')
    return s.lower()


def normalize_text(text):
    """Normalize query text before encoding; delegates to ``normalize_neologd``."""
    return normalize_neologd(text)
class ClipTextModel(nn.Module):
    """Japanese CLIP text encoder: a BERT model followed by a linear projection.

    The sentence embedding concatenates the first-token (``[CLS]``) hidden
    state of the last ``max_cls_depth`` BERT layers and projects it with
    ``output_linear``, whose weights are loaded from ``output_linear.bin``.
    """

    def __init__(self, model_name_or_path, device=None):
        """Load the BERT encoder, the tokenizer and the projection head.

        Args:
            model_name_or_path: Local directory, or a Hugging Face Hub repo id
                when no such path exists locally.
            device: Target device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        super().__init__()
        if os.path.exists(model_name_or_path):
            linear_path = os.path.join(model_name_or_path, "output_linear.bin")
        else:
            linear_path = hf_hub_download(repo_id=model_name_or_path, filename="output_linear.bin")
        # Load onto the CPU so a checkpoint saved from a GPU also opens on
        # CPU-only hosts; the whole module is moved to ``device`` below.
        output_linear_state_dict = torch.load(linear_path, map_location="cpu")

        self.model = BertModel.from_pretrained(model_name_or_path)
        config = self.model.config
        self.max_cls_depth = 6
        # The projection's output size is read from the checkpoint itself.
        sentence_vector_size = output_linear_state_dict["bias"].shape[0]
        self.sentence_vector_size = sentence_vector_size
        self.output_linear = nn.Linear(self.max_cls_depth * config.hidden_size, sentence_vector_size)
        self.output_linear.load_state_dict(output_linear_state_dict)
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)

        self.eval()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
    ):
        """Compute projected sentence vectors.

        Returns:
            ``(logits,) + output_states[2:]``: the projected sentence vectors
            followed by BERT's outputs after the pooler output (this includes
            the per-layer hidden states, since they are requested here).
        """
        output_states = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            output_attentions=None,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden_states = output_states["hidden_states"]
        # First-token vector of the last max_cls_depth layers, from the last
        # layer backwards, concatenated along the feature axis.
        cls_vectors = [hidden_states[-depth][:, 0] for depth in range(1, self.max_cls_depth + 1)]
        output_vector = torch.cat(cls_vectors, dim=1)
        logits = self.output_linear(output_vector)
        return (logits,) + output_states[2:]

    @torch.no_grad()
    def encode_text(self, texts, batch_size=8, max_length=64):
        """Embed ``texts`` in mini-batches.

        Args:
            texts: Sequence of strings (already normalized by the caller).
            batch_size: Number of texts tokenized and encoded per forward pass.
            max_length: Token limit; longer texts are truncated.

        Returns:
            CPU tensor with one row per text. An empty input yields an empty
            ``(0, output size)`` tensor instead of failing.
        """
        self.eval()
        if len(texts) == 0:
            return torch.empty(0, self.output_linear.out_features)
        batches = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            encoded_input = self.tokenizer(
                batch, max_length=max_length, padding="longest",
                truncation=True, return_tensors="pt").to(self.device)
            batches.append(self(**encoded_input)[0].cpu())
        return torch.cat(batches)

    def save(self, output_dir):
        """Save BERT, the tokenizer and ``output_linear.bin`` into ``output_dir``."""
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        torch.save(self.output_linear.state_dict(), os.path.join(output_dir, "output_linear.bin"))
class ClipVisionModel(nn.Module):
    """Japanese CLIP image encoder: a CLIP vision transformer plus a projection.

    The pooled vision output is projected by ``visual_projection`` (no bias),
    whose weights are loaded from ``visual_projection.bin``.
    """

    def __init__(self, model_name_or_path, device=None):
        """Load the vision model, its feature extractor and the projection.

        Args:
            model_name_or_path: Local directory, or a Hugging Face Hub repo id
                when no such path exists locally.
            device: Target device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        super().__init__()
        if os.path.exists(model_name_or_path):
            projection_path = os.path.join(model_name_or_path, "visual_projection.bin")
        else:
            projection_path = hf_hub_download(repo_id=model_name_or_path, filename="visual_projection.bin")
        # Load onto the CPU so a checkpoint saved from a GPU also opens on
        # CPU-only hosts; the whole module is moved to ``device`` below.
        visual_projection_state_dict = torch.load(projection_path, map_location="cpu")

        self.model = transformers.CLIPVisionModel.from_pretrained(model_name_or_path)
        config = self.model.config
        self.feature_extractor = transformers.CLIPFeatureExtractor.from_pretrained(model_name_or_path)
        vision_embed_dim = config.hidden_size
        # nn.Linear stores its weight as (out_features, in_features), so the
        # projection size comes from the checkpoint instead of a fixed 512.
        projection_dim = visual_projection_state_dict["weight"].shape[0]
        self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
        self.visual_projection.load_state_dict(visual_projection_state_dict)

        self.eval()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def forward(
        self,
        pixel_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Return projected image embeddings computed from ``pixel_values``."""
        output_states = self.model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Element 1 of the vision model output is its pooled output.
        image_embeds = self.visual_projection(output_states[1])
        return image_embeds

    @torch.no_grad()
    def encode_image(self, images, batch_size=8):
        """Embed ``images`` in mini-batches.

        Returns:
            CPU tensor with one row per image. An empty input yields an empty
            ``(0, projection size)`` tensor instead of failing.
        """
        self.eval()
        if len(images) == 0:
            return torch.empty(0, self.visual_projection.out_features)
        batches = []
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size]
            encoded_input = self.feature_extractor(batch, return_tensors="pt").to(self.device)
            batches.append(self(**encoded_input).cpu())
        return torch.cat(batches)

    @staticmethod
    def remove_alpha_channel(image):
        """Composite ``image`` onto an opaque white background and return it as RGB."""
        alpha = image.convert("RGBA").split()[-1]
        background = Image.new("RGBA", image.size, (255, 255, 255))
        background.paste(image, mask=alpha)
        return background.convert("RGB")

    def save(self, output_dir):
        """Save the vision model, feature extractor and ``visual_projection.bin``."""
        self.model.save_pretrained(output_dir)
        self.feature_extractor.save_pretrained(output_dir)
        torch.save(self.visual_projection.state_dict(), os.path.join(output_dir, "visual_projection.bin"))
class ClipModel(nn.Module):
    """Joint Japanese CLIP model: text encoder, vision encoder and logit scale.

    The repository directory holds the text model files at its root, the
    vision model under ``vision_model/`` and the scale in ``logit_scale.bin``.
    """

    def __init__(self, model_name_or_path, device=None):
        """Load both encoders and the logit scale.

        Args:
            model_name_or_path: Local directory, or a Hugging Face Hub repo id
                that is downloaded with ``snapshot_download``.
            device: Target device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        super().__init__()
        if os.path.exists(model_name_or_path):
            repo_dir = model_name_or_path
        else:
            repo_dir = snapshot_download(model_name_or_path)
        self.text_model = ClipTextModel(repo_dir, device=device)
        self.vision_model = ClipVisionModel(os.path.join(repo_dir, "vision_model"), device=device)
        with torch.no_grad():
            # Placeholder initial value (presumably CLIP's ln(1/0.07) default);
            # it is overwritten in place by the saved value right away.
            logit_scale = nn.Parameter(torch.ones([]) * 2.6592)
            # map_location="cpu" lets a GPU-saved scale load on CPU-only hosts.
            saved_scale = torch.load(os.path.join(repo_dir, "logit_scale.bin"), map_location="cpu")
            logit_scale.set_(saved_scale.clone())
        self.logit_scale = logit_scale

        self.eval()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def forward(self, pixel_values, input_ids, attention_mask, token_type_ids):
        """Return ``(logits_per_image, logits_per_text)``.

        Both feature sets are L2-normalized, so the logits are cosine
        similarities scaled by ``exp(logit_scale)``.
        """
        image_features = self.vision_model(pixel_values=pixel_values)
        text_features = self.text_model(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids)[0]
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text

    def save(self, output_dir):
        """Save the logit scale and both encoders in the layout ``__init__`` loads."""
        # torch.save does not create directories and runs first, so make sure
        # output_dir exists before anything is written.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.logit_scale, os.path.join(output_dir, "logit_scale.bin"))
        self.text_model.save(output_dir)
        self.vision_model.save(os.path.join(output_dir, "vision_model"))
def encode_text(text, model):
    """Return the CLIP text embedding of ``text`` as a NumPy array.

    The text is normalized with ``normalize_text`` first. It is encoded as a
    one-element batch, so the result has a single row.
    """
    normalized = normalize_text(text)
    embeddings = model.text_model.encode_text([normalized])
    return embeddings.numpy()
def encode_image(image_filename, model):
    """Return the CLIP image embedding of an image file as a NumPy array.

    The image is encoded as a one-element batch, so the result has a single
    row. PIL opens files lazily and keeps the handle open, so the image is
    opened in a ``with`` block; the handle is released once the embedding has
    been computed.
    """
    with Image.open(image_filename) as image:
        image_embedding = model.vision_model.encode_image([image]).numpy()
    return image_embedding
st.title("いらすと検索(日本語CLIPゼロショット)")
description_text = st.empty()

if "model" not in st.session_state:
    # First run in this browser session: load the model and the search index
    # once and keep them in st.session_state so later reruns reuse them.
    description_text.text("日本語CLIPモデル読み込み中... ")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # No separate ClipTextModel is loaded: ClipModel already contains the
    # text encoder (model.text_model), and a second copy was never used.
    model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
    st.session_state.model = model

    print("extract dataset")
    # Password-protected archive; the password comes from Streamlit secrets.
    pyminizip.uncompress(
        "clip_zeroshot_irasuto_items_20210224.pq.zip", st.secrets["ZIP_PASSWORD"], None, 1
    )

    print("loading dataset")
    df = pq.read_table(
        "clip_zeroshot_irasuto_items_20210224.parquet",
        columns=["page", "description", "image_url", "sentence_vector", "image_vector"],
    ).to_pandas()
    # Stack the per-row vectors into matrices (one row per item) for cdist.
    sentence_vectors = np.stack(df["sentence_vector"])
    image_vectors = np.stack(df["image_vector"])
    st.session_state.sentence_vectors = sentence_vectors
    st.session_state.df = df
    st.session_state.image_vectors = image_vectors
    print("finished loading model and dataset")

model = st.session_state.model
df = st.session_state.df
sentence_vectors = st.session_state.sentence_vectors
image_vectors = st.session_state.image_vectors

description_text.text("日本語CLIPモデル(ゼロショット)を用いて、説明文の意味が近い「いらすとや」画像を検索します。\nキーワードを列挙するよりも、自然な文章を入力した方が精度よく検索できます。\n画像は必ずリンク先の「いらすとや」さんのページを開き、そこからダウンロードしてください。")
def clear_result():
    """on_change callback for the query box: blank out the previous results."""
    placeholder = result_text
    placeholder.text("")
import html

prev_query = ""
query_input = st.text_input(label="説明文", value="", on_change=clear_result)
closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100)
search_button = st.button("検索")

result_text = st.empty()

# NOTE: prev_query is reset to "" on every script rerun, so this condition is
# true whenever the query box is non-empty (or the button was pressed).
if search_button or prev_query != query_input:
    prev_query = query_input
    query_embedding = encode_text(query_input, model)
    # Cosine distance from the query to every image vector.
    distances = scipy.spatial.distance.cdist(
        query_embedding, image_vectors, metric="cosine"
    )[0]
    # A stable argsort orders indices exactly like sorting (index, distance)
    # pairs by distance, ties included.
    nearest = np.argsort(distances, kind="stable")[:closest_n]

    items = []
    for idx in nearest:
        row = df.iloc[idx]
        # Dataset text is embedded in raw HTML (unsafe_allow_html), so escape
        # it; quote=True (the default) also protects the quoted attributes.
        page_url = html.escape(str(row["page"]))
        desc = html.escape(str(row["description"]))
        img_url = html.escape(str(row["image_url"]))
        # Cosine distance lies in [0, 2]; halving maps it onto [0, 1].
        items.append(
            f"1. <div><a href='{page_url}' target='_blank' rel='noopener noreferrer'>"
            f"<img src='{img_url}' width='100'>{distances[idx] / 2:.4f}: {desc}</a></div>\n"
        )
    result_text.markdown("".join(items), unsafe_allow_html=True)