# train_script.py — by tomaarsen (HF staff), commit 1aa7c41 (verified), 7.49 kB
import logging
import traceback
from collections import defaultdict
from datasets import load_dataset
from datasets.load import load_from_disk
from torch import nn
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation.CENanoBEIREvaluator import CENanoBEIREvaluator
from sentence_transformers.cross_encoder.evaluation.CERerankingEvaluator import CERerankingEvaluator
from sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
from sentence_transformers.cross_encoder.losses.CachedMultipleNegativesRankingLoss import (
CachedMultipleNegativesRankingLoss,
)
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
from sentence_transformers.evaluation.SequentialEvaluator import SequentialEvaluator
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.util import mine_hard_negatives
def main():
    """Train a ModernBERT-based CrossEncoder reranker on natural-questions.

    Pipeline:
      1. Load ``answerdotai/ModernBERT-base`` as a CrossEncoder.
      2. Split natural-questions into train/eval and mine hard negatives for
         both splits with a static-embedding retrieval model.
      3. Train with BinaryCrossEntropyLoss, evaluating with a reranking
         evaluator (mined eval set) plus NanoBEIR.
      4. Save the final model and best-effort push it to the Hugging Face Hub.
    """
    model_name = "answerdotai/ModernBERT-base"

    # Set the log level to INFO to get more information
    logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

    train_batch_size = 64
    num_epochs = 1

    # 1. Define our CrossEncoder model
    model = CrossEncoder(model_name)
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2. Load the natural-questions dataset and carve out a 1k eval split
    logging.info("Read train dataset")
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1")
    # Plain string split; the original f-string had no placeholders.
    full_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]

    # Mine 30 hard negatives per eval query, searching the full answer corpus,
    # so the reranking evaluator sees realistic candidates.
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],
        num_negatives=30,
        batch_size=256,
        positive_among_negatives=True,
        as_triplets=False,  # one row per query: query, answer, negative_1..negative_30
        use_faiss=True,
    )
    print(hard_eval_dataset)

    # Mine 5 hard negatives per training pair.
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=5,  # 5 negatives per question-answer pair
        margin=0,  # Negatives may be at most as similar to the query as the positive is
        range_min=0,  # Don't skip any of the most similar candidates
        range_max=100,  # Consider only the 100 most similar samples
        sampling_strategy="top",  # Take the hardest (top-ranked) negatives from the range
        batch_size=256,
        as_triplets=False,  # We want 7 columns: query, answer, negative_1..negative_5
        use_faiss=True,
    )
    # Tip: mining is the slow step — cache with `hard_train_dataset.save_to_disk(...)`
    # and reload via `load_from_disk(...)` to skip it on subsequent runs.

    def mapper(batch):
        """Flatten a (query, answer, negative_1..negative_k) batch into
        (query, response, label) rows: label 1 for the answer, 0 for each negative.

        ``len(batch)`` is the number of columns in the batched dict, so
        ``len(batch) - 2`` is the negative count and ``len(batch) - 1`` the
        total candidates (answer + negatives) per query.
        """
        batch_size = len(batch["query"])
        num_negatives = len(batch) - 2
        num_candidates = len(batch) - 1
        return {
            "query": batch["query"] * num_candidates,
            "response": sum(list(batch.values())[1:], []),
            "label": [1] * batch_size + [0] * num_negatives * batch_size,
        }

    hard_train_dataset = hard_train_dataset.map(mapper, batched=True, remove_columns=hard_train_dataset.column_names)
    # NOTE(review): this maps the *un-mined* eval split, so every eval row ends
    # up with label 1 (positives only) — possibly `hard_eval_dataset` was
    # intended here for a meaningful eval loss; confirm before relying on it.
    eval_dataset = eval_dataset.map(mapper, batched=True, remove_columns=eval_dataset.column_names)

    # 3. Define our training loss; pos_weight=5 counters the 5:1
    # negative:positive imbalance produced by the mapper above.
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(5))

    # 4. Define the evaluators: reranking over the mined eval set, plus the
    # light-weight NanoBEIR evaluator on three English retrieval datasets.
    reranking_evaluator = CERerankingEvaluator(
        samples=[
            {
                "query": sample["query"],
                "positive": [sample["answer"]],
                "negative": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
            }
            for sample in hard_eval_dataset
        ],
        batch_size=train_batch_size,
        negatives_are_ranked=True,
        name="nq-dev",
    )
    nano_beir_evaluator = CENanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    # Baseline evaluation of the untrained model, for comparison after training.
    evaluator(model)

    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-{short_model_name}-nq-bce-static-retriever-hardest"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        dataloader_num_workers=4,
        # Track the reranking evaluator's NDCG@10 to pick the best checkpoint.
        load_best_model_at_end=True,
        metric_for_best_model="eval_nq-dev_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=200,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model, useful to include these in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()