isonuma committed
Commit
dcd398a
1 Parent(s): fef3565

Create utils.py

Files changed (1)
  1. utils.py +132 -0
utils.py ADDED
import itertools

import numpy as np
import torch
from torch.nn import CrossEntropyLoss


def is_hiragana_or_katakana(s):
    """Return True if every character of s is hiragana or katakana.

    The prolonged sound mark "ー" is deliberately rejected even though it
    lives in the katakana block.
    """
    for char in s:
        if not ('\u3040' <= char <= '\u309F' or '\u30A0' <= char <= '\u30FF') or char == "ー":
            return False
    return True
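# Illustrative examples (added for clarity, not in the original commit):
#   is_hiragana_or_katakana("とうきょう") -> True
#   is_hiragana_or_katakana("東京")       -> False  (kanji)
#   is_hiragana_or_katakana("ラーメン")   -> False  (contains "ー")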

def add_dakuten_handakuten(query, string_type):
    """Enumerate every variant of query obtained by optionally replacing each
    kana with its dakuten (voiced) or handakuten (semi-voiced) form."""

    def convert_to_hiragana(s):
        """Convert the given string to hiragana."""
        result = []
        for char in s:
            if 'ァ' <= char <= 'ヶ':  # convert katakana to hiragana
                result.append(chr(ord(char) - 96))
            else:
                result.append(char)
        return ''.join(result)

    def convert_to_katakana(s):
        """Convert the given string to katakana."""
        result = []
        for char in s:
            if 'ぁ' <= char <= 'ゖ':  # convert hiragana to katakana
                result.append(chr(ord(char) + 96))
            else:
                result.append(char)
        return ''.join(result)

    if string_type == "hiragana":
        s = convert_to_hiragana(query)
        dakuon_map = {
            'か': 'が', 'き': 'ぎ', 'く': 'ぐ', 'け': 'げ', 'こ': 'ご',
            'さ': 'ざ', 'し': 'じ', 'す': 'ず', 'せ': 'ぜ', 'そ': 'ぞ',
            'た': 'だ', 'ち': 'ぢ', 'つ': 'づ', 'て': 'で', 'と': 'ど',
            'は': 'ば', 'ひ': 'び', 'ふ': 'ぶ', 'へ': 'べ', 'ほ': 'ぼ'
        }
        handakuon_map = {
            'は': 'ぱ', 'ひ': 'ぴ', 'ふ': 'ぷ', 'へ': 'ぺ', 'ほ': 'ぽ'
        }
    elif string_type == "katakana":
        s = convert_to_katakana(query)
        dakuon_map = {
            'カ': 'ガ', 'キ': 'ギ', 'ク': 'グ', 'ケ': 'ゲ', 'コ': 'ゴ',
            'サ': 'ザ', 'シ': 'ジ', 'ス': 'ズ', 'セ': 'ゼ', 'ソ': 'ゾ',
            'タ': 'ダ', 'チ': 'ヂ', 'ツ': 'ヅ', 'テ': 'デ', 'ト': 'ド',
            'ハ': 'バ', 'ヒ': 'ビ', 'フ': 'ブ', 'ヘ': 'ベ', 'ホ': 'ボ',
            'ウ': 'ヴ'
        }
        handakuon_map = {
            'ハ': 'パ', 'ヒ': 'ピ', 'フ': 'プ', 'ヘ': 'ペ', 'ホ': 'ポ'
        }
    else:
        # Guard added here: the original fell through to a NameError on any
        # other value of string_type.
        raise ValueError(f"unknown string_type: {string_type!r}")

    # For each character, list the original together with its dakuten/handakuten variants.
    options = []
    for char in s:
        temp = [char]
        if char in dakuon_map:
            temp.append(dakuon_map[char])
        if char in handakuon_map:
            temp.append(handakuon_map[char])
        options.append(temp)

    # Generate all combinations of the per-character options.
    candidates = list(itertools.product(*options))
    return candidates
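# Illustrative example (added for clarity, not in the original commit):
#   add_dakuten_handakuten("はし", "hiragana") returns the six character tuples
#   ('は','し'), ('は','じ'), ('ば','し'), ('ば','じ'), ('ぱ','し'), ('ぱ','じ').
#   Note the candidates are tuples of characters; add_dashes below accepts
#   them unchanged because it only indexes and slices its argument.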

def add_dashes(s):
    """Enumerate every variant of s with the prolonged sound mark "ー"
    optionally inserted after each character."""
    if not s:
        return ['']

    # Recursively collect the insertion patterns for everything after the first character.
    substr_patterns = add_dashes(s[1:])

    # Build the patterns that include the current character.
    result = []
    for pattern in substr_patterns:
        result.append(s[0] + pattern)         # concatenate as-is
        result.append(s[0] + 'ー' + pattern)  # concatenate with "ー" inserted

    return result
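# Illustrative example (added for clarity, not in the original commit):
#   add_dashes("はし") -> ['はし', 'はーし', 'はしー', 'はーしー']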

def compute_losses(candidates, model, tokenizer):
    """Score each candidate string by its mean per-token cross-entropy under the model."""
    inputs = tokenizer(candidates, return_tensors="pt", padding=True)
    # Replace padding positions with -100 so CrossEntropyLoss ignores them.
    inputs["labels"] = inputs["input_ids"].masked_fill(inputs["input_ids"] == tokenizer.pad_token_id, -100)
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    labels = inputs["labels"]

    # Shift so that each position predicts the following token.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss(reduction="none")

    losses_flat = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
    losses_seq = losses_flat.view(shift_labels.shape)
    # Average over real tokens only. The labels were already masked to -100
    # above, so the original comparison against tokenizer.pad_token_id matched
    # every position and averaged over padding as well; compare against -100 instead.
    mask_labels = shift_labels != -100
    losses = torch.sum(losses_seq * mask_labels, -1) / mask_labels.sum(-1)

    return losses
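# Shape note (added for clarity, not in the original commit): for B candidate
# strings padded to length T, logits are (B, T, vocab_size) and the returned
# tensor has shape (B,), one mean per-token loss per candidate.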

def search_candidates(query, query_candidates, model, tokenizer, top_k=100):
    """Extend the best candidates for the prefix query[:-1] with every reading
    of query[-1], keep the top_k lowest-loss strings, and memoize the result
    in query_candidates."""
    if not query:
        # Base case added so the recursion terminates when the empty prefix
        # has not been pre-seeded in query_candidates (the original recursed
        # forever in that case): one empty candidate with zero loss.
        return [''], [0.0]

    old_query = query[:-1]
    if old_query not in query_candidates:
        old_candidates, _ = search_candidates(old_query, query_candidates, model=model, tokenizer=tokenizer, top_k=top_k)
    else:
        old_candidates, _ = query_candidates[old_query]

    # Enumerate the readings of the newest element: both scripts, with
    # dakuten/handakuten variants and optional prolonged sound marks.
    string = query[-1]
    candidates = []
    for string_type in ["hiragana", "katakana"]:
        candidates_ = add_dakuten_handakuten(string, string_type=string_type)
        for candidate_ in candidates_:
            candidates += add_dashes(candidate_)

    # Cross every surviving prefix candidate with every reading of the new element.
    combinations = itertools.product(old_candidates, candidates)
    new_candidates = [''.join(pair) for pair in combinations]

    # Rank the combined strings by language-model loss and keep the best top_k.
    losses = compute_losses(new_candidates, model=model, tokenizer=tokenizer)
    sorted_items = torch.sort(losses)
    sorted_candidates = np.array(new_candidates)[sorted_items.indices.cpu().numpy()]
    topk_candidates = sorted_candidates[:top_k].tolist()
    topk_losses = sorted_items.values[:top_k].cpu().tolist()

    query_candidates[query] = (topk_candidates, topk_losses)
    return topk_candidates, topk_losses
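
A minimal usage sketch, assuming a Hugging Face causal LM. The checkpoint name, the example query, and the top_k value are illustrative assumptions, not part of the commit:

# Usage sketch; checkpoint name and query are assumptions, not from the commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import search_candidates

model_name = "rinna/japanese-gpt2-medium"  # assumed checkpoint; any Japanese causal LM should do
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

# compute_losses batches with padding, so the tokenizer must have a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Memoization dict shared across calls; prefix results accumulate as keys.
# Thanks to the empty-prefix base case, no seeding is required.
query_candidates = {}
candidates, losses = search_candidates("はし", query_candidates, model, tokenizer, top_k=10)
print(candidates[:5], losses[:5])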