Callbacks
SetFit training can be customized with callbacks, for example for logging or early stopping.
This guide shows what callbacks are and how to use them.
Callbacks in SetFit
Callbacks are objects that customize the behaviour of the training loop in the SetFit Trainer. They can inspect the training loop state (for progress reporting, logging, or inspecting embeddings during training) and take decisions (e.g. early stopping).
In particular, the Trainer uses a TrainerControl that callbacks can influence to stop training, save models, evaluate, or log, and a TrainerState which tracks training loop metrics during training, such as the number of training steps so far.
SetFit relies on the callbacks implemented in transformers, as described in the transformers callback documentation.
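As a rough illustration of how the two objects interact, the sketch below defines a hypothetical StopAfterStepsCallback (the name and the max_steps parameter are made up for this example) that reads TrainerState.global_step and sets TrainerControl.should_training_stop. It assumes the SetFit Trainer calls on_step_end and honours that flag in the same way the transformers Trainer does, and it exercises the callback by hand rather than inside a real training run.

```python
from transformers import TrainerCallback, TrainerControl, TrainerState


class StopAfterStepsCallback(TrainerCallback):
    """Hypothetical callback: request a stop once `max_steps` steps have run."""

    def __init__(self, max_steps: int):
        self.max_steps = max_steps

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        # TrainerState is inspected; TrainerControl carries the decision back to the loop.
        if state.global_step >= self.max_steps:
            control.should_training_stop = True
        return control


# Exercise the callback without a Trainer by building the two objects by hand.
state = TrainerState(global_step=10)
control = TrainerControl()
StopAfterStepsCallback(max_steps=10).on_step_end(args=None, state=state, control=control)
print(control.should_training_stop)
```

The callback never stops the loop itself: it only raises a flag on the control object, and the training loop decides when to act on it.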
Default Callbacks
SetFit uses the TrainingArguments.report_to argument to specify which of the built-in callbacks should be enabled. This argument defaults to "all", meaning that all third-party callbacks from transformers whose integrations are installed will be enabled, for example the TensorBoardCallback or the WandbCallback.
Beyond that, either the PrinterCallback or the ProgressCallback is always enabled to show the training progress, and the DefaultFlowCallback is also always enabled to properly update the TrainerControl.
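To check which reporting integrations are installed in the current environment, one option is the transformers helper shown below. This is a minimal sketch: it assumes SetFit resolves "all" against the same list, and the output depends entirely on which packages are installed.

```python
from transformers.integrations import get_available_reporting_integrations

# Names of the reporting integrations whose packages transformers can find,
# e.g. "tensorboard" or "wandb". The list is empty if none are installed.
installed = get_available_reporting_integrations()
print(installed)
```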
Using Callbacks
As mentioned, you can use TrainingArguments.report_to to specify exactly which callbacks you would like to enable. For example:
from setfit import TrainingArguments

args = TrainingArguments(
    ...,
    report_to="wandb",
    ...,
)
# or
args = TrainingArguments(
    ...,
    report_to=["wandb", "tensorboard"],
    ...,
)
You can also use Trainer.add_callback(), Trainer.pop_callback() and Trainer.remove_callback() to influence the trainer callbacks, and you can specify callbacks via the Trainer init, e.g.:
from setfit import Trainer
from transformers import EarlyStoppingCallback

...

trainer = Trainer(
    model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
trainer.train()
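The semantics of adding, popping, and removing callbacks can be sketched without loading a model by using the transformers CallbackHandler directly. This assumes the SetFit Trainer methods delegate to such a handler, as the transformers Trainer does. The None placeholders stand in for the model, tokenizer, optimizer, and scheduler, which are not needed for this illustration.

```python
from transformers import DefaultFlowCallback, EarlyStoppingCallback, PrinterCallback
from transformers.trainer_callback import CallbackHandler

# Stand-in for the trainer's internal handler (assumption: the Trainer
# add/pop/remove methods forward to an object like this one).
handler = CallbackHandler([DefaultFlowCallback()], None, None, None, None)

# A callback class is instantiated for you; an instance is used as-is.
handler.add_callback(PrinterCallback)
handler.add_callback(EarlyStoppingCallback(early_stopping_patience=5))

# pop_callback removes the first matching callback and returns it ...
popped = handler.pop_callback(PrinterCallback)
# ... while remove_callback removes it without returning anything.
handler.remove_callback(EarlyStoppingCallback)

# Class names of the callbacks that are still registered, one per line.
print(handler.callback_list)
```

Passing a class to pop_callback or remove_callback matches the first registered instance of that class, which is convenient for callbacks that were enabled by default rather than constructed by you.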
Custom Callbacks
SetFit supports custom callbacks in the same way that transformers does: by subclassing TrainerCallback. This class implements a lot of on_... methods that can be overridden. For example, the following script shows a custom callback that saves plots of the tSNE of the training and evaluation embeddings during training.
import os

import matplotlib.pyplot as plt
from setfit import SetFitModel, TrainingArguments
from sklearn.manifold import TSNE
from transformers import TrainerCallback, TrainerControl, TrainerState


class EmbeddingPlotCallback(TrainerCallback):
    """Simple embedding plotting callback that plots the tSNE of the training and evaluation datasets throughout training."""

    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Make sure the output directory for the plots exists.
        os.makedirs("logs", exist_ok=True)

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
        # train_dataset and eval_dataset are the datasets defined earlier in the script.
        train_embeddings = model.encode(train_dataset["text"])
        eval_embeddings = model.encode(eval_dataset["text"])

        fig, (train_ax, eval_ax) = plt.subplots(ncols=2)

        train_X = TSNE(n_components=2).fit_transform(train_embeddings)
        train_ax.scatter(*train_X.T, c=train_dataset["label"], label=train_dataset["label"])
        train_ax.set_title("Training embeddings")

        eval_X = TSNE(n_components=2).fit_transform(eval_embeddings)
        eval_ax.scatter(*eval_X.T, c=eval_dataset["label"], label=eval_dataset["label"])
        eval_ax.set_title("Evaluation embeddings")

        fig.suptitle(f"tSNE of training and evaluation embeddings at step {state.global_step} of {state.max_steps}.")
        fig.savefig(f"logs/step_{state.global_step}.png")
        # Close the figure so that repeated evaluations do not accumulate open figures.
        plt.close(fig)
Use it with:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EmbeddingPlotCallback()],
)
trainer.train()
The on_evaluate from EmbeddingPlotCallback will be triggered on every single evaluation call. In the case of this example, it resulted in the following figures being plotted:
[Figures: tSNE plots of the training and evaluation embeddings at Step 20, Step 40, Step 60, and Step 80.]