import argparse
import os
import unicodedata

import datasets
import evaluate
import numpy as np
from datasets import load_dataset
from sklearn.metrics import confusion_matrix
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)


def compute_f05_score(precision, recall, beta=0.5):
    """Calculate F0.5 score from precision and recall."""
    if precision <= 0 or recall <= 0:
        return 0.0

    return (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall)


def custom_classification_report(y_true, y_pred):
    """Generate a classification report that also includes the F0.5 score."""
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred)

    f05_scores = [compute_f05_score(p, r) for p, r in zip(precision, recall)]

    accuracy = accuracy_score(y_true, y_pred)

    report = "              precision    recall  f1-score  f05-score  support\n\n"

    for i in range(len(precision)):
        report += f"{i:>12}"
        report += f"  {precision[i]:>9.2f}  {recall[i]:>8.2f}"
        report += f"  {f1[i]:>8.2f}  {f05_scores[i]:>9.2f}  {support[i]:>7}\n"

    report += "\n"

    n_samples = sum(support)

    macro_precision = np.mean(precision)
    macro_recall = np.mean(recall)
    macro_f1 = np.mean(f1)
    macro_f05 = np.mean(f05_scores)
    report += f"{'macro avg':>12}  {macro_precision:>9.2f}  {macro_recall:>8.2f}"
    report += f"  {macro_f1:>8.2f}  {macro_f05:>9.2f}  {n_samples:>7}\n"

    weighted_precision = np.average(precision, weights=support)
    weighted_recall = np.average(recall, weights=support)
    weighted_f1 = np.average(f1, weights=support)
    weighted_f05 = np.average(f05_scores, weights=support)
    report += f"{'weighted avg':>12}  {weighted_precision:>9.2f}  {weighted_recall:>8.2f}"
    report += f"  {weighted_f1:>8.2f}  {weighted_f05:>9.2f}  {n_samples:>7}\n"

    report += f"{'accuracy':>12}  {accuracy:>9.2f}  {n_samples:>7}\n"

    return report


def compute_metrics(eval_pred):
    """Compute token-level precision, recall, F1, and F0.5 on the eval split."""
    precision_metric = evaluate.load("precision")
    recall_metric = evaluate.load("recall")
    f1_metric = evaluate.load("f1")

    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=2)

    # Keep only positions with a real label; -100 marks special tokens that
    # the loss (and these metrics) should ignore.
    true_predictions = []
    true_labels = []

    for prediction, label in zip(predictions, labels):
        for p, l in zip(prediction, label):
            if l != -100:
                true_predictions.append(p)
                true_labels.append(l)

    true_predictions = np.array(true_predictions)
    true_labels = np.array(true_labels)
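
    # average="binary" scores label 1 (tokens inside noise spans) as the
    # positive class, so these metrics reflect noise-detection quality.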
    precision = precision_metric.compute(
        predictions=true_predictions,
        references=true_labels,
        average="binary",
    )["precision"]
    recall = recall_metric.compute(
        predictions=true_predictions,
        references=true_labels,
        average="binary",
    )["recall"]
    f1 = f1_metric.compute(
        predictions=true_predictions,
        references=true_labels,
        average="binary",
    )["f1"]

    beta = 0.5
    f05 = compute_f05_score(precision, recall, beta)

    report = custom_classification_report(true_labels, true_predictions)
    cm = confusion_matrix(true_labels, true_predictions)
    print("Validation Report:\n" + report)
    print("Confusion Matrix:\n" + str(cm))

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f05": f05,
    }


def unicode_normalize(text):
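    # NFKC folds compatibility forms, e.g. halfwidth katakana "ﾃｽﾄ" -> "テスト"
    # and fullwidth digits "１２３" -> "123".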
    return unicodedata.normalize("NFKC", text)


def convert_spans_to_labels(text, spans, tokenizer):
    """Turn character-level noise spans into per-token 0/1 labels."""
    tokens = tokenizer(text, truncation=True, return_offsets_mapping=True)
    offset_mapping = tokens["offset_mapping"]

    # Start with every token labeled 0 (clean).
    labels = [0] * len(offset_mapping)

    # Special tokens report the offset (0, 0); label them -100 so they are
    # ignored by the loss and the metrics.
    labels = [
        -100 if offset[0] == offset[1] == 0 else label
        for label, offset in zip(labels, offset_mapping)
    ]
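
    # Mark every token whose character range overlaps a noise span.
    # Example (illustrative): for span (2, 5), a token covering characters
    # (4, 6) overlaps because 4 < 5 and 6 > 2, so it gets label 1.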
    for start, end in spans:
        for idx, (token_start, token_end) in enumerate(offset_mapping):
            if token_start == token_end == 0:
                continue

            if token_start < end and token_end > start:
                labels[idx] = 1

    return {
        "labels": labels,
        "input_ids": tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
    }


def main(args):
    dataset = load_dataset(args.dataset_name, split="train")
    dataset = dataset.select_columns(["text", "noise_spans"])

    # Hold out 5% of the data for evaluation.
    dataset = dataset.train_test_split(train_size=0.95, seed=42)

    # Optionally mix in Wikipedia paragraphs as clean (noise-free) examples.
    wikipedia_dataset_count = args.add_wikipedia_dataset_count
    if wikipedia_dataset_count > 0:
        wikipedia_dataset = load_dataset(
            "hpprc/jawiki-paragraphs", "default", split="train"
        )

        wikipedia_dataset = wikipedia_dataset.map(
            lambda x: {
                "text": unicode_normalize(x["text"]),
                "noise_spans": [],
            },
            num_proc=15,
            remove_columns=wikipedia_dataset.column_names,
        )

        target_indexes = np.random.choice(
            len(wikipedia_dataset), wikipedia_dataset_count, replace=False
        )
        print(wikipedia_dataset)
        wikipedia_dataset = wikipedia_dataset.select(target_indexes)
        new_features = datasets.Features(
            {
                "text": datasets.Value("string"),
                "noise_spans": datasets.Sequence(
                    datasets.Sequence(datasets.Value("int64"))
                ),
            }
        )
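
        # concatenate_datasets requires identical features, so cast the
        # Wikipedia split to the main dataset's schema.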
        wikipedia_dataset = wikipedia_dataset.cast(new_features, num_proc=15)
        print(f"Adding {len(wikipedia_dataset)} examples from the Wikipedia dataset")
        print(f"Original training examples: {len(dataset['train'])}")
        dataset["train"] = datasets.concatenate_datasets(
            [dataset["train"], wikipedia_dataset]
        )
        print(f"Total training examples: {len(dataset['train'])}")

    model = AutoModelForTokenClassification.from_pretrained(
        args.base_model_name,
        num_labels=2,
        classifier_dropout=0.1,
    )

    # Cap sequence length at 512 tokens (or the model's own limit, if smaller).
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model_name,
        model_max_length=min(model.config.max_position_embeddings, 512),
    )

    # Some models ship without a pad token; fall back to EOS.
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    def preprocess(examples):
        results = []
        for text, spans in zip(examples["text"], examples["noise_spans"]):
            result = convert_spans_to_labels(text, spans, tokenizer)
            results.append(result)

        return {
            "input_ids": [r["input_ids"] for r in results],
            "attention_mask": [r["attention_mask"] for r in results],
            "labels": [r["labels"] for r in results],
        }

    tokenized_dataset = dataset.map(
        preprocess,
        batched=True,
        remove_columns=dataset["train"].column_names,
        num_proc=11,
    )
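
    # Dynamically pad input_ids, attention_mask, and labels to the longest
    # sequence in each batch; labels are padded with -100 so the loss skips them.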
    data_collator = DataCollatorForTokenClassification(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )

    training_args = TrainingArguments(
        output_dir=args.checkpoint_dir,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=100,
        save_steps=100,
        logging_steps=10,
        learning_rate=5e-5,
        num_train_epochs=5,
        optim="adafactor",
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        weight_decay=0.01,
        max_grad_norm=1.0,
        seed=42,
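        # Effective train batch size: 256 * 8 = 2048 examples per device step.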
        per_device_train_batch_size=256,
        gradient_accumulation_steps=8,
        per_device_eval_batch_size=256,
        eval_on_start=True,
        eval_accumulation_steps=1,
        # Keep the checkpoint with the best F0.5 score.
        load_best_model_at_end=True,
        metric_for_best_model="f05",
        greater_is_better=True,
        bf16=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.save_model(os.path.join(args.checkpoint_dir, "final"))

    print("\nFinal Evaluation Results:")
    final_metrics = trainer.evaluate()
    print(final_metrics)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_model_name",
        type=str,
        default="hotchpotch/mMiniLMv2-L6-H384",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="hotchpotch/fineweb-2-japanese-noise-spans",
    )
    parser.add_argument(
        "-w",
        "--add-wikipedia-dataset-count",
        type=int,
        default=0,
        help="Number of examples to add from the Wikipedia dataset",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="./models/text-cleaner/",
    )

    args = parser.parse_args()
    main(args)
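
# Example invocation (script name and Wikipedia count are illustrative):
#   python train_text_cleaner.py \
#       --base_model_name hotchpotch/mMiniLMv2-L6-H384 \
#       -w 10000 \
#       --checkpoint_dir ./models/text-cleaner/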