重练AI小号:大模型小白的重修日记(归一化)
文章目录
一.归一化
在深度学习的各个归一化方法中,其核心数学逻辑几乎是一致的。假设我们有一组待归一化的特征数据 x x x,归一化通常包含两个步骤:
第一步:标准化(Standardization)计算这组数据的均值 μ \mu μ 和方差 σ 2 \sigma^2 σ2,然后将数据转化为均值为 0 0 0、方差为 1 1 1 的分布: x ^ = x − μ σ 2 + ϵ \hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} x^=σ2+ϵx−μ注: ϵ \epsilon ϵ 是一个极小的常量(如 10 − 5 10^{-5} 10−5),用于防止分母为零。
第二步:仿射变换(Affine Transformation / Scale and Shift)由于强制将数据限制在标准正态分布可能会破坏模型原本学到的特征表达(例如破坏激活函数如 ReLU 的非线性),我们需要引入两个可学习的参数:缩放因子 γ \gamma γ(Scale)和平移因子 β \beta β(Shift): y = γ x ^ + β y = \gamma \hat{x} + \beta y=γx^+β通过反向传播,模型可以自己学习决定是否要保留归一化,或者将分布恢复到归一化前的状态。
1.1 为啥要进行归一化
1.优化损失地形,加速模型收敛:

损失地形(Loss Landscape)描述的是:在当前固定的数据集下,模型参数取不同值时,整体损失(Loss)的高低变化。
对数据 x x x 归一化,为啥能优化参数 w w w 的损失地形?
答案藏在前向传播的乘法和反向传播的链式法则里。我们用图里的线性模型 y = w 1 x 1 + w 2 x 2 y = w_1x_1 + w_2x_2 y=w1x1+w2x2 来一步步拆解:
阶段一:未归一化时。 x 1 x_1 x1 代表距离,数值很大,比如 x 1 = 10000 x_1 = 10000 x1=10000 x 2 x_2 x2 代表比例,数值很小,比如 x 2 = 0.1 x_2 = 0.1 x2=0.1现在我们要求损失函数 L L L 对参数 w w w 的梯度(也就是地形的坡度)。根据链式法则:
(1) ∂ L ∂ w 1 = ∂ L ∂ y ⋅ ∂ y ∂ w 1 = ∂ L ∂ y ⋅ x 1 \frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial w_1} = \frac{\partial L}{\partial y} \cdot x_1 ∂w1∂L=∂y∂L⋅∂w1∂y=∂y∂L⋅x1
(2) ∂ L ∂ w 2 = ∂ L ∂ y ⋅ ∂ y ∂ w 2 = ∂ L ∂ y ⋅ x 2 \frac{\partial L}{\partial w_2} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial w_2} = \frac{\partial L}{\partial y} \cdot x_2 ∂w2∂L=∂y∂L⋅∂w2∂y=∂y∂L⋅x2
参数的梯度大小,直接挂钩于对应输入数据的数值大小。因为 x 1 x_1 x1 特别大(10000),所以 ∂ L ∂ w 1 \frac{\partial L}{\partial w_1} ∂w1∂L 也会特别大。这意味着, w 1 w_1 w1 哪怕只改变一点点,Loss 就会剧烈飙升或骤降。在地形图上, w 1 w_1 w1 方向非常陡峭。因为 x 2 x_2 x2 特别小(0.1),所以 ∂ L ∂ w 2 \frac{\partial L}{\partial w_2} ∂w2∂L 会特别小。这意味着,你把 w 2 w_2 w2 怎么调,Loss 都懒得动。在地形图上, w 2 w_2 w2 方向非常平缓。
阶段二:归一化后。当我们对数据进行了归一化,把 x 1 x_1 x1 和 x 2 x_2 x2 都强行拉回了均值为 0、方差为 1 的分布。此时,在大部分时候, x 1 x_1 x1 和 x 2 x_2 x2 的数值大小处于同一个量级(比如都在 -2 到 2 之间)。我们再看梯度公式:
(1) ∂ L ∂ w 1 ≈ ∂ L ∂ y ⋅ ( 量级为 1 的数 ) \frac{\partial L}{\partial w_1} \approx \frac{\partial L}{\partial y} \cdot (\text{量级为 1 的数}) ∂w1∂L≈∂y∂L⋅(量级为 1 的数)
(2) ∂ L ∂ w 2 ≈ ∂ L ∂ y ⋅ ( 量级为 1 的数 ) \frac{\partial L}{\partial w_2} \approx \frac{\partial L}{\partial y} \cdot (\text{量级为 1 的数}) ∂w2∂L≈∂y∂L⋅(量级为 1 的数)
现在, ∂ L ∂ w 1 \frac{\partial L}{\partial w_1} ∂w1∂L 和 ∂ L ∂ w 2 \frac{\partial L}{\partial w_2} ∂w2∂L 的大小变得差不多了。这意味着,无论你在 w 1 w_1 w1 方向还是 w 2 w_2 w2 方向迈出一步,Loss 的变化幅度是均匀的、对称的。地形不再有极端的陡峭和平缓之分,被重新“捏”成了一个对称的圆形碗状
1.1.1 量级为1(数据vs参数)
“量级为 1” 指的就是“个位数级别”的数字。比如 0.5、1.2、-0.8、2.1、-1.5 等等。这些数字的共同特点是:它们的绝对值既不是成百上千的巨大数字,也不是零点零零几的微小数字,它们都在“1”的周边徘徊。
当我们对数据进行标准化(Standardization,即减去均值,除以标准差),我们强制把数据的分布拉成了一个均值为 0,方差为 1 的状态(通常近似于标准正态分布)。
当你在训练神经网络时,无论你最初的数据是几万(比如房价),还是零点零几(比如某种微量元素的比例),只要经过了归一化层,你随便抓取一个输出的特征 x ^ i \hat{x}_i x^i,它的绝对值大概率就是 0点几、1点几、最多 2点几。所有的特征,统统被压缩到了“个位数”这个尺度上,这就是所谓的“量级为 1”。
既然所有维度的数据都被压成了差不多的大小(都在 1 左右),那模型怎么区分“重要特征”和“垃圾特征”呢?靠的就是不受严格量级限制的参数(权重 W)。
假设输入了一张图片,经过归一化: x 1 x_1 x1(代表猫耳朵特征的激活值)被拉到了 1.5 x 2 x_2 x2(代表背景杂草特征的激活值)被拉到了 1.2它们在数据量级上大差不差。但是,神经网络在训练过程中会发现猫耳朵更重要,于是它把对应的权重更新为 w 1 = 10 w_1 = 10 w1=10,把背景杂草的权重更新为 w 2 = 0.01 w_2 = 0.01 w2=0.01。当数据乘上参数后( 1.5 × 10 1.5 \times 10 1.5×10 与 1.2 × 0.01 1.2 \times 0.01 1.2×0.01),区分度瞬间就拉开了!参数的不同,代表了模型对不同特征“关注度”的不同。
总结:数据负责稳定流动(量级为 1),参数负责拉开差距(提供区分度)。
2.归一化能缓解梯度消失与梯度爆炸
这主要与深度神经网络中的链式法则和激活函数有关。在深度网络中,数据要经过多层矩阵乘法: H = W n … W 2 W 1 X H = W_n \dots W_2 W_1 X H=Wn…W2W1X。如果初始化权重矩阵的值普遍大于 1,多层相乘后,数据的值会呈指数级放大(导致梯度爆炸)。如果权重矩阵的值普遍小于 1,多层相乘后,数据的值会呈指数级缩小,趋近于 0(导致梯度消失)。
更致命的是激活函数(如 Sigmoid 或 Tanh)。以 Sigmoid 为例,只有当输入值在 [ − 3 , 3 ] [-3, 3] [−3,3] 这个区间(也就是 0 的附近)时,它的导数才比较大(梯度明显)。如果输入值变得特别大(比如 100)或特别小(比如 -100),Sigmoid 函数的输出会非常接近 1 或 0,此时它的导数几乎为 0(这被称为饱和区)。根据链式法则,一旦有一层的导数是 0,传回来的梯度乘上 0,前面的网络层就再也接收不到梯度了,权重直接停止更新(彻底死掉)。
在数据进入激活函数之前,归一化强制把数据拉回到均值为 0,方差为 1 的分布。这就相当于把数据精准地“摁”在了激活函数导数最大的那段非饱和区(敏感区)。这样一来,每次反向传播时,梯度都能顺畅地流过激活函数,既不会因为数值太大而死掉(消失),也被方差限制住了狂飙的势头(爆炸)。
1.2 归一化是否影响信息准确性
(1)信息的本质是“相对关系”,而不是“绝对数值”
在深度学习(尤其是像自然语言处理中的词向量)中,特征真正有价值的部分,通常不是那个孤立的绝对数字,而是特征与特征之间的相对关系(比如大小顺序、距离远近、分布形状)。
假设你有一组温度数据代表今天的天气变化:摄氏度 [10, 20, 30]。如果你把它们转换成华氏度 [50, 68, 86],虽然数值完全变了,但这三个时间点“越来越热”的物理规律、以及它们之间的温差比例,被完美保留了下来。标准化操作 x ^ = x − μ σ \hat{x} = \frac{x - \mu}{\sigma} x^=σx−μ 本质上是一个线性变换(平移和缩放)这种线性操作绝不会改变数据的拓扑结构。原来在一维直线上排在左边的数据,归一化后依然在左边;原来在空间中聚集成一团的词向量,归一化后依然聚在一起。
(2)在高维空间中,“方向”比“长度”更重要
归一化(如 NLP 中常用的 Layer Normalization)主要是在调整向量各个维度的比例或者整体的缩放尺度。虽然向量变“短”或变“长”了,但它们指向的相对空间方向和语义特征模式并没有被抹平。下一层的神经网络依然可以通过权重矩阵的乘法,轻松捕捉到这些语义方向。
(3)最关键的“后悔药”:可训练参数 γ \gamma γ 和 β \beta β
归一化层还自带了两个可训练的参数:缩放因子 γ \gamma γ 和 平移因子 β \beta β。 y = γ x ^ + β y = \gamma \hat{x} + \beta y=γx^+β如果标准化对当前特征有益: 网络会保留标准化后的结果,此时 γ \gamma γ 接近 1 1 1, β \beta β 接近 0 0 0。如果标准化破坏了特征: 网络可以学习出特定的 γ \gamma γ 和 β \beta β,直接把分布还原回去!具体在数学上,如果网络让: γ = σ 2 + ϵ \gamma = \sqrt{\sigma^2 + \epsilon} γ=σ2+ϵ β = μ \beta = \mu β=μ代入公式 y = γ x − μ σ 2 + ϵ + β y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta y=γσ2+ϵx−μ+β 后,你会发现 y = x y = x y=x。这意味着,通过引入这两个参数,归一化层至少能够保证前向传播时不丢失原有的信息(即实现恒等映射),这就保证了插入归一化层后,模型的表达能力下限不会变差。
1.3 层归一化(Layer Normalization)

假设我们有一个输入数据 X X X,它是一个三维张量,形状为 [N, L, D],分别代表: N N N: Batch Size(批次大小) L L L: Sequence Length(序列长度,比如句子的词数) D D D: Hidden Dimension(隐藏层维度,也就是词向量的长度)。标准化后,我们得到矩阵 X ^ \hat{X} X^: X ^ = [ x ^ 11 x ^ 12 … x ^ 1 D ← 第 1 个词的所有特征 (用 μ 1 , σ 1 归一化) x ^ 21 x ^ 22 … x ^ 2 D ← 第 2 个词的所有特征 (用 μ 2 , σ 2 归一化) ⋮ ⋮ ⋱ ⋮ x ^ L 1 x ^ L 2 … x ^ L D ← 第 L 个词的所有特征 (用 μ L , σ L 归一化) ] \hat{X} = \begin{bmatrix} \hat{x}_{11} & \hat{x}_{12} & \dots & \hat{x}_{1D} \\ \leftarrow \text{第 1 个词的所有特征 (用 } \mu_1, \sigma_1 \text{ 归一化)} \\ \hat{x}_{21} & \hat{x}_{22} & \dots & \hat{x}_{2D} \\ \leftarrow \text{第 2 个词的所有特征 (用 } \mu_2, \sigma_2 \text{ 归一化)} \\ \vdots & \vdots & \ddots & \vdots \\ \hat{x}_{L1} & \hat{x}_{L2} & \dots & \hat{x}_{LD} \\ \leftarrow \text{第 L 个词的所有特征 (用 } \mu_L, \sigma_L \text{ 归一化)} \end{bmatrix} X^=
x^11←第 1 个词的所有特征 (用 μ1,σ1 归一化)x^21←第 2 个词的所有特征 (用 μ2,σ2 归一化)⋮x^L1←第 L 个词的所有特征 (用 μL,σL 归一化)x^12x^22⋮x^L2……⋱…x^1Dx^2D⋮x^LD
在 LN 中, γ \gamma γ 和 β \beta β 依然是长度为 D D D 的一维向量。 γ = [ γ 1 , γ 2 , … , γ D ] \gamma = [\gamma_1, \gamma_2, \dots, \gamma_D] γ=[γ1,γ2,…,γD] β = [ β 1 , β 2 , … , β D ] \beta = [\beta_1, \beta_2, \dots, \beta_D] β=[β1,β2,…,βD]
现在我们执行 Y = γ ⊙ X ^ + β Y = \gamma \odot \hat{X} + \beta Y=γ⊙X^+β。通过广播机制,长度为 D D D 的向量 γ \gamma γ 和 β \beta β 会“复制” L L L 份,向下扩展去和每一行相乘相加。展开后的矩阵是这样的: Y = [ γ 1 x ^ 11 + β 1 γ 2 x ^ 12 + β 2 … γ D x ^ 1 D + β D γ 1 x ^ 21 + β 1 γ 2 x ^ 22 + β 2 … γ D x ^ 2 D + β D ⋮ ⋮ ⋱ ⋮ γ 1 x ^ L 1 + β 1 γ 2 x ^ L 2 + β 2 … γ D x ^ L D + β D ] Y = \begin{bmatrix} \gamma_1 \hat{x}_{11} + \beta_1 & \gamma_2 \hat{x}_{12} + \beta_2 & \dots & \gamma_D \hat{x}_{1D} + \beta_D \\ \gamma_1 \hat{x}_{21} + \beta_1 & \gamma_2 \hat{x}_{22} + \beta_2 & \dots & \gamma_D \hat{x}_{2D} + \beta_D \\ \vdots & \vdots & \ddots & \vdots \\ \gamma_1 \hat{x}_{L1} + \beta_1 & \gamma_2 \hat{x}_{L2} + \beta_2 & \dots & \gamma_D \hat{x}_{LD} + \beta_D \end{bmatrix} Y=
γ1x^11+β1γ1x^21+β1⋮γ1x^L1+β1γ2x^12+β2γ2x^22+β2⋮γ2x^L2+β2……⋱…γDx^1D+βDγDx^2D+βD⋮γDx^LD+βD
1.4 批归一化(Batch Normalization)

在计算机视觉中,我们的输入数据(特征图)通常是一个四维张量,形状为 [N, C, H, W]。BN 的核心法则:在“通道(Channel)”维度上保持独立,在另外三个维度( N , H , W N, H, W N,H,W)上进行大锅饭式的统计。数学表达如下(针对第 c c c 个通道): μ c = 1 N × H × W ∑ n = 1 N ∑ h = 1 H ∑ w = 1 W x n , c , h , w \mu_c = \frac{1}{N \times H \times W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n,c,h,w} μc=N×H×W1n=1∑Nh=1∑Hw=1∑Wxn,c,h,w算出统计量后,依然是我们熟悉的配方——标准化、然后乘以可训练的缩放参数 γ c \gamma_c γc 和平移参数 β c \beta_c βc: y n , c , h , w = γ c ( x n , c , h , w − μ c σ c 2 + ϵ ) + β c y_{n,c,h,w} = \gamma_c \left( \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} \right) + \beta_c yn,c,h,w=γc(σc2+ϵxn,c,h,w−μc)+βc
1.5 LN.vs BN.
为什么语言模型用LN,视觉用BN?
在计算机视觉的 CNN 中,数据形状是 ( N , C , H , W ) (N, C, H, W) (N,C,H,W)。BN 是在 ( N , H , W ) (N, H, W) (N,H,W) 上求均值,这意味着它把整个 Batch 里所有图片的所有像素点,在同一个通道 C C C 上的值混在一起算了一个分布。CNN 中,一个特定的通道(Channel)代表一种特定的滤波器(Filter / 特征提取器)。例如,通道 1 专门提取“边缘”,通道 2 专门提取“红色纹理”。“边缘”或“红色纹理”这种低级/中级视觉特征,在所有的图片中(无论这是猫的图片还是狗的图片),它们的统计分布是极其相似且稳定的。因此,跨越整个 Batch 去计算一个“边缘特征”的全局均值和方差,是非常合理的,这能给模型提供一个非常稳定且具有代表性的基准。
在语言模型(如 Transformer)中,数据形状是 ( N , L , D ) (N, L, D) (N,L,D)。如果用 BN,就是在 ( N , L ) (N, L) (N,L) 上求均值,即把整个 Batch 里所有句子的所有词,在同一个特征维度 D D D 上的值混在一起算分布。这在 NLP 中是灾难性的。语言具有极强的高阶语义和语境依赖。在 Batch 中,句子 A 讲的是“量子力学”,句子 B 讲的是“今天中午吃什么”。词汇的特征分布随语境剧烈震荡。如果你把“量子”和“大白菜”在某个特征维度上强行平均,算出的 μ \mu μ 和 σ \sigma σ 是毫无统计学意义的“缝合怪”。 LN 是在特征维度 D D D 上求均值(即对单独的一个词的 D D D 个特征求均值)。它放弃了跨样本寻找共性,而是强制每个词自身内部的能量(特征方差)保持稳定。这保证了无论这个词是“量子”还是“白菜”,它的特征向量在进入下一层前,都不会因为绝对数值过大或过小而导致梯度异常。
并且,语言是参差不齐的“变长序列(Variable-length Sequence)。在 NLP 中,句子的长度是不一样的。为了凑成一个 Batch 的矩阵 ( N , L , D ) (N, L, D) (N,L,D),我们必须以最长的句子为准,对短句子进行大量的填充(Padding,通常用 0 补齐)。如果使用 BN 跨样本计算均值和方差,那些大量毫无意义的 Padding(0) 会严重污染真实的统计量。Batch 里面短句子越多,算出来的均值就越被拉低,方差也就越畸形。 LN 只在每一个词自身内部(特征维度 D D D)进行计算。真正的词算自己的 μ \mu μ 和 σ \sigma σ,Padding 算 Padding 的(后续通过掩码 Mask 忽略即可)。词与词之间互不干扰,完全免疫了长短不一带来的问题。
用一句话来概括:计算机视觉(CV) 处理的是空间上结构对齐、特征语义具备全局一致性的数据,因此 BN 可以跨样本收集丰满的统计信息。语言模型(NLP) 处理的是时序上长度多变、特征语义高度依赖上下文且推理过程动态生长的数据,因此只能采用“各自为政、互不干涉”的 LN 来保证单个 Token 特征的稳定。
1.6.RMSNorm (Root Mean Square Normalization,均方根归一化)
RMSNorm 是 LN 的“极简加速版”。 在当今的大语言模型(LLM)时代,像 LLaMA (1/2/3)、Mistral、Qwen 等几乎所有主流的开源大模型,都已经抛弃了传统的 LN,全面拥抱了 RMSNorm。
RMSNorm 的作者(Biao Zhang 等人在 2019 年提出)敏锐地发现:LN 里面计算均值(Mean)并且减去均值这一步,其实是个“鸡肋”,既消耗算力,对模型性能的提升又没啥大用。于是,RMSNorm 大刀阔斧地砍掉了与“均值”和“平移”相关的所有操作:不计算均值 μ \mu μ(直接强制假设均值为 0)。不减去均值,直接计算数据的均方根(RMS)。不要平移参数 β \beta β,只保留缩放参数 γ \gamma γ。
均方根(Root Mean Square, RMS)的定义: RMS ( x ) = 1 D ∑ i = 1 D x i 2 + ϵ \text{RMS}(x) = \sqrt{\frac{1}{D} \sum_{i=1}^{D} x_i^2 + \epsilon} RMS(x)=D1i=1∑Dxi2+ϵ(注:你看,公式里根本没有 μ \mu μ,直接就是每个元素的平方求和算平均,然后再开根号。)RMSNorm 的最终计算公式: y = x RMS ( x ) ⋅ γ y = \frac{x}{\text{RMS}(x)} \cdot \gamma y=RMS(x)x⋅γ在 GPU/TPU 的底层执行中,计算均值需要遍历一次数据,计算方差又要遍历一次(因为要先算出均值才能算方差)。RMSNorm 直接算平方和,只需要遍历一次数据,极大地减少了内存的读取操作(Memory Access)。实验表明,RMSNorm 比 LN 的前向和反向传播速度快了约 10% 到 50%。
二.归一化层内部的反向传播
2.1 反向传播中的参数梯度与数据梯度
最终被优化器更新的,只有模型的参数(比如归一化层里的 γ \gamma γ 和 β \beta β,或者卷积层/全连接层里的权重 W W W 和偏置 b b b)。数据(也就是每层的输入/输出特征 x i x_i xi)是绝对不会被更新的。既然数据不需要更新,为什么我们还要辛辛苦苦推导并计算损失函数关于数据 x i x_i xi 的梯度( ∂ L ∂ x i \frac{\partial L}{\partial x_i} ∂xi∂L)呢?
答案就四个字:为了传代(链式法则)。
在归一化层中, x i x_i xi 是它接收到的输入数据。同时,这个 x i x_i xi 也是上一层(卷积层)的输出数据。归一化层计算出 ∂ L ∂ x i \frac{\partial L}{\partial x_i} ∂xi∂L 后,并不是拿它来更新 x i x_i xi,而是把它当作“接力棒”,原封不动地传给上一层。卷积层里有它自己的参数(权重 W W W)。卷积层想要更新自己的参数 W W W,就必须知道损失函数 L L L 对 W W W 的梯度( ∂ L ∂ W \frac{\partial L}{\partial W} ∂W∂L)。根据微积分的链式法则,卷积层计算 ∂ L ∂ W \frac{\partial L}{\partial W} ∂W∂L 的公式是: ∂ L ∂ W = ∂ L ∂ x i ⋅ ∂ x i ∂ W \frac{\partial L}{\partial W} = \frac{\partial L}{\partial x_i} \cdot \frac{\partial x_i}{\partial W} ∂W∂L=∂xi∂L⋅∂W∂xi ∂ x i ∂ W \frac{\partial x_i}{\partial W} ∂W∂xi:这是卷积层自己内部可以算出来的(输出对权重的导数)。 ∂ L ∂ x i \frac{\partial L}{\partial x_i} ∂xi∂L:这就是它必须从下一层(归一化层)要过来的梯度!
总结:在任何一个网络层内部,反向传播都承担着两个完全不同的任务:对参数求导(为了自己):计算 ∂ L ∂ 参数 \frac{\partial L}{\partial \text{参数}} ∂参数∂L,交给优化器(如 SGD, Adam),用来更新本层的参数。对输入求导(为了别人):计算 ∂ L ∂ 输入 \frac{\partial L}{\partial \text{输入}} ∂输入∂L,作为误差信号传给上一层,让上一层能够利用链式法则去更新上一层的参数。
2.2 归一化层内部的反向传播数学细节
1.变量定义与前向传播计算图
假设网络最终输出的标量损失函数为 L ∈ R L \in \mathbb{R} L∈R。给定 LN 层的输入向量 x ∈ R D x \in \mathbb{R}^D x∈RD,其前向传播构成了如下依赖关系的计算图:
均值节点 μ \mu μ: μ = 1 D ∑ k = 1 D x k \mu = \frac{1}{D} \sum_{k=1}^D x_k μ=D1k=1∑Dxk方差节点 σ 2 \sigma^2 σ2: σ 2 = 1 D ∑ k = 1 D ( x k − μ ) 2 \sigma^2 = \frac{1}{D} \sum_{k=1}^D (x_k - \mu)^2 σ2=D1k=1∑D(xk−μ)2标准化节点 x ^ i \hat{x}_i x^i (对于 i = 1 , 2 , … , D i = 1, 2, \dots, D i=1,2,…,D): x ^ i = x i − μ σ 2 + ϵ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} x^i=σ2+ϵxi−μ仿射变换节点 y i y_i yi (对于 i = 1 , 2 , … , D i = 1, 2, \dots, D i=1,2,…,D): y i = γ x ^ i + β y_i = \gamma \hat{x}_i + \beta yi=γx^i+β
已知条件: 算法已通过下游网络计算出损失函数对当前层输出的偏导数,即梯度向量 ∇ y L \nabla_y L ∇yL,其各个分量为 ∂ L ∂ y i \frac{\partial L}{\partial y_i} ∂yi∂L。
推导目标: 计算损失函数对本层输入以及可学习参数的偏导数 ∂ L ∂ x i \frac{\partial L}{\partial x_i} ∂xi∂L、 ∂ L ∂ γ \frac{\partial L}{\partial \gamma} ∂γ∂L、 ∂ L ∂ β \frac{\partial L}{\partial \beta} ∂β∂L。
2. 规范推导过程
第一步:计算对仿射参数 γ \gamma γ 和 β \beta β 的偏导数
根据多元函数求导法则,由于所有的 y i y_i yi 都依赖于全局共享参数 γ \gamma γ 和 β \beta β,需对其在特征维度 D D D 上进行全微分求和: ∂ L ∂ γ = ∑ i = 1 D ∂ L ∂ y i ∂ y i ∂ γ = ∑ i = 1 D ∂ L ∂ y i x ^ i \frac{\partial L}{\partial \gamma} = \sum_{i=1}^D \frac{\partial L}{\partial y_i} \frac{\partial y_i}{\partial \gamma} = \sum_{i=1}^D \frac{\partial L}{\partial y_i} \hat{x}_i ∂γ∂L=i=1∑D∂yi∂L∂γ∂yi=i=1∑D∂yi∂Lx^i ∂ L ∂ β = ∑ i = 1 D ∂ L ∂ y i ∂ y i ∂ β = ∑ i = 1 D ∂ L ∂ y i \frac{\partial L}{\partial \beta} = \sum_{i=1}^D \frac{\partial L}{\partial y_i} \frac{\partial y_i}{\partial \beta} = \sum_{i=1}^D \frac{\partial L}{\partial y_i} ∂β∂L=i=1∑D∂yi∂L∂β∂yi=i=1∑D∂yi∂L
第二步:计算对标准化中间变量 x ^ i \hat{x}_i x^i 的偏导数
由于每个 y i y_i yi 仅由对应的 x ^ i \hat{x}_i x^i 决定(无交叉依赖),此步为单变量链式法则: ∂ L ∂ x ^ i = ∂ L ∂ y i ∂ y i ∂ x ^ i = ∂ L ∂ y i γ \frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \frac{\partial y_i}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \gamma ∂x^i∂L=∂yi∂L∂x^i∂yi=∂yi∂Lγ第三步:计算对方差 σ 2 \sigma^2 σ2 的偏导数
计算图显示,方差 σ 2 \sigma^2 σ2 是所有标准化变量 x ^ j \hat{x}_j x^j ( j = 1 , … , D j = 1, \dots, D j=1,…,D) 的分母部分。根据多元链式法则,需对所有路径求和: ∂ L ∂ σ 2 = ∑ j = 1 D ∂ L ∂ x ^ j ∂ x ^ j ∂ σ 2 \frac{\partial L}{\partial \sigma^2} = \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} \frac{\partial \hat{x}_j}{\partial \sigma^2} ∂σ2∂L=j=1∑D∂x^j∂L∂σ2∂x^j对 x ^ j \hat{x}_j x^j 结合 σ 2 \sigma^2 σ2 求偏导: ∂ x ^ j ∂ σ 2 = ∂ ∂ σ 2 [ ( x j − μ ) ( σ 2 + ϵ ) − 1 2 ] = − 1 2 ( x j − μ ) ( σ 2 + ϵ ) − 3 2 \frac{\partial \hat{x}_j}{\partial \sigma^2} = \frac{\partial}{\partial \sigma^2} \left[ (x_j - \mu)(\sigma^2 + \epsilon)^{-\frac{1}{2}} \right] = -\frac{1}{2}(x_j - \mu)(\sigma^2 + \epsilon)^{-\frac{3}{2}} ∂σ2∂x^j=∂σ2∂[(xj−μ)(σ2+ϵ)−21]=−21(xj−μ)(σ2+ϵ)−23代入求和公式: ∂ L ∂ σ 2 = − 1 2 ( σ 2 + ϵ ) − 3 2 ∑ j = 1 D ∂ L ∂ x ^ j ( x j − μ ) \frac{\partial L}{\partial \sigma^2} = -\frac{1}{2} (\sigma^2 + \epsilon)^{-\frac{3}{2}} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} (x_j - \mu) ∂σ2∂L=−21(σ2+ϵ)−23j=1∑D∂x^j∂L(xj−μ)第四步:计算对均值 μ \mu μ 的偏导数
在计算图中,均值 μ \mu μ 是 σ 2 \sigma^2 σ2 的输入,同时也是所有 x ^ j \hat{x}_j x^j 的直接输入。其偏导数包含两部分: ∂ L ∂ μ = ∑ j = 1 D ∂ L ∂ x ^ j ∂ x ^ j ∂ μ + ∂ L ∂ σ 2 ∂ σ 2 ∂ μ \frac{\partial L}{\partial \mu} = \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} \frac{\partial \hat{x}_j}{\partial \mu} + \frac{\partial L}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial \mu} ∂μ∂L=j=1∑D∂x^j∂L∂μ∂x^j+∂σ2∂L∂μ∂σ2分别计算两个局部偏导数: ∂ x ^ j ∂ μ = ∂ ∂ μ [ x j − μ σ 2 + ϵ ] = − 1 σ 2 + ϵ \frac{\partial \hat{x}_j}{\partial \mu} = \frac{\partial}{\partial \mu} \left[ \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}} \right] = -\frac{1}{\sqrt{\sigma^2 + \epsilon}} ∂μ∂x^j=∂μ∂[σ2+ϵxj−μ]=−σ2+ϵ1 ∂ σ 2 ∂ μ = ∂ ∂ μ [ 1 D ∑ k = 1 D ( x k − μ ) 2 ] = 1 D ∑ k = 1 D − 2 ( x k − μ ) = − 2 ( 1 D ∑ k = 1 D x k − μ ) = − 2 ( μ − μ ) = 0 \frac{\partial \sigma^2}{\partial \mu} = \frac{\partial}{\partial \mu} \left[ \frac{1}{D} \sum_{k=1}^D (x_k - \mu)^2 \right] = \frac{1}{D} \sum_{k=1}^D -2(x_k - \mu) = -2 \left( \frac{1}{D} \sum_{k=1}^D x_k - \mu \right) = -2(\mu - \mu) = 0 ∂μ∂σ2=∂μ∂[D1k=1∑D(xk−μ)2]=D1k=1∑D−2(xk−μ)=−2(D1k=1∑Dxk−μ)=−2(μ−μ)=0由于 ∂ σ 2 ∂ μ = 0 \frac{\partial \sigma^2}{\partial \mu} = 0 ∂μ∂σ2=0,第二项消去,结果简化为: ∂ L ∂ μ = − 1 σ 2 + ϵ ∑ j = 1 D ∂ L ∂ x ^ j \frac{\partial L}{\partial \mu} = -\frac{1}{\sqrt{\sigma^2 + \epsilon}} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} ∂μ∂L=−σ2+ϵ1j=1∑D∂x^j∂L第五步:计算对输入 x i x_i xi 的偏导数(最终反向传播输出)
根据有向无环图,输入 x i x_i xi 到最终损失 L L L 有三条有效的前向路径:直接到 x ^ i \hat{x}_i x^i、通过 μ \mu μ、通过 σ 2 \sigma^2 σ2。应用全微分公式: ∂ L ∂ x i = ∂ L ∂ x ^ i ∂ x ^ i ∂ x i ∣ direct + ∂ L ∂ σ 2 ∂ σ 2 ∂ x i ∣ direct + ∂ L ∂ μ ∂ μ ∂ x i \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial x_i} \bigg|_{\text{direct}} + \frac{\partial L}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_i} \bigg|_{\text{direct}} + \frac{\partial L}{\partial \mu} \frac{\partial \mu}{\partial x_i} ∂xi∂L=∂x^i∂L∂xi∂x^i
direct+∂σ2∂L∂xi∂σ2
direct+∂μ∂L∂xi∂μ分别计算节点间的直接偏导数: ∂ x ^ i ∂ x i ∣ direct = 1 σ 2 + ϵ \frac{\partial \hat{x}_i}{\partial x_i}\bigg|_{\text{direct}} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} ∂xi∂x^i
direct=σ2+ϵ1 ∂ σ 2 ∂ x i ∣ direct = 2 ( x i − μ ) D \frac{\partial \sigma^2}{\partial x_i}\bigg|_{\text{direct}} = \frac{2(x_i - \mu)}{D} ∂xi∂σ2
direct=D2(xi−μ) ∂ μ ∂ x i = 1 D \frac{\partial \mu}{\partial x_i} = \frac{1}{D} ∂xi∂μ=D1将前三步求得的 ∂ L ∂ x ^ i \frac{\partial L}{\partial \hat{x}_i} ∂x^i∂L、 ∂ L ∂ σ 2 \frac{\partial L}{\partial \sigma^2} ∂σ2∂L、 ∂ L ∂ μ \frac{\partial L}{\partial \mu} ∂μ∂L 与上述结果代入总公式: ∂ L ∂ x i = ∂ L ∂ x ^ i ( 1 σ 2 + ϵ ) + ( − 1 2 ( σ 2 + ϵ ) − 3 2 ∑ j = 1 D ∂ L ∂ x ^ j ( x j − μ ) ) ( 2 ( x i − μ ) D ) + ( − 1 σ 2 + ϵ ∑ j = 1 D ∂ L ∂ x ^ j ) ( 1 D ) \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i} \left( \frac{1}{\sqrt{\sigma^2 + \epsilon}} \right) + \left( -\frac{1}{2} (\sigma^2 + \epsilon)^{-\frac{3}{2}} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} (x_j - \mu) \right) \left( \frac{2(x_i - \mu)}{D} \right) + \left( -\frac{1}{\sqrt{\sigma^2 + \epsilon}} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} \right) \left( \frac{1}{D} \right) ∂xi∂L=∂x^i∂L(σ2+ϵ1)+(−21(σ2+ϵ)−23j=1∑D∂x^j∂L(xj−μ))(D2(xi−μ))+(−σ2+ϵ1j=1∑D∂x^j∂L)(D1)3. 公式代数化简 (至标准 CUDA 算子实现形式)
为降低计算图节点的冗余计算,对上式进行代数整理。合并同类项 1 σ 2 + ϵ \frac{1}{\sqrt{\sigma^2 + \epsilon}} σ2+ϵ1: ∂ L ∂ x i = 1 σ 2 + ϵ [ ∂ L ∂ x ^ i − x i − μ D ( σ 2 + ϵ ) ∑ j = 1 D ∂ L ∂ x ^ j ( x j − μ ) − 1 D ∑ j = 1 D ∂ L ∂ x ^ j ] \frac{\partial L}{\partial x_i} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \left[ \frac{\partial L}{\partial \hat{x}_i} - \frac{x_i - \mu}{D(\sigma^2 + \epsilon)} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} (x_j - \mu) - \frac{1}{D} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} \right] ∂xi∂L=σ2+ϵ1[∂x^i∂L−D(σ2+ϵ)xi−μj=1∑D∂x^j∂L(xj−μ)−D1j=1∑D∂x^j∂L]利用标准化定义式逆向替换:已知 x ^ j = x j − μ σ 2 + ϵ \hat{x}_j = \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}} x^j=σ2+ϵxj−μ,可推导出 ( x j − μ ) = x ^ j σ 2 + ϵ (x_j - \mu) = \hat{x}_j \sqrt{\sigma^2 + \epsilon} (xj−μ)=x^jσ2+ϵ。同理, x i − μ σ 2 + ϵ = x ^ i σ 2 + ϵ \frac{x_i - \mu}{\sigma^2 + \epsilon} = \frac{\hat{x}_i}{\sqrt{\sigma^2 + \epsilon}} σ2+ϵxi−μ=σ2+ϵx^i。将该关系代入中间项的求和式中: x i − μ D ( σ 2 + ϵ ) ∑ j = 1 D ∂ L ∂ x ^ j ( x j − μ ) = x ^ i D σ 2 + ϵ ∑ j = 1 D ∂ L ∂ x ^ j ( x ^ j σ 2 + ϵ ) = x ^ i D ∑ j = 1 D ∂ L ∂ x ^ j x ^ j \frac{x_i - \mu}{D(\sigma^2 + \epsilon)} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} (x_j - \mu) = \frac{\hat{x}_i}{D\sqrt{\sigma^2 + \epsilon}} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} \left( \hat{x}_j \sqrt{\sigma^2 + \epsilon} \right) = \frac{\hat{x}_i}{D} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} \hat{x}_j D(σ2+ϵ)xi−μj=1∑D∂x^j∂L(xj−μ)=Dσ2+ϵx^ij=1∑D∂x^j∂L(x^jσ2+ϵ)=Dx^ij=1∑D∂x^j∂Lx^j将化简后的中间项替换回原式,并整理各项顺序,最终得到规范的 Layer Normalization 反向传播解析解: ∂ L ∂ x i = 1 σ 2 + ϵ ( ∂ L ∂ x ^ i − 1 D ∑ j = 1 D ∂ L ∂ x ^ j − x ^ i D ∑ j = 1 D ∂ L ∂ x ^ j x ^ j ) \frac{\partial L}{\partial x_i} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \left( \frac{\partial L}{\partial \hat{x}_i} - \frac{1}{D} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} - \frac{\hat{x}_i}{D} \sum_{j=1}^D \frac{\partial L}{\partial \hat{x}_j} \hat{x}_j \right) ∂xi∂L=σ2+ϵ1(∂x^i∂L−D1j=1∑D∂x^j∂L−Dx^ij=1∑D∂x^j∂Lx^j)所有带有下标 i i i 的,都特指当前正在计算的那个维度的局部值。所有带有下标 j j j 的(都在 Σ \Sigma Σ 求和符号后面),都是为了计算出某种全局的统计量。当 j j j 遍历求和结束后,带有 j j j 的部分就变成了一个常数标量,跟具体的维度无关了。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)