FlashAttention1:原理+伪码实现
各位做NLP或者大模型开发的朋友应该都有过这种体验:当模型参数量变大、输入文本变长时,注意力机制的计算速度会突然变慢,甚至直接爆显存。
传统的注意力计算就像让你在一本1000页的书里找所有相关内容,你得先把整本书都翻一遍,还得把每一页的关联笔记都写下来,不仅慢,还占满了你的草稿本。
而今天要讲的FlashAttention1,就像是给你配了一个智能检索助手——它能在不丢信息的前提下,用更少的草稿纸、更快的速度完成注意力计算,直接解决大模型长文本处理的痛点。
一、先搞懂:传统注意力为啥慢?
我们先简单回顾下标准注意力的计算逻辑,核心公式是:
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dkQKT)V
这里面最耗时也最占显存的是**QKT矩阵相乘**:假设输入序列长度是L,Q/K的维度是d_k,那QKT会生成一个L×L的巨大矩阵。
当L是10000时,这个矩阵就有1亿个元素,光是存储就需要几十GB的显存,更别说后续的softmax和乘V的计算了。而且传统计算是把所有中间结果都存在显存里,完全是"暴力堆资源"。
二、FlashAttention1的核心原理:分块+重计算
FlashAttention1的核心思路其实很简单:不一次性处理所有数据,而是分块计算,同时利用CPU缓存/显存的分层存储特性,把数据"搬来搬去",只保留必要的中间结果,不用的就临时删掉,需要的时候再重新计算。
它的核心优化点有两个:
- 分块计算(Tiling):把Q、K、V分成一个个小的块(比如256×256),每次只处理一小块数据,这样中间生成的QK^T子矩阵就很小,不会占满显存。
- 重计算(Recomputation):计算softmax的时候,只保留归一化后的结果,把中间的注意力分数临时丢弃,后续需要的时候再重新计算对应的小块,以此节省显存空间。
简单来说,就是"拆小任务+按需返工",用一点点计算量的增加,换来了显存占用的大幅降低和计算速度的提升。
三、Python伪码实现(核心逻辑还原)
下面我们用Python代码还原FlashAttention1的核心分块计算逻辑,代码里会标注每一步的作用和注意事项:
import torch
import torch.nn.functional as F
def flash_attention(Q, K, V, block_size=256):
"""
FlashAttention1核心逻辑伪实现
参数说明:
Q: 查询矩阵,形状[batch_size, seq_len_q, d_k]
K: 键矩阵,形状[batch_size, seq_len_k, d_k]
V: 值矩阵,形状[batch_size, seq_len_k, d_v]
block_size: 分块大小,通常设为256/512,适配硬件缓存
返回:
output: 注意力计算结果,形状[batch_size, seq_len_q, d_v]
"""
batch_size, seq_len_q, d_k = Q.shape
seq_len_k = K.shape[1]
d_v = V.shape[2]
# 初始化输出矩阵和归一化因子
output = torch.zeros(batch_size, seq_len_q, d_v, device=Q.device)
l = torch.zeros(batch_size, seq_len_q, 1, device=Q.device) # softmax归一化的分母累计值
m = torch.full((batch_size, seq_len_q, 1), -float('inf'), device=Q.device) # 注意力分数的最大值
# 把K和V按序列维度分块,每次处理一个K-V块
for k_start in range(0, seq_len_k, block_size):
k_end = min(k_start + block_size, seq_len_k)
K_block = K[:, k_start:k_end, :] # 当前K块,形状[batch_size, block_size, d_k]
V_block = V[:, k_start:k_end, :] # 当前V块,形状[batch_size, block_size, d_v]
# 计算当前Q和K块的注意力分数
scores = torch.matmul(Q, K_block.transpose(1, 2)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
# 更新每个Q位置的最大注意力分数
m_new = torch.max(m, scores.max(dim=-1, keepdim=True)[0])
# 计算softmax的分子(缩放后的指数)
exp_scores = torch.exp(scores - m_new)
# 更新归一化因子
l_new = l * torch.exp(m - m_new) + exp_scores.sum(dim=-1, keepdim=True)
# 计算当前块对输出的贡献,并累加到最终输出
output = (output * torch.exp(m - m_new)) + torch.matmul(exp_scores, V_block)
# 更新m和l为当前块计算后的结果
m = m_new
l = l_new
# 最后做一次归一化,得到最终的注意力输出
output = output / l
return output
# ---------------------- 测试代码 ----------------------
if __name__ == "__main__":
# 模拟输入:batch_size=2,序列长度=1024,维度=64
batch_size = 2
seq_len = 1024
d_k = 64
d_v = 64
Q = torch.randn(batch_size, seq_len, d_k).cuda()
K = torch.randn(batch_size, seq_len, d_k).cuda()
V = torch.randn(batch_size, seq_len, d_v).cuda()
# 用FlashAttention计算
flash_output = flash_attention(Q, K, V, block_size=256)
# 用标准注意力计算做对比
standard_output = F.scaled_dot_product_attention(Q, K, V)
# 验证结果是否近似(因为浮点精度问题,不会完全相等)
print("结果误差:", torch.mean(torch.abs(flash_output - standard_output)))
代码关键部分说明:
- 分块循环:通过
for k_start in range(0, seq_len_k, block_size)把K和V切成小块,避免一次性处理大矩阵。 - 动态归一化:用
m记录每个Q位置的最大注意力分数,l记录归一化分母,这样每次处理新块时可以动态更新,不用存储完整的softmax中间结果。 - 结果累加:每次计算完一个K-V块的贡献后,直接累加到输出矩阵里,不用保存所有块的中间结果。
- 注意事项:
- block_size的选择要适配硬件缓存,比如GPU的L2缓存大小,通常256或512是比较通用的值。
- 伪实现没有完全还原FlashAttention的硬件优化(比如内存读写对齐),但核心逻辑和官方一致。
- 测试时可以看到,FlashAttention的结果和标准注意力几乎一致,误差来自浮点计算精度。
四、总结一下FlashAttention1的价值
FlashAttention1的出现直接改变了大模型的训练和推理效率:
- 显存占用降低3-5倍,让普通GPU也能处理更长的文本序列;
- 计算速度提升2-4倍,减少大模型训练的时间成本;
- 后续的FlashAttention2、FlashAttention3都是在这个基础上做的硬件级优化,但核心的分块+重计算思路没变。
如果你正在做长文本大模型开发,或者觉得注意力计算太卡显存,一定要试试FlashAttention系列的实现!
个人能力有限,有问题随时交流~
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)