目录

先从一个问题开始

第一步:requires_grad——谁需要被追踪

第二步:前向传播——计算值的同时,悄悄建图

整张计算图长这样

第三步:.backward()——沿着图把梯度传回来

把反向传播的每一步拆开来看

第四步:叶节点与非叶节点的区别

第五步:梯度累积——一个你必须理解的坑

第六步:torch.no_grad()——关掉追踪的开关

第七步:计算图只能用一次

第八步:用 detach() 切断梯度流

第九步:参数更新之后,计算图就消失了$y$

第十步:当你用 nn.Module 和 optimizer 之后

第十一步:梯度消失和梯度爆炸是怎么回事

第十二步:自定义反向传播(选读)

小结:整条链路


在第一篇里我们说过,PyTorch 的动态图是边执行边构建的,.backward() 不是在"计算"梯度,而是在沿着已经构建好的图做一次反向遍历。

这一篇就来兑现这个承诺——把这条链路真正讲清楚。

不用复杂的模型,就用最简单的手写线性回归:输入一个数,输出一个数,一个权重,一个偏置。但我们会把这个过程拆解到足够细,让你看清楚每一步背后发生了什么。

搞懂这一篇之后,你再去看任何关于梯度消失、梯度裁剪、自定义反向传播的内容,都会理解其中的原理。


先从一个问题开始

假设你有一个极其简单的函数:

$L = (wx + b - y)^2$

其中 $x$ 是输入,y 是标签,$w$$b$ 是参数,L 是损失(预测值和真实值的误差的平方)。

你想用梯度下降更新 $w$$b$,就需要:

$\frac{\partial L}{\partial w}, \quad \frac{\partial L}{\partial b}$

手动推导不难:

$\frac{\partial L}{\partial w} = 2(wx + b - y) \cdot x$

$\frac{\partial L}{\partial b} = 2(wx + b - y)$

但问题是:你不想每次换个模型就重新推一遍导数。你希望框架能自动帮你算。

PyTorch 的自动微分(Autograd)就是干这个的。但它不是在符号层面推导公式,而是在计算图上做数值反向传播。这两件事的结果一样,但方式完全不同——理解这个区别是理解 Autograd 的起点。


第一步:requires_grad——谁需要被追踪

import torch

# 数据(不需要求导)
x = torch.tensor(2.0)
y = torch.tensor(5.0)   # 真实标签,预期输出 5

# 参数(需要求导,这是 PyTorch 需要追踪的)
w = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

print(x.requires_grad)   # False
print(w.requires_grad)   # True
print(b.requires_grad)   # True

requires_grad=True 告诉 PyTorch:这个 Tensor 是我们关心的参数,请在所有涉及它的计算过程中构建计算图,以便之后能对它求导

没有设置 requires_grad=True 的 Tensor(比如 xy),PyTorch 不会追踪与它们相关的操作,这样可以节省内存和计算。

有一个非常重要的传播规则:只要一个操作的任意输入有 requires_grad=True,那么这个操作的输出也会自动有 requires_grad=True

a = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(2.0)   # 没有 requires_grad

c = a * b   # 输入里 a 有 requires_grad
print(c.requires_grad)   # True  ← 自动传播

这个传播规则是整个自动微分系统能工作的基础——从参数出发,所有经过它的计算结果都会自动被追踪。


第二步:前向传播——计算值的同时,悄悄建图

# 前向传播:计算预测值和损失
pred = w * x + b      # 预测值:1.0 * 2.0 + 0.0 = 2.0
loss = (pred - y) ** 2  # 损失:(2.0 - 5.0)^2 = 9.0

print(f"预测值: {pred.item()}")   # 2.0
print(f"损失:   {loss.item()}")   # 9.0

这两行代码在你眼里是在做数学运算,但在 PyTorch 背后,还同时发生了另一件事:计算图被悄悄建立起来了

你可以通过 grad_fn 属性窥探这张图:

print(w.grad_fn)      # None  ← 叶节点,没有 grad_fn
print(pred.grad_fn)   # <AddBackward0 object at ...>
print(loss.grad_fn)   # <PowBackward0 object at ...>

grad_fn 是什么?它是每个 Tensor 的"出生证明":记录了这个 Tensor 是由哪个操作产生的,以及这个操作需要什么信息才能做反向传播。

  • wb 是叶节点,直接由用户创建,没有 grad_fn
  • pred = w * x + b,是加法操作的结果,所以 grad_fnAddBackward0
  • loss = (pred - y) ** 2,是幂操作的结果,所以 grad_fnPowBackward0

这些 grad_fn 对象不只是标签,它们每个都存储了反向传播所需的中间数据,并且知道"收到上游梯度之后,该把什么梯度传给自己的输入"。

每个节点知道两件事:

  1. 自己的输出值是什么(前向传播已经算好了)
  2. 收到来自输出方向的梯度时,应该怎么把梯度传给自己的输入

反向传播就是从 loss 出发,沿着这张图的反向边,把梯度一路传回到 wb


第三步:.backward()——沿着图把梯度传回来

loss.backward()

就这一行。PyTorch 会:

  1. loss 开始,初始梯度为 1.0(损失对自身的导数是 1)
  2. 沿着计算图反向遍历,对每个节点调用它的 grad_fn,计算并传播梯度
  3. 直到到达叶节点(wb),把最终梯度累积到它们的 .grad 属性里

让我们看结果:

print(f"w 的梯度: {w.grad}")   # tensor(-12.)
print(f"b 的梯度: {b.grad}")   # tensor(-6.)

验证一下手动推导的结果:

$\frac{\partial L}{\partial w} = 2(wx + b - y) \cdot x = 2(2.0 - 5.0) \cdot 2.0 = 2 \cdot (-3.0) \cdot 2.0 = -12.0 $

$\frac{\partial L}{\partial b} = 2(wx + b - y) = 2(2.0 - 5.0) = 2 \cdot (-3.0) = -6.0 $

和手动计算完全一致。PyTorch 做的事情和你用链式法则手推是等价的,只是它是在计算图上自动完成的。

为了真正理解梯度是怎么在图上流动的,我们一步一步手动追踪:

第一步:loss → pred 的梯度

$L = (\text{pred} - y)^2$,令 $u = \text{pred} - y$,则 $L = u^2$

$\frac{\partial L}{\partial \text{pred}} = 2u = 2(\text{pred} - y) = 2(2.0 - 5.0) = -6.0$

第二步:pred → w 的梯度(通过链式法则)

$\text{pred} = w \cdot x + b$,所以 $\frac{\partial \text{pred}}{\partial w} = x = 2.0$

$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial \text{pred}} \cdot \frac{\partial \text{pred}}{\partial w} = -6.0 \cdot 2.0 = -12.0 $

第三步:pred → b 的梯度

$\frac{\partial \text{pred}}{\partial b} = 1$

$\frac{\partial L}{\partial b} = \frac{\partial L}{\partial \text{pred}} \cdot \frac{\partial \text{pred}}{\partial b} = -6.0 \cdot 1 = -6.0 $

每个 grad_fn 的职责就是完成"第二步"和"第三步"这样的工作:接收上游传来的梯度($\frac{\partial L}{\partial \text{pred}} = -6.0$),乘以自己这层的局部导数,传给下游。

这就是链式法则在计算图上的具体体现


第四步:叶节点与非叶节点的区别

print(f"w 是叶节点吗: {w.is_leaf}")      # True
print(f"b 是叶节点吗: {b.is_leaf}")      # True
print(f"pred 是叶节点吗: {pred.is_leaf}") # False
print(f"loss 是叶节点吗: {loss.is_leaf}") # False

叶节点:由用户直接创建的 Tensor,不是任何运算的结果。模型的参数(权重、偏置)都是叶节点。它们的 grad_fnNone

非叶节点:由运算产生的 Tensor,中间计算结果。它们有 grad_fn,记录了自己从哪里来。

.backward() 完成后,只有叶节点的 .grad 会被保留,非叶节点的梯度默认会被释放:

print(f"w.grad:    {w.grad}")     # tensor(-12.)  ← 保留了
print(f"b.grad:    {b.grad}")     # tensor(-6.)   ← 保留了
print(f"pred.grad: {pred.grad}")  # None           ← 释放了!
print(f"loss.grad: {loss.grad}")  # None           ← 释放了!

这是 PyTorch 的内存优化策略:对于模型训练来说,你只需要叶节点(参数)的梯度来做参数更新,中间激活值的梯度是临时的,算完就可以扔掉。

如果你确实需要某个中间节点的梯度(比如做梯度可视化或者某些特殊的优化),可以在它上面调用 .retain_grad()

w = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)
x = torch.tensor(2.0)
y = torch.tensor(5.0)

pred = w * x + b
pred.retain_grad()   # 告诉 PyTorch:请保留 pred 的梯度

loss = (pred - y) ** 2
loss.backward()

print(f"pred.grad: {pred.grad}")   # tensor(-6.)  ← 现在有了

第五步:梯度累积——一个你必须理解的坑

.backward() 做的事情是把梯度累加.grad 上,而不是替换。

w = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)
x = torch.tensor(2.0)
y = torch.tensor(5.0)

# 第一次前向 + 反向
pred = w * x + b
loss = (pred - y) ** 2
loss.backward()
print(f"第一次 w.grad: {w.grad}")   # tensor(-12.)

# 第二次前向 + 反向(没有清零!)
pred = w * x + b
loss = (pred - y) ** 2
loss.backward()
print(f"第二次 w.grad: {w.grad}")   # tensor(-24.)  ← -12 + (-12) = -24!

梯度被累加了。这就是为什么训练循环里每次迭代必须先清零梯度:

# 正确的训练循环写法
for epoch in range(100):
    # ① 清零梯度(必须在前向传播之前,或反向传播之后)
    w.grad = None   # 手动清零方式
    b.grad = None

    # ② 前向传播
    pred = w * x + b
    loss = (pred - y) ** 2

    # ③ 反向传播
    loss.backward()

    # ④ 参数更新
    with torch.no_grad():
        w -= 0.01 * w.grad
        b -= 0.01 * b.grad

用 optimizer 的时候,optimizer.zero_grad() 就是在做清零这件事,等价于对所有参数的 .grad 置零(或者置 None)。

梯度累积并不总是坏事。当你因为显存限制无法使用大 batch size 时,可以故意不清零,跑几个小 batch、累积梯度,再做一次参数更新——效果等价于跑了一个大 batch。这是一种常用的训练技巧,叫做 Gradient Accumulation。


第六步:torch.no_grad()——关掉追踪的开关

做参数更新的时候,用 w -= 0.01 * w.grad 这一步计算本身也会触发 PyTorch 的追踪(因为 wrequires_grad=True)。但这次更新不应该被追踪——它只是在修改参数值,不是模型的前向计算。

这就是为什么参数更新要包在 torch.no_grad() 里:

with torch.no_grad():
    w -= 0.01 * w.grad
    b -= 0.01 * b.grad

torch.no_grad() 是一个上下文管理器,在它的代码块里:

  • 所有操作都不会被追踪
  • 创建的 Tensor 的 requires_grad 强制为 False
  • 不会构建计算图,节省内存和计算

同样的,验证模型时也应该放在 torch.no_grad() 里,原因在第一篇已经讲过:验证阶段不需要反向传播,不追踪计算图可以大幅节省显存和加快速度。

还有一个类似的工具是 torch.inference_mode(),比 no_grad() 更激进,连 grad_fn 都不记录,推理性能更好,但在它里面修改 Tensor 可能有副作用(不能回到训练模式),一般只在最终部署推理时使用。


第七步:计算图只能用一次

w = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)
x = torch.tensor(2.0)
y = torch.tensor(5.0)

pred = w * x + b
loss = (pred - y) ** 2

loss.backward()   # 第一次反向传播,正常

# loss.backward()  ← 如果你再调用一次,会报错:
# RuntimeError: Trying to backward through the graph a second time...

为什么?因为 .backward() 在遍历计算图的过程中,会释放中间节点存储的那些"用于反向传播的中间数据"(比如前向传播的激活值),以节省内存。图走完之后,这些数据就消失了,没法再走第二遍。

如果你确实需要多次对同一个计算图做反向传播(某些研究场景,比如计算 Hessian 矩阵,或者 MAML 这类元学习算法),需要加 retain_graph=True

loss.backward(retain_graph=True)   # 不释放中间数据
loss.backward()                    # 可以再走一遍

注意:retain_graph=True 会让内存占用翻倍甚至更多,只在确实需要时使用。


第八步:用 detach() 切断梯度流

有时候你希望某段计算不参与反向传播——比如你在做一个 GAN,更新生成器的时候不希望梯度流回判别器;或者你在 RL 里有一个 target network,它的输出作为监督信号但本身不更新。

这时候需要 .detach()

w = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)
x = torch.tensor(2.0)
y = torch.tensor(5.0)

pred = w * x + b      # pred 有 grad_fn,是图的一部分

# detach 返回一个新的 Tensor,和 pred 数值相同,但从计算图里"切断"了
pred_detached = pred.detach()
print(pred_detached.requires_grad)   # False
print(pred_detached.grad_fn)         # None  ← 梯度流在这里断了

loss = (pred_detached - y) ** 2
loss.backward()

# w 和 b 的梯度是 None,因为梯度流在 pred_detached 这里断了
print(w.grad)   # None
print(b.grad)   # None

.detach() 常见的实际用途:

# RL 里的 target network:target 不参与梯度计算
target_q = target_net(next_obs).max(dim=1).values.detach()
loss = F.mse_loss(current_q, target_q)   # target_q 断开梯度
loss.backward()   # 只更新 current network 的参数

# GAN 里更新判别器时不想让梯度流回生成器
fake_img = generator(noise).detach()   # 切断梯度
d_loss = discriminator(fake_img)
d_loss.backward()   # 只更新判别器

detach()torch.no_grad() 的区别:

  • torch.no_grad() 是上下文管理器,在其作用域内所有操作都不追踪,出了这个块一切恢复正常
  • .detach() 是对某个特定 Tensor 的永久切断,这个 Tensor 在任何地方都不再携带梯度信息

第九步:参数更新之后,计算图就消失了

把前面所有知识组装成一个完整的线性回归训练循环,同时展示计算图的生命周期:

import torch

# 数据
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = torch.tensor([3.0, 5.0, 7.0, 9.0])   # y = 2x + 1

# 参数,随机初始化
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

lr = 0.01

for epoch in range(200):
    # ① 前向传播(同时建图)
    pred = w * x + b          # 向量操作,x 是 4 个样本
    loss = ((pred - y) ** 2).mean()   # MSE loss

    # ② 清零梯度(关键!在反向传播之前清零上一次残留的梯度)
    if w.grad is not None:
        w.grad.zero_()   # in-place 清零
    if b.grad is not None:
        b.grad.zero_()

    # ③ 反向传播(沿图传播梯度,图在这之后被释放)
    loss.backward()

    # ④ 参数更新(在 no_grad 里,避免这步操作被追踪进图)
    with torch.no_grad():
        w -= lr * w.grad
        b -= lr * b.grad

    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1:3d} | loss: {loss.item():.6f} | w: {w.item():.4f} | b: {b.item():.4f}")

print(f"\n最终结果:w = {w.item():.4f},b = {b.item():.4f}")
print(f"期望结果:w = 2.0,b = 1.0")

输出大概是:

Epoch  50 | loss: 0.193418 | w: 1.6484 | b: 0.8096
Epoch 100 | loss: 0.018273 | w: 1.9003 | b: 0.9546
Epoch 150 | loss: 0.001726 | w: 1.9727 | b: 0.9854
Epoch 200 | loss: 0.000163 | w: 1.9924 | b: 0.9950

最终结果:w = 1.9924,b = 0.9950
期望结果:w = 2.0,b = 1.0

模型正在向 w=2, b=1 收敛,一切正常。

注意 zero_() 后面的下划线——这是 PyTorch 的 in-place 操作命名约定。所有以 _ 结尾的方法都是原地修改,不创建新 Tensor。zero_() 把 Tensor 里的所有数值置为 0。与之对应,zeros_like(t) 是创建一个新的全零 Tensor(不是 in-place)。


第十步:当你用 nn.Module 和 optimizer 之后

上面手写的训练循环,在实际使用中会用 nn.Moduleoptimizer 来替代,但背后的逻辑是一模一样的:

import torch
import torch.nn as nn
import torch.optim as optim

# 数据
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y = torch.tensor([[3.0], [5.0], [7.0], [9.0]])

# 模型:一个线性层,等价于 w*x + b
model = nn.Linear(1, 1)

# 优化器:SGD,lr=0.01
# optimizer 内部持有模型参数的引用,知道该更新哪些东西
optimizer = optim.SGD(model.parameters(), lr=0.01)

loss_fn = nn.MSELoss()

for epoch in range(200):
    # ① 前向传播
    pred = model(x)
    loss = loss_fn(pred, y)

    # ② 清零梯度(optimizer.zero_grad() 等价于前面手写的清零)
    optimizer.zero_grad()

    # ③ 反向传播
    loss.backward()

    # ④ 参数更新(optimizer.step() 等价于前面手写的 w -= lr * w.grad)
    # optimizer 内部自动用 torch.no_grad() 包裹参数更新
    optimizer.step()

    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1:3d} | loss: {loss.item():.6f}")

optimizer.zero_grad()loss.backward()optimizer.step() 这三步是 PyTorch 训练循环的铁三角,顺序不能错,每步背后干的事情你现在都清楚了。

一个值得注意的细节:optimizer.zero_grad() 可以写在 optimizer.step() 之后(下一次迭代开始之前),也可以写在前向传播之前。两者效果相同,但有人偏好写在 step 之后,理由是这样的代码更接近"更新完参数,立刻清理,准备下一轮"的语义。在带 Gradient Accumulation 的场景里,清零时机需要更精细的控制,但那是后面的话题。


第十一步:梯度消失和梯度爆炸是怎么回事

理解了梯度是如何在图上流动的,现在可以自然地解释这两个深度学习里最烦恼的问题。

梯度在每一层都是通过链式法则相乘传播的:

$\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial a_n} \cdot \frac{\partial a_n}{\partial a_{n-1}} \cdot \frac{\partial a_{n-1}}{\partial a_{n-2}} \cdots \frac{\partial a_2}{\partial a_1} \cdot \frac{\partial a_1}{\partial w_1}$

这是 $n$ 项相乘。如果每一项都小于 1(比如 sigmoid 函数的导数最大只有 0.25),乘了 $n$ 次之后梯度就趋近于 0——梯度消失,靠近输入的层几乎没有梯度,根本学不动。

反过来,如果每一项都大于 1,乘了 $n$ 次之后梯度会爆炸式增长——梯度爆炸,参数更新步幅极大,训练直接发散。

# 用代码演示梯度消失
import torch
import torch.nn as nn

# 10 层网络,用 sigmoid 激活
layers = []
for i in range(10):
    layers.append(nn.Linear(10, 10))
    layers.append(nn.Sigmoid())   # sigmoid 导数最大 0.25
model = nn.Sequential(*layers)

x = torch.randn(1, 10)
y = torch.randn(1, 10)
loss = ((model(x) - y) ** 2).mean()
loss.backward()

# 看第一层和最后一层权重的梯度大小
first_layer_grad = model[0].weight.grad.abs().mean().item()
last_layer_grad  = model[-2].weight.grad.abs().mean().item()
print(f"最后一层梯度均值:  {last_layer_grad:.6f}")
print(f"第一层梯度均值:    {first_layer_grad:.6f}")
# 第一层的梯度会比最后一层小几个数量级

这就是为什么:

  • ReLU 取代 sigmoid 成为默认激活函数——ReLU 的导数是 1(正区间),不会持续缩小梯度
  • ResNet 引入残差连接——给梯度提供了一条"高速公路",不用穿过所有层就能传回来
  • 梯度裁剪(Gradient Clipping)用于对抗梯度爆炸——在更新参数之前,把梯度的范数限制在一个最大值内
# 梯度裁剪的标准写法
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
# clip_grad_norm_ 会计算所有参数梯度的全局范数,
# 如果超过 max_norm,就等比例缩小所有梯度,使总范数等于 max_norm

第十二步:自定义反向传播(选读)

大多数情况下,PyTorch 内置的 grad_fn 能自动处理一切。但偶尔你需要自己定义一个操作的反向传播——比如这个操作在数学上不可微但你知道一个合理的近似梯度,或者你在实现一个论文里的自定义 loss。

这时候用 torch.autograd.Function

class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # ctx 是一个上下文对象,用来在 forward 和 backward 之间传递信息
        ctx.save_for_backward(x)   # 把前向需要传给反向的值存起来
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output:从输出方向传来的梯度(即 dL/d_output)
        x, = ctx.saved_tensors      # 取出前向存的值
        # 本地梯度:d(x^2)/dx = 2x
        # 链式法则:dL/dx = dL/d_output * d_output/dx = grad_output * 2x
        return grad_output * 2 * x

# 使用自定义操作
x = torch.tensor(3.0, requires_grad=True)
y = MySquare.apply(x)   # 注意:用 .apply() 调用,不是直接 ()
y.backward()
print(x.grad)   # tensor(6.)  ← 2 * 3 = 6 ✓

这个机制在以下场景里会用到:

  • 实现 Straight-Through Estimator(量化感知训练里绕过取整操作的不可微性)
  • 实现某些特殊的激活函数或者 loss 函数
  • 性能优化:用 CUDA kernel 实现自定义算子,同时告诉 PyTorch 它的反向传播怎么算

小结:整条链路

从头到尾,自动微分的完整工作流是这样的:

1. 给参数设置 requires_grad=True
         ↓
2. 前向传播:计算值 + 实时构建计算图
   每个操作产生一个 grad_fn 节点,记录操作类型和所需的中间数据
         ↓
3. loss.backward()
   从 loss 出发,沿图反向遍历
   每个节点的 grad_fn 接收上游梯度,用局部导数 * 上游梯度 = 本节点的梯度,传给下游
   中间节点的梯度随用随弃
   叶节点(参数)的梯度累加到 .grad 里
         ↓
4. 用 .grad 更新参数(在 no_grad 里做)
         ↓
5. 清零 .grad(zero_grad),准备下一次迭代

关键认知点汇总:

  • requires_grad 是追踪的开关,只有设置了它的 Tensor 及其下游才会被追踪;
  • grad_fn 是每个非叶节点的出生证明,存储了反向传播需要的中间信息;
  • 计算图在 backward() 之后默认释放,需要重复使用时加 retain_graph=True;
  • 梯度是累加的,不是覆盖的,训练循环每次迭代必须主动清零;
  • detach() 切断某个 Tensor 的梯度流,no_grad() 关掉整个代码块的追踪;
  • 叶节点的 .grad 会保留,非叶节点的默认释放,需要时用 retain_grad();
  • 梯度消失和梯度爆炸是链式法则连乘在深层网络里的必然结果,残差连接和梯度裁剪是工程上的对策。

下一篇,我们进入数据管道——DatasetDataLoader。你会看到,即便是加载数据这件看似平凡的事,在 PyTorch 里也有一套清晰的设计哲学值得理解透彻。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐