LLM:flash-attention概述
标准注意力
FlashAttention 详细解析
FlashAttention 是一种优化 Transformer 模型注意力机制的技术,旨在提升计算效率和降低内存消耗,尤其在处理长序列任务时。以下是 FlashAttention 的详细解析,包括其主要特点、版本更新以及实际应用等方面。
核心技术
-
分块计算 (Tiling):
FlashAttention 通过将注意力计算的矩阵分块,将大矩阵拆分成适合存储在片上内存(SRAM)的较小块。这样减少了对全局内存(HBM)的依赖,降低了内存带宽的需求。这种方法有效地利用了现代 GPU 的多层次内存结构,优化了数据的访问速度【10†source】【13†source】。 -
重计算 (Recomputation):
在传统注意力机制中,中间结果的存储往往占用大量内存。FlashAttention 通过在需要时重新计算部分结果,而不是存储整个矩阵,从而减少了内存占用。这种策略特别在反向传播过程中显著减少了内存需求【10†source】。 -
IO感知 (IO-Awareness):
FlashAttention 优化了数据传输路径,减少了从全局内存到片上内存的数据移动。它通过有序地组织和处理数据,最大限度地利用了硬件资源,提高了整体计算效率【13†source】。
版本更新
-
FlashAttention-2:
- 并行性增强:除了传统的基于批大小和头数的并行处理外,FlashAttention-2 还引入了序列长度上的并行性。这对于长序列、批量较小的情况非常有利,能够显著提高计算速度。
- 支持更多头维度:从最多支持128个头维度扩展到256个,适配如GPT-J、CodeGen等大模型。这使得FlashAttention-2能够在更广泛的场景下应用,特别是需要高精度和长上下文的任务中【12†source】。
- 多查询注意力 (MQA) 和分组查询注意力 (GQA):这些变体在推理时减少了键值(KV)缓存的大小,从而提高了推理吞吐量【12†source】。
-
FlashAttention-3:
- 新硬件支持:利用最新的 NVIDIA Hopper GPU 架构,采用 WGMMA(Warpgroup Matrix Multiply-Accumulate)和 TMA(Tensor Memory Accelerator)等新特性。相比前一版本,FP16精度下性能提升至740 TFLOPS,而FP8精度下更是达到1.2 PFLOPS,且FP8的误差降低了2.6倍【11†source】【14†source】。
- 异步操作:通过异步执行 GEMM 和 softmax 操作,提高了整体吞吐量。例如,FP16的前向传递从570 TFLOPS 提升到620 TFLOPS,进一步达到640-660 TFLOPS【11†source】。
参考:
【Flash Attention 为什么那么快?原理讲解】 https://www.bilibili.com/video/BV1UT421k7rA/?share_source=copy_web&vd_source=29af710704ae24d166ca951b4c167d53
https://blog.csdn.net/v_JULY_v/article/details/133619540?ydreferer=aHR0cHM6Ly93d3cuZ29vZ2xlLmNvbS8%3D
https://zhuanlan.zhihu.com/p/626079753
https://blog.csdn.net/weixin_47196664/article/details/137000361?ydreferer=aHR0cHM6Ly93d3cuZ29vZ2xlLmNvbS8%3D
更多推荐
所有评论(0)