Training Instabilities
Training Instabilities / Loss Spikes
Training Instabilities — нестабильности при обучении LLM: всплески loss, NaN-значения, расходимость. Вызваны проблемами с данными, learning rate, инициализацией весов или ограничениями оптимизатора Adam.
Что такое Training Instabilities
При обучении больших моделей (10B+ параметров) loss может внезапно резко возрасти (spike) или стать NaN. Эти нестабильности — одна из главных проблем LLM-инженерии.
Типы нестабильностей
Loss spikes
Кратковременные всплески loss, после которых обучение может восстановиться.
Команда PaLM (Google) наблюдала десятки loss spikes на «крайне нерегулярных интервалах» при обучении. Их решение: откат к более раннему чекпоинту и пропуск потенциально проблемных батчей данных.
Divergence (расходимость)
Loss уходит в бесконечность или NaN. Обучение не восстанавливается.
Slow divergence
Постепенная деградация, которую сложно заметить — loss медленно растёт или метрики ухудшаются.
Причины
1. «Плохие» батчи данных
Определённые комбинации данных и состояния модели вызывают spikes. Решение: пропуск батча и продолжение с чекпоинта.
2. Adam epsilon
Исследование «A Theory on Adam Instability in Large-Scale Machine Learning» (2023) показало, что epsilon в Adam оптимизаторе может быть недостаточно мал при больших масштабах. Компоненты оценки градиента становятся сопоставимы с epsilon, что вызывает расходимость.
Практическая рекомендация: попробуйте eps=0 (с обработкой деления на ноль) или очень малое значение eps=1e-15.
3. Инициализация весов (STD Init)
Неправильная инициализация стандартного отклонения весов — одна из самых частых причин нестабильности. Требует математического расчёта для каждой архитектуры.
4. Learning rate
Слишком высокий LR при недостаточном warmup.
Диагностика
# Мониторинг градиентов
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm()
if torch.isnan(grad_norm) or torch.isinf(grad_norm):
print(f"NaN/Inf gradient in {name}")
# Gradient clipping (обязательно для LLM)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Защитные меры
- Gradient clipping — обрезка нормы градиента (max_norm=1.0)
- Warmup — постепенное увеличение LR (1000-2000 шагов)
- BF16 вместо FP16 — больший динамический диапазон
- Частые чекпоинты — возможность отката
- Мониторинг loss по батчам — обнаружение спайков
- Z-loss — регуляризация logits (предотвращает рост)