Mathematical Principles of Deep Learning (35): A Complete Numerical Walkthrough of Transformer Training
Preface
In the previous articles we worked through the forward-pass numerics of the encoder (Part 33) and the decoder (Part 34). But the forward pass is only half the story; the actual learning happens during backpropagation: only when we compare the predictions against the true labels, compute the gradient of the error with respect to every parameter, and step the parameters along the negative gradient do the model's predictions improve, one update at a time.
The goal of this article: using the machine-translation example ["我", "爱", "深"] → ["i", "love", "deep"], hand-compute, for the first time in this series, one complete Transformer training iteration (forward → loss → backward → parameter update → forward again).
How this article relates to earlier ones
| Earlier article | Content | How it is used here |
|---|---|---|
| Part 33 | encoder forward-pass numerics | its encoder output $X_{\text{enc}}$ is quoted directly |
| Part 34 | decoder forward-pass numerics | its decoder output $D$ is quoted directly |
| Parts 25-32 | the principles of each component | derivations are skipped; we focus on the numbers |
Note: this article focuses on numerically verifying the full training loop. The mathematics of every component was covered in detail in Parts 25-34 and is not re-derived here; the full code implementation follows in a later installment.
1. Problem Setup
1.1 The Translation Task
| Item | Content |
|---|---|
| Source sequence | ["我", "爱", "深"] (the first 3 tokens of "我爱深度学习", "I love deep learning") |
| Target sequence | ["i", "love", "deep"] (3 tokens) |
| Target input (shifted right) | ["<sos>", "i", "love"] |
| Task | given the source sequence, the decoder predicts "i", "love", "deep" in turn |
1.2 Model Configuration
To keep the hand computation manageable, we use a tiny configuration:
| Parameter | Value | Notes |
|---|---|---|
| $d_{\text{model}}$ | 4 | model dimension (the shared feature width of every layer) |
| $h$ | 1 | single-head attention |
| $d_k$ | 4 | dimension per head |
| $d_{ff}$ | 8 | FFN hidden dimension |
| $N_{\text{enc}}$ | 1 | number of encoder layers |
| $N_{\text{dec}}$ | 1 | number of decoder layers |
| Source vocabulary | `{<pad>:0, 我:1, 爱:2, 深:3}` | 4 tokens |
| Target vocabulary | `{<pad>:0, <sos>:1, <eos>:2, i:3, love:4, deep:5}` | 6 tokens |
1.3 Weight Initialization Strategy
All attention Q/K/V projection matrices are initialized to the identity, so that during the forward pass the attention scores are driven directly by input similarity, which keeps the hand computation easy to trace.
The output projection $W_{\text{proj}} \in \mathbb{R}^{6 \times 4}$ is the only weight trained in this article. We deliberately initialize it to values that make the model predict incorrectly, so that we can watch gradient descent repair the mistakes.
2. Epoch 1: Forward Pass
2.1 Encoder Forward Pass (Quoting Part 33)
The encoder input is ["我", "爱", "深"]; after embedding + positional encoding:
$$X_0 = \begin{pmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{pmatrix} \begin{aligned} &\leftarrow \text{"我"} \\ &\leftarrow \text{"爱"} \\ &\leftarrow \text{"深"} \end{aligned}$$
After the single encoder layer computed in detail in Part 33 (self-attention + FFN + residual LN), the encoder output is:
$$X_{\text{enc}} = \begin{pmatrix} 1.279 & -0.783 & 0.655 & -1.151 \\ -0.783 & 1.279 & -1.151 & 0.655 \\ 1.000 & 1.000 & -1.000 & -1.000 \end{pmatrix} \begin{aligned} &\leftarrow \text{"我"} \\ &\leftarrow \text{"爱"} \\ &\leftarrow \text{"深"} \end{aligned}$$
What the encoder output means: each row is the context-aware feature vector of the corresponding source token. The three vectors all differ, showing that the encoder has distinguished the three source words and encoded the relations among them.
2.2 Decoder Forward Pass (Quoting Part 34)
The decoder input is ["<sos>", "i", "love"]; after embedding + positional encoding:
$$X_0^{\text{dec}} = \begin{pmatrix} 0.1 & 0.2 & 0.1 & 0.2 \\ 0.6 & 0.1 & 0.2 & 0.3 \\ 0.2 & 0.7 & 0.1 & 0.4 \end{pmatrix} \begin{aligned} &\leftarrow \text{"<sos>"} \\ &\leftarrow \text{"i"} \\ &\leftarrow \text{"love"} \end{aligned}$$
After the single decoder layer computed in detail in Part 34 (causally masked self-attention + cross-attention + FFN + residual LN), the decoder output is:
$$D = \begin{pmatrix} 1.000 & 0.000 & 1.000 & 0.000 \\ 0.000 & 1.000 & 0.000 & 1.000 \\ 1.000 & 1.000 & 0.000 & 0.000 \end{pmatrix} \begin{aligned} &\leftarrow \text{position 0 (should predict "i")} \\ &\leftarrow \text{position 1 (should predict "love")} \\ &\leftarrow \text{position 2 (should predict "deep")} \end{aligned}$$
Note: the values of $D$ happen to coincide with the encoder input $X_0$ of Part 33 (both are carefully designed sequence representations). This is deliberate: simple values keep the gradient computations below easy to trace.
2.3 Output Projection and Logits
The output projection matrix $W_{\text{proj}} \in \mathbb{R}^{6 \times 4}$ is initialized as:
$$W_{\text{proj}} = \begin{pmatrix} 0.1 & 0.1 & 0.0 & 0.0 \\ 0.0 & 0.1 & 0.0 & 0.1 \\ 0.0 & 0.0 & 0.1 & 0.1 \\ \mathbf{0.3} & \mathbf{0.0} & \mathbf{0.2} & \mathbf{0.0} \\ \mathbf{0.8} & \mathbf{0.0} & \mathbf{0.1} & \mathbf{0.0} \\ \mathbf{0.0} & \mathbf{0.7} & \mathbf{0.1} & \mathbf{0.0} \end{pmatrix} \begin{aligned} &\leftarrow \text{<pad>} \\ &\leftarrow \text{<sos>} \\ &\leftarrow \text{<eos>} \\ &\leftarrow \text{i (attends to dim0 and dim2)} \\ &\leftarrow \text{love (over-attends to dim0)} \\ &\leftarrow \text{deep (over-attends to dim1)} \end{aligned}$$
Design intent: look at the three bold rows. The weights 0.3 and 0.2 in the "i" row are too small, the "love" row has 0.8 on dim0 (too large), and the "deep" row has 0.7 on dim1 (too large). As a result, every initial prediction will be wrong.
Computing the logits: $\text{Logits} = D \cdot W_{\text{proj}}^\top$
For position $p$ and vocabulary token $j$:
$$\text{Logits}[p, j] = \sum_{k=0}^{3} D[p, k] \times W_{\text{proj}}[j, k]$$
Position by position:
Position 0 (`<sos>` → should predict i, token 3):
$D[0] = [1.0, 0.0, 1.0, 0.0]$
| Token | Computation | Logit |
|---|---|---|
| `<pad>` | $1.0\times0.1 + 0.0\times0.1 + 1.0\times0.0 + 0.0\times0.0 = 0.1$ | 0.1 |
| `<sos>` | $1.0\times0.0 + 0.0\times0.1 + 1.0\times0.0 + 0.0\times0.1 = 0.0$ | 0.0 |
| `<eos>` | $1.0\times0.0 + 0.0\times0.0 + 1.0\times0.1 + 0.0\times0.1 = 0.1$ | 0.1 |
| i | $1.0\times0.3 + 0.0\times0.0 + 1.0\times0.2 + 0.0\times0.0 = \mathbf{0.5}$ | 0.5 |
| love | $1.0\times\mathbf{0.8} + 0.0\times0.0 + 1.0\times0.1 + 0.0\times0.0 = \mathbf{0.9}$ | 0.9 ← ❌ highest |
| deep | $1.0\times0.0 + 0.0\times0.7 + 1.0\times0.1 + 0.0\times0.0 = 0.1$ | 0.1 |
❌ Position 0 predicts love (logit = 0.9); it should be i (logit = 0.5)
Position 1 (i → should predict love, token 4):
$D[1] = [0.0, 1.0, 0.0, 1.0]$
| Token | Computation | Logit |
|---|---|---|
| `<pad>` | $0.0\times0.1 + 1.0\times0.1 + 0.0\times0.0 + 1.0\times0.0 = 0.1$ | 0.1 |
| `<sos>` | $0.0\times0.0 + 1.0\times0.1 + 0.0\times0.0 + 1.0\times0.1 = 0.2$ | 0.2 |
| `<eos>` | $0.0\times0.0 + 1.0\times0.0 + 0.0\times0.1 + 1.0\times0.1 = 0.1$ | 0.1 |
| i | $0.0\times0.3 + 1.0\times0.0 + 0.0\times0.2 + 1.0\times0.0 = 0.0$ | 0.0 |
| love | $0.0\times0.8 + 1.0\times0.0 + 0.0\times0.1 + 1.0\times0.0 = 0.0$ | 0.0 |
| deep | $0.0\times0.0 + 1.0\times\mathbf{0.7} + 0.0\times0.1 + 1.0\times0.0 = \mathbf{0.7}$ | 0.7 ← ❌ highest |
❌ Position 1 predicts deep (logit = 0.7); it should be love (logit = 0.0)
Position 2 (love → should predict deep, token 5):
$D[2] = [1.0, 1.0, 0.0, 0.0]$
| Token | Computation | Logit |
|---|---|---|
| `<pad>` | $1.0\times0.1 + 1.0\times0.1 + 0.0\times0.0 + 0.0\times0.0 = 0.2$ | 0.2 |
| `<sos>` | $1.0\times0.0 + 1.0\times0.1 + 0.0\times0.0 + 0.0\times0.1 = 0.1$ | 0.1 |
| `<eos>` | $1.0\times0.0 + 1.0\times0.0 + 0.0\times0.1 + 0.0\times0.1 = 0.0$ | 0.0 |
| i | $1.0\times0.3 + 1.0\times0.0 + 0.0\times0.2 + 0.0\times0.0 = 0.3$ | 0.3 |
| love | $1.0\times\mathbf{0.8} + 1.0\times0.0 + 0.0\times0.1 + 0.0\times0.0 = \mathbf{0.8}$ | 0.8 ← ❌ highest |
| deep | $1.0\times0.0 + 1.0\times\mathbf{0.7} + 0.0\times0.1 + 0.0\times0.0 = \mathbf{0.7}$ | 0.7 |
❌ Position 2 predicts love (logit = 0.8); it should be deep (logit = 0.7)
Summary of the initial predictions: all three positions are wrong. This is exactly what we wanted: the initial weights are "bad", so the gradient update in the next step can demonstrate how things "get better".
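The whole logits table can be checked mechanically. Below is a minimal NumPy sketch (an illustrative aid, not the PyTorch implementation promised for a later article) that reproduces the epoch-1 logits and the three wrong argmax predictions:

```python
import numpy as np

# Decoder output D (3 positions x d_model=4) and W_proj (6 tokens x 4), copied from above
D = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 1.0, 0.0, 1.0],
              [1.0, 1.0, 0.0, 0.0]])
W_proj = np.array([[0.1, 0.1, 0.0, 0.0],   # <pad>
                   [0.0, 0.1, 0.0, 0.1],   # <sos>
                   [0.0, 0.0, 0.1, 0.1],   # <eos>
                   [0.3, 0.0, 0.2, 0.0],   # i
                   [0.8, 0.0, 0.1, 0.0],   # love
                   [0.0, 0.7, 0.1, 0.0]])  # deep

logits = D @ W_proj.T          # (3, 6): one row of vocabulary scores per position
print(logits)                  # rows: [0.1 0.0 0.1 0.5 0.9 0.1], [0.1 0.2 0.1 0.0 0.0 0.7], [0.2 0.1 0.0 0.3 0.8 0.7]
print(logits.argmax(axis=1))   # [4 5 4] -> love, deep, love: all three positions wrong
```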
2.4 Softmax and Cross-Entropy Loss
Softmax probabilities
Position 0: logits = $[0.1, 0.0, 0.1, 0.5, 0.9, 0.1]$
$$e^{\text{logit}} = [1.105, 1.000, 1.105, 1.649, 2.460, 1.105], \quad \text{sum} = 8.424$$
$$P_0 = [0.131, 0.119, 0.131, 0.196, 0.292, 0.131]$$
Position 1: logits = $[0.1, 0.2, 0.1, 0.0, 0.0, 0.7]$
$$e^{\text{logit}} = [1.105, 1.221, 1.105, 1.000, 1.000, 2.014], \quad \text{sum} = 7.445$$
$$P_1 = [0.148, 0.164, 0.148, 0.134, 0.134, 0.270]$$
Position 2: logits = $[0.2, 0.1, 0.0, 0.3, 0.8, 0.7]$
$$e^{\text{logit}} = [1.221, 1.105, 1.000, 1.350, 2.226, 2.014], \quad \text{sum} = 8.916$$
$$P_2 = [0.137, 0.124, 0.112, 0.151, 0.250, 0.226]$$
Cross-entropy loss
For each position, the loss is $L_p = -\log(P_p[\text{target}_p])$:
| Position | Target token | Probability of target | Loss | Comment |
|---|---|---|---|---|
| 0 | i (token 3) | 0.196 | $-\log(0.196) = 1.630$ | low confidence in "i" |
| 1 | love (token 4) | 0.134 | $-\log(0.134) = 2.010$ | very low confidence in "love" |
| 2 | deep (token 5) | 0.226 | $-\log(0.226) = 1.487$ | moderate confidence in "deep" |
$$\mathcal{L} = \frac{1}{3}(1.630 + 2.010 + 1.487) = \frac{5.127}{3} = 1.709$$
This is the initial loss. The job of gradient descent is to drive this number down.
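As a cross-check, a short NumPy sketch of the softmax and cross-entropy computed above (for tiny logits like these, skipping the usual subtract-the-row-max stabilization is harmless; production code should include it):

```python
import numpy as np

logits = np.array([[0.1, 0.0, 0.1, 0.5, 0.9, 0.1],
                   [0.1, 0.2, 0.1, 0.0, 0.0, 0.7],
                   [0.2, 0.1, 0.0, 0.3, 0.8, 0.7]])
targets = np.array([3, 4, 5])                    # i, love, deep

exp = np.exp(logits)                             # row sums: 8.424, 7.445, 8.916
P = exp / exp.sum(axis=1, keepdims=True)         # softmax probabilities per position
losses = -np.log(P[np.arange(3), targets])       # [1.630, 2.010, 1.487]
print(losses, losses.mean())                     # mean = 1.709
```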
2.5 The Loss Picture After One Forward Pass
```
Source ["我", "爱", "深"] ──→ Encoder ──→ X_enc (3×4)
                                              │
Target ["<sos>", "i", "love"] ──→ Decoder ──→ D (3×4) ──→ W_proj ──→ Logits (3×6) ──→ Softmax ──→ Loss = 1.709
                                     ↑                       ↑
                          cross-attention K, V     the only trainable weight here
```
3. Backpropagation: Gradients of the Output Projection (Fully Hand-Computed)
This is the heart of the article. We compute, entry by entry, the gradient of the loss $\mathcal{L}$ with respect to the output projection matrix $W_{\text{proj}}$.
3.1 The Basic Gradient Formula
For the output projection $\text{Logits} = D \cdot W_{\text{proj}}^\top$ with cross-entropy loss $\mathcal{L} = \frac{1}{3}\sum_p -\log(\text{Softmax}(\text{Logits}[p,:])[t_p])$, the gradient is:
$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[j,k]} = \frac{1}{3} \sum_{p=0}^{2} \underbrace{\left( P_p[j] - \mathbb{1}[t_p = j] \right)}_{\text{logit gradient of position } p \text{ at token } j} \cdot \underbrace{D[p,k]}_{\text{decoder output}}$$
where $P_p[j] = \text{Softmax}(\text{Logits}[p,:])[j]$ and $\mathbb{1}[t_p = j]$ is the indicator function (1 when the target equals $j$, 0 otherwise).
Intuition: each position's logit gradient is (predicted probability − target indicator); multiplying it by the decoder output and summing over sequence positions yields the gradient of each weight parameter.
3.2 Logit Gradients at Each Position
First compute $\delta_p[j] = P_p[j] - \mathbb{1}[t_p = j]$:
| Token $j$ | $P_0[j]$ | $t_0=j$? | $\delta_0[j]$ | $P_1[j]$ | $t_1=j$? | $\delta_1[j]$ | $P_2[j]$ | $t_2=j$? | $\delta_2[j]$ |
|---|---|---|---|---|---|---|---|---|---|
| `<pad>` | 0.131 | no | +0.131 | 0.148 | no | +0.148 | 0.137 | no | +0.137 |
| `<sos>` | 0.119 | no | +0.119 | 0.164 | no | +0.164 | 0.124 | no | +0.124 |
| `<eos>` | 0.131 | no | +0.131 | 0.148 | no | +0.148 | 0.112 | no | +0.112 |
| i | 0.196 | yes | -0.804 | 0.134 | no | +0.134 | 0.151 | no | +0.151 |
| love | 0.292 | no | +0.292 | 0.134 | yes | -0.866 | 0.250 | no | +0.250 |
| deep | 0.131 | no | +0.131 | 0.270 | no | +0.270 | 0.226 | yes | -0.774 |
Key observation: for the correct target token, $\delta = P - 1 < 0$, so the gradient tells the model to raise that token's logit; for a wrong token, $\delta = P > 0$, telling the model to lower its logit.
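The whole $\delta$ table is one vectorized operation: copy $P$ and subtract 1 at each position's target column. A small NumPy sketch:

```python
import numpy as np

P = np.array([[0.131, 0.119, 0.131, 0.196, 0.292, 0.131],
              [0.148, 0.164, 0.148, 0.134, 0.134, 0.270],
              [0.137, 0.124, 0.112, 0.151, 0.250, 0.226]])
targets = np.array([3, 4, 5])                # i, love, deep

delta = P.copy()
delta[np.arange(3), targets] -= 1.0          # delta = P - onehot(target)
print(delta)                                 # target columns become -0.804, -0.866, -0.774
```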
3.3 The Decoder Output Matrix
$$D = \begin{pmatrix} 1.0 & 0.0 & 1.0 & 0.0 \\ 0.0 & 1.0 & 0.0 & 1.0 \\ 1.0 & 1.0 & 0.0 & 0.0 \end{pmatrix}$$
3.4 Gradient Computation: Token i (Row 3) as an Example
$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[3,k]} = \frac{1}{3} \sum_{p=0}^{2} \delta_p[3] \cdot D[p,k]$$
The values of $\delta_p[3]$: $p=0: -0.804,\; p=1: +0.134,\; p=2: +0.151$
For $k=0$ (the dim0 weight):
$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[3,0]} = \frac{1}{3}\left( (-0.804)\times1.0 + 0.134\times0.0 + 0.151\times1.0 \right) = \frac{-0.653}{3} = -0.218$$
The gradient is negative → the weight should increase (the negative gradient points toward lower loss).
For $k=1$ (the dim1 weight):
$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[3,1]} = \frac{1}{3}\left( (-0.804)\times0.0 + 0.134\times1.0 + 0.151\times1.0 \right) = \frac{0.285}{3} = +0.095$$
The gradient is positive → the weight should decrease.
For $k=2$ (the dim2 weight):
$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[3,2]} = \frac{1}{3}\left( (-0.804)\times1.0 + 0.134\times0.0 + 0.151\times0.0 \right) = \frac{-0.804}{3} = -0.268$$
For $k=3$ (the dim3 weight):
$$\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[3,3]} = \frac{1}{3}\left( (-0.804)\times0.0 + 0.134\times1.0 + 0.151\times0.0 \right) = \frac{0.134}{3} = +0.045$$
The gradient vector:
$$\nabla_{W_{\text{proj}}[3,:]} = [-0.218, \; +0.095, \; -0.268, \; +0.045]$$
Interpreting the gradient:
| Dimension | Gradient | Meaning |
|---|---|---|
| dim0 | $-0.218$ | increase $W_{\text{proj}}[3,0]$: position 0 needs a higher logit for "i" ($-0.804$) and $D[0,0]=1.0$ |
| dim1 | $+0.095$ | decrease $W_{\text{proj}}[3,1]$: positions 1 and 2 do not want "i" ($+0.134$, $+0.151$) and $D[1,1]=D[2,1]=1.0$ |
| dim2 | $-0.268$ | increase $W_{\text{proj}}[3,2]$: position 0 needs a higher logit for "i" |
| dim3 | $+0.045$ | decrease $W_{\text{proj}}[3,3]$: the logit for "i" at position 1 is too high ($+0.134$) |
Gradient competition: the dim0 and dim2 gradients say "grow the weight" (position 0 wants more "i"), while the dim1 and dim3 gradients say "shrink it" (positions 1 and 2 want less "i"). This is multi-task training in miniature: one token's weights must strike a balance across positions.
3.5 Gradients for All Tokens (Summary Table)
Applying the same method to all 6 rows × 4 columns of $W_{\text{proj}}$:
Core formula:
$$g_{j,k} = \frac{1}{3} \sum_{p=0}^{2} \delta_p[j] \cdot D[p,k]$$
Taking dim0 of `<pad>` ($j=0$) as an example: $g_{0,0} = \frac{1}{3}(+0.131\times1.0 + 0.148\times0.0 + 0.137\times1.0) = \frac{0.268}{3} = +0.089$
All gradients:
| Token | $\delta_0$ | $\delta_1$ | $\delta_2$ | $\nabla_{\text{dim0}}$ | $\nabla_{\text{dim1}}$ | $\nabla_{\text{dim2}}$ | $\nabla_{\text{dim3}}$ |
|---|---|---|---|---|---|---|---|
| `<pad>` | +0.131 | +0.148 | +0.137 | $+0.089$ | $+0.095$ | $+0.044$ | $+0.049$ |
| `<sos>` | +0.119 | +0.164 | +0.124 | $+0.081$ | $+0.096$ | $+0.040$ | $+0.055$ |
| `<eos>` | +0.131 | +0.148 | +0.112 | $+0.081$ | $+0.087$ | $+0.044$ | $+0.049$ |
| i | -0.804 | +0.134 | +0.151 | $-0.218$ | $+0.095$ | $-0.268$ | $+0.045$ |
| love | +0.292 | -0.866 | +0.250 | $+0.181$ | $-0.205$ | $+0.097$ | $-0.289$ |
| deep | +0.131 | +0.270 | -0.774 | $-0.214$ | $-0.168$ | $+0.044$ | $+0.090$ |
(Spelling out deep's dim1 entry: $g_{5,1} = \frac{1}{3}(0.131\times0.0 + 0.270\times1.0 + (-0.774)\times1.0) = \frac{-0.504}{3} = -0.168$; position 2's strong pull for "deep" outweighs position 1's push against it.)
The complete gradient matrix:
$$\nabla \mathcal{L} = \begin{pmatrix} +0.089 & +0.095 & +0.044 & +0.049 \\ +0.081 & +0.096 & +0.040 & +0.055 \\ +0.081 & +0.087 & +0.044 & +0.049 \\ -0.218 & +0.095 & -0.268 & +0.045 \\ +0.181 & -0.205 & +0.097 & -0.289 \\ -0.214 & -0.168 & +0.044 & +0.090 \end{pmatrix}$$
Patterns to notice:
- The rows for the target tokens (i, love, deep) carry large gradients, because their $\delta = P - 1$ is much larger in magnitude than the $\delta = P$ of the non-target tokens
- The non-target tokens' gradients are small and positive, meaning their logits are uniformly a bit too high and should come down
- The sign of each gradient is the combined result, over the three positions, of the products of $\delta$ and $D$
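In matrix form the whole table above is a single product, $\nabla \mathcal{L} = \frac{1}{3}\,\delta^\top D$. A NumPy sketch that reproduces the 6×4 gradient matrix:

```python
import numpy as np

delta = np.array([[0.131, 0.119, 0.131, -0.804,  0.292,  0.131],
                  [0.148, 0.164, 0.148,  0.134, -0.866,  0.270],
                  [0.137, 0.124, 0.112,  0.151,  0.250, -0.774]])
D = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 1.0, 0.0, 1.0],
              [1.0, 1.0, 0.0, 0.0]])

grad = delta.T @ D / 3        # (6, 4): one row per vocabulary token
print(np.round(grad, 3))      # row "i": [-0.218 0.095 -0.268 0.045]; row "deep": [-0.214 -0.168 0.044 0.090]
```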
4. Parameter Update (SGD)
With learning rate $\eta = 0.5$, the SGD update rule is:
$$W_{\text{proj}}^{\text{new}} = W_{\text{proj}}^{\text{old}} - \eta \cdot \nabla \mathcal{L}$$
4.1 Row-by-Row Updates
Token `<pad>` (row 0):
$$W_{\text{proj}}^{\text{new}}[0,:] = [0.1, 0.1, 0.0, 0.0] - 0.5 \times [+0.089, +0.095, +0.044, +0.049]$$
$$= [0.1 - 0.045, \; 0.1 - 0.048, \; 0.0 - 0.022, \; 0.0 - 0.025]$$
$$= [0.055, 0.052, -0.022, -0.025]$$
Token `<sos>` (row 1):
$$W_{\text{proj}}^{\text{new}}[1,:] = [0.0, 0.1, 0.0, 0.1] - 0.5 \times [+0.081, +0.096, +0.040, +0.055]$$
$$= [0.0 - 0.041, \; 0.1 - 0.048, \; 0.0 - 0.020, \; 0.1 - 0.028]$$
$$= [-0.041, 0.052, -0.020, 0.072]$$
Token `<eos>` (row 2):
$$W_{\text{proj}}^{\text{new}}[2,:] = [0.0, 0.0, 0.1, 0.1] - 0.5 \times [+0.081, +0.087, +0.044, +0.049]$$
$$= [0.0 - 0.041, \; 0.0 - 0.044, \; 0.1 - 0.022, \; 0.1 - 0.025]$$
$$= [-0.041, -0.044, 0.078, 0.075]$$
Token i (row 3), the central update:
$$W_{\text{proj}}^{\text{new}}[3,:] = [0.3, 0.0, 0.2, 0.0] - 0.5 \times [-0.218, +0.095, -0.268, +0.045]$$
$$= [0.3 + 0.109, \; 0.0 - 0.048, \; 0.2 + 0.134, \; 0.0 - 0.023]$$
$$= [0.409, \; -0.048, \; 0.334, \; -0.023]$$
✨ Key change: the dim0 weight rises from 0.3 to 0.409 and the dim2 weight from 0.2 to 0.334. These two dimensions are active at position 0, where $D[0] = [1, 0, 1, 0]$, so raising them directly raises position 0's logit for "i".
Token love (row 4):
$$W_{\text{proj}}^{\text{new}}[4,:] = [0.8, 0.0, 0.1, 0.0] - 0.5 \times [+0.181, -0.205, +0.097, -0.289]$$
$$= [0.8 - 0.091, \; 0.0 + 0.103, \; 0.1 - 0.049, \; 0.0 + 0.145]$$
$$= [0.709, 0.103, 0.051, 0.145]$$
✨ Key change: the dim0 weight drops from 0.8 to 0.709 (its inflated initial value caused the mispredictions), while dim1 rises from 0.0 to 0.103 and dim3 from 0.0 to 0.145; the latter two help position 1 (which should predict "love") earn a higher logit.
Token deep (row 5):
$$W_{\text{proj}}^{\text{new}}[5,:] = [0.0, 0.7, 0.1, 0.0] - 0.5 \times [-0.214, -0.168, +0.044, +0.090]$$
$$= [0.0 + 0.107, \; 0.7 + 0.084, \; 0.1 - 0.022, \; 0.0 - 0.045]$$
$$= [0.107, 0.784, 0.078, -0.045]$$
✨ Key change: dim0 rises from 0.0 to 0.107 and dim1 from 0.7 to 0.784, both raising position 2's logit for "deep". Position 1 would prefer a lower dim1 (its "deep" logit is too high), but position 2's pull ($\delta_2 = -0.774$) outweighs it; instead dim3 falls from 0.0 to $-0.045$, which is what trims position 1's "deep" logit.
4.2 The Updated Weight Matrix
$$W_{\text{proj}}^{\text{new}} = \begin{pmatrix} 0.055 & 0.052 & -0.022 & -0.025 \\ -0.041 & 0.052 & -0.020 & 0.072 \\ -0.041 & -0.044 & 0.078 & 0.075 \\ \mathbf{0.409} & \mathbf{-0.048} & \mathbf{0.334} & \mathbf{-0.023} \\ \mathbf{0.709} & \mathbf{0.103} & \mathbf{0.051} & \mathbf{0.145} \\ \mathbf{0.107} & \mathbf{0.784} & \mathbf{0.078} & \mathbf{-0.045} \end{pmatrix} \begin{aligned} &\leftarrow \text{<pad>} \\ &\leftarrow \text{<sos>} \\ &\leftarrow \text{<eos>} \\ &\leftarrow \text{i ↑dim0, ↑dim2} \\ &\leftarrow \text{love ↓dim0, ↑dim1, ↑dim3} \\ &\leftarrow \text{deep ↑dim0, ↑dim1, ↓dim3} \end{aligned}$$
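All of Section 4.1 collapses into one line of array arithmetic. A sketch with the values from above (tiny differences in the third decimal are expected, since the tables round intermediate results):

```python
import numpy as np

W_proj = np.array([[0.1, 0.1, 0.0, 0.0], [0.0, 0.1, 0.0, 0.1], [0.0, 0.0, 0.1, 0.1],
                   [0.3, 0.0, 0.2, 0.0], [0.8, 0.0, 0.1, 0.0], [0.0, 0.7, 0.1, 0.0]])
grad = np.array([[ 0.089,  0.095,  0.044,  0.049],
                 [ 0.081,  0.096,  0.040,  0.055],
                 [ 0.081,  0.087,  0.044,  0.049],
                 [-0.218,  0.095, -0.268,  0.045],
                 [ 0.181, -0.205,  0.097, -0.289],
                 [-0.214, -0.168,  0.044,  0.090]])
eta = 0.5

W_new = W_proj - eta * grad       # one SGD step
print(np.round(W_new, 3))         # matches the updated matrix above, up to rounding
```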
5. Epoch 2: Forward Pass (Verifying the Improvement)
Run the forward pass again with the updated weights $W_{\text{proj}}^{\text{new}}$ to check whether the loss goes down.
5.1 Recomputing the Logits
Position 0 ($D[0] = [1.0, 0.0, 1.0, 0.0]$):
| Token | Computation | Logit |
|---|---|---|
| `<pad>` | $1.0\times0.055 + 0.0\times0.052 + 1.0\times(-0.022) + 0.0\times(-0.025) = 0.033$ | 0.033 |
| `<sos>` | $1.0\times(-0.041) + 0.0\times0.052 + 1.0\times(-0.020) + 0.0\times0.072 = -0.061$ | -0.061 |
| `<eos>` | $1.0\times(-0.041) + 0.0\times(-0.044) + 1.0\times0.078 + 0.0\times0.075 = 0.037$ | 0.037 |
| i | $1.0\times\mathbf{0.409} + 0.0\times(-0.048) + 1.0\times\mathbf{0.334} + 0.0\times(-0.023) = \mathbf{0.743}$ | 0.743 ✅ |
| love | $1.0\times0.709 + 0.0\times0.103 + 1.0\times0.051 + 0.0\times0.145 = 0.760$ | 0.760 ← ❌ still highest |
| deep | $1.0\times0.107 + 0.0\times0.784 + 1.0\times0.078 + 0.0\times(-0.045) = 0.185$ | 0.185 |
✅ Position 0 improves: the logit for "i" jumps from 0.5 to 0.743. But "love" still sits at 0.760 (since $W_{\text{proj}}[4,0]$ only fell from 0.8 to 0.709, still too high). The two are nearly tied, with almost equal probabilities; this is progress, but more iterations are needed.
Position 1 ($D[1] = [0.0, 1.0, 0.0, 1.0]$):
| Token | Computation | Logit |
|---|---|---|
| `<pad>` | $0.0\times0.055 + 1.0\times0.052 + 0.0\times(-0.022) + 1.0\times(-0.025) = 0.027$ | 0.027 |
| `<sos>` | $0.0\times(-0.041) + 1.0\times0.052 + 0.0\times(-0.020) + 1.0\times0.072 = 0.124$ | 0.124 |
| `<eos>` | $0.0\times(-0.041) + 1.0\times(-0.044) + 0.0\times0.078 + 1.0\times0.075 = 0.031$ | 0.031 |
| i | $0.0\times0.409 + 1.0\times(-0.048) + 0.0\times0.334 + 1.0\times(-0.023) = -0.071$ | -0.071 |
| love | $0.0\times0.709 + 1.0\times\mathbf{0.103} + 0.0\times0.051 + 1.0\times\mathbf{0.145} = \mathbf{0.248}$ | 0.248 ✅ |
| deep | $0.0\times0.107 + 1.0\times0.784 + 0.0\times0.078 + 1.0\times(-0.045) = 0.739$ | 0.739 ← ❌ still highest |
✅ Position 1 improves: the logit for "love" goes from 0.0 to 0.248 (from nothing to something!). But "deep" still has 0.739 ($W_{\text{proj}}[5,1] = 0.784$ remains high). The trend is right; training must continue.
Position 2 ($D[2] = [1.0, 1.0, 0.0, 0.0]$):
| Token | Computation | Logit |
|---|---|---|
| `<pad>` | $1.0\times0.055 + 1.0\times0.052 + 0.0\times(-0.022) + 0.0\times(-0.025) = 0.107$ | 0.107 |
| `<sos>` | $1.0\times(-0.041) + 1.0\times0.052 + 0.0\times(-0.020) + 0.0\times0.072 = 0.011$ | 0.011 |
| `<eos>` | $1.0\times(-0.041) + 1.0\times(-0.044) + 0.0\times0.078 + 0.0\times0.075 = -0.085$ | -0.085 |
| i | $1.0\times0.409 + 1.0\times(-0.048) + 0.0\times0.334 + 0.0\times(-0.023) = 0.361$ | 0.361 |
| love | $1.0\times0.709 + 1.0\times0.103 + 0.0\times0.051 + 0.0\times0.145 = 0.812$ | 0.812 |
| deep | $1.0\times\mathbf{0.107} + 1.0\times\mathbf{0.784} + 0.0\times0.078 + 0.0\times(-0.045) = \mathbf{0.891}$ | 0.891 ✅ ← highest |
✅ Position 2 improves: the logit for "deep" climbs from 0.7 to 0.891, overtaking "love" (0.8 → 0.812, nearly unchanged). Position 2 already predicts correctly after a single update step.
5.2 Recomputing the Loss
Position 0: logits = $[0.033, -0.061, 0.037, 0.743, 0.760, 0.185]$
$$e^{\text{logit}} = [1.034, 0.941, 1.038, 2.102, 2.138, 1.203], \quad \text{sum} = 8.456$$
$$P_0 = [0.122, 0.111, 0.123, 0.249, 0.253, 0.142]$$
Target = i (token 3): $P_0[3] = 0.249$, $L_0^{\text{new}} = -\log(0.249) = 1.390$
Position 1: logits = $[0.027, 0.124, 0.031, -0.071, 0.248, 0.739]$
$$e^{\text{logit}} = [1.027, 1.132, 1.031, 0.931, 1.281, 2.094], \quad \text{sum} = 7.498$$
$$P_1 = [0.137, 0.151, 0.138, 0.124, 0.171, 0.279]$$
Target = love (token 4): $P_1[4] = 0.171$, $L_1^{\text{new}} = -\log(0.171) = 1.767$
Position 2: logits = $[0.107, 0.011, -0.085, 0.361, 0.812, 0.891]$
$$e^{\text{logit}} = [1.113, 1.011, 0.919, 1.435, 2.252, 2.438], \quad \text{sum} = 9.167$$
$$P_2 = [0.121, 0.110, 0.100, 0.157, 0.246, 0.266]$$
Target = deep (token 5): $P_2[5] = 0.266$, $L_2^{\text{new}} = -\log(0.266) = 1.325$
Total loss:
$$\mathcal{L}^{\text{new}} = \frac{1}{3}(1.390 + 1.767 + 1.325) = \frac{4.482}{3} = 1.494$$
5.3 Loss Comparison
| Position | Correct token | Epoch 1 logit | Epoch 2 logit | Epoch 1 prob. | Epoch 2 prob. | Epoch 1 loss | Epoch 2 loss |
|---|---|---|---|---|---|---|---|
| 0 | i | 0.5 | 0.743 ↑ | 0.196 | 0.249 ↑ | 1.630 | 1.390 ↓ |
| 1 | love | 0.0 | 0.248 ↑ | 0.134 | 0.171 ↑ | 2.010 | 1.767 ↓ |
| 2 | deep | 0.7 | 0.891 ↑ | 0.226 | 0.266 ↑ | 1.487 | 1.325 ↓ |
| Total loss | | | | | | 1.709 | 1.494 ↓ |
✅ Verified! After one SGD update, the total loss falls from 1.709 to 1.494, a 12.6% drop. The probability of the target token rises at all three positions, and position 2 is already predicted correctly. Gradient descent really did push the parameters in the loss-reducing direction.
5.4 What Further Training Would Do
If training continued for more epochs, we would expect:
| Token | Situation | Trend over further epochs |
|---|---|---|
| i | competes with love at position 0 | its dim0 weight keeps rising (a higher logit is needed) while love's dim0 keeps falling |
| love | still suppressed by deep at position 1 | its dim1/dim3 weights keep rising while deep's dim3 keeps falling |
| deep | already wins position 2 | its margin over love grows as love's dim0 keeps falling |
After roughly 10-20 iterations, the model should predict all three positions correctly; the sketch below puts that estimate to the test.
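A minimal NumPy sketch of the full loop (D, the targets, and the learning rate are held fixed exactly as in the text; only W_proj is trained). The loss falls every epoch and the number of correctly predicted positions climbs from 0/3 toward 3/3:

```python
import numpy as np

D = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0]])
W = np.array([[0.1, 0.1, 0.0, 0.0], [0.0, 0.1, 0.0, 0.1], [0.0, 0.0, 0.1, 0.1],
              [0.3, 0.0, 0.2, 0.0], [0.8, 0.0, 0.1, 0.0], [0.0, 0.7, 0.1, 0.0]])
targets = np.array([3, 4, 5])     # i, love, deep
eta = 0.5

for epoch in range(1, 21):
    logits = D @ W.T                                        # forward
    P = np.exp(logits); P /= P.sum(axis=1, keepdims=True)   # softmax
    loss = -np.log(P[np.arange(3), targets]).mean()         # cross-entropy
    n_correct = int((logits.argmax(axis=1) == targets).sum())
    print(f"epoch {epoch:2d}  loss {loss:.3f}  correct {n_correct}/3")
    delta = P.copy(); delta[np.arange(3), targets] -= 1.0   # backward
    W -= eta * (delta.T @ D / 3)                            # SGD step
# epoch 1: loss 1.709, 0/3 correct; epoch 2: loss ~1.494, 1/3 correct;
# all three positions should flip to correct within the estimated 10-20 epochs
```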
6. Backpropagating into the Decoder (Overview)
Above we computed in detail the gradient of the output projection $W_{\text{proj}}$. In real training the gradient keeps flowing backward, through the decoder FFN, the cross-attention, and the masked self-attention, until it finally reaches the encoder. Here we sketch that path.
6.1 Gradient of the Loss w.r.t. the Decoder Output
$$\frac{\partial \mathcal{L}}{\partial D[p,k]} = \sum_{j=0}^{5} \frac{\partial \mathcal{L}}{\partial \text{Logits}[p,j]} \cdot W_{\text{proj}}[j,k] = \frac{1}{3} \sum_{j=0}^{5} \delta_p[j] \cdot W_{\text{proj}}[j,k]$$
This is the gradient handed back to the decoder output; it continues backward through each decoder sub-layer.
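In matrix form, $\partial \mathcal{L}/\partial D = \frac{1}{3}\,\delta \, W_{\text{proj}}$. A sketch computing the 3×4 gradient that the decoder's sub-layers would receive:

```python
import numpy as np

delta = np.array([[0.131, 0.119, 0.131, -0.804,  0.292,  0.131],
                  [0.148, 0.164, 0.148,  0.134, -0.866,  0.270],
                  [0.137, 0.124, 0.112,  0.151,  0.250, -0.774]])
W_proj = np.array([[0.1, 0.1, 0.0, 0.0], [0.0, 0.1, 0.0, 0.1], [0.0, 0.0, 0.1, 0.1],
                   [0.3, 0.0, 0.2, 0.0], [0.8, 0.0, 0.1, 0.0], [0.0, 0.7, 0.1, 0.0]])

dD = delta @ W_proj / 3      # (3, 4): gradient w.r.t. the decoder output D
print(np.round(dD, 3))       # this signal flows on into FFN, cross-attention, self-attention
```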
6.2 The Backpropagation Path
```
Loss (scalar)
  ↓ dL/d(Logits) [3×6]   ← output projection gradient (computed in detail above)
  ↓
dL/d(D) [3×4]            ← gradient w.r.t. the decoder output
  ↓
  ├─→ dL/d(W_FFN2)       FFN down-projection
  ├─→ dL/d(W_FFN1)       FFN up-projection
  ↓
dL/d(X_cross) [3×4]      ← gradient w.r.t. the cross-attention input
  ↓
  ├─→ dL/d(W_V_cross)    K, V come from the encoder
  ├─→ dL/d(W_K_cross)
  ├─→ dL/d(W_Q_cross)    Q comes from the decoder
  ↓
dL/d(X_self) [3×4]       ← gradient w.r.t. the masked self-attention
  ↓
  ├─→ dL/d(W_V_self)
  ├─→ dL/d(W_K_self)
  ├─→ dL/d(W_Q_self)
  ↓
dL/d(X_enc) [3×4]        ← gradient flowing back into the encoder output
  ↓
  ├─→ dL/d(W_FFN2_enc)   encoder FFN
  ├─→ dL/d(W_FFN1_enc)
  ↓
dL/d(X_src) [3×4]        ← gradient at the encoder input
```
The chain-rule expansion inside the decoder is long (several matrix products plus the derivative of LayerNorm), but mathematically it is exactly the differentiability story of Parts 30-32: every operation is differentiable, and the gradient flows back layer by layer via the chain rule.
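The follow-up article will verify all of this with autograd; as a preview, here is a minimal sketch (restricted to the output projection, as in this article) showing that PyTorch's `loss.backward()` reproduces the hand-computed loss and $W_{\text{proj}}$ gradient:

```python
import torch
import torch.nn.functional as F

D = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 0.0]])
W = torch.tensor([[0.1, 0.1, 0.0, 0.0], [0.0, 0.1, 0.0, 0.1], [0.0, 0.0, 0.1, 0.1],
                  [0.3, 0.0, 0.2, 0.0], [0.8, 0.0, 0.1, 0.0], [0.0, 0.7, 0.1, 0.0]],
                 requires_grad=True)
targets = torch.tensor([3, 4, 5])           # i, love, deep

loss = F.cross_entropy(D @ W.T, targets)    # softmax + NLL, averaged over the 3 positions
loss.backward()
print(loss.item())                          # ~1.709, matching Section 2.4
print(torch.round(W.grad, decimals=3))      # matches the hand-computed gradient matrix
```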
7. Summary of the Numerical Example
7.1 Core Findings
Through this one complete forward → backward → update → forward cycle, we saw concretely:
- Why is training needed? The initial weights got all three positions wrong, for a loss of 1.709
- How does gradient descent work? Each weight receives the gradient signal $\partial \mathcal{L}/\partial W$; a positive gradient says the parameter should shrink, a negative one says it should grow
- Gradient competition: the same token's weights feel opposite pulls from different positions and settle into a compromise
- Is one step enough? One SGD step cut the loss by 12.6% and raised the logit of every target token; position 2 already flipped to the correct prediction, but the other two positions need more steps
7.2 Key Numbers at a Glance
| Metric | Epoch 1 | Epoch 2 | Change |
|---|---|---|---|
| Total loss $\mathcal{L}$ | 1.709 | 1.494 | ↓ 12.6% |
| Position 0 target logit (i) | 0.5 | 0.743 | ↑ 48.6% |
| Position 1 target logit (love) | 0.0 | 0.248 | from zero |
| Position 2 target logit (deep) | 0.7 | 0.891 | ↑ 27.3% |
7.3 Visualizing Gradient Descent
```
Loss surface (schematic)
    ▲
1.8 │  ★ (1.709)  ← starting point
1.6 │    ↘
1.4 │      ★ (1.494)  ← after one SGD step
1.2 │
1.0 │
    └─────────────────────►  parameter space
          W_proj moves along the negative gradient
```
The essence of gradient descent: on the loss surface, find the direction of fastest descent (the negative gradient), take a small step along it (the learning rate $\eta$ controls the step size), and arrive in a region of lower loss.
8. Wrap-Up
8.1 What Comes Next
This article verified the training loop purely numerically. Next comes a complete PyTorch implementation, including:
- A model definition fully aligned with this article (same initial weights)
- Automatic gradient computation (`loss.backward()`) to verify the hand calculations here
- A multi-epoch training loop showing the loss falling steadily to convergence
- The final inference results
8.2 What About Larger Models?
This article hand-computed a tiny model with $d_{\text{model}} = 4$. In a real Transformer ($d_{\text{model}} = 512$) the computation flows in exactly the same way; the matrices are 128× wider, but:
- the gradient formula $\frac{\partial \mathcal{L}}{\partial W_{\text{proj}}[j,k]} \propto \sum_p (P_p[j] - \mathbb{1}[t_p=j]) \cdot D[p,k]$ is unchanged
- the SGD update $W = W - \eta \nabla \mathcal{L}$ is unchanged
- the chain-rule path of backpropagation is unchanged
This is the power of mathematical abstraction: once you understand the small-scale hand computation, you understand everything essential about large-scale training.