链式法则是微积分中用于求复合函数导数的核心法则。它在机器学习中尤其重要,因为神经网络的反向传播算法本质上就是链式法则的反复应用。本文将全面讲解链式法则,从单变量到多变量,从理论到应用。


7.1 什么是链式法则?

1) 单变量链式法则

假设有两个函数:

  • y = f ( u ) y = f(u) y=f(u)
  • u = g ( x ) u = g(x) u=g(x)

那么 y y y 作为 $ x $ 的复合函数 y = f ( g ( x ) ) y = f(g(x)) y=f(g(x)) 的导数为:
d y d x = d y d u ⋅ d u d x \frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} dxdy=dudydxdu

换句话说,复合函数的导数等于外层函数对内层变量的导数乘以内层函数对自变量的导数

例子:求 y = ( 3 x 2 + 1 ) 5 y = (3x^2 + 1)^5 y=(3x2+1)5 的导数。

  • u = 3 x 2 + 1 u = 3x^2 + 1 u=3x2+1,则 y = u 5 y = u^5 y=u5
  • d y d u = 5 u 4 \frac{dy}{du} = 5u^4 dudy=5u4 d u d x = 6 x \frac{du}{dx} = 6x dxdu=6x
  • 因此 d y d x = 5 u 4 ⋅ 6 x = 30 x ( 3 x 2 + 1 ) 4 \frac{dy}{dx} = 5u^4 \cdot 6x = 30x(3x^2+1)^4 dxdy=5u46x=30x(3x2+1)4
2) 几何直观

可以把导数理解为变化率。 d u / d x du/dx du/dx 表示 x x x 变化时 u u u 的变化速度, d y / d u dy/du dy/du 表示 u u u 变化时 y y y 的变化速度。那么 $ x $ 变化引起的 y y y 的变化速度,自然就是这两个速度的乘积。


7.2 多变量链式法则

当内层函数有多个变量时,链式法则需要扩展到偏导数。

1) 一个自变量,多个中间变量

z = f ( x , y ) z = f(x, y) z=f(x,y),而 x = g ( t ) x = g(t) x=g(t) y = h ( t ) y = h(t) y=h(t),则 z z z 关于 t t t 的导数为:
d z d t = ∂ z ∂ x ⋅ d x d t + ∂ z ∂ y ⋅ d y d t \frac{dz}{dt} = \frac{\partial z}{\partial x} \cdot \frac{dx}{dt} + \frac{\partial z}{\partial y} \cdot \frac{dy}{dt} dtdz=xzdtdx+yzdtdy

例子 z = x 2 + y 2 z = x^2 + y^2 z=x2+y2 x = sin ⁡ t x = \sin t x=sint y = cos ⁡ t y = \cos t y=cost,求 d z / d t dz/dt dz/dt

  • ∂ z ∂ x = 2 x \frac{\partial z}{\partial x} = 2x xz=2x ∂ z ∂ y = 2 y \frac{\partial z}{\partial y} = 2y yz=2y
  • d x d t = cos ⁡ t \frac{dx}{dt} = \cos t dtdx=cost d y d t = − sin ⁡ t \frac{dy}{dt} = -\sin t dtdy=sint
  • 代入得 d z d t = 2 x cos ⁡ t + 2 y ( − sin ⁡ t ) = 2 sin ⁡ t cos ⁡ t − 2 cos ⁡ t sin ⁡ t = 0 \frac{dz}{dt} = 2x \cos t + 2y (-\sin t) = 2\sin t \cos t - 2\cos t \sin t = 0 dtdz=2xcost+2y(sint)=2sintcost2costsint=0
  • 事实上 z = sin ⁡ 2 t + cos ⁡ 2 t = 1 z = \sin^2 t + \cos^2 t = 1 z=sin2t+cos2t=1,导数为0,一致。
2) 多个自变量,多个中间变量

更一般地,设 z = f ( u 1 , u 2 , … , u m ) z = f(u_1, u_2, \dots, u_m) z=f(u1,u2,,um),每个 u i u_i ui x 1 , x 2 , … , x n x_1, x_2, \dots, x_n x1,x2,,xn 的函数,则 z z z x j x_j xj 的偏导数为:
∂ z ∂ x j = ∑ i = 1 m ∂ z ∂ u i ⋅ ∂ u i ∂ x j \frac{\partial z}{\partial x_j} = \sum_{i=1}^m \frac{\partial z}{\partial u_i} \cdot \frac{\partial u_i}{\partial x_j} xjz=i=1muizxjui

这就是多变量链式法则的通用形式。


7.3 向量形式的链式法则

在机器学习中,我们常处理向量和矩阵。这时链式法则用雅可比矩阵表示。

y = f ( u ) \mathbf{y} = \mathbf{f}(\mathbf{u}) y=f(u) u = g ( x ) \mathbf{u} = \mathbf{g}(\mathbf{x}) u=g(x),则复合函数 y = f ( g ( x ) ) \mathbf{y} = \mathbf{f}(\mathbf{g}(\mathbf{x})) y=f(g(x)) 的雅可比矩阵为:
∂ y ∂ x = ∂ y ∂ u ⋅ ∂ u ∂ x \frac{\partial \mathbf{y}}{\partial \mathbf{x}} = \frac{\partial \mathbf{y}}{\partial \mathbf{u}} \cdot \frac{\partial \mathbf{u}}{\partial \mathbf{x}} xy=uyxu
其中 ∂ y ∂ u \frac{\partial \mathbf{y}}{\partial \mathbf{u}} uy f \mathbf{f} f 的雅可比矩阵, ∂ u ∂ x \frac{\partial \mathbf{u}}{\partial \mathbf{x}} xu g \mathbf{g} g 的雅可比矩阵,乘积是矩阵乘法。

在反向传播中,我们通常计算标量损失 L L L 对参数的梯度。若 L L L 是标量, z \mathbf{z} z 是中间向量,则:
∂ L ∂ x = ∂ L ∂ z ⋅ ∂ z ∂ x \frac{\partial L}{\partial \mathbf{x}} = \frac{\partial L}{\partial \mathbf{z}} \cdot \frac{\partial \mathbf{z}}{\partial \mathbf{x}} xL=zLxz
这里的 ∂ L ∂ z \frac{\partial L}{\partial \mathbf{z}} zL 是行向量(梯度), ∂ z ∂ x \frac{\partial \mathbf{z}}{\partial \mathbf{x}} xz 是雅可比矩阵,结果也是行向量(梯度)。


7.4 链式法则在机器学习中的应用:反向传播

反向传播算法是训练神经网络的核心,它利用链式法则高效计算损失函数对每个参数的梯度。

1) 一个简单网络

考虑一个三层神经网络的一部分:

  • 输入 $ x $
  • 第一层: z 1 = W 1 x + b 1 z_1 = W_1 x + b_1 z1=W1x+b1 a 1 = σ ( z 1 ) a_1 = \sigma(z_1) a1=σ(z1)
  • 第二层: z 2 = W 2 a 1 + b 2 z_2 = W_2 a_1 + b_2 z2=W2a1+b2 a 2 = σ ( z 2 ) a_2 = \sigma(z_2) a2=σ(z2)
  • 损失 L = Loss ( a 2 , y ) L = \text{Loss}(a_2, y) L=Loss(a2,y)

我们想求 ∂ L ∂ W 2 \frac{\partial L}{\partial W_2} W2L。根据链式法则:
∂ L ∂ W 2 = ∂ L ∂ a 2 ⋅ ∂ a 2 ∂ z 2 ⋅ ∂ z 2 ∂ W 2 \frac{\partial L}{\partial W_2} = \frac{\partial L}{\partial a_2} \cdot \frac{\partial a_2}{\partial z_2} \cdot \frac{\partial z_2}{\partial W_2} W2L=a2Lz2a2W2z2

  • ∂ L ∂ a 2 \frac{\partial L}{\partial a_2} a2L:损失对输出的梯度。
  • ∂ a 2 ∂ z 2 \frac{\partial a_2}{\partial z_2} z2a2:激活函数的导数(如sigmoid的导数为 σ ( z ) ( 1 − σ ( z ) ) \sigma(z)(1-\sigma(z)) σ(z)(1σ(z)))。
  • ∂ z 2 ∂ W 2 \frac{\partial z_2}{\partial W_2} W2z2:因为 z 2 = W 2 a 1 + b 2 z_2 = W_2 a_1 + b_2 z2=W2a1+b2,所以 ∂ z 2 ∂ W 2 = a 1 ⊤ \frac{\partial z_2}{\partial W_2} = a_1^\top W2z2=a1(考虑维度,通常得到梯度矩阵)。

类似地, ∂ L ∂ W 1 \frac{\partial L}{\partial W_1} W1L 需要继续向后传播:
∂ L ∂ W 1 = ∂ L ∂ a 2 ⋅ ∂ a 2 ∂ z 2 ⋅ ∂ z 2 ∂ a 1 ⋅ ∂ a 1 ∂ z 1 ⋅ ∂ z 1 ∂ W 1 \frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial a_2} \cdot \frac{\partial a_2}{\partial z_2} \cdot \frac{\partial z_2}{\partial a_1} \cdot \frac{\partial a_1}{\partial z_1} \cdot \frac{\partial z_1}{\partial W_1} W1L=a2Lz2a2a1z2z1a1W1z1
每一层都是链式法则的乘积。

2) 计算图与自动微分

现代深度学习框架(如PyTorch、TensorFlow)使用计算图来记录运算,然后自动进行反向传播。它们正是基于链式法则,将每个操作的局部导数相乘并累加。


7.5 常见错误与注意事项

  1. 变量混淆:在应用链式法则时,必须清楚变量之间的依赖关系。
  2. 维度匹配:在多变量情形,注意偏导数的维度,确保矩阵乘法合法。
  3. 全导数与偏导数:当中间变量依赖于多个自变量时,需用求和符号将所有路径相加。

7.6 总结

  • 链式法则是求复合函数导数的法则,形式简洁但应用广泛。
  • 单变量形式 d y d x = d y d u ⋅ d u d x \frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} dxdy=dudydxdu
  • 多变量形式 ∂ z ∂ t = ∂ z ∂ x d x d t + ∂ z ∂ y d y d t \frac{\partial z}{\partial t} = \frac{\partial z}{\partial x} \frac{dx}{dt} + \frac{\partial z}{\partial y} \frac{dy}{dt} tz=xzdtdx+yzdtdy
  • 在机器学习中,它是反向传播算法的数学基础,使得神经网络能够高效训练。

掌握链式法则,就掌握了理解复杂模型梯度的钥匙。无论是简单的线性回归还是深度神经网络,链式法则都是不可或缺的工具。

上一章 机器学习微积分–(6)阶乘与阶数

下一章 机器学习微积分–(8)小结

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐