All-Reduce
All-Reduce
All-Reduce — коллективная операция, при которой каждый участник отправляет данные всем остальным и получает агрегированный результат (сумму). Основная операция синхронизации градиентов при распределённом обучении.
Что такое All-Reduce
All-Reduce — коллективная коммуникационная операция, сочетающая Reduce (агрегация на одном узле) и Broadcast (рассылка результата всем). После завершения все участники имеют одинаковый результат.
Зачем нужен All-Reduce
В Data Parallelism каждая GPU вычисляет градиенты на своей части данных. Для обновления весов нужно усреднить градиенты со всех GPU → All-Reduce.
GPU 0: grad_0 ─┐
GPU 1: grad_1 ─┤ All-Reduce(sum)
GPU 2: grad_2 ─┤ ──────────→ Каждая GPU получает
GPU 3: grad_3 ─┘ grad_0 + grad_1 + grad_2 + grad_3
Алгоритмы
Ring All-Reduce
Узлы образуют кольцо. Данные передаются по кольцу в 2 фазы:
1. Reduce-Scatter: каждый узел отправляет часть данных соседу, получает и суммирует
2. All-Gather: каждый узел рассылает свой агрегированный фрагмент
Стоимость: 2 × (N-1)/N × data_size / bandwidth
Не зависит от количества узлов!
Tree All-Reduce
Древовидная топология. Быстрее при малых объёмах данных (меньше задержка).
SHARP (Scalable Hierarchical Aggregation)
Агрегация данных прямо на InfiniBand-свитчах, минуя GPU. Уменьшает трафик вдвое.
Другие коллективные операции
| Операция | Описание | Применение в ML |
|---|---|---|
| All-Reduce | Reduce + Broadcast | Синхронизация градиентов (DDP) |
| All-Gather | Gather + Broadcast | Сбор шардов весов (FSDP, ZeRO-3) |
| Reduce-Scatter | Reduce + Scatter | Распределение градиентов (FSDP) |
| All-to-All | Каждый отправляет каждому свою часть | Expert Parallelism (MoE) |
| Broadcast | Один отправляет всем | Раздача весов |
Библиотека NCCL
NVIDIA NCCL (NVIDIA Collective Communications Library) — реализация коллективных операций, оптимизированная для GPU. Автоматически выбирает алгоритм (Ring, Tree) в зависимости от топологии и объёма данных.
# PyTorch использует NCCL автоматически
import torch.distributed as dist
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
Бенчмаркинг
# Простой бенчмарк (Stas Bekman's all_reduce_bench.py)
python -m torch.distributed.run --nproc_per_node=8 all_reduce_bench.py
# NCCL tests (более детальный)
./all_reduce_perf -b 1M -e 4G -f 2 -g 1