import itertools

import numpy as np
import torch
from torch.nn import CrossEntropyLoss


def is_hiragana_or_katakana(s):
    """Return True if every character of s is hiragana or katakana.

    The long-vowel mark "ー" is rejected even though it falls inside the
    katakana block, since dashed variants are generated separately by
    add_dashes().
    """
    for char in s:
        if not ('\u3040' <= char <= '\u309F' or '\u30A0' <= char <= '\u30FF') or char == "ー":
            return False
    return True
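
# Quick sanity checks (illustrative, not part of the original file):
#   is_hiragana_or_katakana("にほん")    -> True
#   is_hiragana_or_katakana("ニホン")    -> True
#   is_hiragana_or_katakana("スーパー")  -> False  ("ー" is excluded)
#   is_hiragana_or_katakana("日本")      -> False  (kanji)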


def add_dakuten_handakuten(query, string_type):
    """Enumerate all variants of query obtained by voicing characters.

    The query is first normalized to hiragana or katakana according to
    string_type; each character that can take a dakuten or handakuten is
    then expanded into all of its forms.
    """

    def convert_to_hiragana(s):
        """Convert the given string to hiragana."""
        result = []
        for char in s:
            # Katakana ァ..ヶ sits exactly 0x60 code points above hiragana.
            if 'ァ' <= char <= 'ヶ':
                result.append(chr(ord(char) - 0x60))
            else:
                result.append(char)
        return ''.join(result)

    def convert_to_katakana(s):
        """Convert the given string to katakana."""
        result = []
        for char in s:
            if 'ぁ' <= char <= 'ゖ':
                result.append(chr(ord(char) + 0x60))
            else:
                result.append(char)
        return ''.join(result)

    if string_type == "hiragana":
        s = convert_to_hiragana(query)
        dakuon_map = {
            'か': 'が', 'き': 'ぎ', 'く': 'ぐ', 'け': 'げ', 'こ': 'ご',
            'さ': 'ざ', 'し': 'じ', 'す': 'ず', 'せ': 'ぜ', 'そ': 'ぞ',
            'た': 'だ', 'ち': 'ぢ', 'つ': 'づ', 'て': 'で', 'と': 'ど',
            'は': 'ば', 'ひ': 'び', 'ふ': 'ぶ', 'へ': 'べ', 'ほ': 'ぼ'
        }
        handakuon_map = {
            'は': 'ぱ', 'ひ': 'ぴ', 'ふ': 'ぷ', 'へ': 'ぺ', 'ほ': 'ぽ'
        }
    elif string_type == "katakana":
        s = convert_to_katakana(query)
        dakuon_map = {
            'カ': 'ガ', 'キ': 'ギ', 'ク': 'グ', 'ケ': 'ゲ', 'コ': 'ゴ',
            'サ': 'ザ', 'シ': 'ジ', 'ス': 'ズ', 'セ': 'ゼ', 'ソ': 'ゾ',
            'タ': 'ダ', 'チ': 'ヂ', 'ツ': 'ヅ', 'テ': 'デ', 'ト': 'ド',
            'ハ': 'バ', 'ヒ': 'ビ', 'フ': 'ブ', 'ヘ': 'ベ', 'ホ': 'ボ',
            'ウ': 'ヴ'
        }
        handakuon_map = {
            'ハ': 'パ', 'ヒ': 'ピ', 'フ': 'プ', 'ヘ': 'ペ', 'ホ': 'ポ'
        }
    else:
        raise ValueError(f"unknown string_type: {string_type!r}")

    # For each character, collect the plain form plus any voiced variants.
    options = []
    for char in s:
        temp = [char]
        if char in dakuon_map:
            temp.append(dakuon_map[char])
        if char in handakuon_map:
            temp.append(handakuon_map[char])
        options.append(temp)

    # The Cartesian product enumerates every combination of variants;
    # join each tuple back into a string for downstream use.
    candidates = [''.join(chars) for chars in itertools.product(*options)]
    return candidates
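
# Example (illustrative, not part of the original file):
#   add_dakuten_handakuten("はし", "hiragana")
#   -> ['はし', 'はじ', 'ばし', 'ばじ', 'ぱし', 'ぱじ']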


def add_dashes(s):
    """Enumerate all variants of s with the long-vowel mark "ー" optionally
    inserted after each character (2**len(s) strings in total)."""
    if not s:
        return ['']

    # Recursively build the patterns for the rest of the string.
    substr_patterns = add_dashes(s[1:])

    # Emit the first character both with and without a trailing "ー".
    result = []
    for pattern in substr_patterns:
        result.append(s[0] + pattern)
        result.append(s[0] + 'ー' + pattern)

    return result
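
# Example (illustrative, not part of the original file):
#   add_dashes("あい") -> ['あい', 'あーい', 'あいー', 'あーいー']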


def compute_losses(candidates, model, tokenizer):
    """Score each candidate string by its mean per-token language-model loss."""
    inputs = tokenizer(candidates, return_tensors="pt", padding=True)
    # Replace padding positions with -100 so CrossEntropyLoss ignores them.
    inputs["labels"] = inputs["input_ids"].masked_fill(
        inputs["input_ids"] == tokenizer.pad_token_id, -100
    )
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    labels = inputs["labels"]

    # Shift so that each position predicts the next token.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss(reduction="none")

    losses_flat = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    losses_seq = losses_flat.view(shift_labels.shape)
    # Average over real tokens only. The padded labels were set to -100
    # above, so the mask must compare against -100, not pad_token_id.
    mask_labels = shift_labels != -100
    losses = torch.sum(losses_seq * mask_labels, -1) / mask_labels.sum(-1)

    return losses
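
# Example (hypothetical call; any Hugging Face causal LM with a pad token
# should work):
#   compute_losses(["とうきょう", "とーきょー"], model, tokenizer)
#   -> 1-D tensor with one mean loss per candidate (lower = more plausible)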


def search_candidates(query, query_candidates, model, tokenizer, top_k=100):
    """Extend the cached candidates for query[:-1] with every variant of the
    final character, then keep the top_k strings with the lowest LM loss.

    query_candidates acts as a memo table mapping each prefix to its
    (candidates, losses) pair.
    """
    # Base case: without it, a one-character query would recurse forever.
    if not query:
        query_candidates[query] = ([''], [0.0])
        return [''], [0.0]

    old_query = query[:-1]
    if old_query not in query_candidates:
        old_candidates, _ = search_candidates(old_query, query_candidates, model=model, tokenizer=tokenizer, top_k=top_k)
    else:
        old_candidates, _ = query_candidates[old_query]

    # Expand the final character into all dakuten/handakuten and dash variants.
    string = query[-1]
    candidates = []
    for string_type in ["hiragana", "katakana"]:
        candidates_ = add_dakuten_handakuten(string, string_type=string_type)
        for candidate_ in candidates_:
            candidates += add_dashes(candidate_)

    combinations = itertools.product(old_candidates, candidates)
    new_candidates = [''.join(pair) for pair in combinations]

    # Rank every combination by LM loss and keep the best top_k.
    losses = compute_losses(new_candidates, model=model, tokenizer=tokenizer)
    sorted_items = torch.sort(losses)
    sorted_candidates = np.array(new_candidates)[sorted_items.indices.cpu().numpy()]
    topk_candidates = sorted_candidates[:top_k].tolist()
    topk_losses = sorted_items.values[:top_k].cpu().tolist()

    query_candidates[query] = (topk_candidates, topk_losses)
    return topk_candidates, topk_losses
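

# Minimal usage sketch (hypothetical; the original file does not show how
# these functions are driven). The checkpoint name is a placeholder for
# any Japanese causal LM on the Hugging Face Hub, not necessarily the one
# this code was written against.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "rinna/japanese-gpt2-medium"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # padding is required for batching

    query = "とうきよう"  # plain kana only; no dakuten, handakuten, or "ー"
    assert is_hiragana_or_katakana(query)

    topk_candidates, topk_losses = search_candidates(
        query, {}, model=model, tokenizer=tokenizer, top_k=10
    )
    for candidate, loss in zip(topk_candidates, topk_losses):
        print(f"{candidate}\t{loss:.3f}")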