Create utils.py
utils.py
ADDED
@@ -0,0 +1,132 @@
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
import itertools

def is_hiragana_or_katakana(s):
    # True only if every character is hiragana or katakana; the long-vowel mark "ー" is rejected.
    for char in s:
        if not ('\u3040' <= char <= '\u309F' or '\u30A0' <= char <= '\u30FF') or char == "ー":
            return False
    return True

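# Illustrative checks (not part of the original commit):
#   is_hiragana_or_katakana("とうきょう")  -> True   (all hiragana)
#   is_hiragana_or_katakana("トウキョー")  -> False  (contains the excluded "ー")
#   is_hiragana_or_katakana("東京")        -> False  (kanji)
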
def add_dakuten_handakuten(query, string_type):
    def convert_to_hiragana(s):
        """Convert the given string to hiragana."""
        result = []
        for char in s:
            if 'ァ' <= char <= 'ヶ':  # convert katakana to hiragana
                result.append(chr(ord(char) - 96))
            else:
                result.append(char)
        return ''.join(result)

    def convert_to_katakana(s):
        """Convert the given string to katakana."""
        result = []
        for char in s:
            if 'ぁ' <= char <= 'ゖ':  # convert hiragana to katakana
                result.append(chr(ord(char) + 96))
            else:
                result.append(char)
        return ''.join(result)

    if string_type == "hiragana":
        s = convert_to_hiragana(query)
        dakuon_map = {
            'か': 'が', 'き': 'ぎ', 'く': 'ぐ', 'け': 'げ', 'こ': 'ご',
            'さ': 'ざ', 'し': 'じ', 'す': 'ず', 'せ': 'ぜ', 'そ': 'ぞ',
            'た': 'だ', 'ち': 'ぢ', 'つ': 'づ', 'て': 'で', 'と': 'ど',
            'は': 'ば', 'ひ': 'び', 'ふ': 'ぶ', 'へ': 'べ', 'ほ': 'ぼ'
        }
        handakuon_map = {
            'は': 'ぱ', 'ひ': 'ぴ', 'ふ': 'ぷ', 'へ': 'ぺ', 'ほ': 'ぽ'
        }
    elif string_type == "katakana":
        s = convert_to_katakana(query)
        dakuon_map = {
            'カ': 'ガ', 'キ': 'ギ', 'ク': 'グ', 'ケ': 'ゲ', 'コ': 'ゴ',
            'サ': 'ザ', 'シ': 'ジ', 'ス': 'ズ', 'セ': 'ゼ', 'ソ': 'ゾ',
            'タ': 'ダ', 'チ': 'ヂ', 'ツ': 'ヅ', 'テ': 'デ', 'ト': 'ド',
            'ハ': 'バ', 'ヒ': 'ビ', 'フ': 'ブ', 'ヘ': 'ベ', 'ホ': 'ボ',
            'ウ': 'ヴ'
        }
        handakuon_map = {
            'ハ': 'パ', 'ヒ': 'ピ', 'フ': 'プ', 'ヘ': 'ペ', 'ホ': 'ポ'
        }

    # For each character, collect the original plus its dakuten/handakuten variants.
    options = []
    for char in s:
        temp = [char]
        if char in dakuon_map:
            temp.append(dakuon_map[char])
        if char in handakuon_map:
            temp.append(handakuon_map[char])
        options.append(temp)

    # Generate every combination of the per-character variants.
    candidates = list(itertools.product(*options))
    return candidates

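# Usage sketch (illustrative, not in the original commit): the return value is a list of
# per-character tuples, e.g.
#   add_dakuten_handakuten("はし", "hiragana")
#   -> [('は', 'し'), ('は', 'じ'), ('ば', 'し'), ('ば', 'じ'), ('ぱ', 'し'), ('ぱ', 'じ')]
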
def add_dashes(s):
    if not s:
        return ['']

    # Recursively get the insertion patterns for everything after the first character.
    substr_patterns = add_dashes(s[1:])

    # Build the patterns that include the current character.
    result = []
    for pattern in substr_patterns:
        result.append(s[0] + pattern)  # concatenate as is
        result.append(s[0] + 'ー' + pattern)  # concatenate with a long-vowel mark "ー" inserted

    return result

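# Usage sketch (illustrative, not in the original commit): every way of optionally inserting
# "ー" after each character, e.g.
#   add_dashes("とき") -> ['とき', 'とーき', 'ときー', 'とーきー']
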
def compute_losses(candidates, model, tokenizer):
    inputs = tokenizer(candidates, return_tensors="pt", padding=True)
    # Padding positions are ignored by the loss (ignore_index defaults to -100).
    inputs["labels"] = inputs["input_ids"].masked_fill(inputs["input_ids"] == tokenizer.pad_token_id, -100)
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    labels = inputs["labels"]

    # Shift so that each position predicts the next token.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss(reduction="none")

    losses_flat = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
    losses_seq = losses_flat.view(shift_labels.shape)
    # Average only over real tokens; padded positions carry the -100 label set above.
    mask_labels = shift_labels != -100
    losses = torch.sum(losses_seq * mask_labels, -1) / mask_labels.sum(-1)

    return losses

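# Usage sketch (hypothetical; the commit does not show a caller). Any Hugging Face causal LM
# whose tokenizer defines pad_token_id should work, e.g.:
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-medium")  # example model name
#   model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
#   losses = compute_losses(["とうきょう", "とーきょー"], model=model, tokenizer=tokenizer)
#   # lower loss = the model finds the string more plausible
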
def search_candidates(query, query_candidates, model, tokenizer, top_k=100):
    # Make sure candidates for the query minus its last character are available,
    # computing (and caching) them recursively if necessary.
    old_query = query[:-1]
    if old_query not in query_candidates:
        old_candidates, _ = search_candidates(old_query, query_candidates, model=model, tokenizer=tokenizer, top_k=top_k)
    else:
        old_candidates, _ = query_candidates[old_query]

    # Expand the newly added character into its dakuten/handakuten and long-vowel variants.
    string = query[-1]
    candidates = []
    for string_type in ["hiragana", "katakana"]:
        candidates_ = add_dakuten_handakuten(string, string_type=string_type)
        for candidate_ in candidates_:
            candidates += add_dashes(candidate_)

    combinations = itertools.product(old_candidates, candidates)
    new_candidates = [''.join(pair) for pair in combinations]

    # Rank the combined candidates by language-model loss and keep the top_k.
    losses = compute_losses(new_candidates, model=model, tokenizer=tokenizer)
    sorted_items = torch.sort(losses)
    sorted_candidates = np.array(new_candidates)[sorted_items.indices.cpu().numpy()]
    topk_candidates = sorted_candidates[:top_k].tolist()
    topk_losses = sorted_items.values[:top_k].cpu().tolist()

    query_candidates[query] = (topk_candidates, topk_losses)
    return topk_candidates, topk_losses
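
# End-to-end sketch (hypothetical; the seed-cache format is an assumption, not shown in this commit):
# readings are extended one character at a time, reusing cached prefixes.
#   query_candidates = {"": ([""], [0.0])}   # seed: empty prefix with a dummy loss
#   topk, topk_losses = search_candidates("はし", query_candidates, model=model, tokenizer=tokenizer, top_k=10)
#   # topk holds the 10 lowest-loss hiragana/katakana variants of "はし"
#   # (dakuten/handakuten swaps and optional "ー" insertions), and every prefix is cached
#   # in query_candidates for the next keystroke.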