marutenbo / utils.py
isonuma's picture
Create utils.py
dcd398a verified
raw
history blame
No virus
5.09 kB
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
import itertools
def is_hiragana_or_katakana(s):
for char in s:
if not ('\u3040' <= char <= '\u309F' or '\u30A0' <= char <= '\u30FF') or char == "ー":
return False
return True
def add_dakuten_handakuten(query, string_type):
def convert_to_hiragana(s):
"""与えられた文字列を平仮名に変換する"""
result = []
for char in s:
if 'ァ' <= char <= 'ヶ': # 片仮名を平仮名に変換
result.append(chr(ord(char) - 96))
else:
result.append(char)
return ''.join(result)
def convert_to_katakana(s):
"""与えられた文字列を片仮名に変換する"""
result = []
for char in s:
if 'ぁ' <= char <= 'ゖ': # 平仮名を片仮名に変換
result.append(chr(ord(char) + 96))
else:
result.append(char)
return ''.join(result)
if string_type == "hiragana":
s = convert_to_hiragana(query)
dakuon_map = {
'か': 'が', 'き': 'ぎ', 'く': 'ぐ', 'け': 'げ', 'こ': 'ご',
'さ': 'ざ', 'し': 'じ', 'す': 'ず', 'せ': 'ぜ', 'そ': 'ぞ',
'た': 'だ', 'ち': 'ぢ', 'つ': 'づ', 'て': 'で', 'と': 'ど',
'は': 'ば', 'ひ': 'び', 'ふ': 'ぶ', 'へ': 'べ', 'ほ': 'ぼ'
}
handakuon_map = {
'は': 'ぱ', 'ひ': 'ぴ', 'ふ': 'ぷ', 'へ': 'ぺ', 'ほ': 'ぽ'
}
elif string_type == "katakana":
s = convert_to_katakana(query)
dakuon_map = {
'カ': 'ガ', 'キ': 'ギ', 'ク': 'グ', 'ケ': 'ゲ', 'コ': 'ゴ',
'サ': 'ザ', 'シ': 'ジ', 'ス': 'ズ', 'セ': 'ゼ', 'ソ': 'ゾ',
'タ': 'ダ', 'チ': 'ヂ', 'ツ': 'ヅ', 'テ': 'デ', 'ト': 'ド',
'ハ': 'バ', 'ヒ': 'ビ', 'フ': 'ブ', 'ヘ': 'ベ', 'ホ': 'ボ',
'ウ': 'ヴ'
}
handakuon_map = {
'ハ': 'パ', 'ヒ': 'ピ', 'フ': 'プ', 'ヘ': 'ペ', 'ホ': 'ポ'
}
# 文字ごとに元の文字と濁音・半濁音をリストにする
options = []
for char in s:
temp = [char]
if char in dakuon_map:
temp.append(dakuon_map[char])
if char in handakuon_map:
temp.append(handakuon_map[char])
options.append(temp)
# 全ての組み合わせを生成
candidates = list(itertools.product(*options))
return candidates
def add_dashes(s):
if not s:
return ['']
# 再帰的に文字列の先頭以外の部分に「ー」を挿入するパターンを取得
substr_patterns = add_dashes(s[1:])
# 現在の文字を含めたパターンを生成
result = []
for pattern in substr_patterns:
result.append(s[0] + pattern) # そのまま連結
result.append(s[0] + 'ー' + pattern) # 「ー」を挿入して連結
return result
def compute_losses(candidates, model, tokenizer):
inputs = tokenizer(candidates, return_tensors="pt", padding=True)
inputs["labels"] = inputs["input_ids"].masked_fill(inputs["input_ids"] == tokenizer.pad_token_id, -100)
inputs = inputs.to(model.device)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
labels = inputs["labels"]
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction="none")
losses_flat = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
losses_seq = losses_flat.view(shift_labels.shape)
mask_labels = shift_labels != tokenizer.pad_token_id
losses = torch.sum(losses_seq * mask_labels, -1) / mask_labels.sum(-1)
return losses
def search_candidates(query, query_candidates, model, tokenizer, top_k=100):
old_query = query[:-1]
if old_query not in query_candidates:
old_candidates, _ = search_candidates(old_query, query_candidates, model=model, tokenizer=tokenizer, top_k=top_k)
else:
old_candidates, _ = query_candidates[old_query]
string = query[-1]
candidates = []
for string_type in ["hiragana", "katakana"]:
candidates_ = add_dakuten_handakuten(string, string_type=string_type)
for candidate_ in candidates_:
candidates += add_dashes(candidate_)
combinations = itertools.product(old_candidates, candidates)
new_candidates = [''.join(pair) for pair in combinations]
losses = compute_losses(new_candidates, model=model, tokenizer=tokenizer)
sorted_items = torch.sort(losses)
sorted_candidates = np.array(new_candidates)[sorted_items.indices.cpu().numpy()]
topk_candidates = sorted_candidates[:top_k].tolist()
topk_losses = sorted_items.values[:top_k].cpu().tolist()
query_candidates[query] = (topk_candidates, topk_losses)
return topk_candidates, topk_losses