Preface

In the earlier articles we worked through the forward-pass arithmetic of the encoder (article 33) and the decoder (article 34). But the forward pass is only half of the story: the actual learning happens during backpropagation. Only when we compare the predictions with the true labels, compute the gradient of the error with respect to every parameter, and update the parameters along the gradient direction do the model's predictions gradually become accurate.

The goal of this article: using the machine-translation example ["我", "爱", "深"] → ["i", "love", "deep"], hand-compute, for the first time in this series, one complete training iteration of a Transformer (forward → loss → backward → parameter update → forward again).

How this article relates to the earlier ones

| Earlier article | Content | How it is used here |
| --- | --- | --- |
| Article 33 | Encoder forward-pass arithmetic | Its encoder output $X_{\text{enc}}$ is reused directly |
| Article 34 | Decoder forward-pass arithmetic | Its decoder output $D$ is reused directly |
| Articles 25-32 | Principles of each component | Derivations are skipped; we focus on the numbers |

Note: this article focuses on numerically verifying the full training loop. The mathematics behind every component was covered in detail in articles 25-34 and is not re-derived here. The code implementation follows in a later article.

1. Problem Setup

1.1 The translation task

| Item | Content |
| --- | --- |
| Source sequence | ["我", "爱", "深"] (the first 3 tokens of "我爱深度学习") |
| Target sequence | ["i", "love", "deep"] (3 tokens) |
| Target input (shifted right) | ["<sos>", "i", "love"] |
| Task | Given the source sequence, the decoder predicts "i", "love", "deep" in turn |

1.2 Model configuration

To keep the arithmetic tractable by hand, we use a tiny configuration:

| Parameter | Value | Meaning |
| --- | --- | --- |
| $d_{\text{model}}$ | 4 | model width (the shared feature dimension of all layers) |
| $h$ | 1 | single-head attention |
| $d_k$ | 4 | dimension per head |
| $d_{ff}$ | 8 | FFN hidden size |
| $N_{\text{enc}}$ | 1 | number of encoder layers |
| $N_{\text{dec}}$ | 1 | number of decoder layers |
| Source vocabulary | {<pad>:0, 我:1, 爱:2, 深:3} | 4 tokens |
| Target vocabulary | {<pad>:0, <sos>:1, <eos>:2, i:3, love:4, deep:5} | 6 tokens |

1.3 Weight-initialization strategy

All attention QKV projection matrices are initialized to the identity, so that during the forward pass the attention scores are determined directly by input similarity, which makes the computation easy to trace by hand.

The output projection $W_{\text{proj}} \in \mathbb{R}^{6 \times 4}$ is the only weight trained in this article. It is deliberately initialized to values that make the model predict incorrectly, so that we can clearly watch gradient descent correct the mistakes.


2. Round 1: Forward Pass

2.1 Encoder forward pass (from article 33)

The encoder input is ["我", "爱", "深"]; after embedding + positional encoding:

$$X_0 = \begin{pmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{pmatrix} \quad \begin{matrix} \leftarrow \text{"我"} \\ \leftarrow \text{"爱"} \\ \leftarrow \text{"深"} \end{matrix}$$

Running the single encoder layer computed in detail in article 33 (self-attention + FFN + residual & LayerNorm) gives the encoder output:

$$X_{\text{enc}} = \begin{pmatrix} 1.279 & -0.783 & 0.655 & -1.151 \\ -0.783 & 1.279 & -1.151 & 0.655 \\ 1.000 & 1.000 & -1.000 & -1.000 \end{pmatrix} \quad \begin{matrix} \leftarrow \text{"我"} \\ \leftarrow \text{"爱"} \\ \leftarrow \text{"深"} \end{matrix}$$

What the encoder output means: each row is a context-aware feature vector for the corresponding source token. The three vectors all differ, which shows the encoder has distinguished the three source words and encoded the relations between them.

2.2 Decoder forward pass (from article 34)

The decoder input is ["<sos>", "i", "love"]; after embedding + positional encoding:

$$X_0^{\text{dec}} = \begin{pmatrix} 0.1 & 0.2 & 0.1 & 0.2 \\ 0.6 & 0.1 & 0.2 & 0.3 \\ 0.2 & 0.7 & 0.1 & 0.4 \end{pmatrix} \quad \begin{matrix} \leftarrow \text{"<sos>"} \\ \leftarrow \text{"i"} \\ \leftarrow \text{"love"} \end{matrix}$$

Running the single decoder layer computed in detail in article 34 (causally masked self-attention + cross-attention + FFN + residual & LayerNorm) gives the decoder output:

$$D = \begin{pmatrix} 1.000 & 0.000 & 1.000 & 0.000 \\ 0.000 & 1.000 & 0.000 & 1.000 \\ 1.000 & 1.000 & 0.000 & 0.000 \end{pmatrix} \quad \begin{matrix} \leftarrow \text{position 0 (should predict "i")} \\ \leftarrow \text{position 1 (should predict "love")} \\ \leftarrow \text{position 2 (should predict "deep")} \end{matrix}$$

Note: $D$ here happens to equal the encoder input $X_0$ from article 33 (both are carefully chosen sequence representations). This is deliberate: the simple values keep the gradient computation below easy to trace.

2.3 Output projection and logits

The output projection matrix $W_{\text{proj}} \in \mathbb{R}^{6 \times 4}$ is initialized to:

$$W_{\text{proj}} = \begin{pmatrix} 0.1 & 0.1 & 0.0 & 0.0 \\ 0.0 & 0.1 & 0.0 & 0.1 \\ 0.0 & 0.0 & 0.1 & 0.1 \\ \mathbf{0.3} & \mathbf{0.0} & \mathbf{0.2} & \mathbf{0.0} \\ \mathbf{0.8} & \mathbf{0.0} & \mathbf{0.1} & \mathbf{0.0} \\ \mathbf{0.0} & \mathbf{0.7} & \mathbf{0.1} & \mathbf{0.0} \end{pmatrix} \quad \begin{matrix} \leftarrow \text{<pad>} \\ \leftarrow \text{<sos>} \\ \leftarrow \text{<eos>} \\ \leftarrow \text{i (attends to dim0 and dim2)} \\ \leftarrow \text{love (over-weights dim0)} \\ \leftarrow \text{deep (over-weights dim1)} \end{matrix}$$

Design intent: look at the three bold rows. The "i" row's weights 0.3 and 0.2 are too low, the "love" row carries 0.8 on dim0 (too high), and the "deep" row carries 0.7 on dim1 (too high). This makes every initial prediction wrong.

Compute the logits: $\text{Logits} = D \cdot W_{\text{proj}}^\top$

For position $p$ and vocabulary token $j$:

$$\text{Logits}[p, j] = \sum_{k=0}^{3} D[p, k] \times W_{\text{proj}}[j, k]$$

Position by position:

Position 0 (<sos> → should predict i, token 3):

$D[0] = [1.0, 0.0, 1.0, 0.0]$

| Token | Computation | Logit |
| --- | --- | --- |
| <pad> | $1.0\times0.1 + 0.0\times0.1 + 1.0\times0.0 + 0.0\times0.0 = 0.1$ | 0.1 |
| <sos> | $1.0\times0.0 + 0.0\times0.1 + 1.0\times0.0 + 0.0\times0.1 = 0.0$ | 0.0 |
| <eos> | $1.0\times0.0 + 0.0\times0.0 + 1.0\times0.1 + 0.0\times0.1 = 0.1$ | 0.1 |
| i | $1.0\times0.3 + 0.0\times0.0 + 1.0\times0.2 + 0.0\times0.0 = \mathbf{0.5}$ | 0.5 |
| love | $1.0\times\mathbf{0.8} + 0.0\times0.0 + 1.0\times0.1 + 0.0\times0.0 = \mathbf{0.9}$ | 0.9 ← ❌ highest |
| deep | $1.0\times0.0 + 0.0\times0.7 + 1.0\times0.1 + 0.0\times0.0 = 0.1$ | 0.1 |

Position 0 predicts love (logit = 0.9); it should be i (logit = 0.5).

位置 1i → 应预测 love,token 4):

D[1]=[0.0,1.0,0.0,1.0]D[1] = [0.0, 1.0, 0.0, 1.0]D[1]=[0.0,1.0,0.0,1.0]

Token 计算过程 Logit
<pad> 0.0×0.1+1.0×0.1+0.0×0.0+1.0×0.0=0.10.0\times0.1 + 1.0\times0.1 + 0.0\times0.0 + 1.0\times0.0 = 0.10.0×0.1+1.0×0.1+0.0×0.0+1.0×0.0=0.1 0.1
<sos> 0.0×0.0+1.0×0.1+0.0×0.0+1.0×0.1=0.20.0\times0.0 + 1.0\times0.1 + 0.0\times0.0 + 1.0\times0.1 = 0.20.0×0.0+1.0×0.1+0.0×0.0+1.0×0.1=0.2 0.2
<eos> 0.0×0.0+1.0×0.0+0.0×0.1+1.0×0.1=0.10.0\times0.0 + 1.0\times0.0 + 0.0\times0.1 + 1.0\times0.1 = 0.10.0×0.0+1.0×0.0+0.0×0.1+1.0×0.1=0.1 0.1
i 0.0×0.3+1.0×0.0+0.0×0.2+1.0×0.0=0.00.0\times0.3 + 1.0\times0.0 + 0.0\times0.2 + 1.0\times0.0 = 0.00.0×0.3+1.0×0.0+0.0×0.2+1.0×0.0=0.0 0.0
love 0.0×0.8+1.0×0.0+0.0×0.1+1.0×0.0=0.00.0\times0.8 + 1.0\times0.0 + 0.0\times0.1 + 1.0\times0.0 = 0.00.0×0.8+1.0×0.0+0.0×0.1+1.0×0.0=0.0 0.0
deep 0.0×0.0+1.0×0.7+0.0×0.1+1.0×0.0=0.70.0\times0.0 + 1.0\times\mathbf{0.7} + 0.0\times0.1 + 1.0\times0.0 = \mathbf{0.7}0.0×0.0+1.0×0.7+0.0×0.1+1.0×0.0=0.7 0.7 ← ❌ 最高

位置 1 预测 deep(logit=0.7),应为 love(logit=0.0)

位置 2love → 应预测 deep,token 5):

D[2]=[1.0,1.0,0.0,0.0]D[2] = [1.0, 1.0, 0.0, 0.0]D[2]=[1.0,1.0,0.0,0.0]

Token 计算过程 Logit
<pad> 1.0×0.1+1.0×0.1+0.0×0.0+0.0×0.0=0.21.0\times0.1 + 1.0\times0.1 + 0.0\times0.0 + 0.0\times0.0 = 0.21.0×0.1+1.0×0.1+0.0×0.0+0.0×0.0=0.2 0.2
<sos> 1.0×0.0+1.0×0.1+0.0×0.0+0.0×0.1=0.11.0\times0.0 + 1.0\times0.1 + 0.0\times0.0 + 0.0\times0.1 = 0.11.0×0.0+1.0×0.1+0.0×0.0+0.0×0.1=0.1 0.1
<eos> 1.0×0.0+1.0×0.0+0.0×0.1+0.0×0.1=0.01.0\times0.0 + 1.0\times0.0 + 0.0\times0.1 + 0.0\times0.1 = 0.01.0×0.0+1.0×0.0+0.0×0.1+0.0×0.1=0.0 0.0
i 1.0×0.3+1.0×0.0+0.0×0.2+0.0×0.0=0.31.0\times0.3 + 1.0\times0.0 + 0.0\times0.2 + 0.0\times0.0 = 0.31.0×0.3+1.0×0.0+0.0×0.2+0.0×0.0=0.3 0.3
love 1.0×0.8+1.0×0.0+0.0×0.1+0.0×0.0=0.81.0\times\mathbf{0.8} + 1.0\times0.0 + 0.0\times0.1 + 0.0\times0.0 = \mathbf{0.8}1.0×0.8+1.0×0.0+0.0×0.1+0.0×0.0=0.8 0.8 ← ❌ 最高
deep 1.0×0.0+1.0×0.7+0.0×0.1+0.0×0.0=0.71.0\times0.0 + 1.0\times\mathbf{0.7} + 0.0\times0.1 + 0.0\times0.0 = \mathbf{0.7}1.0×0.0+1.0×0.7+0.0×0.1+0.0×0.0=0.7 0.7

位置 2 预测 love(logit=0.8),应为 deep(logit=0.7)

初始预测总结:三个位置全部预测错误。这正是我们想要的效果——初始权重是"坏的",下一步梯度更新才能展示"变好"的过程。
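The logit computation above is one matrix product. A minimal NumPy sketch (variable names are mine, not from any released implementation) reproduces all three tables at once:

```python
import numpy as np

# Decoder output D (3 positions x d_model=4) and the initial output
# projection W_proj (6 vocabulary tokens x 4), copied from the text.
D = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 1.0, 0.0, 1.0],
              [1.0, 1.0, 0.0, 0.0]])
W_proj = np.array([[0.1, 0.1, 0.0, 0.0],   # <pad>
                   [0.0, 0.1, 0.0, 0.1],   # <sos>
                   [0.0, 0.0, 0.1, 0.1],   # <eos>
                   [0.3, 0.0, 0.2, 0.0],   # i
                   [0.8, 0.0, 0.1, 0.0],   # love
                   [0.0, 0.7, 0.1, 0.0]])  # deep

logits = D @ W_proj.T                      # shape (3, 6), one row per position
print(logits[0])                           # ≈ [0.1, 0.0, 0.1, 0.5, 0.9, 0.1]
print(logits.argmax(axis=1))               # [4 5 4] -> love, deep, love: all wrong
```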

2.4 Softmax and cross-entropy loss

Softmax probabilities

Position 0: logits = $[0.1, 0.0, 0.1, 0.5, 0.9, 0.1]$

$$e^{\text{logit}} = [1.105, 1.000, 1.105, 1.649, 2.460, 1.105], \quad \text{sum}=8.424$$

$$P_0 = [0.131, 0.119, 0.131, 0.196, 0.292, 0.131]$$

Position 1: logits = $[0.1, 0.2, 0.1, 0.0, 0.0, 0.7]$

$$e^{\text{logit}} = [1.105, 1.221, 1.105, 1.000, 1.000, 2.014], \quad \text{sum}=7.445$$

$$P_1 = [0.148, 0.164, 0.148, 0.134, 0.134, 0.270]$$

Position 2: logits = $[0.2, 0.1, 0.0, 0.3, 0.8, 0.7]$

$$e^{\text{logit}} = [1.221, 1.105, 1.000, 1.350, 2.226, 2.014], \quad \text{sum}=8.916$$

$$P_2 = [0.137, 0.124, 0.112, 0.151, 0.250, 0.226]$$

Cross-entropy loss

The loss at each position is $L_p = -\log(P_p[\text{target}_p])$:

| Position | Target token | Prob. of correct token | Loss | Comment |
| --- | --- | --- | --- | --- |
| 0 | i (token 3) | 0.196 | $-\log(0.196)=1.630$ | low confidence in "i" |
| 1 | love (token 4) | 0.134 | $-\log(0.134)=2.010$ | very low confidence in "love" |
| 2 | deep (token 5) | 0.226 | $-\log(0.226)=1.487$ | moderate confidence in "deep" |

$$\mathcal{L} = \frac{1}{3}(1.630 + 2.010 + 1.487) = \frac{5.127}{3} = 1.709$$

This is the initial loss. The goal of gradient descent is to drive this value down.
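The softmax and cross-entropy numbers can be checked with a short NumPy sketch (assumed variable names; the subtraction of the row maximum is the standard numerical-stability trick and changes nothing mathematically):

```python
import numpy as np

logits = np.array([[0.1, 0.0, 0.1, 0.5, 0.9, 0.1],   # position 0
                   [0.1, 0.2, 0.1, 0.0, 0.0, 0.7],   # position 1
                   [0.2, 0.1, 0.0, 0.3, 0.8, 0.7]])  # position 2
targets = np.array([3, 4, 5])                          # i, love, deep

# Numerically stable softmax over the vocabulary axis.
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
P = exp / exp.sum(axis=1, keepdims=True)

# Mean cross-entropy over the three positions.
loss = -np.log(P[np.arange(3), targets]).mean()
print(P[0].round(3))    # [0.131 0.119 0.131 0.196 0.292 0.131]
print(round(loss, 3))   # 1.709
```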

2.5 The loss picture after one forward pass

Source ["我", "爱", "深"] ──→ encoder ──→ X_enc (3×4)
                                              │
Target ["<sos>", "i", "love"] ──→ decoder ──→ D (3×4) ──→ W_proj ──→ Logits (3×6) ──→ Softmax ──→ Loss = 1.709
                                              ↑                      ↑
                                   cross-attention K,V      the only trainable weight here

3. Backpropagation: Gradients of the Output Projection (Full Hand Calculation)

This is the heart of the article. We compute, entry by entry, the gradient of the loss $\mathcal{L}$ with respect to the output projection matrix $W_{\text{proj}}$.

3.1 The basic gradient formula

For the output projection $\text{Logits} = D \cdot W_{\text{proj}}^\top$ with cross-entropy loss $\mathcal{L} = \frac{1}{3}\sum_p -\log(\text{Softmax}(\text{Logits}[p,:])[t_p])$, the gradient is:

$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[j,k]} = \frac{1}{3} \sum_{p=0}^{2} \underbrace{\left( P_p[j] - \mathbb{1}[t_p = j] \right)}_{\text{logit gradient of token } j \text{ at position } p} \cdot \underbrace{D[p,k]}_{\text{decoder output}}$$

where $P_p[j] = \text{Softmax}(\text{Logits}[p,:])[j]$ and $\mathbb{1}[t_p = j]$ is the indicator function (1 when the target equals $j$, otherwise 0).

Intuition: each position's logit gradient = (predicted probability − target indicator). Multiply it by the decoder output, sum over sequence positions, and you get the gradient for each weight parameter.

3.2 Logit gradients at each position

First compute $\delta_p[j] = P_p[j] - \mathbb{1}[t_p = j]$:

| Token $j$ | $P_0[j]$ | $t_0=j$? | $\delta_0[j]$ | $P_1[j]$ | $t_1=j$? | $\delta_1[j]$ | $P_2[j]$ | $t_2=j$? | $\delta_2[j]$ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| <pad> | 0.131 | | +0.131 | 0.148 | | +0.148 | 0.137 | | +0.137 |
| <sos> | 0.119 | | +0.119 | 0.164 | | +0.164 | 0.124 | | +0.124 |
| <eos> | 0.131 | | +0.131 | 0.148 | | +0.148 | 0.112 | | +0.112 |
| i | 0.196 | ✓ | -0.804 | 0.134 | | +0.134 | 0.151 | | +0.151 |
| love | 0.292 | | +0.292 | 0.134 | ✓ | -0.866 | 0.250 | | +0.250 |
| deep | 0.131 | | +0.131 | 0.270 | | +0.270 | 0.226 | ✓ | -0.774 |

Key observation: for the correct target token, $\delta = P - 1 < 0$, and the gradient tells the model to raise that token's logit; for every wrong token, $\delta = P > 0$, and the gradient tells the model to lower its logit.

3.3 The decoder output matrix

$$D = \begin{pmatrix} 1.0 & 0.0 & 1.0 & 0.0 \\ 0.0 & 1.0 & 0.0 & 1.0 \\ 1.0 & 1.0 & 0.0 & 0.0 \end{pmatrix}$$

3.4 Gradient computation: token i (row 3) as an example

$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[3,k]} = \frac{1}{3} \sum_{p=0}^{2} \delta_p[3] \cdot D[p,k]$$

The values of $\delta_p[3]$: $p=0: -0.804,\; p=1: +0.134,\; p=2: +0.151$

For $k=0$ (the dim0 weight):

$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[3,0]} = \frac{1}{3}\left( (-0.804)\times1.0 + 0.134\times0.0 + 0.151\times1.0 \right) = \frac{-0.653}{3} = -0.218$$

The gradient is negative → the weight should increase (the negative gradient points toward lower loss).

For $k=1$ (the dim1 weight):

$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[3,1]} = \frac{1}{3}\left( (-0.804)\times0.0 + 0.134\times1.0 + 0.151\times1.0 \right) = \frac{0.285}{3} = +0.095$$

The gradient is positive → the weight should decrease.

For $k=2$ (the dim2 weight):

$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[3,2]} = \frac{1}{3}\left( (-0.804)\times1.0 + 0.134\times0.0 + 0.151\times0.0 \right) = \frac{-0.804}{3} = -0.268$$

For $k=3$ (the dim3 weight):

$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[3,3]} = \frac{1}{3}\left( (-0.804)\times0.0 + 0.134\times1.0 + 0.151\times0.0 \right) = \frac{0.134}{3} = +0.045$$

The gradient vector:

$$\nabla_{W_{\text{proj}}[3,:]} = [-0.218, \; +0.095, \; -0.268, \; +0.045]$$

Interpreting the gradient:

| Dimension | Gradient | Meaning |
| --- | --- | --- |
| dim0 | $-0.218$ | increase $W_{\text{proj}}[3,0]$: position 0 needs a higher "i" logit (−0.804) and $D[0,0]=1.0$ |
| dim1 | $+0.095$ | decrease $W_{\text{proj}}[3,1]$: positions 1 and 2 do not want "i" (+0.134, +0.151) and both have $D[p,1]=1.0$ |
| dim2 | $-0.268$ | increase $W_{\text{proj}}[3,2]$: position 0 needs a higher "i" logit |
| dim3 | $+0.045$ | decrease $W_{\text{proj}}[3,3]$: the "i" logit at position 1 is too high (+0.134) |

Gradient competition: note that the dim0 and dim2 gradients say to increase the weights (position 0 needs more "i"), while the dim1 and dim3 gradients say to decrease them (positions 1 and 2 need less "i"). This is multi-task training in miniature: the weights of a single token must strike a balance across positions.

3.5 Gradients for all tokens (summary)

By the same method we can compute all 6 rows × 4 columns of the $W_{\text{proj}}$ gradient.

Core formula:

$$g_{j,k} = \frac{1}{3} \sum_{p=0}^{2} \delta_p[j] \cdot D[p,k]$$

Example, <pad> ($j=0$), dim0: $g_{0,0} = \frac{1}{3}(+0.131\times1.0 + 0.148\times0.0 + 0.137\times1.0) = \frac{0.268}{3} = +0.089$

All gradients:

| Token | $\delta_0$ | $\delta_1$ | $\delta_2$ | $\nabla_{\text{dim0}}$ | $\nabla_{\text{dim1}}$ | $\nabla_{\text{dim2}}$ | $\nabla_{\text{dim3}}$ |
| --- | --- | --- | --- | --- | --- | --- | --- |
| <pad> | +0.131 | +0.148 | +0.137 | +0.089 | +0.095 | +0.044 | +0.049 |
| <sos> | +0.119 | +0.164 | +0.124 | +0.081 | +0.096 | +0.040 | +0.055 |
| <eos> | +0.131 | +0.148 | +0.112 | +0.081 | +0.087 | +0.044 | +0.049 |
| i | -0.804 | +0.134 | +0.151 | $-0.218$ | $+0.095$ | $-0.268$ | $+0.045$ |
| love | +0.292 | -0.866 | +0.250 | $+0.181$ | $-0.205$ | $+0.097$ | $-0.289$ |
| deep | +0.131 | +0.270 | -0.774 | $-0.214$ | $-0.168$ | $+0.044$ | $+0.090$ |

(Check the deep row: $\nabla_{\text{dim1}} = \frac{1}{3}(+0.270 - 0.774) = -0.168$, because $D[:,1] = [0, 1, 1]$ picks up positions 1 and 2.)

The full gradient matrix:

$$\nabla \mathcal{L} = \begin{pmatrix} +0.089 & +0.095 & +0.044 & +0.049 \\ +0.081 & +0.096 & +0.040 & +0.055 \\ +0.081 & +0.087 & +0.044 & +0.049 \\ -0.218 & +0.095 & -0.268 & +0.045 \\ +0.181 & -0.205 & +0.097 & -0.289 \\ -0.214 & -0.168 & +0.044 & +0.090 \end{pmatrix}$$

Patterns worth noticing:

• The rows of the target tokens (i, love, deep) carry the large gradients, because their $\delta = P-1$ is far larger in magnitude than the non-targets' $\delta = P$
• The non-target rows' gradients are small and positive: those logits are generally a bit too high and need to come down
• The sign of each gradient is the net result of the products $\delta \cdot D$ across the three positions
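Because $\delta = P - \text{onehot}(t)$, the whole table is a single matrix product, $\nabla \mathcal{L} = \delta^\top D / 3$. A NumPy sketch (names are mine) that reproduces the rows above:

```python
import numpy as np

D = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 1.0, 0.0, 1.0],
              [1.0, 1.0, 0.0, 0.0]])
W_proj = np.array([[0.1, 0.1, 0.0, 0.0], [0.0, 0.1, 0.0, 0.1],
                   [0.0, 0.0, 0.1, 0.1], [0.3, 0.0, 0.2, 0.0],
                   [0.8, 0.0, 0.1, 0.0], [0.0, 0.7, 0.1, 0.0]])
targets = np.array([3, 4, 5])   # i, love, deep

# Forward pass: logits -> softmax probabilities.
logits = D @ W_proj.T
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
P = exp / exp.sum(axis=1, keepdims=True)

# delta = P - one_hot(target); the full gradient is one outer-product sum.
delta = P.copy()
delta[np.arange(3), targets] -= 1.0
grad = delta.T @ D / 3.0        # shape (6, 4), one row per vocabulary token

print(grad[3].round(3))         # i row:    [-0.218  0.095 -0.268  0.045]
print(grad[5].round(3))         # deep row: [-0.214 -0.168  0.044  0.09 ]
```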

4. Parameter Update (SGD)

With learning rate $\eta = 0.5$, the SGD update is:

$$W_{\text{proj}}^{\text{new}} = W_{\text{proj}}^{\text{old}} - \eta \cdot \nabla \mathcal{L}$$

4.1 Row-by-row updates

Token <pad> (row 0):

$$W_{\text{proj}}^{\text{new}}[0,:] = [0.1, 0.1, 0.0, 0.0] - 0.5 \times [+0.089, +0.095, +0.044, +0.049]$$

$$= [0.1 - 0.045, \; 0.1 - 0.048, \; 0.0 - 0.022, \; 0.0 - 0.025] = [0.055, 0.052, -0.022, -0.025]$$

Token <sos> (row 1):

$$W_{\text{proj}}^{\text{new}}[1,:] = [0.0, 0.1, 0.0, 0.1] - 0.5 \times [+0.081, +0.096, +0.040, +0.055]$$

$$= [0.0 - 0.041, \; 0.1 - 0.048, \; 0.0 - 0.020, \; 0.1 - 0.028] = [-0.041, 0.052, -0.020, 0.072]$$

Token <eos> (row 2):

$$W_{\text{proj}}^{\text{new}}[2,:] = [0.0, 0.0, 0.1, 0.1] - 0.5 \times [+0.081, +0.087, +0.044, +0.049]$$

$$= [0.0 - 0.041, \; 0.0 - 0.044, \; 0.1 - 0.022, \; 0.1 - 0.025] = [-0.041, -0.044, 0.078, 0.075]$$

Token i (row 3), the key update:

$$W_{\text{proj}}^{\text{new}}[3,:] = [0.3, 0.0, 0.2, 0.0] - 0.5 \times [-0.218, +0.095, -0.268, +0.045]$$

$$= [0.3 + 0.109, \; 0.0 - 0.048, \; 0.2 + 0.134, \; 0.0 - 0.023] = [0.409, \; -0.048, \; 0.334, \; -0.023]$$

✨ Key changes: the dim0 weight rises from 0.3 to 0.409 and dim2 from 0.2 to 0.334. These two dimensions are active at position 0 ($D[0] = [1, 0, 1, 0]$), so raising them directly raises position 0's logit for "i".

Token love (row 4):

$$W_{\text{proj}}^{\text{new}}[4,:] = [0.8, 0.0, 0.1, 0.0] - 0.5 \times [+0.181, -0.205, +0.097, -0.289]$$

$$= [0.8 - 0.091, \; 0.0 + 0.103, \; 0.1 - 0.049, \; 0.0 + 0.145] = [0.709, 0.103, 0.051, 0.145]$$

✨ Key changes: dim0 falls from 0.8 to 0.709 (its initial value was too high and caused the mispredictions), while dim1 rises from 0.0 to 0.103 and dim3 from 0.0 to 0.145. The latter two raise the logit at position 1, which should predict "love".

Token deep (row 5):

$$W_{\text{proj}}^{\text{new}}[5,:] = [0.0, 0.7, 0.1, 0.0] - 0.5 \times [-0.214, -0.168, +0.044, +0.090]$$

$$= [0.0 + 0.107, \; 0.7 + 0.084, \; 0.1 - 0.022, \; 0.0 - 0.045] = [0.107, 0.784, 0.078, -0.045]$$

✨ Key changes: dim0 rises from 0.0 to 0.107 and dim1 from 0.7 to 0.784. Both help position 2 ($D[2] = [1, 1, 0, 0]$, target "deep"): on the shared dim1 weight, position 2's pull ($\delta_2 = -0.774$) outweighs position 1's opposing pull ($\delta_1 = +0.270$). The dim3 weight falls to −0.045, which lowers the "deep" logit at position 1.

4.2 The updated weight matrix

$$W_{\text{proj}}^{\text{new}} = \begin{pmatrix} 0.055 & 0.052 & -0.022 & -0.025 \\ -0.041 & 0.052 & -0.020 & 0.072 \\ -0.041 & -0.044 & 0.078 & 0.075 \\ \mathbf{0.409} & \mathbf{-0.048} & \mathbf{0.334} & \mathbf{-0.023} \\ \mathbf{0.709} & \mathbf{0.103} & \mathbf{0.051} & \mathbf{0.145} \\ \mathbf{0.107} & \mathbf{0.784} & \mathbf{0.078} & \mathbf{-0.045} \end{pmatrix} \quad \begin{matrix} \leftarrow \text{<pad>} \\ \leftarrow \text{<sos>} \\ \leftarrow \text{<eos>} \\ \leftarrow \textbf{i ↑dim0, ↑dim2} \\ \leftarrow \textbf{love ↓dim0, ↑dim1, ↑dim3} \\ \leftarrow \textbf{deep ↑dim0, ↑dim1} \end{matrix}$$
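In code the entire update is one vectorized line. A sketch (assumed names; the last printed digit can differ from the hand calculation above, which rounds intermediate values to three decimals):

```python
import numpy as np

D = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 1.0, 0.0, 1.0],
              [1.0, 1.0, 0.0, 0.0]])
W_proj = np.array([[0.1, 0.1, 0.0, 0.0], [0.0, 0.1, 0.0, 0.1],
                   [0.0, 0.0, 0.1, 0.1], [0.3, 0.0, 0.2, 0.0],
                   [0.8, 0.0, 0.1, 0.0], [0.0, 0.7, 0.1, 0.0]])
targets = np.array([3, 4, 5])

# Forward pass and gradient, exactly as derived in the text.
logits = D @ W_proj.T
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
P = exp / exp.sum(axis=1, keepdims=True)
delta = P.copy()
delta[np.arange(3), targets] -= 1.0
grad = delta.T @ D / 3.0

eta = 0.5
W_new = W_proj - eta * grad     # the entire SGD step
print(W_new[3].round(3))        # i row: dim0 and dim2 rise
print(W_new[5].round(3))        # deep row: dim0 and dim1 rise
```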


5. Round 2: Forward Pass (Verifying the Improvement)

Re-run the forward pass with the updated weights $W_{\text{proj}}^{\text{new}}$ to check that the loss has dropped.

5.1 Recompute the logits

Position 0 ($D[0] = [1.0, 0.0, 1.0, 0.0]$):

| Token | Computation | Logit |
| --- | --- | --- |
| <pad> | $1.0\times0.055 + 0.0\times0.052 + 1.0\times(-0.022) + 0.0\times(-0.025) = 0.033$ | 0.033 |
| <sos> | $1.0\times(-0.041) + 0.0\times0.052 + 1.0\times(-0.020) + 0.0\times0.072 = -0.061$ | -0.061 |
| <eos> | $1.0\times(-0.041) + 0.0\times(-0.044) + 1.0\times0.078 + 0.0\times0.075 = 0.037$ | 0.037 |
| i | $1.0\times\mathbf{0.409} + 0.0\times(-0.048) + 1.0\times\mathbf{0.334} + 0.0\times(-0.023) = \mathbf{0.743}$ | 0.743 |
| love | $1.0\times0.709 + 0.0\times0.103 + 1.0\times0.051 + 0.0\times0.145 = 0.760$ | 0.760 → ❌ still highest |
| deep | $1.0\times0.107 + 0.0\times0.784 + 1.0\times0.078 + 0.0\times(-0.045) = 0.185$ | 0.185 |

Position 0 improved: the "i" logit jumped from 0.5 to 0.743. But "love" still has 0.760 (because $W_{\text{proj}}[4,0]$ only dropped from 0.8 to 0.709, still too high). The two are now almost tied, with nearly equal probabilities. Progress, but more iterations are needed.

Position 1 ($D[1] = [0.0, 1.0, 0.0, 1.0]$):

| Token | Computation | Logit |
| --- | --- | --- |
| <pad> | $0.0\times0.055 + 1.0\times0.052 + 0.0\times(-0.022) + 1.0\times(-0.025) = 0.027$ | 0.027 |
| <sos> | $0.0\times(-0.041) + 1.0\times0.052 + 0.0\times(-0.020) + 1.0\times0.072 = 0.124$ | 0.124 |
| <eos> | $0.0\times(-0.041) + 1.0\times(-0.044) + 0.0\times0.078 + 1.0\times0.075 = 0.031$ | 0.031 |
| i | $0.0\times0.409 + 1.0\times(-0.048) + 0.0\times0.334 + 1.0\times(-0.023) = -0.071$ | -0.071 |
| love | $0.0\times0.709 + 1.0\times\mathbf{0.103} + 0.0\times0.051 + 1.0\times\mathbf{0.145} = \mathbf{0.248}$ | 0.248 |
| deep | $0.0\times0.107 + 1.0\times0.784 + 0.0\times0.078 + 1.0\times(-0.045) = 0.739$ | 0.739 → ❌ still highest |

Position 1 improved: the "love" logit went from 0.0 to 0.248. "deep" even crept up slightly to 0.739, because on the shared dim1 weight position 2's demand for a higher "deep" logit outweighs position 1's; the loss at this position still falls because the other competing logits dropped. The trend is right, but more training is needed.

Position 2 ($D[2] = [1.0, 1.0, 0.0, 0.0]$):

| Token | Computation | Logit |
| --- | --- | --- |
| <pad> | $1.0\times0.055 + 1.0\times0.052 + 0.0\times(-0.022) + 0.0\times(-0.025) = 0.107$ | 0.107 |
| <sos> | $1.0\times(-0.041) + 1.0\times0.052 + 0.0\times(-0.020) + 0.0\times0.072 = 0.011$ | 0.011 |
| <eos> | $1.0\times(-0.041) + 1.0\times(-0.044) + 0.0\times0.078 + 0.0\times0.075 = -0.085$ | -0.085 |
| i | $1.0\times0.409 + 1.0\times(-0.048) + 0.0\times0.334 + 0.0\times(-0.023) = 0.361$ | 0.361 |
| love | $1.0\times0.709 + 1.0\times0.103 + 0.0\times0.051 + 0.0\times0.145 = 0.812$ | 0.812 |
| deep | $1.0\times\mathbf{0.107} + 1.0\times\mathbf{0.784} + 0.0\times0.078 + 0.0\times(-0.045) = \mathbf{0.891}$ | 0.891 ← ✅ highest |

Position 2 improved: the "deep" logit rose from 0.7 to 0.891, while "love" barely moved (0.8 → 0.812). Position 2 now predicts correctly after a single update.

5.2 Recompute the loss

Position 0: logits = $[0.033, -0.061, 0.037, 0.743, 0.760, 0.185]$

$$e^{\text{logit}} = [1.034, 0.941, 1.038, 2.102, 2.138, 1.203], \quad \text{sum}=8.456$$

$$P_0 = [0.122, 0.111, 0.123, 0.249, 0.253, 0.142]$$

Target = i (token 3): $P_0[3] = 0.249$, so $L_0^{\text{new}} = -\log(0.249) = 1.390$

Position 1: logits = $[0.027, 0.124, 0.031, -0.071, 0.248, 0.739]$

$$e^{\text{logit}} = [1.027, 1.132, 1.031, 0.931, 1.281, 2.094], \quad \text{sum}=7.496$$

$$P_1 = [0.137, 0.151, 0.138, 0.124, 0.171, 0.279]$$

Target = love (token 4): $P_1[4] = 0.171$, so $L_1^{\text{new}} = -\log(0.171) = 1.766$

Position 2: logits = $[0.107, 0.011, -0.085, 0.361, 0.812, 0.891]$

$$e^{\text{logit}} = [1.113, 1.011, 0.919, 1.435, 2.252, 2.438], \quad \text{sum}=9.168$$

$$P_2 = [0.121, 0.110, 0.100, 0.157, 0.246, 0.266]$$

Target = deep (token 5): $P_2[5] = 0.266$, so $L_2^{\text{new}} = -\log(0.266) = 1.324$

Total loss:

$$\mathcal{L}^{\text{new}} = \frac{1}{3}(1.390 + 1.766 + 1.324) = \frac{4.480}{3} = 1.493$$

5.3 Loss comparison

| Position | Correct token | Round 1 logit | Round 2 logit | Round 1 prob. | Round 2 prob. | Round 1 loss | Round 2 loss |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | i | 0.5 | 0.743 | 0.196 | 0.249 | 1.630 | 1.390 |
| 1 | love | 0.0 | 0.248 | 0.134 | 0.171 | 2.010 | 1.766 |
| 2 | deep | 0.7 | 0.891 | 0.226 | 0.266 | 1.487 | 1.324 |
| Total loss | | | | | | 1.709 | 1.493 |

✅ Verified! One SGD update cut the total loss from 1.709 to 1.493, a 12.6% drop. The probability of the target token rose at every position, and position 2 already predicts correctly. Gradient descent is indeed moving the parameters in the loss-reducing direction.
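The whole round trip, two forward passes with one SGD step in between, fits in a short script. A sketch with assumed names (the printed digits can differ in the last place from the hand arithmetic above, which rounds intermediates to three decimals):

```python
import numpy as np

D = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 1.0, 0.0, 1.0],
              [1.0, 1.0, 0.0, 0.0]])
W = np.array([[0.1, 0.1, 0.0, 0.0], [0.0, 0.1, 0.0, 0.1],
              [0.0, 0.0, 0.1, 0.1], [0.3, 0.0, 0.2, 0.0],
              [0.8, 0.0, 0.1, 0.0], [0.0, 0.7, 0.1, 0.0]])
targets = np.array([3, 4, 5])

def forward(W):
    """Return (mean cross-entropy loss, softmax probabilities) for fixed D."""
    logits = D @ W.T
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    P = exp / exp.sum(axis=1, keepdims=True)
    loss = -np.log(P[np.arange(3), targets]).mean()
    return loss, P

loss1, P = forward(W)
delta = P.copy()
delta[np.arange(3), targets] -= 1.0
W = W - 0.5 * (delta.T @ D / 3.0)          # one SGD step, eta = 0.5
loss2, _ = forward(W)
print(round(loss1, 3), round(loss2, 3))    # 1.709 1.495
```

In full precision the new loss comes out near 1.495; the text's 1.493 differs only because each intermediate there was rounded before the next step.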

5.4 What further training would do

If training continues for more rounds, we can expect:

| Token | Issue | Trend over more rounds |
| --- | --- | --- |
| i | still loses to love at position 0 (0.743 vs 0.760) | its dim0/dim2 weights keep rising while love's dim0 keeps falling, until "i" wins |
| love | still suppressed by deep at position 1 | its dim1/dim3 weights keep rising while deep's dim3 keeps falling |
| deep | already correct at position 2 | its lead over love widens as love's dim0 keeps falling |

After roughly 10-20 iterations of this kind, the model should predict all three positions correctly.


6. Backpropagating into the Decoder Layers (Overview)

Above we computed the gradient of the output projection $W_{\text{proj}}$ in full. In real training the gradient keeps flowing backwards through the decoder FFN, cross-attention, and masked self-attention, finally reaching the encoder. Here is a brief sketch of that path.

6.1 Gradient of the loss w.r.t. the decoder output

$$\frac{\partial \mathcal{L}}{\partial D[p,k]} = \sum_{j=0}^{5} \frac{\partial \mathcal{L}}{\partial \text{Logits}[p,j]} \cdot W_{\text{proj}}[j,k] = \frac{1}{3} \sum_{j=0}^{5} \delta_p[j] \cdot W_{\text{proj}}[j,k]$$

This is the gradient that reaches the decoder output; from here it continues backwards through each decoder sublayer.
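In the same vectorized form, $\partial\mathcal{L}/\partial D = \delta \, W_{\text{proj}} / 3$. A sketch (assumed names) that also double-checks one entry against a finite-difference probe, a standard way to validate a hand-derived gradient:

```python
import numpy as np

D = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 1.0, 0.0, 1.0],
              [1.0, 1.0, 0.0, 0.0]])
W_proj = np.array([[0.1, 0.1, 0.0, 0.0], [0.0, 0.1, 0.0, 0.1],
                   [0.0, 0.0, 0.1, 0.1], [0.3, 0.0, 0.2, 0.0],
                   [0.8, 0.0, 0.1, 0.0], [0.0, 0.7, 0.1, 0.0]])
targets = np.array([3, 4, 5])

def loss_of(D):
    """Mean cross-entropy loss as a function of the decoder output."""
    logits = D @ W_proj.T
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    P = exp / exp.sum(axis=1, keepdims=True)
    return -np.log(P[np.arange(3), targets]).mean(), P

loss, P = loss_of(D)
delta = P.copy()
delta[np.arange(3), targets] -= 1.0
dD = delta @ W_proj / 3.0                  # dL/dD, shape (3, 4)

# Finite-difference check of one entry: perturb D[0,2] and re-measure.
eps = 1e-6
D_pert = D.copy()
D_pert[0, 2] += eps
numeric = (loss_of(D_pert)[0] - loss) / eps
print(abs(numeric - dD[0, 2]) < 1e-6)      # True
```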

6.2 The backward path

Loss (scalar)
  ↓ dL/d(Logits)          [3×6]  ← output-projection gradient (computed above)
  ↓                        ↓
dL/d(D)                  [3×4]  ← gradient w.r.t. the decoder output
  ↓
  ├─→ dL/d(W_FFN2)       FFN down-projection
  ├─→ dL/d(W_FFN1)       FFN up-projection
  ↓
dL/d(X_cross)            [3×4]  ← gradient w.r.t. the cross-attention input
  ↓
  ├─→ dL/d(W_V_cross)    K,V come from the encoder
  ├─→ dL/d(W_K_cross)
  ├─→ dL/d(W_Q_cross)    Q comes from the decoder
  ↓
dL/d(X_self)             [3×4]  ← gradient w.r.t. the masked self-attention
  ↓
  ├─→ dL/d(W_V_self)
  ├─→ dL/d(W_K_self)
  ├─→ dL/d(W_Q_self)
  ↓
dL/d(X_enc)              [3×4]  ← gradient reaching the encoder output
  ↓
  ├─→ dL/d(W_FFN2_enc)   encoder FFN
  ├─→ dL/d(W_FFN1_enc)
  ↓
dL/d(X_src)              [3×4]  ← gradient reaching the encoder input

The chain-rule expansion inside the decoder is long (many matrix products plus the LayerNorm derivative), but mathematically it is exactly the differentiability story told in articles 30-32: every operation is differentiable, and the gradient flows back layer by layer through the chain rule.


7. Summary of the Numerical Example

7.1 Core findings

One complete forward → backward → update → forward cycle showed us concretely:

  1. Why train at all? The initial weights made all three positions predict wrongly, with a loss of 1.709.
  2. How does gradient descent work? Every weight receives a gradient signal $\partial \mathcal{L}/\partial W$: a positive gradient says decrease the parameter, a negative one says increase it.
  3. Gradient competition: the weights of a single token are pulled in opposite directions by different positions and settle on a compromise.
  4. Is one round enough? One SGD step cut the loss by 12.6% and raised all three target-token logits. Position 2 already predicts correctly; positions 0 and 1 still need more iterations.

7.2 Key numbers at a glance

| Metric | Round 1 | Round 2 | Change |
| --- | --- | --- | --- |
| Total loss $\mathcal{L}$ | 1.709 | 1.493 | ↓ 12.6% |
| Position 0 correct logit (i) | 0.5 | 0.743 | ↑ 48.6% |
| Position 1 correct logit (love) | 0.0 | 0.248 | from zero |
| Position 2 correct logit (deep) | 0.7 | 0.891 | ↑ 27.3% |

7.3 Gradient descent, visualized

                         loss surface (schematic)
                         ▲
                      1.8│  ★ (1.709)  ← starting point
                      1.6│     ↘
                      1.4│       ★ (1.493) ← after one SGD step
                      1.2│
                      1.0│
                         └─────────────────────► parameter space
                          W_proj moves along the negative gradient

The essence of gradient descent: on the loss surface, find the direction of fastest descent (the negative gradient), take a small step along it (the learning rate $\eta$ controls the step size), and land in a region of lower loss.


8. Conclusion

8.1 What comes next

This article verified the training loop with pure arithmetic. Next comes a full PyTorch implementation, including:

  • a model definition aligned with this article (same initial weights)
  • automatic gradients (loss.backward()) to verify the hand calculations here
  • a multi-round training loop showing the loss falling to convergence
  • the final inference results

8.2 About larger models

This article hand-computed a tiny model with $d_{\text{model}}=4$. In a real Transformer ($d_{\text{model}}=512$) the computation is identical, just with matrices 128 times wider, and:

  • the gradient formula $\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[j,k]} \propto \sum_p (P_p[j] - \mathbb{1}[t_p=j]) \cdot D[p,k]$ is unchanged
  • the SGD update $W = W - \eta \nabla \mathcal{L}$ is unchanged
  • the chain-rule path of backpropagation is unchanged

That is the power of mathematical abstraction: understand the small-scale hand calculation, and you understand the essence of large-scale training.
