import argparse |
import os |
import unicodedata |
import datasets |
import evaluate |
import numpy as np |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from datasets import load_dataset |
from sklearn.metrics import classification_report, confusion_matrix |
from transformers import ( |
AutoModelForTokenClassification, |
AutoTokenizer, |
DataCollatorForTokenClassification, |
Trainer, |
TrainingArguments, |
) |
def compute_f05_score(precision, recall, beta=0.5):
    """Combine precision and recall into an F-beta score (beta defaults to 0.5).

    Returns ``0.0`` as soon as either ``precision`` or ``recall`` is not
    strictly positive; otherwise returns
    ``(1 + beta^2) * P * R / (beta^2 * P + R)``.
    """
    # Guard clauses: a non-positive component makes the score zero.
    if precision <= 0:
        return 0.0
    if recall <= 0:
        return 0.0

    beta_sq = beta**2
    numerator = (1 + beta_sq) * (precision * recall)
    denominator = (beta_sq * precision) + recall
    return numerator / denominator
def custom_classification_report(y_true, y_pred):
    """Build a text classification report that includes an F0.5 column.

    Per-class precision, recall, F1 and support come from scikit-learn's
    ``precision_recall_fscore_support``; the F0.5 column is derived from each
    class's precision and recall with ``compute_f05_score``. Macro averages,
    support-weighted averages and overall accuracy are appended below the
    per-class rows.

    Args:
        y_true: Flat sequence of reference class labels.
        y_pred: Flat sequence of predicted class labels, aligned with
            ``y_true``.

    Returns:
        The report as a single multi-line string.
    """
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    # Resolve the class labels explicitly (sorted union of both inputs) and
    # hand them to scikit-learn so every row can be tagged with its real
    # label. Using the row position as the label is only correct when the
    # labels present happen to be exactly 0..n-1.
    class_labels = np.unique(
        np.concatenate([np.asarray(y_true).ravel(), np.asarray(y_pred).ravel()])
    )
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=class_labels
    )
    f05_scores = [compute_f05_score(p, r) for p, r in zip(precision, recall)]
    accuracy = accuracy_score(y_true, y_pred)
    n_samples = sum(support)

    # Collect the pieces and join once instead of repeated string "+=".
    parts = [" precision recall f1-score f05-score support\n\n"]
    for label, p, r, f, f05, n in zip(
        class_labels, precision, recall, f1, f05_scores, support
    ):
        parts.append(f" {label}")
        parts.append(f" {p:.2f} {r:.2f}")
        parts.append(f" {f:.2f} {f05:.2f} {n}\n")
    parts.append("\n")

    # Unweighted (macro) averages over classes.
    macro_precision = np.mean(precision)
    macro_recall = np.mean(recall)
    macro_f1 = np.mean(f1)
    macro_f05 = np.mean(f05_scores)
    parts.append(f" macro avg {macro_precision:.2f} {macro_recall:.2f}")
    parts.append(f" {macro_f1:.2f} {macro_f05:.2f} {n_samples}\n")

    # Averages weighted by each class's support.
    weighted_precision = np.average(precision, weights=support)
    weighted_recall = np.average(recall, weights=support)
    weighted_f1 = np.average(f1, weights=support)
    weighted_f05 = np.average(f05_scores, weights=support)
    parts.append(f"weighted avg {weighted_precision:.2f} {weighted_recall:.2f}")
    parts.append(f" {weighted_f1:.2f} {weighted_f05:.2f} {n_samples}\n")

    parts.append(f" accuracy {accuracy:.2f} {n_samples}\n")
    return "".join(parts)
# Metrics loaded through ``evaluate`` are kept here so repeated evaluation
# rounds reuse them instead of calling ``evaluate.load`` again each time.
_METRIC_CACHE = {}


def _load_metric(name):
    """Return the ``evaluate`` metric called ``name``, loading it at most once."""
    if name not in _METRIC_CACHE:
        _METRIC_CACHE[name] = evaluate.load(name)
    return _METRIC_CACHE[name]


def compute_metrics(eval_pred):
    """Compute token-level binary precision, recall, F1 and F0.5.

    ``eval_pred`` unpacks into ``(predictions, labels)``. The predicted class
    is the argmax of ``predictions`` over axis 2, and every position whose
    label is ``-100`` is excluded before scoring. A classification report and
    a confusion matrix are printed as a side effect.

    Args:
        eval_pred: Pair of per-token class scores and reference labels.
            Assumes both cover the same (sequence, token) grid and that
            ``labels`` is array-like -- TODO confirm against the Trainer
            version in use.

    Returns:
        A dict with the keys ``"precision"``, ``"recall"``, ``"f1"`` and
        ``"f05"``.
    """
    scores, labels = eval_pred
    predictions = np.argmax(scores, axis=2)
    labels = np.asarray(labels)

    # A boolean mask drops the ignored (-100) positions in one vectorized
    # step; the surviving elements come out in the same row-major order a
    # nested Python loop would visit them.
    keep = labels != -100
    true_predictions = predictions[keep]
    true_labels = labels[keep]

    results = {
        name: _load_metric(name).compute(
            predictions=true_predictions,
            references=true_labels,
            average="binary",
        )[name]
        for name in ("precision", "recall", "f1")
    }
    precision = results["precision"]
    recall = results["recall"]
    f1 = results["f1"]

    beta = 0.5
    f05 = compute_f05_score(precision, recall, beta)

    report = custom_classification_report(true_labels, true_predictions)
    cm = confusion_matrix(true_labels, true_predictions)
    print("Validation Report:\n" + report)
    print("Confusion Matrix:\n" + str(cm))

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f05": f05,
    }
def unicode_normalize(text):
    """Return ``text`` converted to Unicode normalization form NFKC."""
    normalized = unicodedata.normalize("NFKC", text)
    return normalized
def convert_spans_to_labels(text, spans, tokenizer):
    """Tokenize ``text`` and derive one label per token from character spans.

    Each token receives:
      * ``-100`` when its offset pair is ``(0, 0)``,
      * ``1`` when its character range overlaps at least one ``(start, end)``
        pair in ``spans``,
      * ``0`` otherwise.

    Args:
        text: The string to tokenize.
        spans: Iterable of ``(start, end)`` character offset pairs.
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=True, return_offsets_mapping=True)``
            whose result exposes ``"offset_mapping"``, ``"input_ids"`` and
            ``"attention_mask"``.

    Returns:
        A dict with ``"labels"``, ``"input_ids"`` and ``"attention_mask"``.
    """
    encoding = tokenizer(text, truncation=True, return_offsets_mapping=True)
    offsets = encoding["offset_mapping"]

    # Unpack every span once up front so each one is checked to be a pair.
    span_bounds = [(span_start, span_end) for span_start, span_end in spans]

    token_labels = []
    for tok_start, tok_end in offsets:
        if tok_start == tok_end == 0:
            token_labels.append(-100)
            continue
        overlaps = any(
            tok_start < span_end and tok_end > span_start
            for span_start, span_end in span_bounds
        )
        token_labels.append(1 if overlaps else 0)

    return {
        "labels": token_labels,
        "input_ids": encoding["input_ids"],
        "attention_mask": encoding["attention_mask"],
    }
def _sample_wikipedia_examples(count, seed=42):
    """Return ``count`` random Wikipedia paragraphs with empty noise spans."""
    wikipedia_dataset = load_dataset(
        "hpprc/jawiki-paragraphs", "default", split="train"
    )

    # Use a seeded generator so the sampled subset is reproducible, in line
    # with the fixed seed used for the train/test split and for training.
    rng = np.random.default_rng(seed)
    target_indexes = rng.choice(len(wikipedia_dataset), count, replace=False)

    # Select first, then normalize: only the sampled rows need processing.
    wikipedia_dataset = wikipedia_dataset.select(target_indexes)
    wikipedia_dataset = wikipedia_dataset.map(
        lambda x: {
            "text": unicode_normalize(x["text"]),
            "noise_spans": [],
        },
        num_proc=15,
        remove_columns=wikipedia_dataset.column_names,
    )
    print(wikipedia_dataset)

    # Cast to an explicit schema so the empty "noise_spans" lists get a
    # concrete nested-int64 type before concatenation with the main dataset.
    new_features = datasets.Features(
        {
            "text": datasets.Value("string"),
            "noise_spans": datasets.Sequence(
                datasets.Sequence(datasets.Value("int64"))
            ),
        }
    )
    return wikipedia_dataset.cast(new_features, num_proc=15)


def main(args):
    """Fine-tune a two-label token classifier on the noise-span dataset.

    Steps: load ``args.dataset_name`` and split it 95/5 into train/test,
    optionally append ``args.add_wikipedia_dataset_count`` Wikipedia
    paragraphs (with no noise spans) to the training split, tokenize and
    label every example, train with ``Trainer``, save the model under
    ``<args.checkpoint_dir>/final`` and print a final evaluation.

    Args:
        args: Parsed command-line namespace providing ``dataset_name``,
            ``base_model_name``, ``add_wikipedia_dataset_count`` and
            ``checkpoint_dir``.
    """
    dataset = load_dataset(args.dataset_name, split="train")
    dataset = dataset.select_columns(["text", "noise_spans"])
    dataset = dataset.train_test_split(train_size=0.95, seed=42)

    wikipedia_dataset_count = args.add_wikipedia_dataset_count
    if wikipedia_dataset_count > 0:
        wikipedia_dataset = _sample_wikipedia_examples(wikipedia_dataset_count)
        print(f"Adding {len(wikipedia_dataset)} examples from the Wikipedia dataset")
        print(f"original training examples: {len(dataset['train'])}")
        dataset["train"] = datasets.concatenate_datasets(
            [dataset["train"], wikipedia_dataset]
        )
        print(f"Total training examples: {len(dataset['train'])}")

    model = AutoModelForTokenClassification.from_pretrained(
        args.base_model_name,
        num_labels=2,
        classifier_dropout=0.1,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model_name,
        # Cap the sequence length at 512 or the model's position limit,
        # whichever is smaller.
        model_max_length=min(model.config.max_position_embeddings, 512),
    )
    if not tokenizer.pad_token:
        # Fall back to the EOS token when the tokenizer has no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    def preprocess(examples):
        """Tokenize and label one batch of ``text`` / ``noise_spans`` rows."""
        results = [
            convert_spans_to_labels(text, spans, tokenizer)
            for text, spans in zip(examples["text"], examples["noise_spans"])
        ]
        return {
            "input_ids": [r["input_ids"] for r in results],
            "attention_mask": [r["attention_mask"] for r in results],
            "labels": [r["labels"] for r in results],
        }

    tokenized_dataset = dataset.map(
        preprocess,
        batched=True,
        remove_columns=dataset["train"].column_names,
        num_proc=11,
    )

    data_collator = DataCollatorForTokenClassification(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )

    training_args = TrainingArguments(
        output_dir=args.checkpoint_dir,
        # "eval_strategy" is the current name of the deprecated
        # "evaluation_strategy" argument.
        eval_strategy="steps",
        save_strategy="steps",
        eval_steps=100,
        save_steps=100,
        logging_steps=10,
        learning_rate=5e-5,
        num_train_epochs=5,
        optim="adafactor",
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        weight_decay=0.01,
        max_grad_norm=1.0,
        seed=42,
        per_device_train_batch_size=256,
        gradient_accumulation_steps=8,
        per_device_eval_batch_size=256,
        eval_on_start=True,
        eval_accumulation_steps=1,
        load_best_model_at_end=True,
        # Checkpoint selection is driven by the "f05" value that
        # compute_metrics returns.
        metric_for_best_model="f05",
        greater_is_better=True,
        bf16=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    trainer.save_model(os.path.join(args.checkpoint_dir, "final"))

    print("\nFinal Evaluation Results:")
    final_metrics = trainer.evaluate()
    print(final_metrics)
if __name__ == "__main__":
    # Command-line entry point: parse the options and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_model_name",
        type=str,
        default="hotchpotch/mMiniLMv2-L6-H384",
        help="Name or path of the pretrained model to fine-tune",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="hotchpotch/fineweb-2-japanese-noise-spans",
        help="Name of the dataset providing 'text' and 'noise_spans' columns",
    )
    parser.add_argument(
        "-w",
        "--add-wikipedia-dataset-count",
        # Underscore alias so this flag matches the naming style of the other
        # options; the dashed spelling keeps working.
        "--add_wikipedia_dataset_count",
        dest="add_wikipedia_dataset_count",
        type=int,
        default=0,
        help="Number of examples to add from the Wikipedia dataset",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="./models/text-cleaner/",
        help="Directory where checkpoints and the final model are written",
    )
    args = parser.parse_args()
    main(args)