from __future__ import unicode_literals import os import re import unicodedata import torch from torch import nn import streamlit as st import pandas as pd import pyarrow as pa import pyarrow.parquet as pq import numpy as np import scipy.spatial import pyminizip import transformers from transformers import BertJapaneseTokenizer, BertModel from huggingface_hub import hf_hub_download, snapshot_download from PIL import Image def unicode_normalize(cls, s): pt = re.compile("([{}]+)".format(cls)) def norm(c): return unicodedata.normalize("NFKC", c) if pt.match(c) else c s = "".join(norm(x) for x in re.split(pt, s)) s = re.sub("-", "-", s) return s def remove_extra_spaces(s): s = re.sub("[  ]+", " ", s) blocks = "".join( ( "\u4E00-\u9FFF", # CJK UNIFIED IDEOGRAPHS "\u3040-\u309F", # HIRAGANA "\u30A0-\u30FF", # KATAKANA "\u3000-\u303F", # CJK SYMBOLS AND PUNCTUATION "\uFF00-\uFFEF", # HALFWIDTH AND FULLWIDTH FORMS ) ) basic_latin = "\u0000-\u007F" def remove_space_between(cls1, cls2, s): p = re.compile("([{}]) ([{}])".format(cls1, cls2)) while p.search(s): s = p.sub(r"\1\2", s) return s s = remove_space_between(blocks, blocks, s) s = remove_space_between(blocks, basic_latin, s) s = remove_space_between(basic_latin, blocks, s) return s def normalize_neologd(s): s = s.strip() s = unicode_normalize("0-9A-Za-z。-゚", s) def maketrans(f, t): return {ord(x): ord(y) for x, y in zip(f, t)} s = re.sub("[˗֊‐‑‒–⁃⁻₋−]+", "-", s) # normalize hyphens s = re.sub("[﹣-ー—―─━ー]+", "ー", s) # normalize choonpus s = re.sub("[~∼∾〜〰~]+", "〜", s) # normalize tildes (modified by Isao Sonobe) s = s.translate( maketrans( "!\"#$%&'()*+,-./:;<=>?@[¥]^_`{|}~。、・「」", "!”#$%&’()*+,-./:;<=>?@[¥]^_`{|}〜。、・「」", ) ) s = remove_extra_spaces(s) s = unicode_normalize("!”#$%&’()*+,-./:;<>?@[¥]^_`{|}〜", s) # keep =,・,「,」 s = re.sub("[’]", "'", s) s = re.sub("[”]", '"', s) s = s.lower() return s def normalize_text(text): return normalize_neologd(text) class ClipTextModel(nn.Module): def __init__(self, model_name_or_path, device=None): super(ClipTextModel, self).__init__() if os.path.exists(model_name_or_path): # load from file system output_linear_state_dict = torch.load(os.path.join(model_name_or_path, "output_linear.bin")) else: # download from the Hugging Face model hub filename = hf_hub_download(repo_id=model_name_or_path, filename="output_linear.bin") output_linear_state_dict = torch.load(filename) self.model = BertModel.from_pretrained(model_name_or_path) config = self.model.config self.max_cls_depth = 6 sentence_vector_size = output_linear_state_dict["bias"].shape[0] self.sentence_vector_size = sentence_vector_size self.output_linear = nn.Linear(self.max_cls_depth * config.hidden_size, sentence_vector_size) # self.output_linear = nn.Linear(3 * config.hidden_size, sentence_vector_size) self.output_linear.load_state_dict(output_linear_state_dict) self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path, do_lower_case=True) self.eval() if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) self.to(self.device) def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, ): output_states = self.model( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=None, head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=True, return_dict=True, ) token_embeddings = output_states[0] input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() hidden_states = output_states["hidden_states"] output_vectors = [] for i in range(1, self.max_cls_depth + 1): cls_token = hidden_states[-1 * i][:, 0] output_vectors.append(cls_token) output_vector = torch.cat(output_vectors, dim=1) logits = self.output_linear(output_vector) output = (logits,) + output_states[2:] return output @torch.no_grad() def encode_text(self, texts, batch_size=8, max_length=64): self.eval() all_embeddings = [] iterator = range(0, len(texts), batch_size) for batch_idx in iterator: batch = texts[batch_idx:batch_idx + batch_size] encoded_input = self.tokenizer.batch_encode_plus( batch, max_length=max_length, padding="longest", truncation=True, return_tensors="pt").to(self.device) model_output = self(**encoded_input) text_embeddings = model_output[0].cpu() all_embeddings.extend(text_embeddings) # return torch.stack(all_embeddings).numpy() return torch.stack(all_embeddings) def save(self, output_dir): self.model.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir) torch.save(self.output_linear.state_dict(), os.path.join(output_dir, "output_linear.bin")) class ClipVisionModel(nn.Module): def __init__(self, model_name_or_path, device=None): super(ClipVisionModel, self).__init__() if os.path.exists(model_name_or_path): # load from file system visual_projection_state_dict = torch.load(os.path.join(model_name_or_path, "visual_projection.bin")) else: # download from the Hugging Face model hub filename = hf_hub_download(repo_id=model_name_or_path, filename="visual_projection.bin") visual_projection_state_dict = torch.load(filename) self.model = transformers.CLIPVisionModel.from_pretrained(model_name_or_path) config = self.model.config self.feature_extractor = transformers.CLIPFeatureExtractor.from_pretrained(model_name_or_path) vision_embed_dim = config.hidden_size projection_dim = 512 self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False) self.visual_projection.load_state_dict(visual_projection_state_dict) self.eval() if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) self.to(self.device) def forward( self, pixel_values=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): output_states = self.model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) image_embeds = self.visual_projection(output_states[1]) return image_embeds @torch.no_grad() def encode_image(self, images, batch_size=8): self.eval() all_embeddings = [] iterator = range(0, len(images), batch_size) for batch_idx in iterator: batch = images[batch_idx:batch_idx + batch_size] encoded_input = self.feature_extractor(batch, return_tensors="pt").to(self.device) model_output = self(**encoded_input) image_embeddings = model_output.cpu() all_embeddings.extend(image_embeddings) # return torch.stack(all_embeddings).numpy() return torch.stack(all_embeddings) @staticmethod def remove_alpha_channel(image): image.convert("RGBA") alpha = image.convert('RGBA').split()[-1] background = Image.new("RGBA", image.size, (255, 255, 255)) background.paste(image, mask=alpha) image = background.convert("RGB") return image def save(self, output_dir): self.model.save_pretrained(output_dir) self.feature_extractor.save_pretrained(output_dir) torch.save(self.visual_projection.state_dict(), os.path.join(output_dir, "visual_projection.bin")) class ClipModel(nn.Module): def __init__(self, model_name_or_path, device=None): super(ClipModel, self).__init__() if os.path.exists(model_name_or_path): # load from file system repo_dir = model_name_or_path else: # download from the Hugging Face model hub repo_dir = snapshot_download(model_name_or_path) self.text_model = ClipTextModel(repo_dir, device=device) self.vision_model = ClipVisionModel(os.path.join(repo_dir, "vision_model"), device=device) with torch.no_grad(): logit_scale = nn.Parameter(torch.ones([]) * 2.6592) logit_scale.set_(torch.load(os.path.join(repo_dir, "logit_scale.bin")).clone().cpu()) self.logit_scale = logit_scale self.eval() if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) self.to(self.device) def forward(self, pixel_values, input_ids, attention_mask, token_type_ids): image_features = self.vision_model(pixel_values=pixel_values) text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0] image_features = image_features / image_features.norm(dim=-1, keepdim=True) text_features = text_features / text_features.norm(dim=-1, keepdim=True) logit_scale = self.logit_scale.exp() logits_per_image = logit_scale * image_features @ text_features.t() logits_per_text = logits_per_image.t() return logits_per_image, logits_per_text def save(self, output_dir): torch.save(self.logit_scale, os.path.join(output_dir, "logit_scale.bin")) self.text_model.save(output_dir) self.vision_model.save(os.path.join(output_dir, "vision_model")) def encode_text(text, model): text = normalize_text(text) text_embedding = model.text_model.encode_text([text]).numpy() return text_embedding def encode_image(image_filename, model): image = Image.open(image_filename) image = ClipVisionModel.remove_alpha_channel(image) image_embedding = model.vision_model.encode_image([image]).numpy() return image_embedding st.title("いらすと検索(日本語CLIPゼロショット)") description_text = st.empty() if "model" not in st.session_state: description_text.markdown("日本語CLIPモデル読み込み中... ") device = "cuda" if torch.cuda.is_available() else "cpu" model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device) st.session_state.model = model pyminizip.uncompress( "clip_zeroshot_irasuto_items_20210224.pq.zip", st.secrets["ZIP_PASSWORD"], None, 1 ) df = pq.read_table("clip_zeroshot_irasuto_items_20210224.parquet", columns=["page", "description", "image_url", "sentence_vector", "image_vector"]).to_pandas() sentence_vectors = np.stack(df["sentence_vector"]) image_vectors = np.stack(df["image_vector"]) st.session_state.df = df st.session_state.sentence_vectors = sentence_vectors st.session_state.image_vectors = image_vectors model = st.session_state.model df = st.session_state.df sentence_vectors = st.session_state.sentence_vectors image_vectors = st.session_state.image_vectors description_text.markdown("日本語CLIPモデル(ゼロショット)を用いて、説明文の意味が近い「いらすとや」画像を検索します。 \n" + \ "使い方: \n\n" + \ "1. 「クエリ種別」でテキスト(説明文)と画像のどちらをクエリに用いるか選ぶ。\n" + \ "2. 「クエリ種別」に合わせて「説明文」に検索クエリとなるテキストを入力するか、「画像」に検索クエリとなる画像ファイルを指定する。\n" + \ "3. 「検索数」で検索結果の表示数を、「検索対象ベクトル」で画像ベクトルと文ベクトルのどちらとの類似性をもって検索するかを指定することができる。\n\n" + \ "説明文にはキーワードを列挙するよりも、自然な文章を入力した方が精度よく検索できます。 \n" + \ "画像は必ずリンク先の「いらすとや」さんのページを開き、そこからダウンロードしてください。") def clear_result(): result_text.text("") query_type = st.radio(label="クエリ種別", options=("説明文", "画像")) prev_query = "" query_input = st.text_input(label="説明文", value="", on_change=clear_result) query_image = st.file_uploader(label="画像", type=["png", "jpg", "jpeg"], on_change=clear_result) closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100) model_type = st.radio(label="検索対象ベクトル", options=("画像", "文")) search_buttion = st.button("検索") result_text = st.empty() if search_buttion or prev_query != query_input: if query_type == "説明文" or query_image is None: prev_query = query_input query_embedding = encode_text(query_input, model) else: query_embedding = encode_image(query_image, model) if model_type == "画像": target_vectors = image_vectors else: target_vectors = sentence_vectors distances = scipy.spatial.distance.cdist( query_embedding, target_vectors, metric="cosine" )[0] results = zip(range(len(distances)), distances) results = sorted(results, key=lambda x: x[1]) md_content = "" for i, (idx, distance) in enumerate(results[0:closest_n]): page_url = df.iloc[idx]["page"] desc = df.iloc[idx]["description"] img_url = df.iloc[idx]["image_url"] md_content += f"1.
{distance / 2:.4f}: {desc}
\n" result_text.markdown(md_content, unsafe_allow_html=True)