Memory usage breakdown during Training

1. Memory Composition

  • Model Parameters
  • Intermediate Activations (Forward pass)
    • saved so gradients can be computed during the backward pass
  • Gradients (Backward pass)
  • Optimizer States

2. Static Memory (Weights, Gradients, Optimizer)

Using the AdamW optimizer with mixed-precision training, the memory cost per parameter breaks down as follows:

  • Weights: 2 bytes (FP16) + 4 bytes (FP32 master copy)
  • Gradients: 2 bytes (FP16) + 4 bytes (FP32 copy)
  • Optimizer States: 4 bytes (Momentum) + 4 bytes (Variance)

Summing up: 20 bytes per parameter. For a model with $\Phi$ parameters, the total static memory is $20\Phi$ bytes.
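
As a quick illustration, the sketch below applies the 20-bytes-per-parameter figure to a hypothetical 7B-parameter model; the parameter count is an assumed example, not part of the derivation above.

```python
# Static training memory under AdamW + mixed precision: 20 bytes per parameter.
BYTES_PER_PARAM = (2 + 4) + (2 + 4) + (4 + 4)  # weights + gradients + optimizer states = 20

def static_memory_gib(num_params: float) -> float:
    """Static memory (weights + gradients + optimizer states) in GiB."""
    return num_params * BYTES_PER_PARAM / 2**30

phi = 7e9  # assumed 7B-parameter model (illustrative only)
print(f"{static_memory_gib(phi):.1f} GiB")  # ~130.4 GiB, vs. ~13 GiB for the FP16 weights alone
```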

3. Dynamic Memory (Activations)

Assumptions

  • Input Shape: $[b, s, h]$
    • $b$: Batch size
    • $s$: Sequence length
    • $h$: Hidden dimension
  • Data Type: FP16 (2 bytes per element)
  • Dropout Masks: 1 byte per element
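
For concreteness, the sketches later in this post plug one assumed configuration into the formulas. The values below are illustrative choices (roughly the scale of a 7B-parameter model), not numbers from any particular setup.

```python
# Assumed example configuration used in the sketches below (illustrative only).
b = 8      # batch size
s = 4096   # sequence length
h = 4096   # hidden dimension
a = 32     # number of attention heads (introduced in section 3.1.3)
L = 32     # number of Transformer layers (used in the final total)
```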

3.1 Attention block

3.1.1. Input Layer

  • Operation: The entry point to the attention block.
  • Saved Tensor: The input $x$ of shape $[b, s, h]$.
  • Memory Usage: $2bsh$ bytes.

3.1.2. Q & K Matrix matmul

  • Operation: Generating the $Q$ and $K$ vectors for the $QK^T$ matmul.
  • Saved Tensors: Both $Q$ and $K$ must be saved.
  • Memory Usage: $2bsh + 2bsh = 4bsh$ bytes.

3.1.3. Softmax

  • Operation: Computing the attention scores $QK^T$.
  • Saved Tensor: The input to the Softmax function, $QK^T$.
  • Shape Note: The shape expands to include the number of attention heads $a$. Shape is $[b, a, s, s]$.
  • Memory Usage: $2bs^2a$ bytes.
    • Note: This term is quadratic in the sequence length $s$, making it a memory bottleneck for long contexts.
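
To see how quickly this quadratic term grows, here is a small sketch (using the assumed shapes from above) that evaluates $2bs^2a$ for a few sequence lengths:

```python
# Growth of the softmax-input term 2*b*s^2*a with sequence length (FP16, per layer).
b, a = 8, 32  # assumed batch size and number of attention heads

for s in (1024, 4096, 16384):
    gib = 2 * b * s**2 * a / 2**30
    print(f"s={s:>6}: {gib:7.1f} GiB per layer")
# Quadrupling s multiplies this single term by 16x.
```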

3.1.4. Attention Scores (Softmax Output)

  • Operation: Applying softmax to get attention weights.
  • Saved Tensor: The softmax output $\mathrm{softmax}(QK^T)$ (attention probabilities), shape $[b, a, s, s]$.
  • Memory Usage: $2bs^2a$ bytes.

3.1.5. V Matrix Multiplication & Attention Output

  • Operation: Generating Value vectors and computing weighted sum.
  • Saved Tensors:
    1. $V$: $2bsh$ bytes.
    2. Optional: Mask matrix for causal attention, shape $[b, a, s, s]$ at 1 byte per element: $bs^2a$ bytes.
  • Total Step Memory: $2bsh + bs^2a$ bytes.

3.1.6. Output Mapping & Dropout

  • Operation: The final linear projection and dropout
  • Saved Tensors:
    1. Input to the projection layer: $2bsh$ bytes.
    2. Dropout mask matrix (1 byte per element): $bsh$ bytes.
  • Total Step Memory: $2bsh + bsh = 3bsh$ bytes.

Final Formula

Summing up all the terms:

  1. Linear terms ($bsh$): $2bsh + 4bsh + 2bsh + 3bsh = 11bsh$
  2. Quadratic terms ($bs^2a$): $2bs^2a + 2bs^2a + bs^2a = 5bs^2a$

Total Self-Attention Activation Memory: $11bsh + 5bs^2a$ bytes.
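
Plugging the assumed shapes into $11bsh + 5bs^2a$ gives a feel for the split between the linear and quadratic parts (a back-of-the-envelope sketch, not a measurement):

```python
# Self-attention activation memory per layer: 11*b*s*h + 5*b*s^2*a bytes.
b, s, h, a = 8, 4096, 4096, 32  # assumed example shapes

linear_terms = 11 * b * s * h       # input, Q, K, V, projection input, dropout mask
quadratic_terms = 5 * b * s**2 * a  # softmax input/output and attention mask
print(f"linear:    {linear_terms / 2**30:6.2f} GiB")
print(f"quadratic: {quadratic_terms / 2**30:6.2f} GiB")
print(f"total:     {(linear_terms + quadratic_terms) / 2**30:6.2f} GiB per layer")
```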

3.2 MLP block

Formula: $\mathrm{MLP}(x) = f\left(x W_1\right) W_2$, with $W_1 \in \mathbb{R}^{h \times 4h}$ and $W_2 \in \mathbb{R}^{4h \times h}$, followed by dropout.

  • First Linear Layer: Saves its input of shape $[b, s, h]$, occupying $2bsh$ bytes.
  • Activation Function: Saves its input of shape $[b, s, 4h]$, occupying $8bsh$ bytes (the dimension is usually expanded by 4×).
  • Second Linear Layer: Saves its input of shape $[b, s, 4h]$, occupying $8bsh$ bytes.
  • Dropout Mask: Occupies $bsh$ bytes (1 byte per element).

For the MLP block, the intermediate activations that must be saved total $2bsh + 8bsh + 8bsh + bsh = 19bsh$ bytes.
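
With the same assumed shapes, the $19bsh$ MLP term evaluates to:

```python
# MLP activation memory per layer: 2bsh + 8bsh + 8bsh + bsh = 19*b*s*h bytes.
b, s, h = 8, 4096, 4096  # assumed example shapes
print(f"{19 * b * s * h / 2**30:.2f} GiB per layer")  # ~2.38 GiB
```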

3.3 Layer Normalization

Each layer norm requires saving its input ($2bsh$ bytes).

Since the self-attention and MLP blocks each have a layer norm, the total for the two layer norms is $2 \times 2bsh = 4bsh$ bytes.

Total

  • For each Transformer layer, the intermediate activation memory usage is $11bsh + 5bs^2a + 19bsh + 4bsh = 34bsh + 5bs^2a$ bytes.
  • For an $L$-layer Transformer model, the total activation memory is approximately $L\left(34bsh + 5bs^2a\right)$ bytes.
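
Combining everything for an assumed 32-layer model with the same shapes:

```python
# Total activation memory: L * (34*b*s*h + 5*b*s^2*a) bytes, FP16 activations.
b, s, h, a, L = 8, 4096, 4096, 32, 32  # assumed example configuration

per_layer = 34 * b * s * h + 5 * b * s**2 * a
print(f"per layer: {per_layer / 2**30:7.2f} GiB")
print(f"{L} layers: {L * per_layer / 2**30:7.2f} GiB")
# With these assumed shapes, activations far exceed the ~20-bytes/param static
# memory, which is why activation recomputation (next section) matters.
```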

4. Conclusion

To handle large batch sizes, Activation Recomputation (Checkpointing) is essential. This technique discards intermediate activations during the forward pass and recomputes them during the backward pass, trading computation time for memory.
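
As an illustration only, here is a minimal PyTorch sketch of this idea using `torch.utils.checkpoint`; the `TransformerBlock` module and all shapes are hypothetical placeholders, not code from this post.

```python
import torch
from torch.utils.checkpoint import checkpoint

class TransformerBlock(torch.nn.Module):
    """Toy pre-norm Transformer block used only to demonstrate checkpointing."""
    def __init__(self, h: int):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(h, num_heads=8, batch_first=True)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(h, 4 * h), torch.nn.GELU(), torch.nn.Linear(4 * h, h)
        )
        self.ln1 = torch.nn.LayerNorm(h)
        self.ln2 = torch.nn.LayerNorm(h)

    def forward(self, x):
        y = self.ln1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]
        return x + self.mlp(self.ln2(x))

h = 256  # assumed toy hidden size
blocks = torch.nn.ModuleList([TransformerBlock(h) for _ in range(4)])
x = torch.randn(2, 128, h, requires_grad=True)

for block in blocks:
    # Only the block's input is kept; activations inside the block are
    # discarded after the forward pass and recomputed during backward.
    x = checkpoint(block, x, use_reentrant=False)

x.sum().backward()
```

Only each block's input survives between forward and backward, so the per-layer $34bsh + 5bs^2a$ activations are recomputed on demand instead of held in memory for the whole step.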

