
DPO Full Training vs. LoRA: How Good is LoRA for DPO Training?

One model, two adapters

Image generated with Grok

There are various methods to align LLMs with human preferences. Beyond reinforcement learning from human feedback (RLHF), often seen as too resource-intensive to apply consistently to newly fine-tuned models, Direct Preference Optimization (DPO) is one of the most popular alternatives for LLM alignment.

Although DPO is significantly more cost-effective than RLHF, it still requires a reference model in addition to the "policy" model (i.e., the model being actively trained). This means both models must be loaded into GPU memory simultaneously, which can be challenging for single-GPU configurations, especially with large models.

A more memory-efficient approach is to use LoRA for DPO training. Instead of training the entire model, we freeze its parameters and train a small adapter. This method becomes even more efficient if the policy and reference models share the same base model: we load the base model once, then load a frozen adapter for the reference model and a trainable adapter for the policy model, significantly reducing memory requirements.

However, the effect of LoRA on DPO's performance is, in my opinion, still understudied. While LoRA can closely approximate full training, its performance largely depends on the task.

In this article, I train an LLM, Qwen 2.5, with DPO using LoRA and compare its learning curves and costs to those of full training. For full training, neither the reference nor the policy model uses adapters. I also provide a step-by-step guide on using adapters with both the reference and policy models.

I made a notebook implementing the code explained in this article for DPO training with LoRA. You can get it here:

Get the notebook (#122)

Full DPO Training for Qwen2.5

We need an instruct model that has already been fine-tuned on a conversational dataset. This is the supervised fine-tuning (SFT) step, where the model learns the specific task. This SFT model will serve as the initial point for DPO training and as the reference model in DPO.

For this article, I trained the SFT adapter using a fine-tuning recipe almost identical to the one I described here:

Fine-Tuning LLMs with 32-bit, 8-bit, and Paged AdamW Optimizers

I used the HuggingFaceH4/ultrachat_200k dataset (MIT license), a conversational dataset, for training. Only the "messages" column is used. TRL's SFTTrainer automatically applies the chat template from the model's tokenizer to convert the JSON objects into token sequences.

My adapter is available here:

This is an adapter, but since this section covers full training, we don't want to deal with adapters. We merge it into the base model:

from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
import torch

model_name = "Qwen/Qwen2.5-1.5B"
sft_adapter = "kaitchup/Qwen2.5-1.5B-SFT-UltraChat" #Your adapter to merge
compute_dtype = torch.float16

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
      model_name, device_map={"": 0}, torch_dtype=compute_dtype)

#Load the SFT adapter and merge it into the base model's weights
model = PeftModel.from_pretrained(model, sft_adapter)
model = model.merge_and_unload()

#Save the merged model and its tokenizer
model.save_pretrained("./SFT_LoRA_Merged/")
tokenizer.save_pretrained("./SFT_LoRA_Merged/")

The resulting model is saved in a directory named "SFT_LoRA_Merged".

Let’s now import what we will need for DPO:

import torch, multiprocessing
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed
)
from trl import DPOTrainer, DPOConfig
set_seed(1234)
model_name = "/workspace/SFT_LoRA_Merged/" #This is where your SFT model is.

compute_dtype = torch.bfloat16
#If you have troubles with FlashAttention, use 'sdpa' instead
attn_implementation = 'flash_attention_2'
bs = 4 #Batch size per device (training and validation)
gas = 8 #Gradient accumulation steps
mseqlen = 1024 #Maximum sequence length
lr = 1e-6 #Learning rate
output_dir = "/workspace/DPO_FFT/"

Decrease the batch size and increase the gradient accumulation steps if you don’t have enough memory. Decreasing the sequence length is also an option but be aware that your model won’t perform as well on longer sequences.
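For instance, with the values above, the effective batch size works out as follows; halving one factor and doubling the other leaves it unchanged while reducing activation memory:

#Effective batch size = per-device batch size * gradient accumulation steps (* number of GPUs)
print(bs * gas) #4 * 8 = 32 sequences per optimizer step
#Halving bs to 2 and doubling gas to 16 keeps the effective batch size at 32,
#while roughly halving the activation memory of each forward pass
print(2 * 16)   #still 32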

As for the learning rate, I arbitrarily chose 1e-6. A lower learning rate may work better for larger models. Next, we initialize and configure the tokenizer.

#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = "<|image_pad|>"
tokenizer.pad_token_id = 151655
tokenizer.padding_side = 'right'

I use <|image_pad|> as the padding token since it is not used by Qwen2.5 for text. For the padding side, you can choose right or left.
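As a quick optional sanity check, you can verify that this token exists in Qwen2.5's vocabulary and maps to the ID set above:

#The pad token should map to the ID hard-coded above
print(tokenizer.convert_tokens_to_ids("<|image_pad|>")) #expected: 151655
print(tokenizer.decode([151655])) #expected: '<|image_pad|>'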

Padding Large Language Models

The training dataset I chose for DPO training is mlabonne/orpo-dpo-mix-40k:

This dataset contains the three columns needed for DPO training:

  • the prompt
  • the chosen answer
  • the rejected answer

Remember that with DPO, the goal is to train the model to generate, given a prompt, the chosen answer while moving away from the rejected answer.

Both the chosen and rejected answers are in a JSON format supported by the DPO trainer, which would automatically apply the chat template to transform them into sequences of tokens. Nonetheless, I'm used to applying the chat template to the dataset myself, so I still do it here. It also shows you how to proceed in case:

  • you want to apply a custom chat template that is not the one in the tokenizer;
  • there is no prompt column in the dataset.

The chosen and rejected columns both contain the prompt, which is the first element of the list of messages. We can therefore take it, apply the chat template, and overwrite the prompt column with the result. The remaining elements of the messages are what DPO will compare.

ds = load_dataset("mlabonne/orpo-dpo-mix-40k", split="train").train_test_split(test_size=0.01)
ds_train = ds['train']
ds_test = ds['test']
#Apply the chat template and append the EOS token
def process(row):
    #The first message is the prompt: apply the chat template to it
    prompt_messages = tokenizer.apply_chat_template([row["chosen"][0]], tokenize=False)
    #The remaining turns define the chosen/rejected responses
    chosen_messages = tokenizer.apply_chat_template(row["chosen"][1:], tokenize=False)+tokenizer.eos_token
    rejected_messages = tokenizer.apply_chat_template(row["rejected"][1:], tokenize=False)+tokenizer.eos_token
    row["prompt"] = prompt_messages
    row["chosen"] = chosen_messages
    row["rejected"] = rejected_messages
    return row
ds_train = ds_train.map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)
ds_test = ds_test.map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)

Then, we load the model and enable gradient checkpointing to reduce memory consumption:

model = AutoModelForCausalLM.from_pretrained(
      model_name, device_map={"": 0}, torch_dtype=compute_dtype, attn_implementation=attn_implementation)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})

We load the model a second time. This will be our reference model. It won’t be trained and consequently doesn’t require gradient checkpointing.

ref_model = AutoModelForCausalLM.from_pretrained(
      model_name, device_map={"": 0}, torch_dtype=compute_dtype, attn_implementation=attn_implementation)

Next, we can set our training arguments:

training_arguments = DPOConfig(
        output_dir=output_dir,
        eval_strategy="steps",
        do_eval=True,
        optim="paged_adamw_8bit",
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        per_device_eval_batch_size=bs,
        log_level="debug",
        save_strategy="steps",
        save_steps=200,
        logging_steps=25,
        learning_rate=lr,
        bf16 = True,
        beta = 0.1,
        eval_steps=25,
        num_train_epochs=1,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        max_length=mseqlen,
        max_prompt_length=mseqlen,
        dataset_num_proc=multiprocessing.cpu_count(),
)

I explain them in my guide on training hyperparameters and arguments. The "beta" parameter is specific to DPO training: a low value such as 0.1, which is also the default, often works well.

We can now create an instance of the DPOTrainer:

trainer = DPOTrainer(
    model,
    ref_model=ref_model,
    args=training_arguments,
    train_dataset=ds_train,
    eval_dataset=ds_test,
    processing_class=tokenizer,
)

Note that "processing_class" is a new argument that replaces the now-deprecated "tokenizer" argument.

Start training:

trainer_ = trainer.train()

We will discuss the learning curves in the next sections.

DPO Training with LoRA

For DPO training with LoRA, we only have a few lines to change. We don’t need to merge the SFT adapter into the model.

We load the base model first:

model = AutoModelForCausalLM.from_pretrained(
      model_name, device_map={"": 0}, torch_dtype=compute_dtype, attn_implementation=attn_implementation)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})

Then, we load the adapter fine-tuned with SFT on top of it, name it "DPO" (or any other name of your choice), and make it trainable (is_trainable=True). For the reference model, we load the same adapter a second time under a different name, for instance, "reference".

model = PeftModel.from_pretrained(model, sft_adapter, is_trainable=True, adapter_name="DPO")
model.load_adapter(sft_adapter, adapter_name="reference")

Note: This double loading with a trainable adapter generates a very long PyTorch warning about incompatible keys. You can safely ignore it.

The base model now has two adapters: one that initializes DPO training and will be updated, and another that serves as the frozen reference.
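As a quick optional check, the PEFT model keeps track of its adapters, so printing them should show both names and the currently active one:

#List the adapters attached to the base model and the active one
print(model.peft_config.keys()) #expected: dict_keys(['DPO', 'reference'])
print(model.active_adapter) #expected: 'DPO'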

Next, we need to tell the DPO trainer the names of these adapters. This is done through the model_adapter_name and ref_adapter_name arguments of the DPOConfig:

training_arguments = DPOConfig(
        output_dir=output_dir,
        eval_strategy="steps",
        do_eval=True,
        optim="paged_adamw_8bit",
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        per_device_eval_batch_size=bs,
        log_level="debug",
        save_strategy="steps",
        save_steps=200,
        logging_steps=25,
        learning_rate=lr,
        bf16 = True,
        beta = 0.1,
        eval_steps=25,
        num_train_epochs=1,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        max_length=mseqlen,
        max_prompt_length=mseqlen,
        model_adapter_name="DPO",
        ref_adapter_name="reference",
        dataset_num_proc=multiprocessing.cpu_count(),
)

For the DPOTrainer, we only need to remove the "ref_model" argument:

trainer = DPOTrainer(
    model,
    args=training_arguments,
    train_dataset=ds_train,
    eval_dataset=ds_test,
    processing_class=tokenizer,
)
trainer_ = trainer.train()

DPO Training: Full Training vs. LoRA

Now, we can compare the learning curves of full training and LoRA. In the training logs, we have various metrics that we can use to draw learning curves:

Screenshot by the author

The most important metrics:

  • rewards/chosen: the average difference between the log probabilities of the policy model and the reference model for chosen responses, scaled by a factor, beta.
  • rewards/rejected: the average difference between the log probabilities of the policy model and the reference model for rejected responses, also scaled by beta.
  • rewards/accuracies: the average frequency with which the chosen rewards exceed their corresponding rejected rewards.
  • rewards/margins: the average difference between the rewards for chosen responses and their corresponding rejected responses.

The goal of DPO is to distinguish between accepted and rejected answers. Specifically, we aim to increase the difference (rewards/margins) between the rewards for chosen answers (rewards/chosen) and rejected answers (rewards/rejected). The rewards/accuracies metric also provides valuable insight into the learning process, indicating how accurately the model prefers chosen answers over rejected ones.
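To make these definitions concrete, here is a minimal sketch, with assumed variable names, of how these rewards and the DPO loss are derived from the summed log-probabilities of each response under the policy and reference models (this is a simplified illustration, not TRL's exact implementation):

import torch.nn.functional as F

beta = 0.1

def dpo_metrics(policy_chosen_logps, policy_rejected_logps,
                ref_chosen_logps, ref_rejected_logps):
    #Rewards are the beta-scaled log-probability ratios between policy and reference
    rewards_chosen = beta * (policy_chosen_logps - ref_chosen_logps)
    rewards_rejected = beta * (policy_rejected_logps - ref_rejected_logps)
    margins = rewards_chosen - rewards_rejected
    accuracies = (rewards_chosen > rewards_rejected).float()
    #The DPO loss pushes the margins to be positive
    loss = -F.logsigmoid(margins).mean()
    return (rewards_chosen.mean(), rewards_rejected.mean(),
            accuracies.mean(), margins.mean(), loss)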

Using these metrics, we have these learning curves:

At first glance, these curves might lead us to conclude that LoRA performs better than full training. LoRA achieves higher accuracy and more decisively rejects the "rejected" answers, while still preserving rewards for the "chosen" answers.

However, the reality is more nuanced. When studies claim that LoRA outperforms full training or full fine-tuning, they may not have fully optimized hyperparameters for both methods. Specific values for learning rate and beta might be effective for LoRA but not for full training. For a fair comparison, we would need to explore a wide range of learning rates and beta values for both approaches. It’s likely that with sufficient experimentation, we would identify a configuration for full training that surpasses the best LoRA setups – though running these extensive experiments would be very resource-intensive.

What conclusions can we draw from these learning curves?

We can conclude that my default LoRA configuration, a standard setup, performs well overall.

However, examining the runtime reveals an interesting trade-off:

  • Full training: 6 hours and 13 minutes
  • LoRA: 8 hours and 36 minutes

Surprisingly, full training is faster than LoRA. This may be due to LoRA's slight overhead when scoring samples (especially with the reference model), and because we use two adapters on the same base model, one for policy training and one for reference; switching between these adapters might not be fully efficient.

So, is full training more cost-effective than LoRA?

On the same hardware, yes – full training is faster. But in practice, we often face memory constraints. Full training demands significantly more memory, requiring larger GPUs. While DPO full training on a 7B model is challenging on a single 80 GB GPU, it’s manageable with LoRA.
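As a rough back-of-envelope illustration (assumptions, not measurements), here is why full DPO training of a 7B model is tight on an 80 GB GPU while LoRA is not:

#Approximate memory for full DPO training of a 7B model, ignoring activations
params = 7e9
policy_weights = params * 2 / 1e9     #bf16 policy weights: ~14 GB
reference_weights = params * 2 / 1e9  #bf16 reference weights: ~14 GB
gradients = params * 2 / 1e9          #bf16 gradients: ~14 GB
optimizer_states = params * 2 / 1e9   #8-bit AdamW, two states per parameter: ~14 GB
print(policy_weights + reference_weights + gradients + optimizer_states) #~56 GB before activations

#With LoRA, gradients and optimizer states only cover the adapter (tens of millions of parameters),
#and the base model's weights are loaded once and shared by the policy and reference adapters.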

Conclusion

LoRA performs well for DPO. It is a much more memory-efficient alternative to full DPO training, even though it doesn’t necessarily speed up the training process. We could further improve memory efficiency by quantizing the base model using methods like bitsandbytes or a GPTQ-compatible approach like AutoRound.
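For instance, here is a minimal sketch, assuming 4-bit NF4 quantization with bitsandbytes, of how the base model could be loaded quantized before attaching the two adapters (the rest of the DPO setup would remain the same):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
#Load the base model quantized to 4-bit
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B", device_map={"": 0}, quantization_config=bnb_config)
#Attach the trainable policy adapter and the frozen reference adapter as before
sft_adapter = "kaitchup/Qwen2.5-1.5B-SFT-UltraChat"
model = PeftModel.from_pretrained(model, sft_adapter, is_trainable=True, adapter_name="DPO")
model.load_adapter(sft_adapter, adapter_name="reference")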

However, it’s important to note that quantization may slow down training considerably.

For most use cases, DPO training with LoRA is likely a more practical choice than full training, given its memory efficiency.


To support my work, consider subscribing to my newsletter for more articles/tutorials on recent advances in AI:

The Kaitchup – AI on a Budget | Benjamin Marie | Substack

