Flash Attention
Flash Attention
Flash Attention — оптимизированная реализация механизма внимания, которая работает в 2-4× быстрее стандартной и требует O(n) памяти вместо O(n²) за счёт тайлинга и вычислений в SRAM GPU.
Что такое Flash Attention
Flash Attention — алгоритмическая оптимизация стандартного Scaled Dot-Product Attention, разработанная Tri Dao (Stanford, 2022). Ключевая идея: перенести вычисления из медленной HBM в быструю SRAM GPU, обрабатывая данные блоками (tiling).
Проблема стандартного Attention
Наивная реализация:
1. Вычислить S = Q × K^T → матрица n×n в HBM
2. Применить softmax → ещё одна n×n матрица в HBM
3. Умножить на V
При n = 128K это 128K × 128K × 2 байта = 32 GB только для одной матрицы attention. Не помещается в VRAM!
Как работает Flash Attention
Flash Attention обрабатывает attention по блокам (tiles), которые помещаются в SRAM (20 MB на SM):
- Разбить Q, K, V на блоки
- Для каждого блока Q:
- Загрузить блок Q в SRAM
- Итерировать по блокам K, V
- Вычислить локальный attention и обновить running softmax
- Записать результат обратно в HBM
Результат: промежуточная матрица n×n никогда не материализуется в памяти.
Версии
| Версия | GPU | Улучшения |
|---|---|---|
| Flash Attention 1 | Ampere+ | IO-aware tiling |
| Flash Attention 2 | Ampere+ | Лучший параллелизм, 2× быстрее |
| Flash Attention 3 | Hopper (H100) | FP8, асинхронность, warp-level |
Использование
# PyTorch 2.0+ — встроено
import torch.nn.functional as F
# Автоматически использует Flash Attention при CUDA
output = F.scaled_dot_product_attention(query, key, value)
# Или явно через пакет flash-attn
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True)
Влияние
Flash Attention сделал возможным:
- Контекстное окно 128K-1M токенов
- Обучение моделей на 40-60% быстрее
- Инференс с меньшим потреблением VRAM