Memory usage breakdown during Training
1. Memory Composition
- Model Parameters
- Intermediate Activations (Forward pass)
  - saved so they can be used to compute gradients during the backward pass
- Gradients (Backward pass)
- Optimizer States
2. Static Memory (Weights, Gradients, Optimizer)
Using the AdamW optimizer with mixed-precision training, the memory cost per parameter is:
- Weights: 2 bytes (FP16) + 4 bytes (FP32 master copy)
- Gradients: 2 bytes (FP16) + 4 bytes (FP32 copy)
- Optimizer States: 4 bytes (Momentum) + 4 bytes (Variance)
Summing up: 20 bytes per parameter. Total = 20 × (number of parameters) bytes, e.g. roughly 140 GB for a 7B-parameter model.
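A minimal sketch of this accounting in Python (the 7B-parameter figure is an assumed example, not a value from this post):

```python
# Static memory per parameter with AdamW + mixed precision:
#   weights:   2 B (FP16) + 4 B (FP32 master copy)
#   gradients: 2 B (FP16) + 4 B (FP32 copy)
#   optimizer: 4 B (momentum) + 4 B (variance)
BYTES_PER_PARAM = (2 + 4) + (2 + 4) + (4 + 4)  # = 20

def static_memory_gb(num_params: float) -> float:
    """Static memory (weights + gradients + optimizer states) in GB."""
    return num_params * BYTES_PER_PARAM / 1e9

print(static_memory_gb(7e9))  # assumed 7B-parameter example -> 140.0 GB
```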
3. Dynamic Memory (Activations)
Assumptions
- Input Shape: $(b, s, h)$, where $b$ is the batch size, $s$ is the sequence length, and $h$ is the hidden dimension
- Data Type: FP16 (2 bytes per element)
- Dropout Masks: 1 byte per element
3.1 Attention block
3.1.1. Input Layer
- Operation: The entry point to the attention block.
- Saved Tensor: The input $X$, of shape $(b, s, h)$.
- Memory Usage: $2bsh$ bytes
3.1.2. Q & K Matrix matmul
- Operation: Generating the $Q$ and $K$ vectors and multiplying them ($QK^T$).
- Saved Tensors: Both $Q$ and $K$ must be saved.
- Memory Usage: $2bsh + 2bsh = 4bsh$ bytes
3.1.3. Softmax
- Operation: Computing the attention scores $\mathrm{softmax}(QK^T / \sqrt{d_k})$, where $d_k$ is the per-head dimension.
- Saved Tensor: The input to the Softmax function, $QK^T$.
- Shape Note: The shape expands to include the $a$ attention heads; the shape is $(b, a, s, s)$.
- Memory Usage: $2abs^2$ bytes
- Note: This term is quadratic in the sequence length $s$, making it a memory bottleneck for long contexts (as illustrated below).
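To make the quadratic growth concrete, here is a quick worked example under assumed values $a = 32$ and $b = 1$ (illustrative numbers, not from this post):

$$2abs^2\big|_{s=4096} = 2 \cdot 32 \cdot 1 \cdot 4096^2 \approx 1.1\ \text{GB}, \qquad 2abs^2\big|_{s=32768} = 2 \cdot 32 \cdot 1 \cdot 32768^2 \approx 68.7\ \text{GB}$$

per layer: an 8× longer sequence costs 64× more memory in this term.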
3.1.4. Attention Scores (Softmax Output)
- Operation: Applying softmax to get attention weights.
- Saved Tensor: The softmax output (attention probabilities).
- Memory Usage: $2abs^2$ bytes (same $(b, a, s, s)$ shape as above)
3.1.5. V Matrix Multiplication & Attention Output
- Operation: Generating the Value ($V$) vectors and computing the weighted sum over them.
- Saved Tensors: $V$, of shape $(b, s, h)$: $2bsh$ bytes
- Optional: Mask matrix for causal attention, $abs^2$ bytes (1 byte per element)
- Total Step Memory: $2bsh + abs^2$ bytes
3.1.6. Output Mapping & Dropout
- Operation: The final linear projection and dropout.
- Saved Tensors:
  - Input to the projection layer: $2bsh$ bytes
  - Dropout mask matrix: $bsh$ bytes (1 byte per element)
- Total Step Memory: $2bsh + bsh = 3bsh$ bytes
Final Formula
Summing up all the terms:
- Linear terms (in $s$): $2bsh + 4bsh + 2bsh + 3bsh = 11bsh$
- Quadratic terms (in $s^2$): $2abs^2 + 2abs^2 + abs^2 = 5abs^2$
Total Self-Attention Activation Memory: $11bsh + 5abs^2$ bytes
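The same bookkeeping as a small Python sketch; the function name and the grouping into per-step terms are my own, while the byte counts follow the breakdown above:

```python
def attention_activation_bytes(b: int, s: int, h: int, a: int) -> int:
    """Activation memory (bytes) saved by one self-attention block,
    assuming FP16 activations (2 B/element) and 1 B/element masks."""
    block_input   = 2 * b * s * h                  # 3.1.1: input X
    q_and_k       = 4 * b * s * h                  # 3.1.2: Q and K
    softmax_input = 2 * a * b * s * s              # 3.1.3: QK^T scores
    softmax_out   = 2 * a * b * s * s              # 3.1.4: attention probabilities
    v_and_mask    = 2 * b * s * h + a * b * s * s  # 3.1.5: V plus optional mask
    proj_and_drop = 2 * b * s * h + b * s * h      # 3.1.6: projection input + dropout mask
    # Closed form: 11*b*s*h + 5*a*b*s**2
    return (block_input + q_and_k + softmax_input
            + softmax_out + v_and_mask + proj_and_drop)
```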
3.2 MLP block
Formula: $\mathrm{MLP}(X) = \mathrm{Dropout}\big(\mathrm{GeLU}(XW_1)\,W_2\big)$
- First Linear Layer: saves its input, occupying $2bsh$ bytes.
- Activation Function: saves its input, occupying $8bsh$ bytes (the dimension is usually expanded by 4×).
- Second Linear Layer: saves its input, occupying $8bsh$ bytes.
- Dropout Mask: occupies $bsh$ bytes.
For the MLP block, the intermediate activations required to be saved total $2bsh + 8bsh + 8bsh + bsh = 19bsh$ bytes.
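A matching sketch for the MLP block; the `expansion` parameter is my own generalization of the usual 4× widening:

```python
def mlp_activation_bytes(b: int, s: int, h: int, expansion: int = 4) -> int:
    """Activation memory (bytes) saved by one MLP block
    (FP16 activations, 1 B/element dropout mask)."""
    first_linear_in  = 2 * b * s * h              # input to the first linear layer
    activation_in    = 2 * b * s * expansion * h  # input to GeLU (8bsh at 4x)
    second_linear_in = 2 * b * s * expansion * h  # input to the second linear layer (8bsh at 4x)
    dropout_mask     = b * s * h                  # 1 byte per element
    # Closed form for expansion=4: 19*b*s*h
    return first_linear_in + activation_in + second_linear_in + dropout_mask
```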
3.3 Layer Normalization
Each layer norm requires saving its input ($2bsh$ bytes).
Since the self-attention and MLP blocks each have a layer norm, the total for the two layer norms is $4bsh$ bytes.
Total
- For each Transformer layer, the intermediate activation memory usage is $11bsh + 5abs^2 + 19bsh + 4bsh = 34bsh + 5abs^2$ bytes.
- For an $L$-layer Transformer model, the total activation memory is approximately $L \cdot (34bsh + 5abs^2)$ bytes.
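Putting the pieces together, a short sketch that evaluates the per-layer and whole-model estimate. The configuration values are purely illustrative assumptions, not numbers from this post:

```python
def transformer_activation_gb(b: int, s: int, h: int, a: int, L: int) -> float:
    """Approximate activation memory in GB for an L-layer Transformer:
    L * (34*b*s*h + 5*a*b*s^2) bytes, per the breakdown above."""
    per_layer = 34 * b * s * h + 5 * a * b * s * s
    return L * per_layer / 1e9

# Assumed configuration: b=8, s=2048, h=4096, a=32, L=32
print(transformer_activation_gb(8, 2048, 4096, 32, 32))  # ~244.8 GB
```

Even at this assumed scale the activation term is substantial, which is what motivates the recomputation technique discussed next.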
4. Conclusion
To handle large batch sizes, Activation Recomputation (Checkpointing) is essential. This technique discards intermediate activations during the forward pass and re-calculates them during the backward pass, trading computation time for memory space.
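As a rough illustration of this idea (a minimal sketch, not this post's code), the snippet below wraps a layer with PyTorch's `torch.utils.checkpoint` so that only the layer's input and output are kept during the forward pass, while the inner activations counted above are recomputed during backward:

```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(torch.nn.Module):
    """Wraps any nn.Module so its internal activations are discarded in the
    forward pass and recomputed during backward (activation recomputation)."""
    def __init__(self, layer: torch.nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inner activations (QK^T scores, softmax output, MLP intermediates, ...)
        # are not kept; they are re-created when gradients are needed.
        return checkpoint(self.layer, x, use_reentrant=False)

# Usage sketch with a standard encoder layer (illustrative sizes).
layers = torch.nn.ModuleList([
    CheckpointedBlock(
        torch.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
    )
    for _ in range(4)
])
x = torch.randn(8, 128, 512, requires_grad=True)  # (batch, seq, d_model)
for blk in layers:
    x = blk(x)
x.sum().backward()  # activations are recomputed here, trading compute for memory
```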