The world’s leading publication for data science, AI, and ML professionals.

Customizing Your Fine-tuning Code Using HuggingFace’s Transformers Library

Examples of custom callbacks and custom fine-tuning code from different libraries

Image generated by Gemini

The HuggingFace Transformers library offers many basic building blocks and a variety of functionality to kickstart your AI code. Many products and libraries have been built on top of it, and in this short blog I will talk about two ways people have extended it to add custom training code:

  1. Reimplementing the training code by iterating through the training data to recreate the fine-tuning loop, then adding custom code to it, and
  2. Creating custom callbacks attached to the _Trainer_ class, so that custom code runs inside those callbacks.

There are certainly other ways to customize the fine-tuning loop, but this blog focuses on these two approaches.

Reimplementing the Training Code for a Custom Fine-tuning Loop

Typically, when you train a model, you create a _Trainer_ object that lets you specify the training parameters. The Trainer object exposes a train() method that starts the training loop:

from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("stanfordnlp/imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# Example Trainer object. SFTTrainer is a type of Trainer for supervised fine-tuning
trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=SFTConfig(output_dir="/tmp"),
)
# Initiate training
trainer.train()

Instead of calling a Trainer object's train() method, some libraries add custom code by 1) reimplementing the code in the train() function that passes the data to the model for fine-tuning and then 2) adding custom code at different points of that reimplementation. A good example of this is AllenAI's open-source library for Data Cartography.

In the library, training dynamics (the characteristics of the model and the training data points) are captured after each training epoch during the fine-tuning loop, which requires custom code inside the loop itself. The code creates an iterator over the training epochs, and within each epoch the batches of training data are passed to the model for training (the following is a condensed, commented version of the implementation):

model.zero_grad()

# ------------------------------------------------------------------------------
# 1. Create a progress bar with trange (from the tqdm library)
#    to iterate through the specified epoch range
train_iterator = trange(epochs_trained, int(args.num_train_epochs), ...)
# ------------------------------------------------------------------------------

# --------------------- Add custom code here ------------------------------------
# ...
# ------------------------------------------------------------------------------

# Iterate through the epochs
for epoch, _ in enumerate(train_iterator):

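  # Create another iterator that goes through each of the batches of training data
  # defined by train_dataloader, a DataLoader object
  epoch_iterator = tqdm(train_dataloader, desc="Iteration", ...)

  # Reset the training-dynamics accumulators at the start of every epoch
  train_ids = None
  train_golds = None
  train_logits = None
  train_losses = None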

  # ------------------------------------------------------------------------------
  # 2. Iterate through the batches of training data
  for step, batch in enumerate(epoch_iterator):
    # ------------------------------------------------------------------------------

    # --------------------- Add custom code here ------------------------------------
    # Such as checking if it is resuming a training loop and if so skipping past steps that have already been trained on
    # ------------------------------------------------------------------------------

    # Put the model in training mode (enables dropout and similar layers)
    model.train()
    # Prep the data according to the format expected by the model (this is a BERT model)
    batch = tuple(t.to(args.device) for t in batch)
    inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
    outputs = model(**inputs)
    loss = outputs[0]

    # --------------------- Add custom code here ------------------------------------
    # Such as capturing the training dynamics, 
    # i.e. model and training data properties at that specified epoch
    if train_logits is None:  # Keep track of training dynamics.
        train_ids = batch[4].detach().cpu().numpy()
        train_logits = outputs[1].detach().cpu().numpy()
        train_golds = inputs["labels"].detach().cpu().numpy()
        train_losses = loss.detach().cpu().numpy()
    else:
        train_ids = np.append(train_ids, batch[4].detach().cpu().numpy())
        train_logits = np.append(train_logits, outputs[1].detach().cpu().numpy(), axis=0)
        train_golds = np.append(train_golds, inputs["labels"].detach().cpu().numpy())
        train_losses = np.append(train_losses, loss.detach().cpu().numpy())
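    # ------------------------------------------------------------------------------

    # ... the backward pass, the optimizer and scheduler steps, and the saving of the
    # collected training dynamics at the end of each epoch follow in the full implementation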
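
The same pattern can be tried end to end without a GPU or the Transformers library. The sketch below is a hypothetical, pure-Python stand-in: a one-feature logistic regression trained with a hand-written mini-batch loop, where the "custom code" slot records, for every example and every epoch, the probability the model assigns to the gold label. The function and variable names are illustrative only; it mimics the structure of the excerpt above and is not AllenAI's implementation.

```python
import math
import random


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def train_with_dynamics(data, num_epochs=5, lr=0.5, batch_size=2, seed=0):
    """Hand-written training loop for a 1-feature logistic regression (toy example).

    data: list of (example_id, x, gold_label) tuples with gold_label in {0, 1}.
    Returns the learned (w, b) and, per example, the gold-label probability
    observed at each epoch (the "training dynamics").
    """
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    dynamics = {ex_id: [] for ex_id, _, _ in data}

    for _ in range(num_epochs):                            # 1. iterate over epochs
        shuffled = data[:]
        rng.shuffle(shuffled)
        for start in range(0, len(shuffled), batch_size):  # 2. iterate over batches
            batch = shuffled[start:start + batch_size]
            grad_w = grad_b = 0.0
            for ex_id, x, y in batch:
                p = sigmoid(w * x + b)                     # forward pass
                # ---- custom code: record this example's training dynamics ----
                dynamics[ex_id].append(p if y == 1 else 1.0 - p)
                grad_w += (p - y) * x                      # gradient of the log loss
                grad_b += p - y
            w -= lr * grad_w / len(batch)                  # optimizer step (plain SGD)
            b -= lr * grad_b / len(batch)
    return (w, b), dynamics


toy_data = [("a", -2.0, 0), ("b", -1.0, 0), ("c", 1.0, 1), ("d", 2.0, 1)]
(w, b), dynamics = train_with_dynamics(toy_data)
print(dynamics["a"])  # one gold-label probability per epoch
```

As in the excerpt, the probabilities are recorded during the forward pass, before the weights are updated for that batch.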

Reimplementing the fine-tuning loop means reimplementing the basic parts of what the Trainer object already does, including running a training step on each batch of training data and computing the model's loss on that batch. This approach gives developers fine-grained control over the implementation, but it also takes a lot of work to make sure the code is correct. The second approach avoids reimplementing any part of the Trainer object because it relies on custom callbacks instead.

Creating Custom Callbacks to Customize the Trainer Class

A callback is a function passed as an argument to another function, which can then call it at a later time. Callbacks therefore let you inject custom code into a process that is controlled by someone else's code. In the Transformers library, the _TrainerCallback_ class contains empty callback methods that you can override with your custom code. The Trainer class (or a subclass such as SFTTrainer, if you are using one) calls these methods at specific points of the training loop.
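
To make the mechanism concrete, here is a minimal pure-Python sketch of the pattern: a base class with empty hook methods, a toy training loop that calls every registered callback's hooks at fixed points, and a subclass that overrides only one hook. The names ToyCallback, toy_train, and EpochLogger are hypothetical; they only mimic the shape of TrainerCallback and Trainer.train() and are not part of the Transformers library.

```python
class ToyCallback:
    """Base class whose hooks do nothing, mirroring the shape of TrainerCallback."""

    def on_train_begin(self, state: dict) -> None:
        pass

    def on_epoch_end(self, state: dict) -> None:
        pass

    def on_train_end(self, state: dict) -> None:
        pass


def toy_train(num_epochs: int, callbacks: list) -> dict:
    """Toy training loop that calls each callback's hooks at fixed points."""
    state = {"epoch": 0, "events": []}
    for cb in callbacks:
        cb.on_train_begin(state)
    for epoch in range(1, num_epochs + 1):
        state["epoch"] = epoch
        # ... forward pass, loss, backward pass, and optimizer step would go here ...
        for cb in callbacks:
            cb.on_epoch_end(state)
    for cb in callbacks:
        cb.on_train_end(state)
    return state


class EpochLogger(ToyCallback):
    """Overrides only on_epoch_end; the other hooks keep their empty defaults."""

    def on_epoch_end(self, state: dict) -> None:
        state["events"].append(f"epoch {state['epoch']} ended")


final_state = toy_train(num_epochs=3, callbacks=[EpochLogger()])
print(final_state["events"])  # ['epoch 1 ended', 'epoch 2 ended', 'epoch 3 ended']
```

The loop owns the control flow; the callback only decides what happens at the moments the loop exposes, which is why it never has to touch the training step itself.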

# Condensed from the TrainerCallback source code:
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/trainer_callback.py#L260
from transformers import TrainerControl, TrainerState, TrainingArguments
class TrainerCallback:
    # A bunch of empty functions you can override
    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the beginning of training.
        """
        pass

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the end of training.
        """
        pass

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the beginning of an epoch.
        """
        pass

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the end of an epoch.
        """
        pass
    # More empty functions not included here...

Each of these empty methods receives several arguments: (1) args, of type TrainingArguments, (2) state, of type TrainerState, (3) control, of type TrainerControl, and (4) additional keyword arguments collected into **kwargs (for example the model, the optimizer, and the dataloaders). These arguments give your custom code access to the Trainer's current objects and training state, and the control object lets your callback signal decisions back to the Trainer, such as stopping training early. More details about these arguments can be found in the Transformers documentation for callbacks.
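
The division of labor between state and control can be sketched the same way. In the toy loop below, a callback reads a value from the state (a recorded validation loss) and sets a flag on the control object, which the loop checks after every epoch. ToyState, ToyControl, StopWhenLossPlateaus, and toy_train_with_control are hypothetical simplifications; in the real library, TrainerState holds fields such as epoch, global_step, and log_history, and TrainerControl exposes flags such as should_training_stop, which is how the built-in EarlyStoppingCallback ends training early.

```python
from dataclasses import dataclass, field


@dataclass
class ToyState:
    """Training progress the callback can read (hypothetical stand-in for TrainerState)."""
    epoch: int = 0
    val_losses: list = field(default_factory=list)


@dataclass
class ToyControl:
    """Flags a callback can set to steer the loop (hypothetical stand-in for TrainerControl)."""
    should_training_stop: bool = False


class StopWhenLossPlateaus:
    """Stops training once the validation loss fails to improve for `patience` epochs."""

    def __init__(self, patience: int = 2):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def on_epoch_end(self, args: dict, state: ToyState, control: ToyControl, **kwargs) -> None:
        current = state.val_losses[-1]
        if current < self.best:
            self.best, self.bad_epochs = current, 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs >= self.patience:
            control.should_training_stop = True


def toy_train_with_control(args: dict, val_losses: list, callbacks: list) -> ToyState:
    """Toy loop that replays precomputed validation losses, one per epoch."""
    state, control = ToyState(), ToyControl()
    for epoch in range(1, args["num_train_epochs"] + 1):
        state.epoch = epoch
        state.val_losses.append(val_losses[epoch - 1])  # pretend we just evaluated
        for cb in callbacks:
            # Extra objects (here, a placeholder model name) travel through **kwargs
            cb.on_epoch_end(args, state, control, model="toy-model")
        if control.should_training_stop:
            break
    return state


final = toy_train_with_control(
    {"num_train_epochs": 6}, [0.9, 0.7, 0.72, 0.75, 0.6, 0.5], [StopWhenLossPlateaus(patience=2)]
)
print(final.epoch)  # 4
```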

For instance, if you want to access the current state of the model after each epoch, you would

  1. Override the on_epoch_end method, and
  2. Access the model within on_epoch_end through kwargs["model"]:

import torch
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

class ExampleTrainerCallback(TrainerCallback):
    """Custom ExampleTrainerCallback that accesses the model after each epoch
    """
    def __init__(self, some_tokenized_dataset):
        """Initializes the ExampleTrainerCallback instance."""
        super().__init__()
        # --------------------- Add custom code here ------------------------------------
        self.some_tokenized_dataset = some_tokenized_dataset
        # ------------------------------------------------------------------------------

    # Overriding the on_epoch_end() function
    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the end of an epoch.
        """
        # --------------------- Add custom code here ------------------------------------
        print('Hello, an epoch has ended!')

        # Access the current state of the model after the epoch ends: 
        model = kwargs["model"]

        # Add some custom code here...
        model.eval()

        # Perform inference on some dataset
        with torch.no_grad():
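            for item in self.some_tokenized_dataset:
                # Convert to tensors (no copy if the dataset is already in torch format),
                # add a batch dimension, and move the inputs to the model's device
                input_ids = torch.as_tensor(item["input_ids"]).unsqueeze(0).to(model.device)
                attention_mask = torch.as_tensor(item["attention_mask"]).unsqueeze(0).to(model.device)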

                # Forward pass, assuming model is a BertForSequenceClassification type
                # i.e. model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)  
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                probabilities = torch.nn.functional.softmax(logits, dim=-1)
                prediction = torch.argmax(probabilities, dim=-1).item()
                # Do something with prediction

        # Switch back to training mode before training resumes
        model.train()
        # ------------------------------------------------------------------------------

In the above code, we access the current state of the model at the end of every epoch through kwargs["model"], and then run some custom code (in this case, inference on a dataset that was tokenized beforehand and stored in the __init__ method of the ExampleTrainerCallback class). The nice part about using TrainerCallback is that we do not have to reimplement the nitty-gritty of what the Trainer class already does, such as computing the loss and passing batches of training data into the model. We keep all of that existing (and well-tested!) code and build our custom code on top of it.

One important step with custom callbacks is to register them with the Trainer object you will use:

from datasets import load_dataset
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Create a Trainer object. A plain Trainer fits a sequence-classification model;
# train_data is assumed to be your tokenized training Dataset, prepared elsewhere
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    args=TrainingArguments(output_dir="/tmp", num_train_epochs=5, eval_strategy="epoch"),
    # Additional arguments here, e.g. eval_dataset=..., which eval_strategy="epoch" requires
)

# Prep some data; with split='test', some_tokenized_dataset is a single Dataset, not a DatasetDict
def tokenize_function(example):
    return tokenizer(example['text'], padding="max_length", truncation=True)
some_tokenized_dataset = load_dataset('json', data_files={'test': 'path_to_your_data'}, split='test')
some_tokenized_dataset = some_tokenized_dataset.map(tokenize_function, batched=True)

# --------------- DO NOT FORGET TO ADD YOUR CALLBACKS TO YOUR TRAINER! -----------------------------------------------
# Create the callback with your custom code
example_callback = ExampleTrainerCallback(
  some_tokenized_dataset = some_tokenized_dataset
)

# Add the callback to the Trainer
# (alternatively, pass callbacks=[example_callback] when constructing the Trainer)
trainer.add_callback(example_callback)

# ------------------------------------------------------------------------------

# Train the model
trainer.train()

For a concrete example, the Transformers library itself ships a TrainerCallback for Weights and Biases, called WandbCallback, that adds custom code to the fine-tuning loop. WandbCallback logs metrics, model checkpoints, and other artifacts throughout training and sends them to Weights and Biases, where you can later visualize and compare your experiments using their tools (the following is a condensed, commented version of the custom callback):

from transformers import TrainerCallback

class WandbCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
    """

    def __init__(self):
        # ---------------------------------------------------------------------------
        # Custom code that runs when the WandbCallback class is initialized,
        # such as importing wandb and keeping a handle to it for the other methods
        import wandb
        self._wandb = wandb
        # ---------------------------------------------------------------------------

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        # ---------------------------------------------------------------------------
        # Custom code 
        # add the model architecture to a separate text file
        save_model_architecture_to_file(model, temp_dir)
        # ---------------------------------------------------------------------------

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        # ---------------------------------------------------------------------------
        # Custom code that logs the items listed in single_value_scalars
        single_value_scalars = [
            "train_runtime",
            "train_samples_per_second",
            "train_steps_per_second",
            "train_loss",
            "total_flos",
        ]

        # More code here that accesses these values in the logs argument that then gets saved to wandb
        if state.is_world_process_zero:
            for k, v in logs.items():
                if k in single_value_scalars:
                    self._wandb.run.summary[k] = v
            non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
            non_scalar_logs = rewrite_logs(non_scalar_logs)
            self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
        # ---------------------------------------------------------------------------
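
The splitting logic in on_log is easy to try in isolation. The sketch below re-creates it with plain dictionaries: keys in single_value_scalars go to a run summary, and the remaining metrics are renamed into namespaced keys before being logged together with the global step. split_logs and toy_rewrite_logs are hypothetical names, and the prefix rule (eval_ becomes eval/, test_ becomes test/, everything else gets train/) is my reading of the library's rewrite_logs helper, so treat it as an assumption.

```python
SINGLE_VALUE_SCALARS = {
    "train_runtime",
    "train_samples_per_second",
    "train_steps_per_second",
    "train_loss",
    "total_flos",
}


def toy_rewrite_logs(logs: dict) -> dict:
    """Namespace metric names (assumed to mirror rewrite_logs; not the library code)."""
    rewritten = {}
    for key, value in logs.items():
        if key.startswith("eval_"):
            rewritten["eval/" + key[len("eval_"):]] = value
        elif key.startswith("test_"):
            rewritten["test/" + key[len("test_"):]] = value
        else:
            rewritten["train/" + key] = value
    return rewritten


def split_logs(logs: dict, global_step: int):
    """Return (summary, history_row), routing values the way on_log does above."""
    summary = {k: v for k, v in logs.items() if k in SINGLE_VALUE_SCALARS}
    non_scalar_logs = {k: v for k, v in logs.items() if k not in SINGLE_VALUE_SCALARS}
    history_row = {**toy_rewrite_logs(non_scalar_logs), "train/global_step": global_step}
    return summary, history_row


summary, row = split_logs({"loss": 0.42, "eval_accuracy": 0.91, "train_runtime": 12.5}, global_step=100)
print(summary)  # {'train_runtime': 12.5}
print(row)      # {'train/loss': 0.42, 'eval/accuracy': 0.91, 'train/global_step': 100}
```

In the real callback, the summary values go to self._wandb.run.summary and the history row goes to self._wandb.log, and only the main process (state.is_world_process_zero) does the logging.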

Side note: when you fine-tune with the TrainingArguments class (which is then passed to the Trainer class), reporting to Weights and Biases may be enabled by default: in Transformers v4, report_to defaults to all installed integrations, so simply having wandb installed is enough. It should be disabled if you are using proprietary data (depending also on your organization's policy):

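from transformers import TrainingArguments
from trl import SFTTrainer

# Disable any reporting of internal, proprietary data when using HuggingFace classes.
# Use the string "none": in Transformers v4, report_to=None falls back to
# reporting to every installed integration
args = TrainingArguments(output_dir="/tmp", report_to="none")  # plus your other arguments

trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    args=args
)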

Conclusion

This short blog introduces two approaches to customizing your fine-tuning loop in the HuggingFace Transformers library. I draw on examples from existing libraries and discuss the benefits and drawbacks of each approach. Many other creative approaches likely exist, but they are beyond the scope of this mini-blog.


Related Articles