Gradient Checkpointing
Gradient Checkpointing / Activation Recomputation
Gradient Checkpointing — техника экономии VRAM при обучении, при которой промежуточные активации не хранятся, а пересчитываются при backward pass. Уменьшает потребление памяти в 3-5× ценой ~30% замедления.
Что такое Gradient Checkpointing
При forward pass каждый слой сохраняет промежуточные активации для backward pass. Для больших моделей это потребляет огромный объём VRAM. Gradient checkpointing (activation recomputation) вместо хранения пересчитывает активации во время backward pass.
Как работает
Без checkpointing (быстро, много памяти):
Forward: слой 1 → [сохранить a1] → слой 2 → [сохранить a2] → ... → loss
Backward: [использовать a2] → grad → [использовать a1] → grad
С checkpointing (медленнее, мало памяти):
Forward: слой 1 → слой 2 → ... → loss (активации НЕ сохраняются)
Backward: [пересчитать a2] → grad → [пересчитать a1] → grad
Экономия памяти
Для модели с L слоями:
- Без checkpointing: O(L) памяти на активации
- С checkpointing: O(√L) при оптимальном размещении чекпоинтов
Практический пример (LLaMA 7B, batch_size=4, seq_len=2048):
- Без checkpointing: ~24 GB на активации
- С checkpointing: ~6 GB (4× экономия)
Виды
Full checkpointing
Пересчитываются все активации. Максимальная экономия, максимальное замедление (~33%).
Selective checkpointing
Сохраняются только дешёвые для хранения, но дорогие для вычисления активации (Megatron-LM).
Использование
# PyTorch
from torch.utils.checkpoint import checkpoint
class TransformerBlock(nn.Module):
def forward(self, x):
return checkpoint(self._forward, x, use_reentrant=False)
# Hugging Face Transformers
model.gradient_checkpointing_enable()
# DeepSpeed
"activation_checkpointing": {
"partition_activations": true,
"contiguous_memory_optimization": true
}
Компромисс
| Параметр | Без checkpointing | С checkpointing |
|---|---|---|
| VRAM на активации | O(L) | O(√L) |
| Скорость | 1× | ~0.7× |
| Эффективный throughput | Ниже (малый batch) | Выше (большой batch) |
Парадокс: хотя checkpointing замедляет один шаг, он позволяет увеличить batch size, что часто увеличивает общий throughput.