# fineweb-2-japanese-text-cleaner/scripts/trainer-fineweb-2-japanese-text-cleaner.py
import argparse
import os
import unicodedata
import datasets
import evaluate
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.metrics import classification_report, confusion_matrix
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
Trainer,
TrainingArguments,
)
def compute_f05_score(precision, recall, beta=0.5):
"""Calculate F0.5 score from precision and recall."""
if precision <= 0 or recall <= 0:
return 0.0
return (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall)
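# Worked example (hypothetical numbers): with precision=0.8 and recall=0.5,
# F0.5 = 1.25 * (0.8 * 0.5) / (0.25 * 0.8 + 0.5) = 0.5 / 0.7 ≈ 0.714,
# i.e. the metric rewards precision more strongly than recall.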
def custom_classification_report(y_true, y_pred):
"""Generate classification report with F0.5 score."""
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
# Calculate precision, recall, f1, and support for each class
precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred)
# Calculate F0.5 scores
f05_scores = [compute_f05_score(p, r) for p, r in zip(precision, recall)]
# Calculate accuracy
accuracy = accuracy_score(y_true, y_pred)
# Generate report string
report = " precision recall f1-score f05-score support\n\n"
# Add metrics for each class
for i in range(len(precision)):
report += f" {i}"
report += f" {precision[i]:.2f} {recall[i]:.2f}"
report += f" {f1[i]:.2f} {f05_scores[i]:.2f} {support[i]}\n"
report += "\n"
# Calculate and add averages
n_samples = sum(support)
# Macro average
macro_precision = np.mean(precision)
macro_recall = np.mean(recall)
macro_f1 = np.mean(f1)
macro_f05 = np.mean(f05_scores)
report += f" macro avg {macro_precision:.2f} {macro_recall:.2f}"
report += f" {macro_f1:.2f} {macro_f05:.2f} {n_samples}\n"
# Weighted average
weighted_precision = np.average(precision, weights=support)
weighted_recall = np.average(recall, weights=support)
weighted_f1 = np.average(f1, weights=support)
weighted_f05 = np.average(f05_scores, weights=support)
report += f"weighted avg {weighted_precision:.2f} {weighted_recall:.2f}"
report += f" {weighted_f1:.2f} {weighted_f05:.2f} {n_samples}\n"
# Add accuracy
report += f" accuracy {accuracy:.2f} {n_samples}\n"
return report
def compute_metrics(eval_pred):
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=2)
# Remove ignored index (special tokens)
true_predictions = []
true_labels = []
for prediction, label in zip(predictions, labels):
for p, l in zip(prediction, label):
if l != -100: # We have a valid label
true_predictions.append(p)
true_labels.append(l)
# Convert to numpy arrays
true_predictions = np.array(true_predictions)
true_labels = np.array(true_labels)
precision = precision_metric.compute(
predictions=true_predictions,
references=true_labels,
average="binary",
)["precision"]
recall = recall_metric.compute(
predictions=true_predictions,
references=true_labels,
average="binary",
)["recall"]
f1 = f1_metric.compute(
predictions=true_predictions,
references=true_labels,
average="binary",
)["f1"]
# Calculate F0.5 score
beta = 0.5
f05 = compute_f05_score(precision, recall, beta)
# Generate custom classification report
report = custom_classification_report(true_labels, true_predictions)
cm = confusion_matrix(true_labels, true_predictions)
print("Validation Report:\n" + report)
print("Confusion Matrix:\n" + str(cm))
return {
"precision": precision,
"recall": recall,
"f1": f1,
"f05": f05,
}
def unicode_normalize(text):
return unicodedata.normalize("NFKC", text)
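# NFKC normalization folds compatibility forms: for example, full-width
# "ＡＢＣ１２３" becomes "ABC123" and half-width "ｶﾀｶﾅ" becomes "カタカナ".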
def convert_spans_to_labels(text, spans, tokenizer):
# Tokenize text
tokens = tokenizer(text, truncation=True, return_offsets_mapping=True)
offset_mapping = tokens["offset_mapping"]
# Initialize labels (0 for non-noise, 1 for noise)
labels = [0] * len(offset_mapping)
# Mark special tokens with -100
labels = [
-100 if offset[0] == offset[1] == 0 else label
for label, offset in zip(labels, offset_mapping)
]
# Convert character spans to token labels
for start, end in spans:
for idx, (token_start, token_end) in enumerate(offset_mapping):
# Skip special tokens
if token_start == token_end == 0:
continue
# If token overlaps with noise span, mark as noise
if token_start < end and token_end > start:
labels[idx] = 1
return {
"labels": labels,
"input_ids": tokens["input_ids"],
"attention_mask": tokens["attention_mask"],
}
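# Illustrative example (hypothetical tokenization): with noise_spans=[(0, 10)],
# every token whose character offsets overlap [0, 10) gets label 1, special
# tokens with offset (0, 0) keep -100 (ignored by the loss), and the rest stay 0.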
def main(args):
# Load dataset
dataset = load_dataset(args.dataset_name, split="train")
dataset = dataset.select_columns(["text", "noise_spans"])
# Split dataset
dataset = dataset.train_test_split(train_size=0.95, seed=42)
wikipedia_dataset_count = args.add_wikipedia_dataset_count
if wikipedia_dataset_count > 0:
wikipedia_dataset = load_dataset(
"hpprc/jawiki-paragraphs", "default", split="train"
)
        # normalize the text and attach empty noise_spans (Wikipedia paragraphs are treated as noise-free)
wikipedia_dataset = wikipedia_dataset.map(
lambda x: {
"text": unicode_normalize(x["text"]),
"noise_spans": [],
},
num_proc=15,
remove_columns=wikipedia_dataset.column_names,
)
        # random sampling
target_indexes = np.random.choice(
len(wikipedia_dataset), wikipedia_dataset_count, replace=False
)
print(wikipedia_dataset)
wikipedia_dataset = wikipedia_dataset.select(target_indexes)
new_features = datasets.Features(
{
"text": datasets.Value("string"),
"noise_spans": datasets.Sequence(
datasets.Sequence(datasets.Value("int64"))
),
}
)
        # Cast the dataset features to match the main dataset's schema
wikipedia_dataset = wikipedia_dataset.cast(new_features, num_proc=15)
print(f"Adding {len(wikipedia_dataset)} examples from the Wikipedia dataset")
print(f"original training examples: {len(dataset['train'])}")
dataset["train"] = datasets.concatenate_datasets(
[dataset["train"], wikipedia_dataset]
)
print(f"Total training examples: {len(dataset['train'])}")
# Initialize model and tokenizer
model = AutoModelForTokenClassification.from_pretrained(
args.base_model_name,
num_labels=2, # Binary classification: noise or not noise
classifier_dropout=0.1,
)
tokenizer = AutoTokenizer.from_pretrained(
args.base_model_name,
model_max_length=min(model.config.max_position_embeddings, 512),
)
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
# Preprocess dataset
def preprocess(examples):
results = []
for text, spans in zip(examples["text"], examples["noise_spans"]):
result = convert_spans_to_labels(text, spans, tokenizer)
results.append(result)
return {
"input_ids": [r["input_ids"] for r in results],
"attention_mask": [r["attention_mask"] for r in results],
"labels": [r["labels"] for r in results],
}
tokenized_dataset = dataset.map(
preprocess,
batched=True,
remove_columns=dataset["train"].column_names,
num_proc=11,
)
# Data collator
data_collator = DataCollatorForTokenClassification(
tokenizer=tokenizer,
padding=True,
return_tensors="pt",
)
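    # DataCollatorForTokenClassification pads label sequences with -100 (its
    # default label_pad_token_id), so padded positions are excluded from the
    # loss just like the special tokens above.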
# Training arguments
training_args = TrainingArguments(
output_dir=args.checkpoint_dir,
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=100,
save_steps=100,
logging_steps=10,
learning_rate=5e-5,
num_train_epochs=5,
optim="adafactor",
warmup_ratio=0.1,
lr_scheduler_type="cosine",
weight_decay=0.01,
max_grad_norm=1.0,
seed=42,
per_device_train_batch_size=256,
gradient_accumulation_steps=8,
per_device_eval_batch_size=256,
eval_on_start=True,
eval_accumulation_steps=1,
load_best_model_at_end=True,
metric_for_best_model="f05",
greater_is_better=True,
bf16=True,
)
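    # With per_device_train_batch_size=256 and gradient_accumulation_steps=8,
    # each optimizer step uses an effective batch of 256 * 8 = 2048 examples per
    # device; the best checkpoint is chosen by the eval F0.5 score defined above.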
# Initialize trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["test"],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
# Train
trainer.train()
trainer.save_model(os.path.join(args.checkpoint_dir, "final"))
# Evaluate
print("\nFinal Evaluation Results:")
final_metrics = trainer.evaluate()
print(final_metrics)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_model_name",
type=str,
default="hotchpotch/mMiniLMv2-L6-H384",
)
parser.add_argument(
"--dataset_name",
type=str,
default="hotchpotch/fineweb-2-japanese-noise-spans",
)
parser.add_argument(
"-w",
"--add-wikipedia-dataset-count",
type=int,
default=0,
help="Number of examples to add from the Wikipedia dataset",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
default="./models/text-cleaner/",
)
args = parser.parse_args()
main(args)
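# Example invocation (the -w value is only an illustration; the other defaults
# are defined in the argparse section above):
#   python trainer-fineweb-2-japanese-text-cleaner.py \
#       --base_model_name hotchpotch/mMiniLMv2-L6-H384 \
#       --dataset_name hotchpotch/fineweb-2-japanese-noise-spans \
#       -w 10000 \
#       --checkpoint_dir ./models/text-cleaner/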