
Multi-GPU Fine-tuning for Llama 3.1 70B with FSDP and QLoRA

What you can do with only 2×24 GB GPUs and a lot of CPU RAM

Generated with DALL-E

Fine-tuning large language models (LLMs) with up to 35B parameters is relatively easy and cheap since it can be done with a single consumer GPU. Fine-tuning larger models with a single consumer GPU is not impossible in theory, as we can offload parts of the model to CPU memory, but it would be extremely slow, even with high-end CPUs.

Using multiple GPUs is the only alternative to keep fine-tuning fast enough. A configuration with 2×24 GB GPUs opens a lot of possibilities. 48 GB of GPU memory is enough to fine-tune 70B models such as Llama 3 70B and Qwen2 72B.

In this article, I explain how to fine-tune 70B LLMs using only two GPUs thanks to FSDP and QLoRA.

I first explain what FSDP is, and then we will see how to modify a standard QLoRA fine-tuning code to run it on multiple GPUs. For the experiments and demonstrations, I use Llama 3.1 70B, but it would work similarly for other LLMs. For the hardware, I relied on 2 RTX 3090 GPUs provided by RunPod (referral link). Using 2 RTX 4090 GPUs would be faster but more expensive.

I also made a notebook implementing the code described in this article. It’s available here:

Get the notebook (#92)

Fully Sharded Data Parallel (FSDP): How Does It Work?

A common and efficient way to perform distributed training over multiple GPUs is to load a copy of the model on each GPU. However, this is not possible for very large models or if we only have small GPUs. For instance, Llama 3 70B occupies around 140 GB while the largest NVIDIA GPU that you can find in clouds, the H100, only has 80 GB of memory.
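As a quick sanity check on that figure: storing 70 billion parameters in 16-bit precision already takes roughly 70B × 2 bytes ≈ 140 GB, before counting gradients, activations, and optimizer states.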

However, if we could split the model, we could load one part on one 80 GB GPU and the remainder on another 80 GB GPU. Yet, it is not that simple. The model itself only consumes a small fraction of the total memory used during fine-tuning: we also need to store the gradients, the activations, and the optimizer states. We need to split everything across all the GPUs and offload some of it to the CPU RAM if we don't have enough GPU memory.

This is exactly what FSDP does. It was introduced in PyTorch in 2022:

Introducing PyTorch Fully Sharded Data Parallel (FSDP) API

In a nutshell, FSDP distributes optimizer states, gradients, and parameters across multiple devices (e.g., GPUs and CPUs). During the forward pass, each FSDP unit gathers the necessary weight shards from other devices to form the complete set of weights, performs the computation, and then discards the non-local shards. After computing the loss, during the backward pass, each FSDP unit again gathers the complete set of weights and performs computations to determine local gradients, which are then averaged. These averaged gradients are redistributed across the devices through a reduce-scatter operation. After this, each device updates its own shard of the parameters.
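For illustration, here is roughly what wrapping a model with FSDP looks like in raw PyTorch. This is only a minimal sketch: MyModel is a placeholder, and it assumes one process per GPU launched with torchrun.

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import CPUOffload, ShardingStrategy

# One process per GPU, e.g., launched with: torchrun --nproc_per_node=2 train.py
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = MyModel().cuda()  # placeholder model
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # shard parameters, gradients, and optimizer states
    cpu_offload=CPUOffload(offload_params=True),    # offload shards to CPU RAM when not in use
)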

FSDP would be complex to implement ourselves. Fortunately, Hugging Face's Accelerate makes the use of FSDP almost transparent. We only have to:

  1. Generate a configuration file with Accelerate
  2. Add a call to Accelerator()
  3. Implement a strategy to save the final model after fine-tuning (only a few lines of code)
  4. Launch the fine-tuning script with Accelerate

Generate a configuration file for FSDP with Accelerate

We need Hugging Face’s Accelerate. Make sure it’s installed and up to date:

pip install --upgrade accelerate

Then, configure it by running:

accelerate config

It will ask you several questions. The goal here is to generate a configuration file that will be used for fine-tuning with FSDP. Some of the questions can be difficult to answer if you are not familiar with how FSDP works. If this is the case, you can skip this step and use an existing configuration file, such as this one:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

You only have to modify "num_processes", which is the number of GPUs on your machine. Then, save it into a file, e.g., config_fsdp.yaml.
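If you are not sure how many GPUs your machine exposes, you can check it with PyTorch before editing the file, for instance:

python -c "import torch; print(torch.cuda.device_count())"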

Add a call to Accelerator() for Fine-tuning with FSDP and QLoRA

We need to add the following lines of code before loading the model:

from accelerate import Accelerator
accelerator = Accelerator()

Code for Saving the Model Fine-tuned with FSDP

Since the model is split during fine-tuning, the checkpoints only contain pieces of the model.

To save the model, we need to gather all the pieces of the model on one device. This is achieved by the following code, which also handles the (Q)LoRA case (the auto-wrap policy is set before training starts, and the state dict type is switched to FULL_STATE_DICT once training is done):

fsdp_plugin = trainer.accelerator.state.fsdp_plugin
fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)

if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
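With the state dict type switched to FULL_STATE_DICT, the usual save call then writes the gathered (Q)LoRA adapter to disk (output_dir being your output directory, as in the full script below):

trainer.save_model(output_dir)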

Launching the Fine-tuning with Accelerate

Then, instead of running the script with the "python" command, we use "accelerate":

accelerate launch --config_file config_fsdp.yaml fine_tuning_FSDP_QLoRA.py

However, this won’t work with QLoRA. We need to make modifications in the QLoRA fine-tuning code.

Setting up QLoRA for FSDP

When we set up QLoRA for fine-tuning, we define a BitsAndBytesConfig which looks like this:

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

When FSDP is activated, we need one more argument, "bnb_4bit_quant_storage":

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=compute_dtype,
)

This new argument is explained in the documentation.

compute_dtype can be torch.bfloat16 (preferred: more stable and faster) or torch.float16 for older GPUs.
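For instance, you can select it automatically depending on what your GPU supports, which is also what the full script below does:

compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16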

In addition, we need to prepare the model for training. For single-GPU QLoRA fine-tuning, we would simply add this line:

model = prepare_model_for_kbit_training(model, gradient_checkpointing_kwargs={'use_reentrant':True})

It does the following:

  1. Casts the layer norms and the language modeling head to fp32
  2. Freezes the parameters of the model
  3. Makes the output embedding layer require gradients
  4. Activates gradient checkpointing

With FSDP, (1) doesn’t seem to be possible and triggers an error when the fine-tuning starts. To avoid this casting, I implemented what "prepare_model_for_kbit_training" does, minus this first step:

for name, param in model.named_parameters():
    # freeze base model's layers
    param.requires_grad = False

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})

Finally, for QLoRA fine-tuning, it is common to also use quantized optimizer states with AdamW 8-bit or paged AdamW 8-bit. This doesn't seem to be well supported by FSDP: when using AdamW 8-bit, the trainer raised an error because some of the parameters were on CPU RAM rather than on the GPU. I didn't try to solve it and instead switched to the standard AdamW (32-bit), which only increased the memory consumption by a few GBs.
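In practice, this is only a matter of which optimizer name is passed to the training arguments. A minimal sketch of the change, assuming TRL's SFTConfig (the other arguments are shown in the full script below):

from trl import SFTConfig

training_arguments = SFTConfig(
    output_dir="./Llama3.1_70b_QLoRA/",
    # optim="paged_adamw_8bit",  # common for single-GPU QLoRA, but it failed with FSDP offloading in my runs
    optim="adamw_torch",  # standard 32-bit AdamW: a few more GBs of memory, but works with FSDP
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
)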

Fine-tuning Llama 3.1 70B with QLoRA and FSDP

I recommend at least 2×24 GB GPUs and 200 GB of CPU RAM for fine-tuning 70B models with FSDP and QLoRA. It can be consumer GPUs such as the RTX 3090 or RTX 4090. Adding one more GPU would significantly decrease the consumption of CPU RAM and would speed up fine-tuning.

I use RunPod (referral link) for this experiment. However, RunPod doesn't offer a large choice of machines equipped with consumer GPUs and a lot of CPU RAM. The most cost-effective configuration that I found is to use RunPod's secure cloud to set up a machine with 2 RTX 3090s (48 GB of VRAM in total) and 251 GB of CPU RAM. This configuration costs $0.66/hour (+ $0.1/hour for storage) as of August 6th, 2024.

The complete fine-tuning code is as follows:

import torch, os, multiprocessing
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    set_seed
)
from trl import SFTTrainer, SFTConfig
from peft.utils.other import fsdp_auto_wrap_policy
from accelerate import Accelerator

accelerator = Accelerator()
set_seed(1234)
#use bf16 and FlashAttention if supported
if torch.cuda.is_bf16_supported():
  os.system('pip install flash_attn')
  compute_dtype = torch.bfloat16
  attn_implementation = 'flash_attention_2'
else:
  compute_dtype = torch.float16
  attn_implementation = 'sdpa'
model_name = "meta-llama/Meta-Llama-3.1-70B"
#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = "<|finetune_right_pad_id|>"
tokenizer.pad_token_id = 128004
tokenizer.padding_side = 'right'
ds = load_dataset("timdettmers/openassistant-guanaco")
#Add the EOS token
def process(row):
    row["text"] = row["text"]+"<|end_of_text|>"
    return row
ds = ds.map(
    process,
    num_proc=multiprocessing.cpu_count(),
    load_from_cache_file=False,
)
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=compute_dtype,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,
    attn_implementation=attn_implementation,
)
for name, param in model.named_parameters():
    # freeze base model's layers
    param.requires_grad = False
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})
peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
        r=16,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules= ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"]
)
output_dir = "./Llama3.1_70b_QLoRA/"

training_arguments = SFTConfig(
        output_dir=output_dir,
        eval_strategy="steps",
        do_eval=True,
        optim="adamw_torch",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        per_device_eval_batch_size=1,
        log_level="debug",
        logging_steps=10,
        learning_rate=1e-4,
        bf16=True,
        eval_steps=10,
        max_steps=50,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        dataset_text_field="text",
        max_seq_length=512,
)
trainer = SFTTrainer(
        model=model,
        train_dataset=ds['train'],
        eval_dataset=ds['test'],
        peft_config=peft_config,
        tokenizer=tokenizer,
        args=training_arguments,
)
fsdp_plugin = trainer.accelerator.state.fsdp_plugin
fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)
trainer.train()
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model(output_dir)

Note: I recommend copying it in a Python file, rather than running it in a Jupyter notebook. Then, run this file with "accelerate" to start fine-tuning as explained in the previous section.

I ran only 50 training steps, which took around 10 hours in total, including 5 validation runs of 28 minutes each. Running one epoch would take around 90 hours with this configuration. To give you a point of comparison, this is 80 times slower than running one epoch with Llama 3.1 8B on a single RTX 3090.

Note that the training batch size is set per GPU. In the training arguments, I set a batch size of 1 with 8 gradient accumulation steps. Since we use two GPUs, the total training batch size is 1 × 8 × 2 = 16. You may increase the batch size for faster training if you use 3 GPUs or more.

The learning curves for these 50 training steps:

Figure by the author

The training loss and validation loss are both decreasing. The model is learning.

Conclusion

FSDP with QLoRA works very well once we understand which parts of the code need to be adapted. While fine-tuning a 70B LLM with only 2 GPUs is reasonably fast, I would recommend investing in a third GPU to avoid relying too much on CPU RAM, which slows down fine-tuning. Fine-tuning would become much faster and more cost-effective.


To support my work, consider subscribing to my newsletter for more articles/tutorials on recent advances in AI:

The Kaitchup – AI on a Budget | Benjamin Marie | Substack

