斯坦福 CS336 从零构建大模型 (2025 春) - 第八讲:多机并行训练 2(Parallelism 2)

文章目录

斯坦福 CS336 第八讲的主题是**“多机并行训练 2(代码实战与通信底层)”。如果说第七讲侧重于并行策略的理论推导和显存账本,那么第八讲则是将这些理论落地,通过 PyTorch 代码从零实现这些并行算法**。讲师通过多层感知机(MLP)作为极简案例,详细剥析了数据搬运的底层逻辑。

以下是本讲不遗漏任何知识点的详细解析:

一、核心矛盾:计算与数据传输的层级 (Memory & Communication Hierarchy)

深度学习底层永远在围绕一个核心矛盾做优化:保持计算单元(ALU/SM)高强度运转,而避免被慢速的数据传输卡脖子。在多机多卡训练中,数据传输的物理层级决定了并行的速度,从快到慢依次为:

  • L1 Cache / Shared Memory:在 GPU 的 SM 内部,极快但极小。
  • HBM (高带宽显存):单张 GPU 的全局显存。
  • NVLink:同一台机器(Node)内部不同 GPU 之间的专用通信桥梁,带宽极高(如 H100 的 NVLink 带宽可达 900 GB/s),完全绕过了慢速的 CPU 和 PCIe 总线。
  • NVSwitch / 以太网:跨机器(跨 Node)的 GPU 通信。传统家用机器使用 PCIe 到以太网(慢且有内核开销),而现代 AI 集群使用 NVSwitch 直接相连。

二、集合通信原语与软件栈 (Collective Operations)

并行算法底层依赖于一系列经典的集合通信原语(源自 80 年代)。

  • 操作回顾
    • Broadcast(广播,一对多)与 Scatter(散播不同的切片,一对多)。
    • Gather(收集,多对一)与 All-Gather(全收集,多对多)。
    • Reduce(规约求和/求最大值,多对一)与 Reduce-Scatter(规约后打散分配,多对多)。
    • All-Reduce(全规约,相当于 Reduce-Scatter + All-Gather)。
  • 软件栈 (NCCL 与 PyTorch)
    • NVIDIA 提供了 NCCL 库,它能自动探测硬件拓扑结构,将这些抽象的集合操作翻译成最优的底层数据包进行传输。
    • torch.distributed 则为 Python 提供了高层接口,并且支持不同的后端(有 GPU 时用 NCCL,纯 CPU 调试时可用 Gloo)。
  • 基准测试的坑与带宽计算
    • 在测试通信时间时,必须进行“预热(Warm-up)”,并使用 barriersynchronize 确保异步进程同步。
    • 带宽计算公式差异:All-Reduce 需要先收集求和再分发,数据来回跑了一次,因此传输字节数是张量大小的 2倍;而 Reduce-Scatter 只有 1 倍。讲师实测中,All-Reduce 的实际带宽约为 277 GB/s。

三、分布式训练代码实战 (Coding Parallelism from Scratch)

讲师以一个纯纯的深度 MLP 模型为例,展示了如何用最底层的原语写出三种并行:

1. 数据并行 (Data Parallelism, DDP)

  • 切分维度:沿着 Batch 维度切分数据。
  • 实现逻辑:初始化时,根据不同的 rank(设备ID)让各个 GPU 获取不同的数据切片(Local batch size)。
  • 魔法注入点:前向传播和算 Loss 就像写普通的单卡 SGD 完全一样。唯一的改变是在反向传播算完梯度后、优化器更新参数前,强行插入一行代码:对所有层的 param.grad 调用一次 all_reduce 取平均。
  • 结果:虽然每张卡的前向 Loss 不一样(因为数据不同),但梯度被强行平均了,所以各个卡的参数永远保持完全一致。

2. 张量并行 (Tensor Parallelism, TP)

  • 切分维度:沿着网络宽度(Hidden Dimension)切分模型权重。
  • 实现逻辑:每张卡收到的是完全相同的一整批数据,但初始化的权重矩阵被切小了(比如 1024 维切成 4 份,每卡只有 256 维)。
  • 通信逻辑:每张卡先算出局部的激活值(此时只有局部的特征维度)。然后,必须调用 all_gather,把所有卡算出的局部激活值全部拉取过来,拼接(Concatenate)成完整的激活值矩阵。
  • 代价:每算一层都要调用 all_gather,通信极其频繁,这就是为什么 TP 极度吃带宽、通常只能在单机 8 卡内使用的底层代码原因。

3. 流水线并行 (Pipeline Parallelism, PP)

  • 切分维度:沿着网络深度(Layers)切分。
  • 实现逻辑:Rank 0 只初始化并负责前 2 层,Rank 1 负责后 2 层。
  • 通信逻辑:不再使用集合通信,而是使用点对点通信(Point-to-point)的 sendrecv 原语。为了减少流水线气泡,必须将一个大 Batch 切分成多个微批次(Micro-batches)。Rank 0 算完一个微批次后 send 给 Rank 1,Rank 1 recv 收到后接着算。
  • 挑战:讲师演示的基础代码是同步阻塞的。如果想要真正高效,必须把通信异步化,并处理极度复杂的通信与计算的重叠(Overlapping),且同时处理前向和反向的交错调度。

四、其他生态系统与硬件的未来 (Ecosystems & Future)

  • JAX 的声明式并行:讲师最后对比了 PyTorch 生态(如 FSDP,需要开发者做大量繁琐的模型状态 bookkeeping)与 Google JAX 生态。在 JAX(如 Levanter 框架)中,你只需要从高维度“声明”你想在哪一个张量维度上做切分(比如按 batch 切,或按 head 切),底层的 XLA 编译器会自动推导出所需的所有通信原语。
  • 专用芯片(Cerebras / Groq):有学生提问 GPU 能否被替代。讲师指出,因为 GPU 身上背负着传统 CPU 时代的“控制流(分支跳转)”历史包袱,而大模型的计算图是极度静态的“数据流(Data flow)”。像 Cerebras 这样的专用 AI 芯片,通过把巨大的静态内存直接做进单块超大芯片里(相当于把 SRAM/L1 做到极大),直接消除了数据反复搬移的通信瓶颈,代表着一种在硬件底层解决通信开销的思路。

五、核心概念问答 (Q&A)

Q1:在传统的 PCIe 总线连接中,GPU 间的数据传输是不是必须先经过 CPU?那 NVLink 呢?

回答:是的,传统的 PCIe 总线通信,数据传输必须先经过 CPU(主机内存)。而现代的 NVLink 允许 GPU 之间直接进行超高速的通信,绕过了 CPU 内存的瓶颈。不过即便有了 NVLink,GPU 依然需要与 CPU 保持连接以接收指令调度。

Q2:在调用 reduce_scatter 操作时,系统是如何记录和知道哪个索引的数据该发给哪个具体的 GPU 的?

回答:这主要依靠张量维度的约定(Convention)。通常,输入张量的其中一个维度的大小必须严格等于 world_size(总设备数)。底层通信库(如 NCCL)据此就知道,它需要将对应的切片自动分发给对应的 Rank(设备)。

Q3:在计算带宽时,为什么 all_reduce 的传输字节数要乘以 2,而 reduce_scatter 不用?计算 reduce_scatter 时不用考虑输入数据的读取耗时吗?

回答:我们假设输入数据已经存在于当前设备的显存中,所以不把最初的显存读取算作跨节点通信时间。all_reduce 之所以要乘以 2,是因为它本质上是两步:先进行归约(等价于 reduce_scatter),然后再将结果广播回所有人(等价于 all_gather),有发和收的双倍流量。而 reduce_scatter 只有单纯的归约到对应设备这一步操作。

Q4:在写纯数据并行(DDP)的代码时,由于各个 Rank 是异步并行运行的,当我们调用 all_reduce 混合梯度时,怎么保证它们刚好都运行到了同一个 Step?

回答:这是一个好问题。all_reduce 这样的集合通信原语在底层本身就是一个**“同步屏障(Synchronization point / Barrier)”**。先执行到这一步的进程会被挂起等待,直到所有的 Rank 都调用了同一个 all_reduce 才会放行。当然这也意味着,如果某个 Rank 因为 bug 死机或漏掉了这个调用,整个训练程序就会死锁卡住(Hang)。

Q5:在数据并行(DDP)中,如果有像 BatchNorm 这样严重依赖全局数据分布的层,由于每张卡只看到自己的局部数据,这会导致问题吗?

回答:是的,BatchNorm 跨 GPU 同步一直非常让人头疼(需要专门的 SyncBatchNorm)。但幸运的是,在大语言模型(LLM)的世界里这个问题几乎不存在,因为 LLM 普遍使用的是 LayerNorm 或 RMSNorm。这类归一化只依赖于当前处理的单个 Token / 序列的局部特征,不同卡之间完全独立,只要初始权重和随机种子一致,就不会有任何问题。

Q6:在流水线并行(PP)的代码中,Rank 之间互相等待激活值的传递。这是一种“事件驱动(Event-driven)”的编程模式吗?

回答:并不是。事件驱动编程通常是指写一堆回调函数,当某个网络包到达时触发执行。而现代深度学习的流水线并行代码通常写得非常“同步且死板(Synchronous paradigm)”,程序是按照严格的预设顺序,同步地调用 recv 阻塞等待,收到数据后计算,再调用 send。

Q7:如果我们在流水线并行里调用了同步的 send 和 recv,那计算和通信就没法重叠(Overlap)了。在代码里应该怎么改?

回答:你需要将同步的阻塞指令替换为异步原语(Asynchronous primitives,例如 isend 和 irecv)。异步调用会立刻返回一个句柄(Handle),这样 CPU 就可以立刻让 GPU 去计算下一个微批次(Microbatch),而底层的 DMA 引擎会在后台同时处理网络传输。最后你只需要在合适的地方调用 wait() 确保传输完成即可。

Q8:如果在使用 send 发送张量时,接收方根本没调用 recv,程序会怎样?如果是连续调用多个 send 和 recv,系统怎么区分哪个张量对应哪个?

回答:如果只发不收,发送进程大概率会一直阻塞死等下去。至于区分张量,底层的点对点通信并不关心张量在 Python 里叫什么名字变量,它只看源 Rank、目标 Rank 以及发送的先后顺序(Stream order)。所以顺序绝对不能写错。

Q9:在流水线并行中,排在最后面的那个 Rank(拿到了最终输出)接下来会做什么?

回答:最后一个 Rank 拿到了完整的正向传播结果,它会用这个结果计算 Loss(损失函数)。接着,它会立刻开启反向传播算梯度,并将其对应层的激活值梯度通过网络逆向 send 回给上一个 Rank。

Q10:比起需要手动写这么多复杂的底层通信代码,PyTorch 有没有像 Jax 那样优雅的高级 API 帮我们自动处理分片(Sharding)?

回答:PyTorch 提供了 FSDP(Fully Sharded Data Parallel)这个高级封装,你只需要包裹住模型它就能自动运行。但客观地说,在声明式自动切分上,Jax + TPU 生态(如 Levanter)目前确实更优雅。 不过,生态不同玩法也不同。像 Google 在高大上的 Jax 里通过编译器处理一切;而像 DeepSeek 这样的公司,面对的是网络互联极差的“平民 GPU 集群”,他们为了压榨出极限性能,会直接绕过高级 API,深入到 NCCL 底层去大搞 Hack 优化。所以了解底层逻辑依然极其重要。

Q11:GPU 会不会有一天被专门为 Transformer 设计的芯片彻底取代?我看像 Groq 这种芯片好像没这些复杂的概念。

回答:在**推理(Inference)**领域,这种趋势已经非常明显了,像 Groq 和 Cerebras 确实在做这种事。 从硬件发展史来看,GPU 最初是为了渲染图形设计的,带有大量的控制逻辑和分支预测包袱。而深度学习的特点是“数据流图(Data flow)”一旦确定就基本静态不变。专有芯片(如 Cerebras)的做法干脆是把海量的 SRAM内存直接集成到计算核心旁边(Wafer-Scale),彻底干掉“把数据在显存和计算核心间搬来搬去”的瓶颈。这也是未来芯片极具潜力的方向。

Q12:既然有物理瓶颈,那我们为什么不把单个 GPU 节点造得无限大(容纳无限多的显存和算力)?

回答:主要是极大的物理限制和散热(Power and thermal issues)。你无法将芯片无限做大且无限密集,因为供电和把巨大的热量排出去本身就是目前硬件工程里的终极难题之一。

Q13:我们在讲这些分布式通信原语(如 NCCL 库),它们究竟是 CPU 指令,还是在 GPU 上运行的程序?

回答:CPU 依然是整个系统的“主控(Master)”。当你调用一个集合操作时,是 CPU 上的代码调用了 NCCL 库,然后 NCCL 会负责向 GPU 发射(Launch)专门的通信 Kernel,由 GPU 上的 DMA 控制器去执行实际的数据搬运。

Q14:你刚才展示了用这套系统从头训练一个模型。那如果有了新的增量数据,我们可以用同一套并行架构去做继续训练(Continual Pre-training)吗,而不用从头算起?

回答:当然可以绝对没问题。因为不管你是从头开始,还是从某个 Checkpoint 加载了一半的权重,对于并行架构和优化器来说,本质上都是在做“前向、后向、更新梯度”的循环。架构对此是完全透明的。


六、第八讲复习题

📝 第八讲复习题 (Lecture 8: Parallelism 2)

一、 硬件连接与通信基础设施 (Hardware & Collective Comms)

  1. 网络拓扑的演进: 传统的消费级 GPU 通常通过 PCIe 总线和以太网(Ethernet)连接。为什么现代用于深度学习的大规模 GPU 集群(如 H100 集群)要抛弃这种架构,转而使用 NVLink 和 NVSwitch?
  2. 底层软件栈: 在 PyTorch 中,当我们调用 torch.distributed 进行集合通信时,底层负责实际执行 GPU 间通信优化的 Nvidia 高性能核心库叫什么?如果我们在没有 GPU 的笔记本上用 CPU 进行分布式代码调试,PyTorch 推荐使用哪个后端(Backend)?
  3. 带宽计算的数学细节 (Bandwidth Math): 在对集合通信进行基准测试(Benchmarking)并计算通信带宽(GB/s)时,为什么讲师强调在计算 All-reduce 操作的传输字节数时需要乘以 2,而计算 Reduce-scatter 时不需要乘以 2?

二、 数据并行代码实现 (Data Parallelism - DDP)

  1. DDP 的切分逻辑: 在手写纯数据并行的代码时,输入的 Batch 数据和模型的参数(Parameters)在各个 Rank(GPU 设备)上分别是怎样分配?
  2. DDP 的梯度同步: 在 DDP 的训练循环中,前向传播和反向传播与单机代码无异。但在调用优化器更新(optimizer.step())之前,必须在代码中显式地插入哪一个集合通信操作(Collective Operation),并且是作用于哪个变量上?

三、 张量并行代码实现 (Tensor Parallelism - TP)

  1. TP 的切分逻辑: 如果我们要用代码对一个多层感知机(MLP)实现张量并行,假设将其分配给 W 个 Rank,模型每层权重矩阵的形状(Shape)会发生怎样的改变?输入数据又该如何分配?
  2. TP 的前向通信: 在张量并行的前向传播代码中,当每个 Rank 用自己局部切分的权重矩阵算出了局部的激活值(Activations)后,必须调用哪一个集合通信操作才能拼凑出完整的激活值,以便送入下一层?

四、 流水线并行代码实现 (Pipeline Parallelism - PP)

  1. PP 的切分逻辑: 在手写流水线并行时,模型结构是如何被切分并分配给不同 Rank 的?
  2. PP 的点对点通信: 与 DDP 和 TP 依赖全局的集合通信(如 All-reduce, All-gather)不同,流水线并行的代码在层与层之间传递激活值时,主要依赖哪两个底层通信原语(Primitives)?
  3. 重叠通信的进阶 (Overlapping): 在极简的流水线并行示例代码中,通信会导致后续进程死等。为了实现通信与计算的重叠(Overlap),在代码层面上应该把阻塞的同步通信指令替换成什么类型的指令?

七、参考答案与知识点解析

  1. 网络拓扑的演进?
    答案: 传统架构中,GPU 间的数据传输必须经过 CPU 内存并使用较慢的以太网,这会引入极大的延迟和开销。而 NVLink 和 NVSwitch 绕过了 CPU,允许 GPU 之间直接进行超高速的通信(例如 H100 的 NVLink 带宽高达 900 GB/s),这是专门为深度学习中高频、大规模的数据同步而设计的。

  2. 底层软件栈?
    答案: 负责 GPU 间高性能通信的底层库是 NCCL (Nvidia Collective Communication Library)。如果要在纯 CPU 环境下调试分布式代码,PyTorch 提供了 Gloo 后端作为替代。

  3. 带宽计算的数学细节?
    答案: 因为 All-reduce 本质上包含两个阶段:每个 Rank 先将自己的数据发送出去参与规约(等价于 Reduce-scatter),然后再接收规约后的全局结果(等价于 All-gather)。这意味着每个节点既发送了一份完整数据大小的流量,又接收了一份,因此双向传输的数据量是单向的 2 倍。而 Reduce-scatter 仅涉及发送并规约到特定节点,没有随后的全局广播步骤,所以没有这额外的 2 倍系数。

  4. DDP 的切分逻辑?
    答案: 在 DDP 中,模型参数是完整复制的,每个 Rank 都拥有一个完全相同的模型副本;而输入数据是沿着 Batch 维度被切分的(例如全局 Batch Size 为 128,在 4 个 Rank 上,每个 Rank 的 local_batch_size 只有 32)。

  5. DDP 的梯度同步?
    答案: 必须使用 All-reduce 操作,并且该操作是作用于每一层权重的梯度(param.grad)上。这确保了各个 Rank 在局部数据上算出的独立梯度被求和(或平均),使得所有 GPU 在更新前拥有完全一致的全局梯度。

  6. TP 的切分逻辑?
    答案: 权重矩阵会沿着隐藏维度(Hidden Dimension)切分;每个 Rank 都会获得完整且相同的一份 Batch 数据副本来与自己的局部权重相乘。

  7. TP 的前向通信?
    答案: 必须调用 All-gather。因为每个 Rank 的局部矩阵乘法只算出了激活值特征的一部分(Partial dimension),只有通过 All-gather 将所有 Rank 的输出收集并拼接(Concatenate)起来,才能恢复出具有完整隐藏维度的激活值矩阵,以便供下一层使用。

  8. PP 的切分逻辑?
    答案: 模型沿着**深度(层/Layers)**被切分。每个 Rank 只分配到模型的一段连续层(例如 4 层网络,Rank 0 拿前 2 层,Rank 1 拿后 2 层)。

  9. PP 的点对点通信?
    答案: 流水线并行主要使用点对点通信原语(Point-to-point primitives),即 sendrecv。因为激活值和梯度只需要沿着流水线传递给紧邻的下一个(或上一个)Rank,不需要全局广播。

  10. 重叠通信的进阶?
    答案: 应该将同步的 send/recv 替换为异步通信指令(Asynchronous primitives,如 isend)。这会返回一个句柄(Handle),让 CPU 可以在后台派发网络传输任务,同时 GPU 立刻开始处理下一个微批次(Microbatch)的计算,从而实现通信时间与计算时间的完美重叠。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐