FlashAttention的Triton实现:IO-aware注意力机制底层原理

当注意力变成内存噩梦

在Transformer统治NLP和CV领域的今天,注意力机制(Attention Mechanism)作为其核心组件,却始终面临一个鲜为人知却致命的性能瓶颈:O(N²)的内存复杂度

让我们做一个简单的数学计算。当序列长度 N 达到16,384(16K)时,单层注意力需要处理的注意力矩阵大小为:

注意力矩阵 S = Q × K^T
矩阵尺寸:16,384 × 16,384 = 268,435,456 元素
每个元素 float16 = 2 bytes
显存占用:536,870,912 bytes ≈ 512 MB(仅一个矩阵!)

而这仅仅是中间矩阵 S,还未计算 attention 矩阵 P = softmax(S) 和最终输出。完整计算需要三个这样的 N×N 矩阵,如果用标准实现,单层注意力就能轻松吞掉数GB显存。当层数加深、batch size 增大时,显存直接爆炸。

这还不是最致命的。真正的问题在于:我们正在用1949年的方法解决2024年的问题


GPU内存层次:被忽视的10倍性能鸿沟

理解FlashAttention为什么能革命性地加速注意力计算,必须先理解GPU的内存层次结构。

                    ┌─────────────────────────────────────┐
                    │           GPU Architecture           │
                    │                                      │
                    │   ┌────────────────────────────────┐ │
                    │   │         HBM (High Bandwidth)    │ │
                    │   │   ════════════════════════════  │ │
                    │   │   Capacity: 40-80 GB            │ │
                    │   │   Bandwidth: 1.5-2.0 TB/s       │ │
                    │   │   Latency: ~500 ns               │ │
                    │   │   (Main GPU Memory - VRAM)       │ │
                    │   └────────────────────────────────┘ │
                    │                ↑                       │
                    │                │ 10-100x slower          │
                    │                │                        │
                    │   ┌────────────────────────────────┐ │
                    │   │    On-Chip SRAM (L1/L2 Cache)   │ │
                    │   │   ════════════════════════════  │ │
                    │   │   Capacity: 192-256 KB (A100)   │ │
                    │   │   Bandwidth: ~19 TB/s           │ │
                    │   │   Latency: ~10 ns               │ │
                    │   │   (Shared Memory + Registers)    │ │
                    │   └────────────────────────────────┘ │
                    │                                      │
                    └─────────────────────────────────────┘

          Bandwidth Comparison:
          ┌──────────────────────────────────────────────────┐
          │ HBM:     ████████████████████████████ 1.5 TB/s   │
          │ SRAM:    ████████████████████████████████ 19 TB/s │
          │          (12.7x faster!)                          │
          └──────────────────────────────────────────────────┘

关键洞察:SRAM的带宽是HBM的10倍以上,但容量却只有其千分之一

在A100 GPU上:

  • HBM显存带宽:1.5 TB/s(理论峰值)
  • L1/L2缓存带宽:约5 TB/s
  • 共享内存(SRAM)带宽:约19 TB/s(理论峰值)

内存带宽差距高达12倍! 这意味着,如果你能让计算完全在SRAM中进行,理论加速比就能达到这个数量级。

标准注意力实现的问题在于:它会将整个 N×N 的注意力矩阵**物化(materialize)**到HBM中,然后再逐元素操作。这相当于每次计算都在"骑自行车穿越高速公路"——你明明有赛车,却选择用最慢的方式完成比赛。


IO-aware核心思想:让数据待在快的地方

FlashAttention的核心创新可以概括为一句话:永远不要把大矩阵物化到HBM中

传统的注意力计算遵循这个流程:

1. Load Q, K, V from HBM → compute S = Q × K^T → Store S to HBM
2. Load S from HBM → compute P = softmax(S) → Store P to HBM  
3. Load P, V from HBM → compute O = P × V → Store O to HBM
4. Load O from HBM → (后续操作)

每一步都涉及完整的N×N矩阵在HBM和计算单元之间的搬运。对于N=16K,单次搬运的数据量就是:

16K × 16K × 4(bytes for float32) = 1 GB per pass
至少需要3-4次这样的搬运 = 3-4 GB HBM traffic

FlashAttention的IO-aware策略则是:

1. 分块加载 Q_i, K, V
2. 在SRAM中计算局部 S_block = Q_i × K^T_block
3. 增量更新 softmax 结果(无需存储完整S矩阵)
4. 产出部分输出 O_i
5. 循环处理所有 Q 块,最终得到完整输出

这样,HBM中永远只需要存储原始的Q、K、V和最终输出O,中间计算全部在SRAM中流式完成。


分块策略:SRAM容量限制下的精打细算

SRAM虽然快,但容量极其有限。以A100为例,共享内存(shared memory)最大为164KB per block,而A100 L1 cache为192KB。这意味着我们不能一次性加载整个16K×128的Q矩阵。

FlashAttention采用Q分块(Block)策略

┌─────────────────────────────────────────────────────────────────┐
│                      Q Matrix (N × d)                           │
│  ┌─────────┐                                                     │
│  │ Q₁ (B×d)│  B = block size = 32/64 (head维度上的限制)           │
│  └─────────┘                                                     │
│  ┌─────────┐                                                     │
│  │ Q₂ (B×d)│                                                     │
│  └─────────┘                                                     │
│     ···                                                         │
│  ┌─────────┐                                                     │
│  │ Qₜ (B×d)│  t = N/B = 序列长度/块大小                          │
│  └─────────┘                                                     │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│                      K Matrix (N × d)                           │
│  ┌─────────┬─────────┬     ┌─────────┐                          │
│  │ K₁ (B×d)│ K₂ (B×d)│ ··· │ Kₜ (B×d)│ ← 整列加载到SRAM         │
│  └─────────┴─────────┴     ┌─────────┘                          │
└─────────────────────────────────────────────────────────────────┘

关键约束:我们需要同时在SRAM中放下:

  • Q块:B × d 元素
  • K块:B × d 元素
  • V块:B × d 元素
  • 局部S块:B × B 元素(用于计算 softmax 的部分)
  • 输出O块:B × d 元素
  • 一些辅助变量

对于 d=128, B=64, dtype=float16

内存需求 ≈ 5 × (64 × 128 × 2) + 64 × 64 × 2 
        ≈ 5 × 16KB + 8KB 
        ≈ 88 KB < 164 KB (A100 SMEM limit) ✓

在线Softmax:增量计算的艺术

传统softmax需要"先看全部,再算结果":

softmax(x_i) = exp(x_i) / Σ exp(x_j)

这要求我们先遍历所有元素计算分母的累加和。但对于流式分块计算,这是不可行的。我们需要在线算法——在看到新元素时,能够增量更新softmax结果。

Safe Softmax 的数学推导

设我们正在处理第 j 个元素,此时已有统计量 (m_j, ℓ_j)

  • m_j = max(x_1, ..., x_j) — 当前最大值
  • ℓ_j = Σ_{i=1}^{j} exp(x_i - m_j) — 归一化因子的累加和

当第 j+1 个元素到达时:

情况1:新元素更大 (x_{j+1} > m_j)

m_{j+1} = x_{j+1}
ℓ_{j+1} = exp(x_{j+1} - m_{j+1}) + Σ exp(x_i - m_{j+1})
        = 1 + Σ exp(x_i - x_{j+1})
        = 1 + Σ exp(x_i - m_j) × exp(m_j - x_{j+1})
        = 1 + ℓ_j × exp(m_j - x_{j+1})

情况2:新元素更小 (x_{j+1} ≤ m_j)

m_{j+1} = m_j
ℓ_{j+1} = Σ exp(x_i - m_j) + exp(x_{j+1} - m_j)
        = ℓ_j + exp(x_{j+1} - m_j)

两种情况可以统一为:

m_{j+1} = max(m_j, x_{j+1})
ℓ_{j+1} = ℓ_j × exp(m_j - m_{j+1}) + exp(x_{j+1} - m_{j+1})

最终归一化

当所有元素处理完毕后,最终的softmax值为:

softmax(x_i) = exp(x_i - m_N) / ℓ_N

这给了我们一个单遍历就能完成softmax的在线算法,只需要维护两个标量统计量 (m, ℓ)

分块在线Softmax

对于分块场景,设我们处理第 r 个Q块(对应第 r 行块),需要维护:

  • m_r — 当前行块的最大值(跨所有K块)
  • ℓ_r — 当前行块的归一化因子
  • d_r — 累积的缩放因子(用于正确合并不同阶段的计算结果)

具体算法:

# 伪代码展示分块在线softmax的核心逻辑
def block_softmax_update(m_prev, l_prev, d_prev, block_scores):
    """
    m_prev: 之前块的最大值
    l_prev: 之前块的归一化因子
    d_prev: 累积的缩放因子
    block_scores: 当前块的原始注意力分数 [B, B]
    """
    # 1. 计算当前块的行最大值
    m_block = block_scores.max(dim=-1)  # [B]
    
    # 2. 更新全局最大值
    m_new = torch.maximum(m_prev, m_block)
    
    # 3. 计算当前块对归一化因子的贡献
    correction_prev = torch.exp(m_prev - m_new)  # 用于修正之前的贡献
    block_contrib = torch.exp(block_scores - m_new.unsqueeze(-1)).sum(dim=-1)
    
    # 4. 合并归一化因子
    l_new = l_prev * correction_prev + block_contrib
    
    # 5. 更新缩放因子
    d_new = d_prev * correction_prev
    
    return m_new, l_new, d_new

Triton实现:从PyTorch到硬件的桥梁

Triton是一个专为深度学习设计的DSL(领域特定语言),它允许开发者用Python风格的代码编写高性能GPU kernel,同时保留接近CUDA C的性能。

FlashAttention的Triton Kernel结构

import triton
import triton.kernel as tk

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128}, num_stages=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128}, num_stages=3),
    ],
    key=['seq_len', 'head_dim'],
)
@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    q_head_dim, kv_head_dim, seq_len, 
    stride_qm, stride_kn, stride_vn, stride_on,
    M_ptr, L_ptr,  # 用于存储中间统计量
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    # 块索引
    row_pid = tl.program_id(0)
    col_pid = tl.program_id(1)
    
    # 计算当前Q块在SRAM中的位置
    row_offset = row_pid * BLOCK_M + tl.arange(0, BLOCK_M)
    col_offset = col_pid * BLOCK_N + tl.arange(0, BLOCK_N)
    
    # 加载Q块到SRAM
    q_ptrs = Q_ptr + row_offset[:, None] * stride_qm + col_offset[None, :] * q_head_dim
    q_mask = (row_offset[:, None] < seq_len) & (col_offset[None, :] < HEAD_DIM)
    Q_block = tl.load(q_ptrs, mask=q_mask, other=0.0)
    
    # 初始化统计量
    m_i = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    
    # ============ FA1 vs FA2 的关键区别 ============
    # FlashAttention-1: 按列循环 (K/V块在外循环)
    # FlashAttention-2: 按行循环 (Q块在外循环) — 对decoder更友好
    # 这里展示FA2风格的实现
    
    # 遍历所有K/V块
    num_blocks = (seq_len + BLOCK_N - 1) // BLOCK_N
    
    for block_idx in range(num_blocks):
        # 加载K块
        k_ptrs = K_ptr + block_idx * BLOCK_N + tl.arange(0, BLOCK_N)
        k_offsets = tl.arange(0, BLOCK_N)
        k_mask = (k_offsets < seq_len)
        
        # 加载K块并计算 S = Q @ K^T
        K_block = tl.load(K_ptr + ..., mask=k_mask)
        
        # 计算局部注意力分数
        s_block = tl.dot(Q_block, K_block)  # [BLOCK_M, BLOCK_N]
        
        # ============ 在线Softmax更新 ============
        # 对列维度归约得到行最大值
        m_block = tl.max(s_block, axis=1)
        
        # Safe softmax更新
        m_new = tl.maximum(m_i, m_block)
        correction = tl.exp(m_i - m_new)
        
        # 归一化因子更新
        p_block = tl.exp(s_block - m_new[:, None])
        l_block = tl.sum(p_block, axis=1)
        l_new = l_i * correction + l_block
        
        # 缩放之前的累加结果
        acc_scale = correction / l_new[:, None]
        acc = acc * acc_scale[:, None]
        
        # 加载V块并计算新的累加
        V_block = tl.load(...)
        acc = tl.dot(p_block.to(V_block.dtype), V_block) + acc
        
        m_i = m_new
        l_i = l_new
    
    # 最终归一化
    O_block = acc / l_i[:, None]
    
    # 写回输出
    tl.store(O_ptr + ..., O_block, mask=...)

Autotuner:自动搜索最优配置

Triton的 @triton.autotune 装饰器会在运行时自动尝试不同的配置,选择最优的blocksize和stage组合:

@triton.autotune(
    configs=[
        # 不同blocksize对不同序列长度的性能影响显著
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64}, num_stages=3, num_warps=2),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128}, num_stages=4, num_warps=8),
    ],
    key=['seq_len', 'head_dim'],
)

对于A100(108个SM),num_warps=4 意味着每个SM运行4个warp(共128线程),总共512线程并行处理一个block。


FA1 vs FA2:循环顺序的工程权衡

FlashAttention有两个主要版本,它们的计算循环顺序有本质区别:

FlashAttention-1 (Encoder优化):
┌─────────────────────────────────────────────────────┐
│ for block_j in blocks(K):                           │
│     Load K_j, V_j into SRAM                         │
│     for block_i in blocks(Q):                       │
│         Load Q_i                                    │
│         Compute S_ij, Update softmax(O_i)           │
│         Store partial O_i                           │
└─────────────────────────────────────────────────────┘
外层循环遍历K,内层循环遍历Q — 对encoder(bidirectional)友好

FlashAttention-2 (Decoder优化):
┌─────────────────────────────────────────────────────┐
│ for block_i in blocks
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐