test code
Browse files- classifier.py +147 -0
- merger.py +181 -0
- requirements.txt +4 -1
- ru_errant.py +117 -18
classifier.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from collections import defaultdict
|
4 |
+
from string import punctuation
|
5 |
+
|
6 |
+
import Levenshtein
|
7 |
+
from errant.edit import Edit
|
8 |
+
|
9 |
+
|
10 |
+
def edit_to_tuple(edit: Edit, idx: int = 0) -> tuple[int, int, str, str, int]:
|
11 |
+
cor_toks_str = " ".join([tok.text for tok in edit.c_toks])
|
12 |
+
return [edit.o_start, edit.o_end, edit.type, cor_toks_str, idx]
|
13 |
+
|
14 |
+
|
15 |
+
def classify(edit: Edit) -> list[Edit]:
|
16 |
+
"""Classifies an Edit via updating its `type` attribute."""
|
17 |
+
# Insertion and deletion
|
18 |
+
if ((not edit.o_toks and edit.c_toks) or (edit.o_toks and not edit.c_toks)):
|
19 |
+
error_cats = get_one_sided_type(edit.o_toks, edit.c_toks)
|
20 |
+
elif edit.o_toks != edit.c_toks:
|
21 |
+
error_cats = get_two_sided_type(edit.o_toks, edit.c_toks)
|
22 |
+
else:
|
23 |
+
error_cats = {"NA": edit.c_toks[0].text}
|
24 |
+
new_edit_list = []
|
25 |
+
if error_cats:
|
26 |
+
for error_cat, correct_str in error_cats.items():
|
27 |
+
edit.type = error_cat
|
28 |
+
edit_tuple = edit_to_tuple(edit)
|
29 |
+
edit_tuple[3] = correct_str
|
30 |
+
new_edit_list.append(edit_tuple)
|
31 |
+
return new_edit_list
|
32 |
+
|
33 |
+
|
34 |
+
def get_edit_info(toks):
|
35 |
+
pos = []
|
36 |
+
dep = []
|
37 |
+
morph = dict()
|
38 |
+
for tok in toks:
|
39 |
+
pos.append(tok.tag_)
|
40 |
+
dep.append(tok.dep_)
|
41 |
+
morphs = str(tok.morph).split('|')
|
42 |
+
for m in morphs:
|
43 |
+
if len(m.strip()):
|
44 |
+
k, v = m.strip().split('=')
|
45 |
+
morph[k] = v
|
46 |
+
return pos, dep, morph
|
47 |
+
|
48 |
+
|
49 |
+
def get_one_sided_type(o_toks, c_toks):
|
50 |
+
"""Classifies a zero-to-one or one-to-zero error based on a token list."""
|
51 |
+
pos_list, _, _ = get_edit_info(o_toks if o_toks else c_toks)
|
52 |
+
if "PUNCT" in pos_list or "SPACE" in pos_list:
|
53 |
+
return {"PUNCT": c_toks[0].text if c_toks else ""}
|
54 |
+
return {"SPELL": c_toks[0].text if c_toks else ""}
|
55 |
+
|
56 |
+
|
57 |
+
def get_two_sided_type(o_toks, c_toks) -> dict[str, str]:
|
58 |
+
"""Classifies a one-to-one or one-to-many or many-to-one error based on token lists."""
|
59 |
+
# one-to-one cases
|
60 |
+
if len(o_toks) == len(c_toks) == 1:
|
61 |
+
if (
|
62 |
+
all(char in punctuation + " " for char in o_toks[0].text) and
|
63 |
+
all(char in punctuation + " " for char in c_toks[0].text)
|
64 |
+
):
|
65 |
+
return {"PUNCT": c_toks[0].text}
|
66 |
+
source_w, correct_w = o_toks[0].text, c_toks[0].text
|
67 |
+
if source_w != correct_w:
|
68 |
+
# if both string are lowercase or both are uppercase,
|
69 |
+
# and there is no "ё" in both, then it may be only "SPELL" error type
|
70 |
+
if (((source_w.islower() and correct_w.islower()) or
|
71 |
+
(source_w.isupper() and correct_w.isupper())) and
|
72 |
+
"ё" not in source_w + correct_w):
|
73 |
+
return {"SPELL": correct_w}
|
74 |
+
# edits with multiple errors (e.g. SPELL + CASE)
|
75 |
+
# Step 1. Make char-level Levenstein table
|
76 |
+
char_edits = Levenshtein.editops(source_w, correct_w)
|
77 |
+
# Step 2. Classify operations (CASE, YO, SPELL)
|
78 |
+
edits_classified = classify_char_edits(char_edits, source_w, correct_w)
|
79 |
+
# Step 3. Combine the same-typed errors into minimal string pairs
|
80 |
+
separated_edits = get_edit_strings(source_w, correct_w, edits_classified)
|
81 |
+
return separated_edits
|
82 |
+
# one-to-many and many-to-one cases
|
83 |
+
if all(char in punctuation + " " for char in o_toks.text + c_toks.text):
|
84 |
+
return {"PUNCT": c_toks.text}
|
85 |
+
joint_corr_str = " ".join([tok.text for tok in c_toks])
|
86 |
+
joint_corr_str = joint_corr_str.replace("- ", "-").replace(" -", "-")
|
87 |
+
return {"SPELL": joint_corr_str}
|
88 |
+
|
89 |
+
|
90 |
+
def classify_char_edits(char_edits, source_w, correct_w):
|
91 |
+
"""Classifies char-level Levenstein operations into SPELL, YO and CASE."""
|
92 |
+
edits_classified = []
|
93 |
+
for edit in char_edits:
|
94 |
+
if edit[0] == "replace":
|
95 |
+
if "ё" in [source_w[edit[1]], correct_w[edit[2]]]:
|
96 |
+
edits_classified.append((*edit, "YO"))
|
97 |
+
elif source_w[edit[1]].lower() == correct_w[edit[2]].lower():
|
98 |
+
edits_classified.append((*edit, "CASE"))
|
99 |
+
else:
|
100 |
+
if (
|
101 |
+
(source_w[edit[1]].islower() and correct_w[edit[2]].isupper()) or
|
102 |
+
(source_w[edit[1]].isupper() and correct_w[edit[2]].islower())
|
103 |
+
):
|
104 |
+
edits_classified.append((*edit, "CASE"))
|
105 |
+
edits_classified.append((*edit, "SPELL"))
|
106 |
+
else:
|
107 |
+
edits_classified.append((*edit, "SPELL"))
|
108 |
+
return edits_classified
|
109 |
+
|
110 |
+
|
111 |
+
def get_edit_strings(source: str, correction: str,
|
112 |
+
edits_classified: list[tuple]) -> dict[str, str]:
|
113 |
+
"""
|
114 |
+
Applies classified (SPELL, YO and CASE) char operations to source word separately.
|
115 |
+
Returns a dict mapping error type to source string with corrections of this type only.
|
116 |
+
"""
|
117 |
+
separated_edits = defaultdict(lambda: source)
|
118 |
+
shift = 0 # char position shift to consider on deletions and insertions
|
119 |
+
for edit in edits_classified:
|
120 |
+
edit_type = edit[3]
|
121 |
+
curr_src = separated_edits[edit_type]
|
122 |
+
if edit_type == "CASE": # SOURCE letter spelled in CORRECTION case
|
123 |
+
if correction[edit[2]].isupper():
|
124 |
+
correction_char = source[edit[1]].upper()
|
125 |
+
else:
|
126 |
+
correction_char = source[edit[1]].lower()
|
127 |
+
else:
|
128 |
+
if edit[0] == "delete":
|
129 |
+
correction_char = ""
|
130 |
+
elif edit[0] == "insert":
|
131 |
+
correction_char = correction[edit[2]]
|
132 |
+
elif source[edit[1]].isupper():
|
133 |
+
correction_char = correction[edit[2]].upper()
|
134 |
+
else:
|
135 |
+
correction_char = correction[edit[2]].lower()
|
136 |
+
if edit[0] == "replace":
|
137 |
+
separated_edits[edit_type] = curr_src[:edit[1] + shift] + correction_char + \
|
138 |
+
curr_src[edit[1]+shift + 1:]
|
139 |
+
elif edit[0] == "delete":
|
140 |
+
separated_edits[edit_type] = curr_src[:edit[1] + shift] + \
|
141 |
+
curr_src[edit[1]+shift + 1:]
|
142 |
+
shift -= 1
|
143 |
+
elif edit[0] == "insert":
|
144 |
+
separated_edits[edit_type] = curr_src[:edit[1] + shift] + correction_char + \
|
145 |
+
curr_src[edit[1]+shift:]
|
146 |
+
shift += 1
|
147 |
+
return dict(separated_edits)
|
merger.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import itertools
|
4 |
+
import re
|
5 |
+
from string import punctuation
|
6 |
+
|
7 |
+
import Levenshtein
|
8 |
+
from errant.alignment import Alignment
|
9 |
+
from errant.edit import Edit
|
10 |
+
|
11 |
+
|
12 |
+
def get_rule_edits(alignment: Alignment) -> list[Edit]:
|
13 |
+
"""Groups word-level alignment according to merging rules."""
|
14 |
+
edits = []
|
15 |
+
# Split alignment into groups
|
16 |
+
alignment_groups = group_alignment(alignment, "new")
|
17 |
+
for op, group in alignment_groups:
|
18 |
+
group = list(group)
|
19 |
+
# Ignore M
|
20 |
+
if op == "M":
|
21 |
+
continue
|
22 |
+
# T is always split
|
23 |
+
if op == "T":
|
24 |
+
for seq in group:
|
25 |
+
edits.append(Edit(alignment.orig, alignment.cor, seq[1:]))
|
26 |
+
# Process D, I and S subsequence
|
27 |
+
else:
|
28 |
+
processed = process_seq(group, alignment)
|
29 |
+
# Turn the processed sequence into edits
|
30 |
+
for seq in processed:
|
31 |
+
edits.append(Edit(alignment.orig, alignment.cor, seq[1:]))
|
32 |
+
return edits
|
33 |
+
|
34 |
+
|
35 |
+
def group_alignment(alignment: Alignment, mode: str = "default") -> list[tuple[str, list[tuple]]]:
|
36 |
+
"""
|
37 |
+
Does initial alignment grouping:
|
38 |
+
1. Make groups of MDM, MIM od MSM.
|
39 |
+
2. In remaining operations, make groups of Ms, groups of Ts, and D/I/Ss.
|
40 |
+
Do not group what was on the sides of M[DIS]M: SSMDMS -> [SS, MDM, S], not [MDM, SSS].
|
41 |
+
3. Sort groups by the order in which they appear in the alignment.
|
42 |
+
"""
|
43 |
+
if mode == "new":
|
44 |
+
op_groups = []
|
45 |
+
# Format operation types sequence as string to use regex sequence search
|
46 |
+
all_ops_seq = "".join([op[0][0] for op in alignment.align_seq])
|
47 |
+
# Find M[DIS]M groups and merge (need them to detect hyphen vs. space spelling)
|
48 |
+
ungrouped_ids = list(range(len(alignment.align_seq)))
|
49 |
+
for match in re.finditer("M[DIS]M", all_ops_seq):
|
50 |
+
start, end = match.start(), match.end()
|
51 |
+
op_groups.append(("MSM", alignment.align_seq[start:end]))
|
52 |
+
for idx in range(start, end):
|
53 |
+
ungrouped_ids.remove(idx)
|
54 |
+
# Group remaining operations by default rules (groups of M, T and rest)
|
55 |
+
if ungrouped_ids:
|
56 |
+
def get_group_type(operation):
|
57 |
+
return operation if operation in {"M", "T"} else "DIS"
|
58 |
+
curr_group = [alignment.align_seq[ungrouped_ids[0]]]
|
59 |
+
last_oper_type = get_group_type(curr_group[0][0][0])
|
60 |
+
for i, idx in enumerate(ungrouped_ids[1:], start=1):
|
61 |
+
operation = alignment.align_seq[idx]
|
62 |
+
oper_type = get_group_type(operation[0][0])
|
63 |
+
if (oper_type == last_oper_type and
|
64 |
+
(idx - ungrouped_ids[i-1] == 1 or oper_type in {"M", "T"})):
|
65 |
+
curr_group.append(operation)
|
66 |
+
else:
|
67 |
+
op_groups.append((last_oper_type, curr_group))
|
68 |
+
curr_group = [operation]
|
69 |
+
last_oper_type = oper_type
|
70 |
+
if curr_group:
|
71 |
+
op_groups.append((last_oper_type, curr_group))
|
72 |
+
# Sort groups by the start id of the first group entry
|
73 |
+
op_groups = sorted(op_groups, key=lambda x: x[1][0][1])
|
74 |
+
else:
|
75 |
+
grouped = itertools.groupby(alignment.align_seq,
|
76 |
+
lambda x: x[0][0] if x[0][0] in {"M", "T"} else False)
|
77 |
+
op_groups = [(op, list(group)) for op, group in grouped]
|
78 |
+
return op_groups
|
79 |
+
|
80 |
+
|
81 |
+
def process_seq(seq: list[tuple], alignment: Alignment) -> list[tuple]:
|
82 |
+
"""Applies merging rules to previously formed alignment groups (`seq`)."""
|
83 |
+
# Return single alignments
|
84 |
+
if len(seq) <= 1:
|
85 |
+
return seq
|
86 |
+
# Get the ops for the whole sequence
|
87 |
+
ops = [op[0] for op in seq]
|
88 |
+
|
89 |
+
# Get indices of all start-end combinations in the seq: 012 = 01, 02, 12
|
90 |
+
combos = list(itertools.combinations(range(0, len(seq)), 2))
|
91 |
+
# Sort them starting with largest spans first
|
92 |
+
combos.sort(key=lambda x: x[1] - x[0], reverse=True)
|
93 |
+
# Loop through combos
|
94 |
+
for start, end in combos:
|
95 |
+
# Ignore ranges that do NOT contain a substitution, deletion or insertion.
|
96 |
+
if not any(type_ in ops[start:end + 1] for type_ in ["D", "I", "S"]):
|
97 |
+
continue
|
98 |
+
# Merge all D xor I ops. (95% of human multi-token edits contain S).
|
99 |
+
if set(ops[start:end + 1]) == {"D"} or set(ops[start:end + 1]) == {"I"}:
|
100 |
+
return (process_seq(seq[:start], alignment)
|
101 |
+
+ merge_edits(seq[start:end + 1])
|
102 |
+
+ process_seq(seq[end + 1:], alignment))
|
103 |
+
# Get the tokens in orig and cor.
|
104 |
+
o = alignment.orig[seq[start][1]:seq[end][2]]
|
105 |
+
c = alignment.cor[seq[start][3]:seq[end][4]]
|
106 |
+
if ops[start:end + 1] in [["M", "D", "M"], ["M", "I", "M"], ["M", "S", "M"]]:
|
107 |
+
# merge hyphens
|
108 |
+
if (o[start + 1].text == "-" or c[start + 1].text == "-") and len(o) != len(c):
|
109 |
+
return (process_seq(seq[:start], alignment)
|
110 |
+
+ merge_edits(seq[start:end + 1])
|
111 |
+
+ process_seq(seq[end + 1:], alignment))
|
112 |
+
# if it is not a hyphen-space edit, return only punct edit
|
113 |
+
return seq[start + 1: end]
|
114 |
+
# Merge possessive suffixes: [friends -> friend 's]
|
115 |
+
if o[-1].tag_ == "POS" or c[-1].tag_ == "POS":
|
116 |
+
return (process_seq(seq[:end - 1], alignment)
|
117 |
+
+ merge_edits(seq[end - 1:end + 1])
|
118 |
+
+ process_seq(seq[end + 1:], alignment))
|
119 |
+
# Case changes
|
120 |
+
if o[-1].lower == c[-1].lower:
|
121 |
+
# Merge first token I or D: [Cat -> The big cat]
|
122 |
+
if (start == 0 and
|
123 |
+
(len(o) == 1 and c[0].text[0].isupper()) or
|
124 |
+
(len(c) == 1 and o[0].text[0].isupper())):
|
125 |
+
return (merge_edits(seq[start:end + 1])
|
126 |
+
+ process_seq(seq[end + 1:], alignment))
|
127 |
+
# Merge with previous punctuation: [, we -> . We], [we -> . We]
|
128 |
+
if (len(o) > 1 and is_punct(o[-2])) or \
|
129 |
+
(len(c) > 1 and is_punct(c[-2])):
|
130 |
+
return (process_seq(seq[:end - 1], alignment)
|
131 |
+
+ merge_edits(seq[end - 1:end + 1])
|
132 |
+
+ process_seq(seq[end + 1:], alignment))
|
133 |
+
# Merge whitespace/hyphens: [acat -> a cat], [sub - way -> subway]
|
134 |
+
s_str = re.sub("['-]", "", "".join([tok.lower_ for tok in o]))
|
135 |
+
t_str = re.sub("['-]", "", "".join([tok.lower_ for tok in c]))
|
136 |
+
if s_str == t_str or s_str.replace(" ", "") == t_str.replace(" ", ""):
|
137 |
+
return (process_seq(seq[:start], alignment)
|
138 |
+
+ merge_edits(seq[start:end + 1])
|
139 |
+
+ process_seq(seq[end + 1:], alignment))
|
140 |
+
# Merge same POS or auxiliary/infinitive/phrasal verbs:
|
141 |
+
# [to eat -> eating], [watch -> look at]
|
142 |
+
pos_set = set([tok.pos for tok in o] + [tok.pos for tok in c])
|
143 |
+
if len(o) != len(c) and (len(pos_set) == 1 or pos_set.issubset({"AUX", "PART", "VERB"})):
|
144 |
+
return (process_seq(seq[:start], alignment)
|
145 |
+
+ merge_edits(seq[start:end + 1])
|
146 |
+
+ process_seq(seq[end + 1:], alignment))
|
147 |
+
# Split rules take effect when we get to smallest chunks
|
148 |
+
if end - start < 2:
|
149 |
+
# Split adjacent substitutions
|
150 |
+
if len(o) == len(c) == 2:
|
151 |
+
return (process_seq(seq[:start + 1], alignment)
|
152 |
+
+ process_seq(seq[start + 1:], alignment))
|
153 |
+
# Split similar substitutions at sequence boundaries
|
154 |
+
if ((ops[start] == "S" and char_cost(o[0].text, c[0].text) > 0.75) or
|
155 |
+
(ops[end] == "S" and char_cost(o[-1].text, c[-1].text) > 0.75)):
|
156 |
+
return (process_seq(seq[:start + 1], alignment)
|
157 |
+
+ process_seq(seq[start + 1:], alignment))
|
158 |
+
# Split final determiners
|
159 |
+
if (end == len(seq) - 1 and
|
160 |
+
((ops[-1] in {"D", "S"} and o[-1].pos == "DET") or
|
161 |
+
(ops[-1] in {"I", "S"} and c[-1].pos == "DET"))):
|
162 |
+
return process_seq(seq[:-1], alignment) + [seq[-1]]
|
163 |
+
return seq
|
164 |
+
|
165 |
+
|
166 |
+
def is_punct(token) -> bool:
|
167 |
+
return token.text in punctuation
|
168 |
+
|
169 |
+
|
170 |
+
def char_cost(a: str, b: str) -> float:
|
171 |
+
"""Calculate the cost of character alignment; i.e. char similarity."""
|
172 |
+
|
173 |
+
return Levenshtein.ratio(a, b)
|
174 |
+
|
175 |
+
|
176 |
+
def merge_edits(seq: list[tuple]) -> list[tuple]:
|
177 |
+
"""Merge the input alignment sequence to a single edit span."""
|
178 |
+
|
179 |
+
if seq:
|
180 |
+
return [("X", seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
|
181 |
+
return seq
|
requirements.txt
CHANGED
@@ -1 +1,4 @@
|
|
1 |
-
git+https://github.com/huggingface/evaluate@main
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/huggingface/evaluate@main
|
2 |
+
git+https://github.com/Askinkaty/errant/@4183e57
|
3 |
+
Levenshtein
|
4 |
+
ru-core-news-lg @ https://huggingface.co/spacy/ru_core_news_lg/resolve/main/ru_core_news_lg-any-py3-none-any.whl
|
ru_errant.py
CHANGED
@@ -12,11 +12,26 @@
|
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
"""TODO: Add a description here."""
|
|
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
import evaluate
|
17 |
import datasets
|
18 |
|
19 |
-
|
20 |
# TODO: Add BibTeX citation
|
21 |
_CITATION = """\
|
22 |
@InProceedings{huggingface:module,
|
@@ -31,7 +46,6 @@ _DESCRIPTION = """\
|
|
31 |
This new module is designed to solve this great ML task and is crafted with a lot of care.
|
32 |
"""
|
33 |
|
34 |
-
|
35 |
# TODO: Add description of the arguments of the module here
|
36 |
_KWARGS_DESCRIPTION = """
|
37 |
Calculates how good are predictions given some references, using certain scores
|
@@ -57,6 +71,40 @@ Examples:
|
|
57 |
BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
|
58 |
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
61 |
class RuErrant(evaluate.Metric):
|
62 |
"""TODO: Short description of my evaluation module."""
|
@@ -70,26 +118,77 @@ class RuErrant(evaluate.Metric):
|
|
70 |
citation=_CITATION,
|
71 |
inputs_description=_KWARGS_DESCRIPTION,
|
72 |
# This defines the format of each prediction and reference
|
73 |
-
features=datasets.Features(
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
77 |
# Homepage of the module for documentation
|
78 |
homepage="http://module.homepage",
|
79 |
# Additional links to the codebase or references
|
80 |
-
codebase_urls=["
|
81 |
reference_urls=["http://path.to.reference.url/new_module"]
|
82 |
)
|
83 |
|
84 |
def _download_and_prepare(self, dl_manager):
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
"""TODO: Add a description here."""
|
15 |
+
from __future__ import annotations
|
16 |
|
17 |
+
import re
|
18 |
+
from collections import Counter, namedtuple
|
19 |
+
from typing import Iterable
|
20 |
+
from tqdm.auto import tqdm
|
21 |
+
|
22 |
+
from errant.annotator import Annotator
|
23 |
+
from errant.commands.compare_m2 import process_edits
|
24 |
+
from errant.commands.compare_m2 import evaluate_edits
|
25 |
+
from errant.commands.compare_m2 import merge_dict
|
26 |
+
from errant.edit import Edit
|
27 |
+
import spacy
|
28 |
+
from spacy.tokenizer import Tokenizer
|
29 |
+
from spacy.util import compile_prefix_regex, compile_infix_regex, compile_suffix_regex
|
30 |
+
import classifier
|
31 |
+
import merger
|
32 |
import evaluate
|
33 |
import datasets
|
34 |
|
|
|
35 |
# TODO: Add BibTeX citation
|
36 |
_CITATION = """\
|
37 |
@InProceedings{huggingface:module,
|
|
|
46 |
This new module is designed to solve this great ML task and is crafted with a lot of care.
|
47 |
"""
|
48 |
|
|
|
49 |
# TODO: Add description of the arguments of the module here
|
50 |
_KWARGS_DESCRIPTION = """
|
51 |
Calculates how good are predictions given some references, using certain scores
|
|
|
71 |
BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
|
72 |
|
73 |
|
74 |
+
def update_spacy_tokenizer(nlp):
|
75 |
+
"""
|
76 |
+
Changes Spacy tokenizer to parse additional patterns.
|
77 |
+
"""
|
78 |
+
infix_re = compile_infix_regex(nlp.Defaults.infixes[:-1] + ["\]\("])
|
79 |
+
simple_url_re = re.compile(r'''^https?://''')
|
80 |
+
nlp.tokenizer = Tokenizer(
|
81 |
+
nlp.vocab,
|
82 |
+
prefix_search=compile_prefix_regex(nlp.Defaults.prefixes + ['\\\\\"']).search,
|
83 |
+
suffix_search=compile_suffix_regex(nlp.Defaults.suffixes + ['\\\\']).search,
|
84 |
+
infix_finditer=infix_re.finditer,
|
85 |
+
token_match=None,
|
86 |
+
url_match=simple_url_re.match
|
87 |
+
)
|
88 |
+
return nlp
|
89 |
+
|
90 |
+
|
91 |
+
def annotate_errors(self, orig: str, cor: str, merging: str = "rules") -> list[Edit]:
|
92 |
+
"""
|
93 |
+
Overrides `Annotator.annotate()` function to allow multiple errors per token.
|
94 |
+
This is nesessary to parse combined errors, e.g.:
|
95 |
+
["werd", "Word"] >>> Errors: ["SPELL", "CASE"]
|
96 |
+
The `classify()` method called inside is implemented in ruerrant_classifier.py
|
97 |
+
(also overrides the original classifier).
|
98 |
+
"""
|
99 |
+
|
100 |
+
alignment = self.annotator.align(orig, cor, False)
|
101 |
+
edits = self.annotator.merge(alignment, merging)
|
102 |
+
classified_edits = []
|
103 |
+
for edit in edits:
|
104 |
+
classified_edits.extend(self.annotator.classify(edit))
|
105 |
+
return sorted(classified_edits, key=lambda x: (x[0], x[2]))
|
106 |
+
|
107 |
+
|
108 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
109 |
class RuErrant(evaluate.Metric):
|
110 |
"""TODO: Short description of my evaluation module."""
|
|
|
118 |
citation=_CITATION,
|
119 |
inputs_description=_KWARGS_DESCRIPTION,
|
120 |
# This defines the format of each prediction and reference
|
121 |
+
features=datasets.Features(
|
122 |
+
{
|
123 |
+
"sources": datasets.Value("string", id="sequence"),
|
124 |
+
"corrections": datasets.Value("string", id="sequence"),
|
125 |
+
"answers": datasets.Value("string", id="sequence"),
|
126 |
+
}
|
127 |
+
),
|
128 |
# Homepage of the module for documentation
|
129 |
homepage="http://module.homepage",
|
130 |
# Additional links to the codebase or references
|
131 |
+
codebase_urls=["https://github.com/ai-forever/sage"],
|
132 |
reference_urls=["http://path.to.reference.url/new_module"]
|
133 |
)
|
134 |
|
135 |
def _download_and_prepare(self, dl_manager):
|
136 |
+
self.annotator = Annotator("ru",
|
137 |
+
nlp=update_spacy_tokenizer(spacy.load("ru_core_news_lg")),
|
138 |
+
merger=merger,
|
139 |
+
classifier=classifier)
|
140 |
+
|
141 |
+
def _compute(self, sources, corrections, answers):
|
142 |
+
"""
|
143 |
+
Evaluates iterables of sources, hyp and ref corrections with ERRANT metric.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
sources (Iterable[str]): an iterable of source texts;
|
147 |
+
corrections (Iterable[str]): an iterable of gold corrections for the source texts;
|
148 |
+
answers (Iterable[str]): an iterable of evaluated corrections for the source texts;
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
dict[str, tuple[float, ...]]: a dict mapping error categories to the corresponding
|
152 |
+
P, R, F1 metric values.
|
153 |
+
"""
|
154 |
+
best_dict = Counter({"tp": 0, "fp": 0, "fn": 0})
|
155 |
+
best_cats = {}
|
156 |
+
sents = zip(sources, corrections, answers)
|
157 |
+
pb = tqdm(sents, desc="Calculating errant metric", total=len(sources))
|
158 |
+
for sent_id, sent in enumerate(pb):
|
159 |
+
src = self.annotator.parse(sent[0])
|
160 |
+
ref = self.annotator.parse(sent[1])
|
161 |
+
hyp = self.annotator.parse(sent[2])
|
162 |
+
# Align hyp and ref corrections and annotate errors
|
163 |
+
hyp_edits = self.annotate_errors(src, hyp)
|
164 |
+
ref_edits = self.annotate_errors(src, ref)
|
165 |
+
# Process the edits for detection/correction based on args
|
166 |
+
ProcessingArgs = namedtuple("ProcessingArgs",
|
167 |
+
["dt", "ds", "single", "multi", "filt", "cse"],
|
168 |
+
defaults=[False, False, False, False, [], True])
|
169 |
+
processing_args = ProcessingArgs()
|
170 |
+
hyp_dict = process_edits(hyp_edits, processing_args)
|
171 |
+
ref_dict = process_edits(ref_edits, processing_args)
|
172 |
+
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
|
173 |
+
EvaluationArgs = namedtuple("EvaluationArgs",
|
174 |
+
["beta", "verbose"],
|
175 |
+
defaults=[1.0, False])
|
176 |
+
evaluation_args = EvaluationArgs()
|
177 |
+
count_dict, cat_dict = evaluate_edits(
|
178 |
+
hyp_dict, ref_dict, best_dict, sent_id, evaluation_args)
|
179 |
+
# Merge these dicts with best_dict and best_cats
|
180 |
+
best_dict += Counter(count_dict) # corpus-level TP, FP, FN
|
181 |
+
best_cats = merge_dict(best_cats, cat_dict) # corpus-level errortype-wise TP, FP, FN
|
182 |
+
cat_prf = {}
|
183 |
+
for cat, values in best_cats.items():
|
184 |
+
tp, fp, fn = values # fp - extra corrections, fn - missed corrections
|
185 |
+
p = float(tp) / (tp + fp) if tp + fp else 1.0
|
186 |
+
r = float(tp) / (tp + fn) if tp + fn else 1.0
|
187 |
+
f = (2 * p * r) / (p + r) if p + r else 0.0
|
188 |
+
cat_prf[cat] = (p, r, f)
|
189 |
+
|
190 |
+
for error_category in ["CASE", "PUNCT", "SPELL", "YO"]:
|
191 |
+
if error_category not in cat_prf:
|
192 |
+
cat_prf[error_category] = (1.0, 1.0, 1.0)
|
193 |
+
|
194 |
+
return cat_prf
|