Обучение моделей
14 просмотров
FSDP
Fully Sharded Data Parallel
FSDP — встроенный в PyTorch механизм распределённого обучения, аналог DeepSpeed ZeRO-3. Шардирует веса, градиенты и состояние оптимизатора между GPU для обучения сверхбольших моделей.
Что такое FSDP
FSDP (Fully Sharded Data Parallel) — реализация шардированного параллелизма данных в PyTorch, аналогичная DeepSpeed ZeRO-3. Каждая GPU хранит только свою часть (шард) весов модели, собирая полные веса только на момент вычислений.
Как работает
- Веса модели разделяются между GPU
- Перед forward pass: all-gather для сбора полных весов слоя
- Вычисление forward → сразу освобождение чужих шардов
- Backward: снова all-gather → вычисление градиентов → reduce-scatter
- Каждая GPU обновляет только свой шард весов
Стратегии шардирования
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
# Аналог ZeRO-3: полное шардирование
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
# Аналог ZeRO-2: шардирование только градиентов
model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
# Без шардирования (DDP)
model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)
Пример обучения
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
# Auto-wrap каждый Transformer блок
policy = transformer_auto_wrap_policy(
transformer_layer_cls={TransformerDecoderLayer}
)
model = FSDP(
model,
auto_wrap_policy=policy,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
),
)
FSDP vs DDP
| DDP | FSDP | |
|---|---|---|
| Копия модели | Полная на каждой GPU | Шардированная |
| VRAM | O(модель) на GPU | O(модель/N) на GPU |
| Коммуникация | All-Reduce | All-Gather + Reduce-Scatter |
| Макс. модель | Помещается в 1 GPU | N × VRAM |
Связанные термины
Альтернатива
Использует