Обучение моделей 14 просмотров

FSDP

Fully Sharded Data Parallel

FSDP — встроенный в PyTorch механизм распределённого обучения, аналог DeepSpeed ZeRO-3. Шардирует веса, градиенты и состояние оптимизатора между GPU для обучения сверхбольших моделей.

Что такое FSDP

FSDP (Fully Sharded Data Parallel) — реализация шардированного параллелизма данных в PyTorch, аналогичная DeepSpeed ZeRO-3. Каждая GPU хранит только свою часть (шард) весов модели, собирая полные веса только на момент вычислений.

Как работает

  1. Веса модели разделяются между GPU
  2. Перед forward pass: all-gather для сбора полных весов слоя
  3. Вычисление forward → сразу освобождение чужих шардов
  4. Backward: снова all-gather → вычисление градиентов → reduce-scatter
  5. Каждая GPU обновляет только свой шард весов

Стратегии шардирования

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Аналог ZeRO-3: полное шардирование
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)

# Аналог ZeRO-2: шардирование только градиентов
model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)

# Без шардирования (DDP)
model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)

Пример обучения

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Auto-wrap каждый Transformer блок
policy = transformer_auto_wrap_policy(
    transformer_layer_cls={TransformerDecoderLayer}
)

model = FSDP(
    model,
    auto_wrap_policy=policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
    ),
)

FSDP vs DDP

DDP FSDP
Копия модели Полная на каждой GPU Шардированная
VRAM O(модель) на GPU O(модель/N) на GPU
Коммуникация All-Reduce All-Gather + Reduce-Scatter
Макс. модель Помещается в 1 GPU N × VRAM

Связанные термины

Альтернатива
Использует

Попробуйте на практике

Арендуйте GPU и запустите ML-модели в Intelion Cloud

Начать работу