Upload 2 files
Browse files
scripts/noise_detecter.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unicodedata
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
6 |
+
|
7 |
+
|
8 |
+
class NoiseDetector:
|
9 |
+
def __init__(self, model_path: str):
|
10 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
+
self.model = AutoModelForTokenClassification.from_pretrained(model_path).to(
|
12 |
+
self.device
|
13 |
+
)
|
14 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
15 |
+
self.model.eval()
|
16 |
+
|
17 |
+
def _normalize_text(self, text: str) -> str:
|
18 |
+
return unicodedata.normalize("NFKC", text)
|
19 |
+
|
20 |
+
def _convert_token_spans_to_char_spans(
|
21 |
+
self,
|
22 |
+
text: str,
|
23 |
+
noise_token_indices: List[int],
|
24 |
+
offset_mapping: List[Tuple[int, int]],
|
25 |
+
) -> List[Tuple[int, int]]:
|
26 |
+
char_spans = []
|
27 |
+
current_span = None
|
28 |
+
|
29 |
+
for idx, (is_noise, (start, end)) in enumerate(
|
30 |
+
zip(noise_token_indices, offset_mapping)
|
31 |
+
):
|
32 |
+
# Skip special tokens (CLS, SEP, etc.)
|
33 |
+
if start == end == 0:
|
34 |
+
continue
|
35 |
+
|
36 |
+
if is_noise:
|
37 |
+
if current_span is None:
|
38 |
+
current_span = [start, end]
|
39 |
+
else:
|
40 |
+
current_span[1] = end
|
41 |
+
elif current_span is not None:
|
42 |
+
char_spans.append(tuple(current_span))
|
43 |
+
current_span = None
|
44 |
+
|
45 |
+
# Don't forget to add the last span if it exists
|
46 |
+
if current_span is not None:
|
47 |
+
char_spans.append(tuple(current_span))
|
48 |
+
|
49 |
+
return char_spans
|
50 |
+
|
51 |
+
def detect(
|
52 |
+
self, texts: List[str], threshold: float = 0.5
|
53 |
+
) -> List[List[Tuple[int, int]]]:
|
54 |
+
"""
|
55 |
+
Detect noise spans in the given texts.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
texts: List of input texts
|
59 |
+
threshold: Confidence threshold for noise detection (default: 0.5)
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
List of lists containing (start, end) character positions of detected noise spans for each text
|
63 |
+
"""
|
64 |
+
results = []
|
65 |
+
|
66 |
+
with torch.no_grad():
|
67 |
+
for text in texts:
|
68 |
+
# Normalize text
|
69 |
+
normalized_text = self._normalize_text(text)
|
70 |
+
|
71 |
+
# Tokenize
|
72 |
+
tokens = self.tokenizer(
|
73 |
+
normalized_text,
|
74 |
+
truncation=True,
|
75 |
+
return_offsets_mapping=True,
|
76 |
+
return_tensors="pt",
|
77 |
+
)
|
78 |
+
|
79 |
+
# Move to device
|
80 |
+
input_ids = tokens["input_ids"].to(self.device)
|
81 |
+
attention_mask = tokens["attention_mask"].to(self.device)
|
82 |
+
|
83 |
+
# Get predictions
|
84 |
+
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
85 |
+
logits = outputs.logits
|
86 |
+
|
87 |
+
# Convert logits to probabilities
|
88 |
+
probs = torch.softmax(logits, dim=-1)
|
89 |
+
|
90 |
+
# Get noise predictions (class 1)
|
91 |
+
noise_probs = probs[0, :, 1].cpu().numpy()
|
92 |
+
noise_predictions = (noise_probs > threshold).astype(int)
|
93 |
+
|
94 |
+
# Convert token-level predictions to character spans
|
95 |
+
char_spans = self._convert_token_spans_to_char_spans(
|
96 |
+
normalized_text,
|
97 |
+
noise_predictions,
|
98 |
+
tokens["offset_mapping"][0].tolist(),
|
99 |
+
)
|
100 |
+
|
101 |
+
results.append(char_spans)
|
102 |
+
|
103 |
+
return results
|
104 |
+
|
105 |
+
def detect_and_highlight(
|
106 |
+
self, texts: List[str], threshold: float = 0.5
|
107 |
+
) -> List[str]:
|
108 |
+
"""
|
109 |
+
Detect noise spans and return texts with noise sections highlighted.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
texts: List of input texts
|
113 |
+
threshold: Confidence threshold for noise detection (default: 0.5)
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
List of texts with noise sections wrapped in [NOISE]...[/NOISE] tags
|
117 |
+
"""
|
118 |
+
noise_spans = self.detect(texts, threshold)
|
119 |
+
highlighted_texts = []
|
120 |
+
|
121 |
+
for text, spans in zip(texts, noise_spans):
|
122 |
+
if not spans:
|
123 |
+
highlighted_texts.append(text)
|
124 |
+
continue
|
125 |
+
|
126 |
+
# Sort spans by start position
|
127 |
+
spans = sorted(spans)
|
128 |
+
|
129 |
+
# Build highlighted text
|
130 |
+
result = []
|
131 |
+
last_end = 0
|
132 |
+
|
133 |
+
for start, end in spans:
|
134 |
+
# Add text before noise
|
135 |
+
result.append(text[last_end:start])
|
136 |
+
# Add highlighted noise
|
137 |
+
# もし長さがN以下なら、ハイライトしない
|
138 |
+
if end - start > 3:
|
139 |
+
result.append(f"[NOISE]{text[start:end]}[/NOISE]")
|
140 |
+
else:
|
141 |
+
result.append(text[start:end])
|
142 |
+
# result.append(f"[NOISE]{text[start:end]}[/NOISE]")
|
143 |
+
last_end = end
|
144 |
+
|
145 |
+
# Add remaining text
|
146 |
+
result.append(text[last_end:])
|
147 |
+
|
148 |
+
highlighted_texts.append("".join(result))
|
149 |
+
|
150 |
+
return highlighted_texts
|
151 |
+
|
152 |
+
|
153 |
+
def main():
|
154 |
+
model_path = "hotchpotch/fineweb-2-japanese-text-cleaner"
|
155 |
+
detector = NoiseDetector(model_path)
|
156 |
+
|
157 |
+
NOISE_TEXT = """
|
158 |
+
この文章は90日以上更新の無いサイトに表示されています。
|
159 |
+
ログイン ログアウト
|
160 |
+
|
161 |
+
本当に必要な文章以外にも、さまざまなノイズが含まれていることがあります。例えば、この文章もその一例です。本来不要なテキストが入ってしまうことがこのようにあるでしょう。
|
162 |
+
|
163 |
+
今なら50%オフ!クリックしてリンク先の商品を表示
|
164 |
+
|
165 |
+
とりわけ文章長が短い場合、文章のほとんどがノイズを含む可能性があります。それらを取り除くことで、より高品質の文章を抽出できないかと考えています。
|
166 |
+
|
167 |
+
前のページ 次のページ
|
168 |
+
""".strip()
|
169 |
+
|
170 |
+
texts = [
|
171 |
+
NOISE_TEXT,
|
172 |
+
"これは正常なテキストです。しかし、ここに🤣絵文字があります。そして普通の文章が続きます。",
|
173 |
+
"普通の文章です。ASCII ART(^_^)があります。最後も普通です。",
|
174 |
+
"ログイン 文章の密ベクトルは、情報検索・文章判別・類似文章抽出など、さまざまな用途に使うことができます。しかしながら最先端のTransformerモデルは小さいモデルでも、とりわけCPU環境では処理速度が遅いため実用でないこともしばしばあります。この課題を解決する新しいアプローチとして、先日公開されたTransformerモデル「ではない」 StaticEmbeddingモデルは、例えば intfloat/multilingual-e5-small (以下mE5-small)とのベンチマーク比較では85%のスコアという最低十分な性能で、何よりCPUで動作時に126倍高速に文ベクトルを作成することができる、という驚きの速度です。 記事の一覧 >",
|
175 |
+
]
|
176 |
+
|
177 |
+
highlighted_texts = detector.detect_and_highlight(texts, threshold=0.7)
|
178 |
+
for text in highlighted_texts:
|
179 |
+
print(f"\n{text}")
|
180 |
+
|
181 |
+
|
182 |
+
if __name__ == "__main__":
|
183 |
+
main()
|
scripts/trainer-fineweb-2-japanese-text-cleaner.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import unicodedata
|
4 |
+
|
5 |
+
import datasets
|
6 |
+
import evaluate
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from datasets import load_dataset
|
12 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
13 |
+
from transformers import (
|
14 |
+
AutoModelForTokenClassification,
|
15 |
+
AutoTokenizer,
|
16 |
+
DataCollatorForTokenClassification,
|
17 |
+
Trainer,
|
18 |
+
TrainingArguments,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
def compute_f05_score(precision, recall, beta=0.5):
|
23 |
+
"""Calculate F0.5 score from precision and recall."""
|
24 |
+
if precision <= 0 or recall <= 0:
|
25 |
+
return 0.0
|
26 |
+
|
27 |
+
return (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall)
|
28 |
+
|
29 |
+
|
30 |
+
def custom_classification_report(y_true, y_pred):
|
31 |
+
"""Generate classification report with F0.5 score."""
|
32 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
33 |
+
|
34 |
+
# Calculate precision, recall, f1, and support for each class
|
35 |
+
precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred)
|
36 |
+
|
37 |
+
# Calculate F0.5 scores
|
38 |
+
f05_scores = [compute_f05_score(p, r) for p, r in zip(precision, recall)]
|
39 |
+
|
40 |
+
# Calculate accuracy
|
41 |
+
accuracy = accuracy_score(y_true, y_pred)
|
42 |
+
|
43 |
+
# Generate report string
|
44 |
+
report = " precision recall f1-score f05-score support\n\n"
|
45 |
+
|
46 |
+
# Add metrics for each class
|
47 |
+
for i in range(len(precision)):
|
48 |
+
report += f" {i}"
|
49 |
+
report += f" {precision[i]:.2f} {recall[i]:.2f}"
|
50 |
+
report += f" {f1[i]:.2f} {f05_scores[i]:.2f} {support[i]}\n"
|
51 |
+
|
52 |
+
report += "\n"
|
53 |
+
|
54 |
+
# Calculate and add averages
|
55 |
+
n_samples = sum(support)
|
56 |
+
|
57 |
+
# Macro average
|
58 |
+
macro_precision = np.mean(precision)
|
59 |
+
macro_recall = np.mean(recall)
|
60 |
+
macro_f1 = np.mean(f1)
|
61 |
+
macro_f05 = np.mean(f05_scores)
|
62 |
+
report += f" macro avg {macro_precision:.2f} {macro_recall:.2f}"
|
63 |
+
report += f" {macro_f1:.2f} {macro_f05:.2f} {n_samples}\n"
|
64 |
+
|
65 |
+
# Weighted average
|
66 |
+
weighted_precision = np.average(precision, weights=support)
|
67 |
+
weighted_recall = np.average(recall, weights=support)
|
68 |
+
weighted_f1 = np.average(f1, weights=support)
|
69 |
+
weighted_f05 = np.average(f05_scores, weights=support)
|
70 |
+
report += f"weighted avg {weighted_precision:.2f} {weighted_recall:.2f}"
|
71 |
+
report += f" {weighted_f1:.2f} {weighted_f05:.2f} {n_samples}\n"
|
72 |
+
|
73 |
+
# Add accuracy
|
74 |
+
report += f" accuracy {accuracy:.2f} {n_samples}\n"
|
75 |
+
|
76 |
+
return report
|
77 |
+
|
78 |
+
|
79 |
+
def compute_metrics(eval_pred):
|
80 |
+
precision_metric = evaluate.load("precision")
|
81 |
+
recall_metric = evaluate.load("recall")
|
82 |
+
f1_metric = evaluate.load("f1")
|
83 |
+
|
84 |
+
predictions, labels = eval_pred
|
85 |
+
predictions = np.argmax(predictions, axis=2)
|
86 |
+
|
87 |
+
# Remove ignored index (special tokens)
|
88 |
+
true_predictions = []
|
89 |
+
true_labels = []
|
90 |
+
|
91 |
+
for prediction, label in zip(predictions, labels):
|
92 |
+
for p, l in zip(prediction, label):
|
93 |
+
if l != -100: # We have a valid label
|
94 |
+
true_predictions.append(p)
|
95 |
+
true_labels.append(l)
|
96 |
+
|
97 |
+
# Convert to numpy arrays
|
98 |
+
true_predictions = np.array(true_predictions)
|
99 |
+
true_labels = np.array(true_labels)
|
100 |
+
|
101 |
+
precision = precision_metric.compute(
|
102 |
+
predictions=true_predictions,
|
103 |
+
references=true_labels,
|
104 |
+
average="binary",
|
105 |
+
)["precision"]
|
106 |
+
recall = recall_metric.compute(
|
107 |
+
predictions=true_predictions,
|
108 |
+
references=true_labels,
|
109 |
+
average="binary",
|
110 |
+
)["recall"]
|
111 |
+
f1 = f1_metric.compute(
|
112 |
+
predictions=true_predictions,
|
113 |
+
references=true_labels,
|
114 |
+
average="binary",
|
115 |
+
)["f1"]
|
116 |
+
|
117 |
+
# Calculate F0.5 score
|
118 |
+
beta = 0.5
|
119 |
+
f05 = compute_f05_score(precision, recall, beta)
|
120 |
+
|
121 |
+
# Generate custom classification report
|
122 |
+
report = custom_classification_report(true_labels, true_predictions)
|
123 |
+
cm = confusion_matrix(true_labels, true_predictions)
|
124 |
+
print("Validation Report:\n" + report)
|
125 |
+
print("Confusion Matrix:\n" + str(cm))
|
126 |
+
|
127 |
+
return {
|
128 |
+
"precision": precision,
|
129 |
+
"recall": recall,
|
130 |
+
"f1": f1,
|
131 |
+
"f05": f05,
|
132 |
+
}
|
133 |
+
|
134 |
+
|
135 |
+
def unicode_normalize(text):
|
136 |
+
return unicodedata.normalize("NFKC", text)
|
137 |
+
|
138 |
+
|
139 |
+
def convert_spans_to_labels(text, spans, tokenizer):
|
140 |
+
# Tokenize text
|
141 |
+
tokens = tokenizer(text, truncation=True, return_offsets_mapping=True)
|
142 |
+
offset_mapping = tokens["offset_mapping"]
|
143 |
+
|
144 |
+
# Initialize labels (0 for non-noise, 1 for noise)
|
145 |
+
labels = [0] * len(offset_mapping)
|
146 |
+
|
147 |
+
# Mark special tokens with -100
|
148 |
+
labels = [
|
149 |
+
-100 if offset[0] == offset[1] == 0 else label
|
150 |
+
for label, offset in zip(labels, offset_mapping)
|
151 |
+
]
|
152 |
+
|
153 |
+
# Convert character spans to token labels
|
154 |
+
for start, end in spans:
|
155 |
+
for idx, (token_start, token_end) in enumerate(offset_mapping):
|
156 |
+
# Skip special tokens
|
157 |
+
if token_start == token_end == 0:
|
158 |
+
continue
|
159 |
+
# If token overlaps with noise span, mark as noise
|
160 |
+
if token_start < end and token_end > start:
|
161 |
+
labels[idx] = 1
|
162 |
+
|
163 |
+
return {
|
164 |
+
"labels": labels,
|
165 |
+
"input_ids": tokens["input_ids"],
|
166 |
+
"attention_mask": tokens["attention_mask"],
|
167 |
+
}
|
168 |
+
|
169 |
+
|
170 |
+
def main(args):
|
171 |
+
# Load dataset
|
172 |
+
dataset = load_dataset(args.dataset_name, split="train")
|
173 |
+
dataset = dataset.select_columns(["text", "noise_spans"])
|
174 |
+
# Split dataset
|
175 |
+
dataset = dataset.train_test_split(train_size=0.95, seed=42)
|
176 |
+
|
177 |
+
wikipedia_dataset_count = args.add_wikipedia_dataset_count
|
178 |
+
if wikipedia_dataset_count > 0:
|
179 |
+
wikipedia_dataset = load_dataset(
|
180 |
+
"hpprc/jawiki-paragraphs", "default", split="train"
|
181 |
+
)
|
182 |
+
# select columns
|
183 |
+
wikipedia_dataset = wikipedia_dataset.map(
|
184 |
+
lambda x: {
|
185 |
+
"text": unicode_normalize(x["text"]),
|
186 |
+
"noise_spans": [],
|
187 |
+
},
|
188 |
+
num_proc=15,
|
189 |
+
remove_columns=wikipedia_dataset.column_names,
|
190 |
+
)
|
191 |
+
# random rampling
|
192 |
+
target_indexes = np.random.choice(
|
193 |
+
len(wikipedia_dataset), wikipedia_dataset_count, replace=False
|
194 |
+
)
|
195 |
+
print(wikipedia_dataset)
|
196 |
+
wikipedia_dataset = wikipedia_dataset.select(target_indexes)
|
197 |
+
new_features = datasets.Features(
|
198 |
+
{
|
199 |
+
"text": datasets.Value("string"),
|
200 |
+
"noise_spans": datasets.Sequence(
|
201 |
+
datasets.Sequence(datasets.Value("int64"))
|
202 |
+
),
|
203 |
+
}
|
204 |
+
)
|
205 |
+
|
206 |
+
# データセットの特徴を変換
|
207 |
+
wikipedia_dataset = wikipedia_dataset.cast(new_features, num_proc=15)
|
208 |
+
print(f"Adding {len(wikipedia_dataset)} examples from the Wikipedia dataset")
|
209 |
+
print(f"original training examples: {len(dataset['train'])}")
|
210 |
+
dataset["train"] = datasets.concatenate_datasets(
|
211 |
+
[dataset["train"], wikipedia_dataset]
|
212 |
+
)
|
213 |
+
print(f"Total training examples: {len(dataset['train'])}")
|
214 |
+
|
215 |
+
# Initialize model and tokenizer
|
216 |
+
model = AutoModelForTokenClassification.from_pretrained(
|
217 |
+
args.base_model_name,
|
218 |
+
num_labels=2, # Binary classification: noise or not noise
|
219 |
+
classifier_dropout=0.1,
|
220 |
+
)
|
221 |
+
|
222 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
223 |
+
args.base_model_name,
|
224 |
+
model_max_length=min(model.config.max_position_embeddings, 512),
|
225 |
+
)
|
226 |
+
|
227 |
+
if not tokenizer.pad_token:
|
228 |
+
tokenizer.pad_token = tokenizer.eos_token
|
229 |
+
|
230 |
+
# Preprocess dataset
|
231 |
+
def preprocess(examples):
|
232 |
+
results = []
|
233 |
+
for text, spans in zip(examples["text"], examples["noise_spans"]):
|
234 |
+
result = convert_spans_to_labels(text, spans, tokenizer)
|
235 |
+
results.append(result)
|
236 |
+
|
237 |
+
return {
|
238 |
+
"input_ids": [r["input_ids"] for r in results],
|
239 |
+
"attention_mask": [r["attention_mask"] for r in results],
|
240 |
+
"labels": [r["labels"] for r in results],
|
241 |
+
}
|
242 |
+
|
243 |
+
tokenized_dataset = dataset.map(
|
244 |
+
preprocess,
|
245 |
+
batched=True,
|
246 |
+
remove_columns=dataset["train"].column_names,
|
247 |
+
num_proc=11,
|
248 |
+
)
|
249 |
+
|
250 |
+
# Data collator
|
251 |
+
data_collator = DataCollatorForTokenClassification(
|
252 |
+
tokenizer=tokenizer,
|
253 |
+
padding=True,
|
254 |
+
return_tensors="pt",
|
255 |
+
)
|
256 |
+
|
257 |
+
# Training arguments
|
258 |
+
training_args = TrainingArguments(
|
259 |
+
output_dir=args.checkpoint_dir,
|
260 |
+
evaluation_strategy="steps",
|
261 |
+
save_strategy="steps",
|
262 |
+
eval_steps=100,
|
263 |
+
save_steps=100,
|
264 |
+
logging_steps=10,
|
265 |
+
learning_rate=5e-5,
|
266 |
+
num_train_epochs=5,
|
267 |
+
optim="adafactor",
|
268 |
+
warmup_ratio=0.1,
|
269 |
+
lr_scheduler_type="cosine",
|
270 |
+
weight_decay=0.01,
|
271 |
+
max_grad_norm=1.0,
|
272 |
+
seed=42,
|
273 |
+
per_device_train_batch_size=256,
|
274 |
+
gradient_accumulation_steps=8,
|
275 |
+
per_device_eval_batch_size=256,
|
276 |
+
eval_on_start=True,
|
277 |
+
eval_accumulation_steps=1,
|
278 |
+
load_best_model_at_end=True,
|
279 |
+
metric_for_best_model="f05",
|
280 |
+
greater_is_better=True,
|
281 |
+
bf16=True,
|
282 |
+
)
|
283 |
+
|
284 |
+
# Initialize trainer
|
285 |
+
trainer = Trainer(
|
286 |
+
model=model,
|
287 |
+
args=training_args,
|
288 |
+
train_dataset=tokenized_dataset["train"],
|
289 |
+
eval_dataset=tokenized_dataset["test"],
|
290 |
+
tokenizer=tokenizer,
|
291 |
+
data_collator=data_collator,
|
292 |
+
compute_metrics=compute_metrics,
|
293 |
+
)
|
294 |
+
|
295 |
+
# Train
|
296 |
+
trainer.train()
|
297 |
+
trainer.save_model(os.path.join(args.checkpoint_dir, "final"))
|
298 |
+
|
299 |
+
# Evaluate
|
300 |
+
print("\nFinal Evaluation Results:")
|
301 |
+
final_metrics = trainer.evaluate()
|
302 |
+
print(final_metrics)
|
303 |
+
|
304 |
+
|
305 |
+
if __name__ == "__main__":
|
306 |
+
parser = argparse.ArgumentParser()
|
307 |
+
parser.add_argument(
|
308 |
+
"--base_model_name",
|
309 |
+
type=str,
|
310 |
+
default="hotchpotch/mMiniLMv2-L6-H384",
|
311 |
+
)
|
312 |
+
parser.add_argument(
|
313 |
+
"--dataset_name",
|
314 |
+
type=str,
|
315 |
+
default="hotchpotch/fineweb-2-japanese-noise-spans",
|
316 |
+
)
|
317 |
+
parser.add_argument(
|
318 |
+
"-w",
|
319 |
+
"--add-wikipedia-dataset-count",
|
320 |
+
type=int,
|
321 |
+
default=0,
|
322 |
+
help="Number of examples to add from the Wikipedia dataset",
|
323 |
+
)
|
324 |
+
parser.add_argument(
|
325 |
+
"--checkpoint_dir",
|
326 |
+
type=str,
|
327 |
+
default="./models/text-cleaner/",
|
328 |
+
)
|
329 |
+
|
330 |
+
args = parser.parse_args()
|
331 |
+
main(args)
|