引言:显存——大模型训练的“阿喀琉斯之踵”

近年来,大模型的参数规模呈指数级增长,从几亿到几千亿甚至万亿。然而,GPU显存的扩展速度远远跟不上模型膨胀的步伐。训练一个千亿参数模型,如果采用最简单的FP32精度,仅模型权重就需要400GB以上显存,而目前最先进的H100也只有80GB。因此,如何精确计算显存占用、如何通过各种优化手段把模型塞进有限显存,成为大模型训练的必修课。

本文将带你系统梳理大模型训练中的显存计算与优化方法,从基础的模型权重、梯度、优化器状态,到复杂的混合精度训练、激活值缓存、ZeRO优化、模型并行等,并用具体的公式和示例让你彻底搞懂显存去哪儿了、怎么省下来。


第一部分:显存占用的基本构成

深度神经网络训练的显存消耗主要包含两大类:

  1. 模型状态:包括模型权重(参数)、梯度、优化器状态(如Adam的一阶动量和二阶方差)。

  2. 中间激活值:前向传播过程中各层输出的张量,用于反向传播时的梯度计算。

我们先聚焦于模型状态。

1.1 模型权重的显存占用

以参数量为 Φ 的模型为例,加载到GPU所需显存取决于数据类型:

量化程度 显存占用
FP32 4Φ 字节
FP16/BF16 2Φ 字节
INT8 1Φ 字节
INT4 ≤1Φ 字节

例如,一个70亿参数的模型(Φ=7e9),FP16权重占用约14GB。

1.2 训练时的完整模型状态(权重+梯度+优化器)

在训练过程中,除了权重,我们还需要存储梯度(用于参数更新)和优化器状态(如Adam的动量项)。以最常用的AdamW优化器为例,它维护两个动量(momentum和variance),每个都是FP32精度。

FP32全精度训练

  • 权重:4Φ

  • 梯度:4Φ

  • 优化器状态:动量4Φ + 方差4Φ = 8Φ

  • 总计:16Φ

这意味着,一个70亿参数的模型,仅模型状态就需要 16 × 7e9 = 112 GB 显存!这已经远远超过单卡容量。

1.3 混合精度训练真的省显存吗?

混合精度训练(如AMP)通常采用FP16/BF16计算,同时保留FP32权重副本。计算一下:

  • 权重(FP16):2Φ

  • 梯度(FP16):2Φ

  • 优化器状态:FP32权重副本 4Φ + 动量 4Φ + 方差 4Φ = 12Φ

  • 总计:2Φ + 2Φ + 12Φ = 16Φ

结果竟然也是16Φ!那混合精度训练的好处是什么?关键在于计算速度和激活值显存。半精度计算可以大幅加速矩阵运算,同时半精度激活值占用的显存比FP32少一半。但它并没有减少模型状态的显存占用,因为优化器必须依赖FP32精度来保证数值稳定性。

一点思考:有些读者可能见过“18Φ”或“20Φ”的说法,这取决于框架实现细节。例如,某些框架在梯度累积时会额外保留一份FP32梯度,导致显存升至20Φ。但核心原理不变——模型状态是训练显存的“大头”。


第二部分:激活值——被忽视的显存大户

除了模型状态,激活值(activations)在训练中同样消耗大量显存。它是在前向传播过程中各层产生的中间输出,必须保存下来供反向传播使用。

2.1 Transformer的激活值组成

以Llama-3这类Transformer模型为例,其基本单元包括多头注意力(MHA/GQA)和前馈网络(FFN)。我们定义以下符号:

  • bb:batch size

  • ss:序列长度

  • dhiddendhidden​:隐藏层维度

  • nheadnhead​:注意力头数

  • nkv-headnkv-head​:GQA中的KV头数(MHA时等于nheadnhead​,MQA时为1)

  • nlayernlayer​:层数

  • dFFNdFFN​:FFN中间维度

自注意力部分的激活缓存(每层)
  • 层归一化前的输入:2bsdhidden2bsdhidden​(半精度,下同)

  • 归一化后的输入:2bsdhidden2bsdhidden​

  • Q、K、V矩阵:2bs(dhidden+dhidden×nkv-headnhead)2bs(dhidden​+dhidden​×nhead​nkv-head​​)

  • 注意力logits:2bnheads22bnhead​s2

  • 注意力Dropout掩码(0/1):1bnheads21bnhead​s2

  • Dropout后的注意力得分:2bnheads22bnhead​s2

  • 注意力输出(与V相乘后):2bsdhidden2bsdhidden​

合并后,自注意力部分激活量约为:

8bsdhidden+4nkv-headnheadbsdhidden+5bnheads28bsdhidden​+4nhead​nkv-head​​bsdhidden​+5bnhead​s2

FFN部分的激活缓存(每层)
  • 层归一化前的输入:2bsdhidden2bsdhidden​

  • 归一化后的输入:2bsdhidden2bsdhidden​

  • gate矩阵激活前:2bsdFFN2bsdFFN​

  • up矩阵与gate矩阵(SwiGLU需要保存两个):2bsdFFN+2bsdFFN2bsdFFN​+2bsdFFN​

  • SwiGLU后、降维前:2bsdFFN2bsdFFN​

FFN部分合计:8bsdFFN+2bsdhidden8bsdFFN​+2bsdhidden​

所有层总激活显存(加上输出层)

Mact=nlayer[(10+4nkv-headnhead)bsdhidden+8bsdFFN+5bnheads2]+4bsdhiddenMact​=nlayer​[(10+4nhead​nkv-head​​)bsdhidden​+8bsdFFN​+5bnhead​s2]+4bsdhidden​

2.2 激活值的特点

  • 与batch size bb、序列长度 ss 强相关,尤其是 s2s2 项(注意力矩阵)会随着长序列急剧膨胀。

  • 与模型层数 nlayernlayer​ 成正比。

例如,一个70亿参数模型,b=1,s=2048,dhidden=4096,nhead=32,dFFN=11008,nlayer=32b=1,s=2048,dhidden​=4096,nhead​=32,dFFN​=11008,nlayer​=32,代入公式可得激活显存约为几十GB,甚至超过模型状态!

2.3 梯度检查点(Gradient Checkpointing)

为了降低激活显存,可以采用梯度检查点:在前向传播时只保存部分关键激活值(通常是每一层的输入),反向传播时临时重新计算缺失的激活值。这相当于用计算换显存,可将激活显存降低一个数量级,但会增加约30-40%的重计算开销。


第三部分:分布式训练与显存优化

当单卡无法装下整个模型时,必须采用分布式训练。下面介绍几种常用并行策略及其显存影响。

3.1 数据并行(DDP)

最朴素的数据并行:每张卡上都有一份完整的模型副本,处理不同的数据批次,通过All-Reduce同步梯度。

  • 每卡显存:依然是 16Φ(模型状态)

  • 总显存:num_devices×16Φnum_devices×16Φ

数据并行不节省单卡显存,但能通过增大全局batch size加速训练,并利用多卡算力。

梯度累积

梯度累积是一种“伪数据并行”:在单卡上连续做n次前向/反向,累积梯度后再更新参数。这样可以在不增加显存的情况下模拟更大的batch size,但训练速度会下降(更新次数减少)。

3.2 ZeRO(零冗余优化)—— 划时代的显存节省技术

ZeRO的核心思想:将模型状态(权重、梯度、优化器)进行分片,每张卡只存一部分,并通过通信原语在需要时获取完整状态。它分为三个级别,显存节省逐级增加,通信开销也逐级增大。

通信原语速成
  • All-Gather:从所有卡收集数据,每卡获得完整数据。

  • Reduce-Scatter:各卡数据先规约(如求和),然后按分片分发,每卡只得到一部分规约结果。

  • All-Reduce = Reduce-Scatter + All-Gather。

ZeRO-1(PosPos​):切分优化器状态
  • 每卡保留完整权重(FP16)和完整梯度(FP16),但优化器状态(FP32动量、方差)只存一部分。

  • 更新时:先Reduce-Scatter梯度,使每卡得到自己负责的那部分梯度的全局平均值;然后用本地优化器分片更新对应权重分片;最后All-Gather更新后的权重分片,使所有卡获得完整新权重。

  • 每卡显存:2Φ+2Φ+12Φnum_devices2Φ+2Φ+num_devices12Φ​

ZeRO-2(Pos+gPos+g​):切分梯度 + 优化器
  • 每卡保留完整权重,但梯度和优化器都分片。

  • 反向传播时逐层计算梯度,每算完一层就用Reduce-Scatter将这一层梯度分发给对应卡(聚合后存储),因此不需要保存完整梯度。

  • 每卡显存:2Φ+2Φ+12Φnum_devices2Φ+num_devices2Φ+12Φ​

ZeRO-3(Pos+g+pPos+g+p​):切分权重 + 梯度 + 优化器
  • 所有模型状态全部分片。

  • 前向传播时,需要从其他卡All-Gather权重分片,用完即弃;反向传播同样需要重新获取权重分片。

  • 每卡显存:2Φ+2Φ+12Φnum_devicesnum_devices2Φ+2Φ+12Φ​

峰值显存:由于通信过程中的临时缓冲区,实际峰值会比上述公式略高。例如ZeRO-3的峰值约为:

2Φndev+2Φndev+2Φnlayer+Mact+(2Φnlayerndev)ndev​2Φ​+ndev​2Φ​+nlayer​2Φ​+Mact​+(nlayer​ndev​2Φ​)

其中 nlayernlayer​ 是层数,逐层操作可降低临时内存。

ZeRO-Offload

若卡数不足,可将部分状态卸载到CPU内存甚至NVMe,进一步降低显存,但会引入CPU-GPU数据传输开销。


第四部分:模型并行——从结构上拆分模型

模型并行与数据并行不同,它将模型本身切分到多个设备上,每个设备只负责一部分计算。

4.1 张量并行(Tensor Parallelism)

在层内对矩阵运算进行切分,例如将注意力QKV矩阵按列切分,输出矩阵按行切分。每个设备只需存储一部分权重,但前向/反向时需要频繁通信中间激活值(All-Reduce或All-Gather)。

4.2 流水线并行(Pipeline Parallelism)

按层切分,不同层放在不同设备上。设备间通过点对点通信传递激活值和梯度。单卡显存减少为原来的 1/(DtpDpp)1/(Dtp​Dpp​),但存在流水线气泡。

4.3 模型并行下的显存

在张量并行度 DtpDtp​ 和流水线并行度 DppDpp​ 下,每卡保存的模型状态为:

Mper_gpu=16ΦDtpDppMper_gpu​=Dtp​Dpp​16Φ​

即总模型状态被切分到 DtpDppDtp​Dpp​ 张卡上。


第五部分:3D并行与综合优化

现代大模型训练往往采用数据并行 + 张量并行 + 流水线并行的三维并行(3D并行),有时还结合ZeRO。此时显存计算需考虑数据并行组数 Ddp=总卡数DtpDppDdp​=Dtp​Dpp​总卡数​。

在3D并行基础上叠加ZeRO,优化器状态可以在数据并行组内进一步分片:

Mper_gpu=2ΦDtpDpp+2ΦDtpDpp+12ΦDdpDtpDppMper_gpu​=Dtp​Dpp​2Φ​+Dtp​Dpp​2Φ​+Ddp​Dtp​Dpp​12Φ​

由于 DdpDtpDppDdp​Dtp​Dpp​ = 总卡数,当总卡数很大时,优化器状态几乎可以忽略不计!这正是万卡集群能够训练万亿参数模型的原因——靠卡数分摊状态。

例如,DeepSeek-V3训练时放弃张量并行,采用流水线并行+ZeRO-1,就是为了在通信效率和显存节省之间取得平衡。


第六部分:总结与个人思考

通过以上分析,我们可以总结出大模型显存优化的几个关键点:

  1. 模型状态是基础:16Φ(FP32/混合精度)是单卡训练的理论下限,必须通过分布式切分才能降低。

  2. 激活值与序列长度强相关:长序列时激活可能超过模型状态,梯度检查点是必备手段。

  3. ZeRO是数据并行的革命:通过分片状态,将显存占用与卡数成反比,是大规模训练的核心技术。

  4. 模型并行与ZeRO相辅相成:张量并行和流水线并行能进一步切分模型,与ZeRO组合实现极致显存节省。

  5. 通信开销是代价:任何显存优化都会增加通信量,需要在训练效率与显存之间权衡。

个人观点:未来随着硬件发展,显存容量会持续提升,但模型规模增长更快,因此这些优化技巧不会过时。更重要的是,我们要学会根据模型大小、集群规模、训练成本来选择合适的并行策略。例如,对于百亿模型,可能仅用ZeRO-3就够了;对于万亿模型,必须依赖3D并行。理解显存计算的本质,才能在大模型时代游刃有余。

希望本文能帮你拨开显存计算的迷雾,在实际训练中少走弯路。如果你对具体实现细节感兴趣,欢迎继续关注后续章节(如通信原语详解、流水线并行调度等)。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐