Обучение моделей 17 просмотров

Training Instabilities

Training Instabilities / Loss Spikes

Training Instabilities — нестабильности при обучении LLM: всплески loss, NaN-значения, расходимость. Вызваны проблемами с данными, learning rate, инициализацией весов или ограничениями оптимизатора Adam.

Что такое Training Instabilities

При обучении больших моделей (10B+ параметров) loss может внезапно резко возрасти (spike) или стать NaN. Эти нестабильности — одна из главных проблем LLM-инженерии.

Типы нестабильностей

Loss spikes

Кратковременные всплески loss, после которых обучение может восстановиться.

Команда PaLM (Google) наблюдала десятки loss spikes на «крайне нерегулярных интервалах» при обучении. Их решение: откат к более раннему чекпоинту и пропуск потенциально проблемных батчей данных.

Divergence (расходимость)

Loss уходит в бесконечность или NaN. Обучение не восстанавливается.

Slow divergence

Постепенная деградация, которую сложно заметить — loss медленно растёт или метрики ухудшаются.

Причины

1. «Плохие» батчи данных

Определённые комбинации данных и состояния модели вызывают spikes. Решение: пропуск батча и продолжение с чекпоинта.

2. Adam epsilon

Исследование «A Theory on Adam Instability in Large-Scale Machine Learning» (2023) показало, что epsilon в Adam оптимизаторе может быть недостаточно мал при больших масштабах. Компоненты оценки градиента становятся сопоставимы с epsilon, что вызывает расходимость.

Практическая рекомендация: попробуйте eps=0 (с обработкой деления на ноль) или очень малое значение eps=1e-15.

3. Инициализация весов (STD Init)

Неправильная инициализация стандартного отклонения весов — одна из самых частых причин нестабильности. Требует математического расчёта для каждой архитектуры.

4. Learning rate

Слишком высокий LR при недостаточном warmup.

Диагностика

# Мониторинг градиентов
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm()
        if torch.isnan(grad_norm) or torch.isinf(grad_norm):
            print(f"NaN/Inf gradient in {name}")

# Gradient clipping (обязательно для LLM)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Защитные меры

  1. Gradient clipping — обрезка нормы градиента (max_norm=1.0)
  2. Warmup — постепенное увеличение LR (1000-2000 шагов)
  3. BF16 вместо FP16 — больший динамический диапазон
  4. Частые чекпоинты — возможность отката
  5. Мониторинг loss по батчам — обнаружение спайков
  6. Z-loss — регуляризация logits (предотвращает рост)

Связанные термины

Требует
Используется в

Попробуйте на практике

Арендуйте GPU и запустите ML-модели в Intelion Cloud

Начать работу