sonoisa commited on
Commit
c67f441
·
1 Parent(s): 20efab0

Add application files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +187 -0
  3. irasuto_items_20210224.pq.zip +3 -0
  4. requirements.txt +6 -0
.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ irasuto_items_20210224.pq.zip filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import unicode_literals
2
+ import re
3
+ import unicodedata
4
+ import torch
5
+ import streamlit as st
6
+ import pandas as pd
7
+ import pyarrow as pa
8
+ import pyarrow.parquet as pq
9
+ import numpy as np
10
+ import scipy.spatial
11
+ from transformers import BertJapaneseTokenizer, BertModel
12
+ import pyminizip
13
+
14
+
15
+ def unicode_normalize(cls, s):
16
+ pt = re.compile("([{}]+)".format(cls))
17
+
18
+ def norm(c):
19
+ return unicodedata.normalize("NFKC", c) if pt.match(c) else c
20
+
21
+ s = "".join(norm(x) for x in re.split(pt, s))
22
+ s = re.sub("-", "-", s)
23
+ return s
24
+
25
+
26
+ def remove_extra_spaces(s):
27
+ s = re.sub("[  ]+", " ", s)
28
+ blocks = "".join(
29
+ (
30
+ "\u4E00-\u9FFF", # CJK UNIFIED IDEOGRAPHS
31
+ "\u3040-\u309F", # HIRAGANA
32
+ "\u30A0-\u30FF", # KATAKANA
33
+ "\u3000-\u303F", # CJK SYMBOLS AND PUNCTUATION
34
+ "\uFF00-\uFFEF", # HALFWIDTH AND FULLWIDTH FORMS
35
+ )
36
+ )
37
+ basic_latin = "\u0000-\u007F"
38
+
39
+ def remove_space_between(cls1, cls2, s):
40
+ p = re.compile("([{}]) ([{}])".format(cls1, cls2))
41
+ while p.search(s):
42
+ s = p.sub(r"\1\2", s)
43
+ return s
44
+
45
+ s = remove_space_between(blocks, blocks, s)
46
+ s = remove_space_between(blocks, basic_latin, s)
47
+ s = remove_space_between(basic_latin, blocks, s)
48
+ return s
49
+
50
+
51
+ def normalize_neologd(s):
52
+ s = s.strip()
53
+ s = unicode_normalize("0-9A-Za-z。-゚", s)
54
+
55
+ def maketrans(f, t):
56
+ return {ord(x): ord(y) for x, y in zip(f, t)}
57
+
58
+ s = re.sub("[˗֊‐‑‒–⁃⁻₋−]+", "-", s) # normalize hyphens
59
+ s = re.sub("[﹣-ー—―─━ー]+", "ー", s) # normalize choonpus
60
+ s = re.sub("[~∼∾〜〰~]+", "〜", s) # normalize tildes (modified by Isao Sonobe)
61
+ s = s.translate(
62
+ maketrans(
63
+ "!\"#$%&'()*+,-./:;<=>?@[¥]^_`{|}~。、・「」",
64
+ "!”#$%&’()*+,-./:;<=>?@[¥]^_`{|}〜。、・「」",
65
+ )
66
+ )
67
+
68
+ s = remove_extra_spaces(s)
69
+ s = unicode_normalize("!”#$%&’()*+,-./:;<>?@[¥]^_`{|}〜", s) # keep =,・,「,」
70
+ s = re.sub("[’]", "'", s)
71
+ s = re.sub("[”]", '"', s)
72
+ # s = s.upper()
73
+ return s
74
+
75
+
76
+ def normalize_text(text):
77
+ return normalize_neologd(text)
78
+
79
+
80
+ def normalize_title(title):
81
+ title = title.strip()
82
+
83
+ match = re.match(r"^「([^」]+)」$", title)
84
+ if match:
85
+ title = match.group(1)
86
+
87
+ match = re.match(r"^POP素材「([^」]+)」$", title)
88
+ if match:
89
+ title = match.group(1)
90
+
91
+ match = re.match(
92
+ r"^(.*?)(の?(?:イラスト|イラストの|イラストト|イ子のラスト|イラス|イラスト文字|「イラスト文字」|イラストPOP文字|ペンキ文字|タイトル文字|イラスト・メッセージ|イラスト文字・バナー|キャラクター(たち)?|マーク|アイコン|シルエット|シルエット素材|フレーム(枠)|フレーム|フレーム素材|テンプレート|パターン|パターン素材|ライン素材|コーナー素材|リボン型バナー|評価スタンプ|背景素材))+(\s*([0-90-9]*|その[0-90-9]+)\s*(((|\()[^))]+()|\))|「[^」]+」|・.+)*(です。)?)",
93
+ title,
94
+ )
95
+ if match:
96
+ title = match.group(1) + ("" if match.group(3) is None else match.group(3))
97
+ if title == "":
98
+ raise ValueError(title)
99
+
100
+ title = normalize_text(title)
101
+
102
+ return title
103
+
104
+
105
+ class SentenceBertJapanese:
106
+ def __init__(self, model_name_or_path, device=None):
107
+ self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
108
+ self.model = BertModel.from_pretrained(model_name_or_path)
109
+ self.model.eval()
110
+
111
+ if device is None:
112
+ device = "cuda" if torch.cuda.is_available() else "cpu"
113
+ self.device = torch.device(device)
114
+ self.model.to(device)
115
+
116
+ def _mean_pooling(self, model_output, attention_mask):
117
+ token_embeddings = model_output[
118
+ 0
119
+ ] # First element of model_output contains all token embeddings
120
+ input_mask_expanded = (
121
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
122
+ )
123
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
124
+ input_mask_expanded.sum(1), min=1e-9
125
+ )
126
+
127
+ @torch.no_grad()
128
+ def encode(self, sentences, batch_size=8):
129
+ all_embeddings = []
130
+ iterator = range(0, len(sentences), batch_size)
131
+ for batch_idx in iterator:
132
+ batch = sentences[batch_idx : batch_idx + batch_size]
133
+
134
+ encoded_input = self.tokenizer.batch_encode_plus(
135
+ batch, padding="longest", truncation=True, return_tensors="pt"
136
+ ).to(self.device)
137
+ model_output = self.model(**encoded_input)
138
+ sentence_embeddings = self._mean_pooling(
139
+ model_output, encoded_input["attention_mask"]
140
+ ).to("cpu")
141
+
142
+ all_embeddings.extend(sentence_embeddings)
143
+
144
+ # return torch.stack(all_embeddings).numpy()
145
+ return torch.stack(all_embeddings)
146
+
147
+
148
+ st.title("いらすと検索")
149
+ description_text = st.empty()
150
+ description_text.text("...モデル読み込み中...")
151
+
152
+ model = SentenceBertJapanese("sonoisa/sentence-bert-base-ja-mean-tokens")
153
+
154
+ pyminizip.uncompress(
155
+ "irasuto_items_20210224.pq.zip", st.secrets["ZIP_PASSWORD"], None, 1
156
+ )
157
+
158
+ df = pq.read_table("irasuto_items_20210224.parquet").to_pandas()
159
+ sentence_vectors = np.array(df["sentence_vector"])
160
+
161
+ st.text("説明文の意味が近い「いらすとや」画像を検索します。")
162
+ query_input = st.text_input(label="説明文", value="")
163
+ search_buttion = st.button("検索")
164
+
165
+ closest_n = 5
166
+
167
+ if search_buttion:
168
+ query = str(query_input)
169
+ query_embedding = model.encode([query]).numpy()
170
+
171
+ distances = scipy.spatial.distance.cdist(
172
+ [query_embedding], sentence_vectors, metric="cosine"
173
+ )[0]
174
+
175
+ results = zip(range(len(distances)), distances)
176
+ results = sorted(results, key=lambda x: x[1])
177
+
178
+ print("\n\n======================\n\n")
179
+ print("Query:", query)
180
+ print("\nTop 5 most similar sentences in corpus:")
181
+
182
+ for idx, distance in results[0:closest_n]:
183
+ # print(sentences[idx].strip(), "(Score: %.4f)" % (distance / 2))
184
+ print(
185
+ f"{df.iloc[idx]['title']} {df.iloc[idx]['normalized_description']} (Score: %.4f)"
186
+ % (distance / 2)
187
+ )
irasuto_items_20210224.pq.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:338ffc5865419f827dd02a22f7962dbbf5e2cae4670861c518035d1fce7ead12
3
+ size 77950743
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers==4.7.0
2
+ torch==1.7.0
3
+ sentencepiece
4
+ pyminizip
5
+ fugashi
6
+ ipadic