Архитектуры моделей 15 просмотров

Flash Attention

Flash Attention

Flash Attention — оптимизированная реализация механизма внимания, которая работает в 2-4× быстрее стандартной и требует O(n) памяти вместо O(n²) за счёт тайлинга и вычислений в SRAM GPU.

Что такое Flash Attention

Flash Attention — алгоритмическая оптимизация стандартного Scaled Dot-Product Attention, разработанная Tri Dao (Stanford, 2022). Ключевая идея: перенести вычисления из медленной HBM в быструю SRAM GPU, обрабатывая данные блоками (tiling).

Проблема стандартного Attention

Наивная реализация:
1. Вычислить S = Q × K^T → матрица n×n в HBM
2. Применить softmax → ещё одна n×n матрица в HBM
3. Умножить на V

При n = 128K это 128K × 128K × 2 байта = 32 GB только для одной матрицы attention. Не помещается в VRAM!

Как работает Flash Attention

Flash Attention обрабатывает attention по блокам (tiles), которые помещаются в SRAM (20 MB на SM):

  1. Разбить Q, K, V на блоки
  2. Для каждого блока Q:
  3. Загрузить блок Q в SRAM
  4. Итерировать по блокам K, V
  5. Вычислить локальный attention и обновить running softmax
  6. Записать результат обратно в HBM

Результат: промежуточная матрица n×n никогда не материализуется в памяти.

Версии

Версия GPU Улучшения
Flash Attention 1 Ampere+ IO-aware tiling
Flash Attention 2 Ampere+ Лучший параллелизм, 2× быстрее
Flash Attention 3 Hopper (H100) FP8, асинхронность, warp-level

Использование

# PyTorch 2.0+ — встроено
import torch.nn.functional as F

# Автоматически использует Flash Attention при CUDA
output = F.scaled_dot_product_attention(query, key, value)

# Или явно через пакет flash-attn
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True)

Влияние

Flash Attention сделал возможным:
- Контекстное окно 128K-1M токенов
- Обучение моделей на 40-60% быстрее
- Инференс с меньшим потреблением VRAM

Связанные термины

Улучшает

Попробуйте на практике

Арендуйте GPU и запустите ML-модели в Intelion Cloud

Начать работу