Safetensors
Japanese
xlm-roberta
hotchpotch committed on
Commit
92c0372
·
verified ·
1 Parent(s): 5269525

Upload 2 files

Browse files
scripts/noise_detecter.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unicodedata
2
+ from typing import List, Tuple
3
+
4
+ import torch
5
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
6
+
7
+
8
class NoiseDetector:
    """Detect noisy character spans in text with a token-classification model.

    The model and tokenizer are loaded from ``model_path``. Every input text
    is NFKC-normalized before tokenization, so all character offsets produced
    by this class refer to the *normalized* text, not the raw input.
    """

    def __init__(self, model_path: str):
        """Load the model and tokenizer and put the model in eval mode.

        Args:
            model_path: Name or path passed to ``from_pretrained``.
        """
        # Prefer the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForTokenClassification.from_pretrained(model_path).to(
            self.device
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.eval()

    def _normalize_text(self, text: str) -> str:
        """Return ``text`` in Unicode NFKC normal form."""
        return unicodedata.normalize("NFKC", text)

    def _convert_token_spans_to_char_spans(
        self,
        text: str,
        noise_token_indices: List[int],
        offset_mapping: List[Tuple[int, int]],
    ) -> List[Tuple[int, int]]:
        """Merge runs of consecutive noise tokens into character spans.

        Args:
            text: The tokenized text (unused; kept for interface stability).
            noise_token_indices: One truthy/falsy noise flag per token.
            offset_mapping: One ``(start, end)`` character offset per token.

        Returns:
            List of ``(start, end)`` character spans, in token order.
        """
        char_spans = []
        current_span = None

        for is_noise, (start, end) in zip(noise_token_indices, offset_mapping):
            # Tokens with a (0, 0) offset are special tokens (CLS, SEP, ...).
            # They neither extend nor close the span being built.
            if start == end == 0:
                continue

            if is_noise:
                if current_span is None:
                    current_span = [start, end]
                else:
                    current_span[1] = end
            elif current_span is not None:
                char_spans.append(tuple(current_span))
                current_span = None

        # Flush a span that runs up to the last token.
        if current_span is not None:
            char_spans.append(tuple(current_span))

        return char_spans

    def detect(
        self, texts: List[str], threshold: float = 0.5
    ) -> List[List[Tuple[int, int]]]:
        """
        Detect noise spans in the given texts.

        Args:
            texts: List of input texts
            threshold: Confidence threshold for noise detection (default: 0.5)

        Returns:
            List of lists containing (start, end) character positions of
            detected noise spans for each text. The positions index the
            NFKC-normalized form of each text. Because the tokenizer is called
            with ``truncation=True``, text beyond the tokenizer's maximum
            length is not examined.
        """
        results = []

        with torch.no_grad():
            for text in texts:
                normalized_text = self._normalize_text(text)

                tokens = self.tokenizer(
                    normalized_text,
                    truncation=True,
                    return_offsets_mapping=True,
                    return_tensors="pt",
                )

                input_ids = tokens["input_ids"].to(self.device)
                attention_mask = tokens["attention_mask"].to(self.device)

                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                probs = torch.softmax(outputs.logits, dim=-1)

                # Class index 1 is the noise class. Comparing on the tensor
                # avoids a numpy round trip, which is unsupported for bf16.
                noise_predictions = (probs[0, :, 1] > threshold).tolist()

                char_spans = self._convert_token_spans_to_char_spans(
                    normalized_text,
                    noise_predictions,
                    tokens["offset_mapping"][0].tolist(),
                )

                results.append(char_spans)

        return results

    def detect_and_highlight(
        self,
        texts: List[str],
        threshold: float = 0.5,
        min_noise_length: int = 4,
    ) -> List[str]:
        """
        Detect noise spans and return texts with noise sections highlighted.

        Args:
            texts: List of input texts
            threshold: Confidence threshold for noise detection (default: 0.5)
            min_noise_length: Spans shorter than this many characters are left
                untagged (default: 4, i.e. only spans longer than 3 are tagged)

        Returns:
            List of texts with noise sections wrapped in [NOISE]...[/NOISE]
            tags. A text with at least one detected span is returned in its
            NFKC-normalized form, because the span offsets index that form;
            a text with no detected span is returned unchanged.
        """
        noise_spans = self.detect(texts, threshold)
        highlighted_texts = []

        for text, spans in zip(texts, noise_spans):
            if not spans:
                highlighted_texts.append(text)
                continue

            # Offsets from detect() refer to the normalized text; slicing the
            # raw text would misplace the tags whenever NFKC changes lengths.
            normalized_text = self._normalize_text(text)

            result = []
            last_end = 0

            for start, end in sorted(spans):
                # Text between the previous span and this one.
                result.append(normalized_text[last_end:start])
                segment = normalized_text[start:end]
                # Very short spans are kept as plain text.
                if end - start >= min_noise_length:
                    result.append(f"[NOISE]{segment}[/NOISE]")
                else:
                    result.append(segment)
                last_end = end

            # Text after the last span.
            result.append(normalized_text[last_end:])

            highlighted_texts.append("".join(result))

        return highlighted_texts
151
+
152
+
153
def main():
    """Run the published cleaner model on sample texts and print the results.

    Each sample is printed, preceded by a blank line, with detected noise
    wrapped in [NOISE]...[/NOISE] tags.
    """
    detector = NoiseDetector("hotchpotch/fineweb-2-japanese-text-cleaner")

    # A web-page-like sample mixing body text with navigation and ad lines.
    web_page_sample = """
この文章は90日以上更新の無いサイトに表示されています。
ログイン ログアウト

本当に必要な文章以外にも、さまざまなノイズが含まれていることがあります。例えば、この文章もその一例です。本来不要なテキストが入ってしまうことがこのようにあるでしょう。

今なら50%オフ!クリックしてリンク先の商品を表示

とりわけ文章長が短い場合、文章のほとんどがノイズを含む可能性があります。それらを取り除くことで、より高品質の文章を抽出できないかと考えています。

前のページ 次のページ
""".strip()

    samples = [web_page_sample]
    samples.append(
        "これは正常なテキストです。しかし、ここに🤣絵文字があります。そして普通の文章が続きます。"
    )
    samples.append("普通の文章です。ASCII ART(^_^)があります。最後も普通です。")
    samples.append(
        "ログイン 文章の密ベクトルは、情報検索・文章判別・類似文章抽出など、さまざまな用途に使うことができます。しかしながら最先端のTransformerモデルは小さいモデルでも、とりわけCPU環境では処理速度が遅いため実用でないこともしばしばあります。この課題を解決する新しいアプローチとして、先日公開されたTransformerモデル「ではない」 StaticEmbeddingモデルは、例えば intfloat/multilingual-e5-small (以下mE5-small)とのベンチマーク比較では85%のスコアという最低十分な性能で、何よりCPUで動作時に126倍高速に文ベクトルを作成することができる、という驚きの速度です。 記事の一覧 >"
    )

    for highlighted in detector.detect_and_highlight(samples, threshold=0.7):
        print("\n" + highlighted)


if __name__ == "__main__":
    main()
scripts/trainer-fineweb-2-japanese-text-cleaner.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import unicodedata
4
+
5
+ import datasets
6
+ import evaluate
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from datasets import load_dataset
12
+ from sklearn.metrics import classification_report, confusion_matrix
13
+ from transformers import (
14
+ AutoModelForTokenClassification,
15
+ AutoTokenizer,
16
+ DataCollatorForTokenClassification,
17
+ Trainer,
18
+ TrainingArguments,
19
+ )
20
+
21
+
22
def compute_f05_score(precision, recall, beta=0.5):
    """Return the F-beta score for the given precision and recall.

    With the default ``beta=0.5`` this is the F0.5 score. The result is 0.0
    when either precision or recall is zero or negative.
    """
    if precision <= 0 or recall <= 0:
        return 0.0

    beta_squared = beta**2
    numerator = (1 + beta_squared) * (precision * recall)
    denominator = (beta_squared * precision) + recall
    return numerator / denominator
28
+
29
+
30
def custom_classification_report(y_true, y_pred):
    """Generate a text classification report that includes an F0.5 column.

    Args:
        y_true: 1-D sequence of reference labels.
        y_pred: 1-D sequence of predicted labels, same length as ``y_true``.

    Returns:
        A multi-line string with one row per class (labelled with the actual
        class value), followed by macro average, weighted average and
        accuracy rows. Columns are right-aligned to fixed widths.
    """
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Sorted union of the labels that occur, so each row is labelled with the
    # real class value rather than its position in the result arrays.
    class_labels = np.unique(np.concatenate([y_true, y_pred]))

    # zero_division=0 yields 0.0 for undefined scores without emitting
    # a warning.
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=class_labels, zero_division=0
    )
    f05_scores = [compute_f05_score(p, r) for p, r in zip(precision, recall)]
    accuracy = accuracy_score(y_true, y_pred)
    n_samples = int(np.sum(support))

    def format_row(name, p, r, f, f05, count):
        return (
            f"{name:>12} {p:>10.2f} {r:>10.2f} {f:>10.2f} {f05:>10.2f}"
            f" {count:>10}\n"
        )

    parts = [
        f"{'':>12} {'precision':>10} {'recall':>10} {'f1-score':>10}"
        f" {'f05-score':>10} {'support':>10}\n\n"
    ]

    # One row per class.
    for label, p, r, f, f05, count in zip(
        class_labels, precision, recall, f1, f05_scores, support
    ):
        parts.append(format_row(str(label), p, r, f, f05, int(count)))

    parts.append("\n")

    # Macro average: unweighted mean over classes.
    parts.append(
        format_row(
            "macro avg",
            np.mean(precision),
            np.mean(recall),
            np.mean(f1),
            np.mean(f05_scores),
            n_samples,
        )
    )

    # Weighted average: mean over classes weighted by support.
    parts.append(
        format_row(
            "weighted avg",
            np.average(precision, weights=support),
            np.average(recall, weights=support),
            np.average(f1, weights=support),
            np.average(f05_scores, weights=support),
            n_samples,
        )
    )

    # Accuracy is a single number, shown in the last score column.
    parts.append(
        f"{'accuracy':>12} {'':>10} {'':>10} {'':>10} {accuracy:>10.2f}"
        f" {n_samples:>10}\n"
    )

    return "".join(parts)
77
+
78
+
79
def compute_metrics(eval_pred):
    """Compute token-level precision, recall, F1 and F0.5 for the noise class.

    Positions whose label is -100 are excluded. A classification report and
    confusion matrix for the remaining tokens are printed as a side effect.

    Args:
        eval_pred: Pair ``(predictions, labels)``; ``predictions`` holds
            per-token class scores with the class dimension on axis 2.

    Returns:
        Dict with keys ``precision``, ``recall``, ``f1`` and ``f05``.
    """
    precision_metric = evaluate.load("precision")
    recall_metric = evaluate.load("recall")
    f1_metric = evaluate.load("f1")

    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=2)
    labels = np.asarray(labels)

    # Keep only positions with a real label. Boolean-mask indexing preserves
    # row-major order, matching a token-by-token scan, without a Python loop.
    valid_positions = labels != -100
    true_predictions = predictions[valid_positions]
    true_labels = labels[valid_positions]

    precision = precision_metric.compute(
        predictions=true_predictions,
        references=true_labels,
        average="binary",
    )["precision"]
    recall = recall_metric.compute(
        predictions=true_predictions,
        references=true_labels,
        average="binary",
    )["recall"]
    f1 = f1_metric.compute(
        predictions=true_predictions,
        references=true_labels,
        average="binary",
    )["f1"]

    # F0.5 weights precision more heavily than recall.
    beta = 0.5
    f05 = compute_f05_score(precision, recall, beta)

    report = custom_classification_report(true_labels, true_predictions)
    cm = confusion_matrix(true_labels, true_predictions)
    print("Validation Report:\n" + report)
    print("Confusion Matrix:\n" + str(cm))

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f05": f05,
    }
133
+
134
+
135
def unicode_normalize(text):
    """Return ``text`` converted to Unicode NFKC normal form."""
    normalized = unicodedata.normalize("NFKC", text)
    return normalized
137
+
138
+
139
def convert_spans_to_labels(text, spans, tokenizer):
    """Tokenize ``text`` and derive one label per token from character spans.

    Labels are -100 for tokens whose offset is (0, 0) (special tokens),
    1 for tokens overlapping any ``(start, end)`` span in ``spans``,
    and 0 otherwise.

    Returns:
        Dict with ``labels`` plus the tokenizer's ``input_ids`` and
        ``attention_mask``.
    """
    encoded = tokenizer(text, truncation=True, return_offsets_mapping=True)

    def label_for(token_start, token_end):
        # A (0, 0) offset marks a special token; it gets the ignore label.
        if token_start == token_end == 0:
            return -100
        # Half-open interval overlap test against every noise span.
        touches_noise = any(
            token_start < span_end and token_end > span_start
            for span_start, span_end in spans
        )
        return 1 if touches_noise else 0

    token_labels = [
        label_for(token_start, token_end)
        for token_start, token_end in encoded["offset_mapping"]
    ]

    return {
        "labels": token_labels,
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
    }
168
+
169
+
170
def main(args):
    """Fine-tune a token-classification model to tag noise tokens.

    Loads the noise-span dataset, optionally adds Wikipedia paragraphs as
    noise-free examples, converts character spans to token labels, trains
    with the Hugging Face ``Trainer`` and prints the final evaluation metrics.

    Args:
        args: Parsed CLI arguments providing ``dataset_name``,
            ``add_wikipedia_dataset_count``, ``base_model_name`` and
            ``checkpoint_dir``.
    """
    # One seed for the split, the Wikipedia sampling and the trainer so a run
    # can be reproduced.
    seed = 42

    # Load dataset
    dataset = load_dataset(args.dataset_name, split="train")
    dataset = dataset.select_columns(["text", "noise_spans"])
    # Split dataset
    dataset = dataset.train_test_split(train_size=0.95, seed=seed)

    wikipedia_dataset_count = args.add_wikipedia_dataset_count
    if wikipedia_dataset_count > 0:
        wikipedia_dataset = load_dataset(
            "hpprc/jawiki-paragraphs", "default", split="train"
        )
        print(wikipedia_dataset)

        # Random sampling with a seeded generator. The request is clamped to
        # the dataset size, since sampling without replacement cannot return
        # more rows than exist.
        sample_size = min(wikipedia_dataset_count, len(wikipedia_dataset))
        rng = np.random.default_rng(seed)
        target_indexes = rng.choice(
            len(wikipedia_dataset), sample_size, replace=False
        )
        # Sample first so only the selected rows are normalized.
        wikipedia_dataset = wikipedia_dataset.select(target_indexes)

        # Keep only normalized text, with an empty span list (no noise).
        wikipedia_dataset = wikipedia_dataset.map(
            lambda x: {
                "text": unicode_normalize(x["text"]),
                "noise_spans": [],
            },
            num_proc=15,
            remove_columns=wikipedia_dataset.column_names,
        )
        new_features = datasets.Features(
            {
                "text": datasets.Value("string"),
                "noise_spans": datasets.Sequence(
                    datasets.Sequence(datasets.Value("int64"))
                ),
            }
        )

        # Cast the dataset features so they match the main dataset before
        # concatenation.
        wikipedia_dataset = wikipedia_dataset.cast(new_features, num_proc=15)
        print(f"Adding {len(wikipedia_dataset)} examples from the Wikipedia dataset")
        print(f"original training examples: {len(dataset['train'])}")
        dataset["train"] = datasets.concatenate_datasets(
            [dataset["train"], wikipedia_dataset]
        )
        print(f"Total training examples: {len(dataset['train'])}")

    # Initialize model and tokenizer
    model = AutoModelForTokenClassification.from_pretrained(
        args.base_model_name,
        num_labels=2,  # Binary classification: noise or not noise
        classifier_dropout=0.1,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model_name,
        model_max_length=min(model.config.max_position_embeddings, 512),
    )

    # Fall back to the EOS token when the tokenizer defines no pad token.
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    # Preprocess dataset
    def preprocess(examples):
        results = [
            convert_spans_to_labels(text, spans, tokenizer)
            for text, spans in zip(examples["text"], examples["noise_spans"])
        ]

        return {
            "input_ids": [r["input_ids"] for r in results],
            "attention_mask": [r["attention_mask"] for r in results],
            "labels": [r["labels"] for r in results],
        }

    tokenized_dataset = dataset.map(
        preprocess,
        batched=True,
        remove_columns=dataset["train"].column_names,
        num_proc=11,
    )

    # Data collator
    data_collator = DataCollatorForTokenClassification(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=args.checkpoint_dir,
        # `eval_strategy` is the current name of the deprecated
        # `evaluation_strategy` argument.
        eval_strategy="steps",
        save_strategy="steps",
        eval_steps=100,
        save_steps=100,
        logging_steps=10,
        learning_rate=5e-5,
        num_train_epochs=5,
        optim="adafactor",
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        weight_decay=0.01,
        max_grad_norm=1.0,
        seed=seed,
        per_device_train_batch_size=256,
        gradient_accumulation_steps=8,
        per_device_eval_batch_size=256,
        eval_on_start=True,
        eval_accumulation_steps=1,
        load_best_model_at_end=True,
        metric_for_best_model="f05",
        greater_is_better=True,
        bf16=True,
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Train
    trainer.train()
    trainer.save_model(os.path.join(args.checkpoint_dir, "final"))

    # Evaluate
    print("\nFinal Evaluation Results:")
    final_metrics = trainer.evaluate()
    print(final_metrics)
303
+
304
+
305
if __name__ == "__main__":
    # Command-line entry point: collect the training options and run main().
    parser = argparse.ArgumentParser()

    # Base checkpoint to fine-tune.
    parser.add_argument(
        "--base_model_name", type=str, default="hotchpotch/mMiniLMv2-L6-H384"
    )
    # Dataset providing the "text" and "noise_spans" columns.
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="hotchpotch/fineweb-2-japanese-noise-spans",
    )
    # Optional number of extra Wikipedia examples; 0 adds none.
    parser.add_argument(
        "-w",
        "--add-wikipedia-dataset-count",
        type=int,
        default=0,
        help="Number of examples to add from the Wikipedia dataset",
    )
    # Directory for checkpoints and the final model.
    parser.add_argument(
        "--checkpoint_dir", type=str, default="./models/text-cleaner/"
    )

    args = parser.parse_args()
    main(args)