Callbacks

SetFit training can be customized with callbacks, for example for logging or early stopping.

This guide will show you what they are and how they can be used.

Callbacks in SetFit

Callbacks are objects that customize the behaviour of the training loop in the SetFit Trainer. They can inspect the training loop state (for progress reporting, logging, or inspecting embeddings during training) and make decisions (e.g. early stopping).

In particular, the Trainer uses a TrainerControl, whose flags callbacks can set to stop training, save the model, evaluate, or log, and a TrainerState, which tracks training loop metrics such as the number of training steps completed so far.
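To make the interplay between TrainerState and TrainerControl concrete, here is a minimal sketch of a custom callback that reads state.global_step and asks the trainer to stop by setting control.should_training_stop. The class name StopAtStepCallback and its stop_at_step parameter are hypothetical; TrainerCallback, TrainerState and TrainerControl are the transformers classes referred to above. Instead of running a real training loop, the last lines call the hook by hand with default-constructed state and control objects, passing None for args because this callback never reads it.

```python
from transformers import TrainerCallback, TrainerControl, TrainerState


class StopAtStepCallback(TrainerCallback):
    """Hypothetical callback: request a stop once a given global step is reached."""

    def __init__(self, stop_at_step: int):
        self.stop_at_step = stop_at_step

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        # TrainerState tracks loop metrics, e.g. the number of update steps completed.
        if state.global_step >= self.stop_at_step:
            # Setting a flag on TrainerControl is how a callback influences the loop.
            control.should_training_stop = True
        return control


# Drive the hook manually instead of running a real training loop.
callback = StopAtStepCallback(stop_at_step=3)
state = TrainerState()
control = TrainerControl()
for step in range(1, 6):
    state.global_step = step
    control = callback.on_step_end(None, state, control)
    if control.should_training_stop:
        break

print(f"stopped after step {state.global_step}")  # prints: stopped after step 3
```

In an actual run you would pass StopAtStepCallback(stop_at_step=...) through the Trainer's callbacks argument, and the trainer would check the flag after each step.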

SetFit relies on the callbacks implemented in transformers, as described in the transformers Callbacks documentation.

Default Callbacks

SetFit uses the TrainingArguments.report_to argument to specify which of the built-in integration callbacks should be enabled. This argument defaults to "all", meaning that every third-party integration callback from transformers whose corresponding library is installed will be enabled, for example the TensorBoardCallback or the WandbCallback.

Beyond that, either the PrinterCallback or the ProgressCallback is always enabled to display training progress, and the DefaultFlowCallback is always enabled so that the TrainerControl is updated correctly.
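To check which integration callbacks "all" can resolve to in a particular environment, one option is to ask transformers directly. This is a sketch that assumes SetFit expands "all" the same way transformers does, via get_available_reporting_integrations from transformers.integrations; verify this against your installed SetFit and transformers versions. The output depends entirely on which third-party libraries are installed.

```python
from transformers import TrainerCallback
from transformers.integrations import (
    get_available_reporting_integrations,
    get_reporting_integration_callbacks,
)

# Names of reporting integrations whose third-party packages are importable here,
# e.g. "tensorboard" or "wandb" if those libraries are installed.
available = get_available_reporting_integrations()

# Map each integration name to its transformers callback class (classes, not instances).
callback_classes = get_reporting_integration_callbacks(available)

for name, cls in zip(available, callback_classes):
    print(f"{name}: {cls.__name__}")
```

An empty list simply means no supported logging library was found, in which case only the always-enabled callbacks remain.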

Using Callbacks

As mentioned, you can use TrainingArguments.report_to to specify exactly which callbacks you would like to enable. For example:

from setfit import TrainingArguments

args = TrainingArguments(
    ...,
    report_to="wandb",
    ...,
)
# or 
args = TrainingArguments(
    ...,
    report_to=["wandb", "tensorboard"],
    ...,
)

You can also use Trainer.add_callback(), Trainer.pop_callback() and Trainer.remove_callback() to modify the trainer's callbacks after initialization, or pass callbacks directly to the Trainer via its callbacks argument, e.g.:

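from setfit import Trainer
from transformers import EarlyStoppingCallback

...

# EarlyStoppingCallback expects `args` to enable periodic evaluation
# and to set load_best_model_at_end=True; otherwise it raises an error at train start.
trainer = Trainer(
    model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
trainer.train()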
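The difference between add_callback, pop_callback and remove_callback can be seen without building a model. This sketch assumes, as in transformers, that the Trainer methods forward to a CallbackHandler from transformers.trainer_callback; the None arguments stand in for the model, tokenizer/processing class, optimizer and scheduler, which are not needed to manage the callback list.

```python
from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    EarlyStoppingCallback,
    PrinterCallback,
)

# Assumption: Trainer.add/pop/remove_callback delegate to a handler like this one.
handler = CallbackHandler([DefaultFlowCallback()], None, None, None, None)

# add_callback accepts either a class (instantiated for you) or an instance.
handler.add_callback(PrinterCallback)
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)
handler.add_callback(early_stopping)

# pop_callback removes the first matching callback and returns it, so it can be re-added.
popped = handler.pop_callback(EarlyStoppingCallback)

# remove_callback removes the first matching callback without returning it.
handler.remove_callback(PrinterCallback)

print([type(cb).__name__ for cb in handler.callbacks])  # ['DefaultFlowCallback']
```

Use pop_callback when you want to keep the removed instance (and any state it holds, such as an early-stopping counter) for later reuse.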

Custom Callbacks

SetFit supports custom callbacks in the same way as transformers: by subclassing TrainerCallback. This class defines a number of on_... event methods (such as on_evaluate or on_log) that can be overridden. For example, the following script shows a custom callback that saves t-SNE plots of the training and evaluation embeddings during training.

import os

import matplotlib.pyplot as plt
from setfit import SetFitModel, TrainingArguments
from sklearn.manifold import TSNE
from transformers import TrainerCallback, TrainerControl, TrainerState

# train_dataset and eval_dataset are assumed to be defined at module level
# (the same datasets that are passed to the Trainer below).

class EmbeddingPlotCallback(TrainerCallback):
    """Simple embedding plotting callback that plots the t-SNE of the training and evaluation datasets throughout training."""
    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Create the output directory once, when the Trainer is initialized
        os.makedirs("logs", exist_ok=True)

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
        # Embed both datasets with the current state of the model
        train_embeddings = model.encode(train_dataset["text"])
        eval_embeddings = model.encode(eval_dataset["text"])

        fig, (train_ax, eval_ax) = plt.subplots(ncols=2)

        # Project to 2D with t-SNE and color each point by its label
        train_X = TSNE(n_components=2).fit_transform(train_embeddings)
        train_ax.scatter(*train_X.T, c=train_dataset["label"])
        train_ax.set_title("Training embeddings")

        eval_X = TSNE(n_components=2).fit_transform(eval_embeddings)
        eval_ax.scatter(*eval_X.T, c=eval_dataset["label"])
        eval_ax.set_title("Evaluation embeddings")

        fig.suptitle(f"t-SNE of training and evaluation embeddings at step {state.global_step} of {state.max_steps}.")
        fig.savefig(f"logs/step_{state.global_step}.png")
        # Close the figure so repeated evaluations don't accumulate open figures
        plt.close(fig)

and pass an instance of it to the Trainer:

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EmbeddingPlotCallback()]
)
trainer.train()

The on_evaluate method of EmbeddingPlotCallback is triggered on every evaluation call. In this example, that produced the following figures:

[Figures: t-SNE plots of the training and evaluation embeddings at steps 20, 40, 60 and 80]
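Other on_... hooks follow the same pattern as on_evaluate. As a further minimal sketch, the hypothetical LogHistoryCallback below overrides on_log to collect every dict the trainer logs, tagged with the current step. The trainer passes the logged values as the logs keyword argument; which keys appear depends on the trainer and version, so the embedding_loss key used in the simulation is illustrative only. In a real run you would pass LogHistoryCallback() via callbacks=[...]; here the hook is called by hand with default-constructed TrainerState and TrainerControl objects and None for args.

```python
from transformers import TrainerCallback, TrainerControl, TrainerState


class LogHistoryCallback(TrainerCallback):
    """Hypothetical callback that keeps every logged dict together with its step."""

    def __init__(self):
        self.history = []

    def on_log(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        # `logs` holds whatever the trainer just logged; keys vary by trainer and version.
        if logs is not None:
            self.history.append({"step": state.global_step, **logs})


# Simulate two logging events by hand; "embedding_loss" is an illustrative key.
callback = LogHistoryCallback()
state, control = TrainerState(), TrainerControl()
for step, loss in [(10, 0.25), (20, 0.125)]:
    state.global_step = step
    callback.on_log(None, state, control, logs={"embedding_loss": loss})

print(callback.history)
# [{'step': 10, 'embedding_loss': 0.25}, {'step': 20, 'embedding_loss': 0.125}]
```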