作者丨宫酱手艺人@知乎(已授权)

来源丨https://zhuanlan.zhihu.com/p/686240618

编辑丨极市平台

导读

 

flash attention是LLM训练的标配。它是一个加速attention的cuda算子;ring attention则是利用分布式计算扩展attention长度的一个工作。然而它们背后的核心则都是softmax局部和全局关系的一个巧妙公式。 

TL; DR

观察到局部softmax和全局softmax的关系,充分利用容量小但速度快的cache计算局部的attention,再推导出全局的attention,最终达到加速attention计算,或扩展attention长度的目的。

概要

局部的softmax和全局的softmax可以推出一个公式关系。利用这一点,flash-attention使用SRAM来计算局部的attention,再规约到全局的attention,并将attention包装为一个CUDA kernel,大大加速attention计算速度,并减小现存占用。而ring attention则反着利用这个公式关系,让一个GPU计算attention的一个局部,整个GPU多卡集群就可以计算出全局的attention,这样就大大扩展了Transformer序列长度。

方法

局部softmax和全局softmax之间的关系

Attention机制里面使用Softmax函数将Attention权重归一化。考虑向量 ,那么向量 作为Softmax的输出,有

注意到softmax函数有平移不变的性质

一般的softmax实现都是利用这个性质,给分子分母同时减去最大值,这样取指数的时候就不容易越界了:

flash-attention则利用这个性质,来大幅提高softmax算子的局部性。具体上,将 向量分成 个块,每个块 个元素,对每个分块先进行计算,这样SRAM里面需要处理的元素就不多了,处理完以后,再将每个分块里面计算的结果进行组合,计算出来最终的softmax结果。

这个事情的关键就是要把局部块的值和全局的值得关系找出来。为了符号的简单,考虑两个块之间的关系。 。令 和 为局部的最大值, 为全局最大值:

对于第一个块里面的输出值 ,有

第二个块也类似。那么我们就可以看出来,全局softmax值通过局部块的softmax值和这些 这些局部最大值因子可以算出来。所以呢,通过计算局部块的指数 ,累积和 ,以及最值 并保留下来,就可以算出全局的softmax了。

这样的好处是什么呢?

attention里面的softmax,一般值还挺多的,特别是对于长序列,所以整体计算全局softmax,可能cache里面就放不下,就得放到容量大但是比较慢的地方去计算这些数了。而局部的值,数量少,就可以放在cache里面算,算完以后再根据上面的公式,把总体的softmax算出来。

flash attention:利用SRAM作为cache

flash attention是一个attention的算子,主要目的是加速attention的计算。

GPU里面的存储有个层次结构。HBM (high bandwidth memory,可以认为就是cuda编程里面的global memory)就是显卡上边的memory,容量大,但是速度慢; SRAM (Static Random-Access Memory,可以认为就是cuda编程里面的shared memory),容量小,但是速度快。

flash-attention的核心思想就是,把attention的计算分成一小块一小块的,放在SRAM里面算,算完以后再通过前面介绍的关系,把全局的attention值算出来。大大提升了attention的计算速度。

flash-attention还把整个attention的计算做成一个算子,这样就可以把中间的结果给它省掉,大大减小了显存占用。

cf71f697585d2609b8391421a5e4de0f.jpeg
CPU/GPU计算时候的存储层次结构 from flash-attention

ring attention:利用单GPU卡作为cache

ring attention的主要目的是扩展Transformer的序列长度。计算Transformer序列长度的一个核心困难是算attention的时候,序列太长会OOM。

ring attention的核心想法是,每一个GPU只计算一个局部的attention,然后全局的attention再利用前面的公式给计算出来。这样,因为每个GPU的算的attention长度就没那么长了,就可以计算了,但整体的attention长度就可以大大扩展了。这个attention长度的扩展还是根据GPU数量线性增加的,有多少GPU就能扩多长,所以ring attention的论文题目里说"Near-Infinite Context"。

小结与想法

flash attention已经是LLM训练的标配了。它是一个加速attention的cuda算子;ring attention则是利用分布式计算扩展attention长度的一个工作。然而它们背后的核心则都是softmax局部和全局关系的一个巧妙公式。真的是非常漂亮。

推荐阅读

欢迎大家加入DLer-计算机视觉技术交流群!

大家好,群里会第一时间发布计算机视觉方向的前沿论文解读和交流分享,主要方向有:图像分类、Transformer、目标检测、目标跟踪、点云与语义分割、GAN、超分辨率、人脸检测与识别、动作行为与时空运动、模型压缩和量化剪枝、迁移学习、人体姿态估计等内容。

进群请备注:研究方向+学校/公司+昵称(如图像分类+上交+小明)

773a60557e2143ec9077325747fa5aa5.jpeg

👆 长按识别,邀请您进群!

55b1acba3a074e2066a64b4e036d9dac.gif

GitHub 加速计划 / fl / flash-attention
6
1
下载
Fast and memory-efficient exact attention
最近提交(Master分支:3 个月前 )
3669b252 4 个月前
5d5bfbb6 - 4 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐