Memory Efficient Training

🤗 PEFT makes fine-tuning parameter efficient, but not automatically memory efficient. This overview collects tips for cutting training memory and links to the detailed guides.

Always consider the basics first: choosing a smaller base model, a smaller batch size or shorter sequence lengths all lower your memory usage.

Training memory overview

Let’s dissect how training memory is distributed so we can reason about potential countermeasures. We will use a large language model trained with the Adam optimizer as an example. When doing full fine-tuning, the following components take up memory:

  1. base model parameters: the memory consumption highly depends on the chosen dtype. The fewer bits per parameter (16 for float16), the less memory the parameters take up. A 1B model in float16 (16 bit/2 byte per parameter) will roughly take 1e9 × 2 bytes = 2e9 bytes ≈ 1.863 GiB of memory.
  2. gradients (1×) and Adam optimizer states (2×) for all trainable parameters: three times the parameter memory, so about 3 × 1.863 GiB ≈ 5.59 GiB when they are kept in the same dtype as the parameters.
  3. intermediate activations between layers: these are hard to predict but mostly depend on the compute dtype, the sequence length and the batch size.
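The arithmetic from the list above can be reproduced with a small helper. This is a rough sketch that, like the list, assumes gradients and both Adam states are kept in the same dtype as the parameters, and it leaves activations out entirely.

```python
GIB = 1024**3  # bytes per GiB


def full_finetune_memory_gib(num_params: int, bytes_per_param: int = 2) -> dict:
    """Rough memory estimate for full fine-tuning with Adam (activations excluded)."""
    weights = num_params * bytes_per_param
    # 1x for gradients + 2x for Adam's first and second moment estimates
    grads_and_optimizer = 3 * weights
    return {
        "weights_gib": weights / GIB,
        "grads_and_optimizer_gib": grads_and_optimizer / GIB,
        "total_gib": (weights + grads_and_optimizer) / GIB,
    }


# 1B parameters in float16 (2 bytes each), as in the example above
est = full_finetune_memory_gib(1_000_000_000, bytes_per_param=2)
print({k: round(v, 3) for k, v in est.items()})
# {'weights_gib': 1.863, 'grads_and_optimizer_gib': 5.588, 'total_gib': 7.451}
```

Changing `bytes_per_param` to 4 (float32) doubles every entry, which is why the dtype choice affects all of these positions at once.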

A smaller base model or a smaller compute dtype reduces all three, while shorter sequences or smaller batches mainly reduce activation memory. Employing PEFT methods reduces the number of trainable parameters and therefore significantly shrinks both gradient and optimizer-state memory, saving a lot of memory overall.
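To see why fewer trainable parameters matter, the sketch below compares gradient and Adam-state memory for a single linear layer trained fully versus with a LoRA-style low-rank update (factors A of shape r × d_in and B of shape d_out × r). The layer size (4096 × 4096), the rank r = 16 and the float32 adapter dtype are illustrative assumptions, not PEFT defaults.

```python
def grad_and_adam_bytes(trainable_params: int, bytes_per_param: int) -> int:
    # gradients (1x) + Adam first and second moments (2x)
    return 3 * trainable_params * bytes_per_param


d_in, d_out, r = 4096, 4096, 16  # assumed layer size and LoRA rank

full_params = d_in * d_out          # every weight of the layer is trainable
lora_params = r * d_in + d_out * r  # only the low-rank factors A and B are trainable

full_bytes = grad_and_adam_bytes(full_params, bytes_per_param=2)  # float16 as above
lora_bytes = grad_and_adam_bytes(lora_params, bytes_per_param=4)  # adapter assumed float32

print(f"full: {full_params:,} params, {full_bytes / 2**20:.1f} MiB")
print(f"LoRA: {lora_params:,} params, {lora_bytes / 2**20:.2f} MiB")
```

Even with the adapter kept in float32, the gradient and optimizer memory for this layer drops from 96 MiB to 1.5 MiB, because it scales with the number of trainable parameters rather than with the layer size.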

Choosing the right method

Not every PEFT method is built equally, and some formulations are easier to implement in a memory-efficient manner than others. If you are on a memory budget, it makes sense to check out the PEFT method comparison suite and filter for maximum accelerator memory usage. Average accelerator memory usage can be fairly similar across methods, but not every method scales equally with activations and sequence length; some methods are more prone to memory spikes than others.

If you only need to fine-tune specific tokens, consider using trainable tokens instead of targeting large layers like the language modeling head or the embedding layer.
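As a rough illustration of the savings, the count below compares making a whole embedding matrix trainable with training only a handful of its rows, which is the idea behind trainable tokens. The vocabulary size, hidden size and number of tokens are assumed values, and the estimate ignores any bookkeeping overhead of the actual PEFT implementation.

```python
vocab_size, hidden_size = 128_000, 4096  # assumed model dimensions
num_tokens = 10                          # assumed number of tokens to fine-tune

full_embedding = vocab_size * hidden_size   # whole embedding matrix trainable
selected_rows = num_tokens * hidden_size    # only the chosen token rows trainable

# gradients (1x) + Adam states (2x), assuming float32 (4 bytes) for both cases
bytes_full = 3 * full_embedding * 4
bytes_tokens = 3 * selected_rows * 4

print(f"full embedding: {bytes_full / 2**30:.2f} GiB")
print(f"selected tokens: {bytes_tokens / 2**20:.2f} MiB")
```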

Quantization

Quantization is one of the best ways to reduce the memory consumption of the base model and, depending on the quantization method, can also reduce activation memory. Since PEFT adapters only make up a small portion of the total number of parameters, PEFT defaults to keeping the adapter weights in a higher precision than the base model. This can also mean that adapters mitigate some of the quality loss incurred by quantization. Read the PEFT quantization guide.
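A back-of-the-envelope comparison for the 1B-parameter example: base weights in float16 versus 4-bit quantization, plus a small adapter kept in float32. The 4-bit figure ignores quantization constants such as per-block scales, and the 2M-parameter adapter size is an assumption for illustration.

```python
GIB = 1024**3  # bytes per GiB

num_params = 1_000_000_000  # base model from the example above
adapter_params = 2_000_000  # assumed adapter size

fp16_base = num_params * 16 / 8    # 16 bits per weight
int4_base = num_params * 4 / 8     # 4 bits per weight, scales/zero-points ignored
fp32_adapter = adapter_params * 4  # adapter weights kept in higher precision (4 bytes)

print(f"float16 base:    {fp16_base / GIB:.3f} GiB")
print(f"4-bit base:      {int4_base / GIB:.3f} GiB")
print(f"float32 adapter: {fp32_adapter / 2**20:.1f} MiB")
```

Keeping the adapter in float32 costs only a few MiB here, which is why the higher adapter precision is cheap compared with the savings on the base model.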

Compilation

The models we train are composed of operations like matrix multiplications, sums and assignments, where each operation produces a new result and, consequently, takes up memory. If those intermediate results are not needed, we can fuse these operations and save memory. This is just one of many optimizations that torch.compile can do for you, so check out the PEFT torch.compile guide.

Gradient Checkpointing

You can trade memory for computation by saving only some of the intermediate activations between layers (for example, every nth one) and recomputing the rest on the fly during the backward pass. Check out the gradient checkpointing documentation of Transformers to learn more.

When not using Diffusers or Transformers, you may need to implement your own gradient checkpointing logic, depending on the training framework that you are using.
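In plain PyTorch, `torch.utils.checkpoint.checkpoint` implements this trade-off: activations inside a checkpointed block are not stored during the forward pass and are recomputed during backward. The sketch below assumes a toy stack of MLP blocks; the model, its sizes and the per-block checkpointing granularity are illustrative choices. For Transformers models, `model.gradient_checkpointing_enable()` switches this on for you.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedMLPStack(nn.Module):
    """Toy residual MLP stack whose blocks can run with activation checkpointing."""

    def __init__(self, dim: int = 64, num_blocks: int = 4, use_checkpointing: bool = True):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
                for _ in range(num_blocks)
            ]
        )
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Only the block input is kept; the wide hidden activations
                # inside the block are recomputed during the backward pass.
                x = x + checkpoint(block, x, use_reentrant=False)
            else:
                x = x + block(x)
        return x


torch.manual_seed(0)
model = CheckpointedMLPStack()
x = torch.randn(8, 64)
loss = model(x).pow(2).mean()
loss.backward()  # recomputes each block's forward before computing its gradients
```

The gradients are the same as without checkpointing; only peak activation memory goes down, at the cost of running each block's forward pass twice.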

Chunked NLL loss

Using NLLLoss (negative log-likelihood loss) is very common when training language models (or for classification tasks). It requires allocating a matrix of size batch × sequence length × vocabulary size, which gets expensive fast with particularly long sequences or large vocabularies.
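To get a feel for the numbers, the sketch below computes the size of a float32 batch × sequence × vocabulary matrix for an assumed batch of 8 sequences of 4,096 tokens and an assumed vocabulary of 128,000 tokens, and compares it with a single 256-token slice of the sequence.

```python
def logits_gib(batch: int, seq_len: int, vocab: int, bytes_per_elem: int = 4) -> float:
    """Size of a batch x seq_len x vocab tensor in GiB (float32 by default)."""
    return batch * seq_len * vocab * bytes_per_elem / 1024**3


full = logits_gib(batch=8, seq_len=4096, vocab=128_000)  # 15.625 GiB
chunk = logits_gib(batch=8, seq_len=256, vocab=128_000)  # 0.9765625 GiB
print(full, chunk)
```

The full matrix grows linearly with sequence length, while a fixed-size slice does not, which is what makes processing the sequence in chunks attractive.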

When using TRL, you can either use the Liger kernel integration or Chunked NLLLoss. The latter splits the sequence into chunks of size 256 to keep the maximum memory consumption constant.

Figure: NLL vs. Chunked NLL comparison

If the default chunk size is not optimal for your setting, see the original TRL PR for more information on how to tune it.
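The chunking idea can be sketched in plain PyTorch. This is not TRL's implementation; it is a minimal illustration that assumes you have the final hidden states, the output (LM head) weight matrix and labels that are already shifted to line up with the hidden states. It computes the loss 256 tokens at a time and wraps each chunk in `torch.utils.checkpoint`, so a chunk's logits are recomputed during backward instead of being kept alive for the whole sequence. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def _chunk_nll_sum(hidden_chunk, lm_head_weight, labels_chunk, ignore_index):
    # Logits exist only for this chunk: (batch, chunk_len, vocab)
    logits = hidden_chunk @ lm_head_weight.T
    # Upcast so half-precision inputs still get a float32 log-softmax
    log_probs = F.log_softmax(logits.float(), dim=-1)
    return F.nll_loss(
        log_probs.flatten(0, 1),
        labels_chunk.flatten(),
        ignore_index=ignore_index,
        reduction="sum",
    )


def chunked_nll_loss(hidden, lm_head_weight, labels, chunk_size=256, ignore_index=-100):
    """Mean NLL over non-ignored tokens, computed chunk by chunk along the sequence."""
    total = hidden.new_zeros((), dtype=torch.float32)
    for start in range(0, hidden.size(1), chunk_size):
        end = start + chunk_size
        # Checkpointing stores only the chunk inputs; logits are recomputed in backward.
        total = total + checkpoint(
            _chunk_nll_sum,
            hidden[:, start:end],
            lm_head_weight,
            labels[:, start:end],
            ignore_index,
            use_reentrant=False,
        )
    num_targets = (labels != ignore_index).sum().clamp(min=1)
    return total / num_targets
```

Summing per chunk and dividing by the total number of non-ignored targets at the end keeps the result equal to `F.cross_entropy` on the full logits (up to floating-point error), regardless of how padding is spread across chunks.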
