Обучение моделей 24 просмотра

Gradient Checkpointing

Gradient Checkpointing / Activation Recomputation

Gradient Checkpointing — техника экономии VRAM при обучении, при которой промежуточные активации не хранятся, а пересчитываются при backward pass. Уменьшает потребление памяти в 3-5× ценой ~30% замедления.

Что такое Gradient Checkpointing

При forward pass каждый слой сохраняет промежуточные активации для backward pass. Для больших моделей это потребляет огромный объём VRAM. Gradient checkpointing (activation recomputation) вместо хранения пересчитывает активации во время backward pass.

Как работает

Без checkpointing (быстро, много памяти):
Forward:  слой 1 → [сохранить a1] → слой 2 → [сохранить a2] → ... → loss
Backward: [использовать a2] → grad → [использовать a1] → grad

С checkpointing (медленнее, мало памяти):
Forward:  слой 1 → слой 2 → ... → loss  (активации НЕ сохраняются)
Backward: [пересчитать a2] → grad → [пересчитать a1] → grad

Экономия памяти

Для модели с L слоями:
- Без checkpointing: O(L) памяти на активации
- С checkpointing: O(√L) при оптимальном размещении чекпоинтов

Практический пример (LLaMA 7B, batch_size=4, seq_len=2048):
- Без checkpointing: ~24 GB на активации
- С checkpointing: ~6 GB (4× экономия)

Виды

Full checkpointing

Пересчитываются все активации. Максимальная экономия, максимальное замедление (~33%).

Selective checkpointing

Сохраняются только дешёвые для хранения, но дорогие для вычисления активации (Megatron-LM).

Использование

# PyTorch
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def forward(self, x):
        return checkpoint(self._forward, x, use_reentrant=False)

# Hugging Face Transformers
model.gradient_checkpointing_enable()

# DeepSpeed
"activation_checkpointing": {
    "partition_activations": true,
    "contiguous_memory_optimization": true
}

Компромисс

Параметр Без checkpointing С checkpointing
VRAM на активации O(L) O(√L)
Скорость ~0.7×
Эффективный throughput Ниже (малый batch) Выше (большой batch)

Парадокс: хотя checkpointing замедляет один шаг, он позволяет увеличить batch size, что часто увеличивает общий throughput.

Связанные термины

Улучшает
Измеряется
Использует

Попробуйте на практике

Арендуйте GPU и запустите ML-модели в Intelion Cloud

Начать работу