注意力机制
注意力机制 (Attention Mechanism)
核心思想
让模型在处理信息时,学会“聚焦”于输入序列中与当前任务最相关的部分,而不是同等地看待所有信息。它模拟了人类视觉的注意力机制——我们看一张图或读一句话时,不会平均分配精力,而是关注重点。
🤔为什么需要注意力?
在注意力机制出现之前,主流的序列模型(如 RNN、LSTM、GRU)主要依赖编码器-解码器架构来处理序列到序列的任务(如机器翻译)。
- 瓶颈问题:编码器需要将整个输入序列压缩成一个固定长度的上下文向量。如果句子很长,这个固定长度的向量很难完美保存所有信息,导致“长距离依赖”丢失。
- 计算效率:RNN 必须按顺序处理序列( t1,t2,...t1,t2,...t1,t2,...),无法并行计算,训练速度慢。
注意力机制的诞生:它允许解码器在生成每一个输出词时,直接“回头”查看编码器的所有隐藏状态,从中挑选出最相关的信息,而不是只依赖最后一个状态。
直观类比
我们可以把注意力机制想象成一个字典查找的过程:
- 你有一个查询词(比如“苹果”)。
- 你有一本字典(输入序列的所有单词)。
- 你拿着“苹果”去字典里比对,发现和“水果”、“红色”这两个词条最相关。
- 你根据相关程度(权重),提取出这两个词条对应的解释(值)。
在深度学习中,这个过程被形式化为 Query (查询), Key (键), Value (值) 的交互。
核心数学原理 (Q, K, V)
注意力机制的本质是:根据 Query 和 Key 的相似度计算权重,然后对 Value 进行加权求和。
- Query ( QQQ ):当前时刻我想找什么?(例如:解码器当前的状态)
- Key ( KKK ):输入序列里有什么?(例如:编码器所有时刻的隐藏状态)
- Value ( VVV ):输入序列里对应的具体内容是什么?(通常与 KKK 相同,但在自注意力中可能不同)
计算步骤:
- 计算相似度:计算 QQQ 和每一个 KKK 的点积(或余弦相似度)。
- 缩放:除以 dk\sqrt{d_k}dk (防止点积过大导致梯度消失)。
- 归一化:使用
Softmax函数将分数转化为概率分布(权重 ααα )。 - 加权求和:用权重 ααα 乘以对应的 VVV 。
数学公式(缩放点积注意力):
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V
其中:
- dkd_kdk 是键向量的维度。
- 1dk\frac{1}{\sqrt{d_k}}dk1 是缩放因子。
Hard vs. Soft Attention
Hard Attention (硬注意力)
Hard Attention 是一种“非此即彼”的机制。它在某一时刻只选择输入序列中的一个或几个特定位置进行关注,完全忽略其他部分。它是不可微的(离散的),通常需要特殊的训练技巧。
核心概念
如果说 Soft Attention 是“看着整张图,但眼神聚焦在鸟身上”,那么 Hard Attention 就是“用剪刀把鸟剪下来,只看这一块”。
-
机制:它是一个离散的随机过程或确定性过程。
- 输入序列长度为LLL 。
- 模型输出一个位置索引 iii (或者一组索引)。
- 输出仅包含位置 iii 的信息,其他位置的信息被直接丢弃(权重为0)。
-
数学表达:
不同于 Soft Attention 的加权求和 ∑αivi\sumα_iv_i∑αivi,Hard Attention 的输出通常是:c=viwhere i∼P(i∣q,K)c = v_i \quad \text{where } i \sim P(i|q, K)c=viwhere i∼P(i∣q,K)
或者通过一个二进制掩码(Mask) mmm 来实现,其中 mmm 只有一个位置是 1,其余是 0。
🤔为什么 Hard Attention 很难训练?
这是 Hard Attention 最关键的知识点。
-
不可微性:
标准的神经网络训练依赖反向传播,这要求计算图中的所有操作都是可微的(即可以求导数)。
- Soft Attention 使用了 Softmax,它是连续且可微的。
- Hard Attention 涉及“采样”或“取最大值索引”(Argmax),这是一个阶跃函数。阶跃函数的导数在几乎所有地方都是 0。
-
后果:
如果使用普通的梯度下降,梯度无法穿过 Hard Attention 层传回到前面的网络。网络不知道“选错了位置”是因为特征提取没做好,还是因为运气不好。
如何训练 Hard Attention?
既然不能直接求导,研究人员提出了几种解决方案:
- 强化学习 (Reinforcement Learning, RL):
- 把注意力位置的选择看作是一个“动作”(Action)。
- 把最终的模型输出准确率(或损失函数的负值)看作“奖励”(Reward)。
- 使用 策略梯度 (Policy Gradient) 或 REINFORCE 算法 来更新参数。简单来说,就是告诉网络:“刚才你选了这个位置,结果导致得分高了/低了,下次要多/少选这种位置。”
- Gumbel-Softmax 技巧:
- 这是一种“欺骗”梯度的技巧。它用一种特殊的重参数化方法,生成近似离散的样本,但在反向传播时又假装它是连续可微的。
- 直通估计器 (Straight-Through Estimator, STE):
- 在前向传播时,执行硬性的 Argmax(真的选一个)。
- 在反向传播时,假装梯度可以直接传过去(通常直接复制梯度或使用 Softmax 的梯度)。
典型应用场景
虽然 Soft Attention 更流行,但 Hard Attention 在以下领域有独特优势:
- 视觉问答 (VQA) - DeepMind 的 HAN 模型:
- 在回答“图片里有几只狗?”时,模型不需要关注背景里的树。Hard Attention 可以根据特征向量的范数(L2-norm),直接选出包含“狗”特征的区域进行计算,忽略背景。这不仅提高了准确率,还大大减少了计算量。
- 图像描述生成 (Show, Attend and Tell):
- 早期的研究对比了 Soft 和 Hard。Hard Attention 模型在生成单词 “bird” 时,会真正地把注意力“移动”到鸟所在的图像块上,生成的注意力图(Attention Map)非常清晰锐利。
- 克服灾难性遗忘:
- 在持续学习中,Hard Attention 可以用来“冻结”网络中对旧任务重要的神经元,强制模型只修改与新任务相关的部分参数。
PyTorch代码
import torch
def hard_attention_sample(probs):
"""
probs: 概率分布张量 [batch_size, seq_len]
返回: 选中的索引
"""
# 1. 根据概率分布进行多项式采样 (Multinomial Sampling)
# 这是一个不可微的操作
selected_indices = torch.multinomial(probs, num_samples=1)
return selected_indices
# 模拟数据
# 假设模型预测第3个位置的概率最高
probs = torch.tensor([[0.05, 0.05, 0.80, 0.05, 0.05]])
# 执行硬注意力选择
idx = hard_attention_sample(probs)
print(f"选中的位置索引: {idx.item()}")
# 输出可能是 2 (对应概率0.80),但也可能是其他(因为有随机性)
# 注意:在反向传播时,这个操作会阻断梯度,
# 除非使用 Gumbel-Softmax 或强化学习库来处理。
Soft Attention (软注意力)
Soft Attention 是一种“雨露均沾”但“重点突出”的机制。它不硬性选择某一个输入,而是给所有输入分配一个权重(概率分布),通过加权求和来提取信息。它是可微的,因此可以直接通过反向传播进行端到端训练。
核心概念
为了理解 Soft Attention,我们需要对比它的对立面——Hard Attention。
- Hard Attention (硬注意力):
- 行为:像探照灯。在某一时刻,只关注输入序列中的某一个特定位置,完全忽略其他位置。
- 机制:通常是离散的操作(例如:采样或取 argmax)。
- 缺点:因为是不可导的离散操作,无法直接使用梯度下降,通常需要强化学习(如策略梯度)或变分推断来训练,训练难度大且不稳定。
- Soft Attention (软注意力):
- 行为:像调节亮度的灯光。关注所有位置,但相关的部分亮(权重高),不相关的部分暗(权重接近0)。
- 机制:连续的数学运算(加权求和)。
- 优点:完全可微。梯度可以顺畅地流过注意力层,直接更新网络参数。
数学原理与计算流程
Soft Attention 的核心在于利用 Softmax 函数将原始的注意力分数转化为概率分布。
假设我们有一个查询向量 qqq (Query),以及一组输入信息的键值对 (k1,v1),(k2,v2),...,(kn,vn){(k_1,v_1),(k_2,v_2),...,(k_n,v_n)}(k1,v1),(k2,v2),...,(kn,vn)。
计算步骤如下:
-
打分 (Scoring):
计算查询 qqq 与每一个键 kik_iki 的相似度(分数 eie_iei )。常用的打分函数有点积、加性模型等。ei=score(q,ki)e_i=score(q,k_i)ei=score(q,ki)
-
归一化 (Softmax):
使用 Softmax 将分数 eie_iei 转化为注意力权重 αiα_iαi 。这确保了所有权重之和为 1,且都在 (0, 1) 之间。
αi=exp(ei)∑j=1nexp(ej)\alpha_i = \frac{\exp(e_i)}{\sum_{j=1}^{n} \exp(e_j)}αi=∑j=1nexp(ej)exp(ei)
-
加权求和 (Weighted Sum):
利用计算出的权重 αiα_iαi ,对值向量 viv_ivi 进行加权求和,得到最终的上下文向量 ccc 。c=∑i=1nαivic=\sum_{i=1}^{n}α_iv_ic=∑i=1nαivi
直观理解:
如果 qqq 和 k3k_3k3 最相关,那么 e3e_3e3 会很大,经过 Softmax 后 α3α_3α3 会接近 1,而其他的 ααα 会接近 0。最终结果 ccc 就会主要由 v3v_3v3 决定,但也包含了一点点其他 vvv 的信息。
数学公式:
Attention(Query,Source)=∑i=1LxSimilarity(Query,Keyi)∗ValueiAttention(Query, Source) = \sum_{i=1}^{L_x} Similarity(Query, Key_i) * Value_iAttention(Query,Source)=∑i=1LxSimilarity(Query,Keyi)∗Valuei
经典应用场景
Soft Attention 是许多经典模型的基石:
-
机器翻译 (Seq2Seq + Attention):
在生成目标句子的每一个单词时,模型计算当前解码器状态(Query)与源句子所有单词隐藏状态(Keys)的 Soft Attention。这使得模型在翻译 “bank” 时,能根据上下文给予 “river” 或 “money” 不同的权重。 -
图像描述生成 (Show, Attend and Tell):
这是 Soft Attention 在计算机视觉中的经典应用。
- 输入:CNN 提取的图像特征图(被切分为 LLL 个区域)。
- 过程:RNN 在生成每一个描述词(如 “bird”)时,计算 Soft Attention 权重。
- 结果:模型会给图像中包含“鸟”的区域分配高权重,给背景(如“水”、“天空”)分配低权重,从而生成准确的描述。
-
Transformer (Self-Attention):
Transformer 中的核心组件 Scaled Dot-Product Attention 本质上就是 Soft Attention 的一种高效实现形式。
进阶:关于 Softmax 的讨论
虽然 Soft Attention 依赖 Softmax,但在处理超长序列时,Softmax 带来的 O(N2)O(N^2)O(N2)计算复杂度是一个瓶颈。
- FlashAttention:这是一种优化算法,它不改变 Soft Attention 的数学结果,但通过分块计算和重计算技术,极大地减少了 GPU 显存读写(HBM 访问),使得 Soft Attention 在处理长文本时更快、更省显存。
- Linear Attention / SOFT:有些研究试图去掉 Softmax(例如使用线性核函数代替),以实现线性复杂度,但这通常属于更前沿的探索领域。目前主流的 Soft Attention 依然牢牢占据统治地位。
PyTorch代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class SoftAttention(nn.Module):
def __init__(self, query_dim, key_dim, value_dim):
super(SoftAttention, self).__init__()
# 定义线性层用于计算分数 (这里使用加性模型作为示例)
self.attn = nn.Linear(query_dim + key_dim, key_dim)
self.v = nn.Parameter(torch.rand(key_dim)) # 可学习参数向量
def forward(self, query, keys, values):
# query: [batch, query_dim]
# keys: [batch, seq_len, key_dim]
# values: [batch, seq_len, value_dim]
batch_size, seq_len, _ = keys.size()
# 1. 计算能量分数 (Energy/Score)
# 这里演示广播机制:将 query 扩展以匹配 keys 的序列长度
query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)
energy = torch.tanh(self.attn(torch.cat((query_expanded, keys), dim=2)))
scores = torch.matmul(energy, self.v) # [batch, seq_len]
# 2. Softmax 归一化 -> 得到 Soft Attention 权重
# 这就是 "Soft" 的核心:概率分布
attn_weights = F.softmax(scores, dim=1)
# 3. 加权求和
context = torch.bmm(attn_weights.unsqueeze(1), values).squeeze(1)
return context, attn_weights
Soft vs. Hard Attention 对比表
| 特性 | Soft Attention | Hard Attention |
|---|---|---|
| 关注范围 | 全局关注(所有位置),权重不同 | 局部关注(只选一个或少数位置) |
| 可微性 | 可微 (Differentiable) | 不可微 (Non-differentiable) |
| 训练方法 | 标准反向传播 (Backpropagation) | 强化学习 (REINFORCE) 或 变分法 |
| 确定性 | 确定性输出 (给定输入,输出固定) | 随机性输出 (通常涉及采样) |
| 计算效率 | 较高 (并行矩阵运算) | 较低 (通常涉及循环或采样) |
| 主要用途 | 绝大多数现代 NLP/CV 模型 | 需要明确“定位”或“裁剪”的任务 |
注意力机制的演变
以下都属于软注意力,这是目前深度学习中最主流、最常用的注意力机制。
| 类型 | 描述 | 典型应用 |
|---|---|---|
| Bahdanau Attention | 最早提出的注意力机制。使用加法计算能量(Additive/MLP)。计算量稍大,但在某些任务上效果好。 | 神经机器翻译 (RNN) |
| Luong Attention | 简化了计算,使用点积或矩阵乘法。效率更高。 | 神经机器翻译 (RNN) |
| Self-Attention | Q, K, V 都来自同一个输入序列。用于捕捉句子内部单词之间的关联(例如代词指代)。 | Transformer, BERT, GPT |
| Multi-Head Attention | 将 Q, K, V 映射到不同的子空间,并行计算多次注意力,最后拼接。让模型关注不同位置的不同信息。 | Transformer |
加性注意力 (Additive)
加性注意力(也叫 Bahdanau Attention)是注意力机制的开山之作。它通过一个**小型前馈神经网络(MLP)**来计算查询(Query)和键(Key)的相似度,而不是简单的向量点积。
1.解决的问题
在早期的 RNN 机器翻译中,编码器必须将整个句子压缩成一个固定向量。加性注意力允许解码器在生成翻译时,动态地“回头看”源句子的所有隐藏状态。
2.核心数学原理
与“点积注意力”直接将两个向量相乘不同,加性注意力的核心在于拼接和非线性变换。
计算公式:
score(q,k)=vTtanh(Wqq+Wkk)score(q,k)=v^Ttanh(W_qq+W_kk)score(q,k)=vTtanh(Wqq+Wkk)
符号解释:
- qqq :查询向量(例如解码器当前的隐藏状态)。
- kkk :键向量(例如编码器某个时刻的隐藏状态)。
- Wq,WkW_q,W_kWq,Wk :可学习的权重矩阵。它们的作用是将 qqq 和 kkk 投影到同一个维度(隐藏维度 hhh )。
- vvv :可学习的权重向量(通常是一维向量)。它的作用是将经过tanhtanhtanh 激活后的向量再次映射为一个标量分数。
- tanhtanhtanh :非线性激活函数,引入了非线性能力。
计算流程:
- 线性变换:分别用矩阵 WqW_qWq 和 WkW_kWk 处理 qqq 和 kkk 。
- 相加:将变换后的两个向量相加(Element-wise addition)。
- 激活:通过 tanhtanhtanh 函数。
- 打分:与向量 vvv 做点积,得到最终的相似度分数。
- 归一化:最后通过 Softmax 得到注意力权重。
3.为什么叫“加性”?
这个名字来源于其计算过程的核心操作是向量的相加( Wqq+WkkW_qq+W_kkWqq+Wkk)。
- 对比点积:点积注意力是 q⋅kq⋅kq⋅k (乘法)。
- 对比拼接:有些教材也将其视为一种特殊的拼接(Concatenation),因为 Wqq+WkkW_qq+W_kkWqq+Wkk 在数学上等价于先拼接 [q;k][q;k][q;k] 再乘一个大矩阵,但“加性”这个术语更强调其将两个向量映射到同一空间后相加的特性。
PyTorch代码实现
import torch
import torch.nn as nn
class AdditiveAttention(nn.Module):
def __init__(self, query_dim, key_dim, hidden_dim):
super(AdditiveAttention, self).__init__()
# 1. 定义线性变换层
self.W_q = nn.Linear(query_dim, hidden_dim, bias=False)
self.W_k = nn.Linear(key_dim, hidden_dim, bias=False)
# 2. 定义最后的打分向量 v (这里用 Linear 模拟向量点积)
self.v = nn.Linear(hidden_dim, 1, bias=False)
# 3. 激活函数
self.activation = nn.Tanh()
def forward(self, query, keys):
# query: [batch_size, query_dim]
# keys: [batch_size, seq_len, key_dim]
batch_size, seq_len, _ = keys.size()
# 扩展 query 以匹配 keys 的序列长度
# query_expanded: [batch_size, seq_len, query_dim]
query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)
# 1. 线性变换并相加
# energy: [batch_size, seq_len, hidden_dim]
energy = self.W_q(query_expanded) + self.W_k(keys)
# 2. 激活函数
energy = self.activation(energy)
# 3. 与 v 做点积,得到分数
# scores: [batch_size, seq_len, 1] -> squeeze -> [batch_size, seq_len]
scores = self.v(energy).squeeze(-1)
return scores
# --- 测试 ---
# 假设 query 维度 64, key 维度 64, 隐藏层维度 32
attn = AdditiveAttention(64, 64, 32)
q = torch.randn(2, 64)
k = torch.randn(2, 10, 64) # 10 是序列长度
scores = attn(q, k)
print(f"分数形状: {scores.shape}") # [2, 10]
总结:
历史地位:它是注意力机制的鼻祖,由 Bahdanau 提出。
核心机制:使用 MLP + Tanh 计算相似度,公式为 vTtanh(Wqq+Wkk)v^Ttanh(W_qq+W_kk)vTtanh(Wqq+Wkk)。
优缺点:优点是灵活、表达能力强;缺点是计算慢,不适合超长序列。
现状:虽然被点积注意力取代成为主流,但在追求线性复杂度和移动端高效推理的新型架构中,加性注意力正在“文艺复兴”。
点积注意力 (Dot-Product)
点积注意力(也叫乘性注意力)通过计算查询向量(Query)和键向量(Key)的点积来衡量相似度。它是 Transformer 架构的核心计算单元,利用高度优化的矩阵乘法,比加性注意力更快、更节省空间。
1. 核心数学原理
点积注意力的核心思想非常直观:两个向量越相似,它们的点积就越大。
基础公式(未缩放):
score(q,k)=qTkscore(q,k)=q^Tkscore(q,k)=qTk
完整计算流程(矩阵形式):
给定查询矩阵 QQQ 、键矩阵 KKK 和值矩阵 VVV :
-
计算相似度:计算 QQ 和 KK 的转置的矩阵乘法。
Scores=QKTScores=QK^TScores=QKT
这一步会得到一个注意力分数矩阵,每个元素代表一个 Query 和一个 Key 的匹配程度。
-
缩放 (Scaling):除以 dk\sqrt{d_k}dk ( dkd_kdk 是向量的维度)。
Scaled Scores=QKTdk\text{Scaled Scores}=\frac{QK^T}{\sqrt{d_k}}Scaled Scores=dkQKT
-
归一化:使用 Softmax 将分数转化为概率(权重)。
Weights=softmax(Scaled Scores)Weights=softmax(\text{Scaled Scores})Weights=softmax(Scaled Scores)
-
加权求和:用权重乘以 VVV 。
Output=Weights⋅VOutput=Weights⋅VOutput=Weights⋅V
最终公式:
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V
2.为什么要“缩放”?(关键点)
为什么公式里要除以 dk\sqrt{d_k}dk ?
-
问题根源:
当向量维度 dk\sqrt{d_k}dk 很大时(例如 512 或 1024), QQQ 和 KKK 的点积结果会变得非常大(数值范围会随维度线性增长)。 -
Softmax 的特性:
Softmax 函数在输入数值过大时,会进入梯度极小的区域(饱和区)。
- 如果输入很大,Softmax 的输出会接近 One-hot 分布(一个接近 1,其他接近 0)。
- 在这种极端分布下,反向传播的梯度会接近于 0(梯度消失),导致模型无法有效训练。
-
缩放的作用:
除以 dkd_kdk 可以将点积结果的方差控制在 1 左右,把数值拉回到 Softmax 的敏感区域(梯度较大区域),保证模型训练的稳定性和收敛速度。
3.PyTorch代码
import torch
import torch.nn as nn
import math
class DotProductAttention(nn.Module):
def __init__(self, dropout=0.1):
super(DotProductAttention, self).__init__()
self.dropout = nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
# Q: [batch_size, seq_len_q, d_k]
# K: [batch_size, seq_len_k, d_k]
# V: [batch_size, seq_len_v, d_v]
d_k = Q.size(-1)
# 1. 计算 Q 和 K 的点积 (矩阵乘法)
# scores: [batch_size, seq_len_q, seq_len_k]
scores = torch.matmul(Q, K.transpose(-2, -1))
# 2. 缩放 (关键步骤!)
scores = scores / math.sqrt(d_k)
# 3. 掩码处理 (可选,用于 Padding 或 防止看未来)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 4. Softmax 归一化
attn_weights = self.dropout(torch.softmax(scores, dim=-1))
# 5. 加权求和
# output: [batch_size, seq_len_q, d_v]
output = torch.matmul(attn_weights, V)
return output, attn_weights
# --- 测试 ---
# 假设 batch=2, 序列长度=5, 维度=64
Q = torch.randn(2, 5, 64)
K = torch.randn(2, 5, 64)
V = torch.randn(2, 5, 64)
attn = DotProductAttention()
out, weights = attn(Q, K, V)
print(f"输出形状: {out.shape}") # [2, 5, 64]
4.几何意义
从几何角度看,点积注意力是在计算向量之间的夹角:
- 点积定义: A⋅B=∣A∣∣B∣cos(θ)A⋅B=∣A∣∣B∣cos(θ)A⋅B=∣A∣∣B∣cos(θ)
- 归一化后:如果 QQQ 和 KKK 经过了归一化(长度为 1),点积就是 cos(θ)cos(θ)cos(θ) 。
- 含义:
- 点积越大 → 夹角越小 → 向量方向越一致 → 语义越相似。
- 点积为 0 → 正交 → 无关。
- 点积为负 → 方向相反 → 语义相斥。
总结:
地位:它是现代 AI 的引擎,支撑起了 Transformer 和所有大语言模型。
核心优势:快。利用 GPU 对矩阵乘法的极致优化,实现了并行计算。
关键细节:缩放因子 dk\sqrt{d_k}dk 必不可少,它是防止梯度消失、保证模型收敛的数学保障。
直观理解:通过计算向量夹角的余弦值,快速筛选出语义最相关的信息。
点积 vs. 加性 (Dot-Product vs. Additive)
| 特性 | 点积注意力 (Luong/Transformer) | 加性注意力 (Bahdanau) |
|---|---|---|
| 计算方式 | qTkq^TkqTk (向量内积) | vTtanh(Wqq+Wkk)v^Ttanh(W_qq+W_kk)vTtanh(Wqq+Wkk) (MLP) |
| 计算速度 | 极快 (利用矩阵乘法优化) | 较慢 (涉及逐元素加法和激活函数) |
| 空间效率 | 高 (参数少) | 较低 (需要额外的 WWW 和 vvv 参数) |
| 维度要求 | QQQ 和 KKK 维度必须相同 | QQQ 和 KKK 维度可以不同 |
| 主要应用 | Transformer, GPT, BERT | 早期 RNN 翻译模型 |
自注意力机制 (Self-Attention)
自注意力机制是一种**“自我反思”的机制。它允许序列中的每一个元素(如单词)都与序列中的所有其他元素**进行交互,从而根据上下文动态地更新自己的表示。
1. 核心概念:从“寻找”到“自省”
在之前的注意力机制(如 RNN 中的 Attention)中,Query 来自解码器,Key/Value 来自编码器,这是一种“跨序列”的关注。
而 Self-Attention 不同:
- 输入:只有一个序列(例如一句话)。
- 机制:序列中的每个词都会变成 Query,去查询序列中所有词(作为 Key/Value)。
- 目的:为了捕捉句子内部的依赖关系(如指代关系、语法结构)。
直观类比:全员会议
- RNN (旧方式):像击鼓传花。信息从第一个人传到第二个人,再传到第三个……传到后面时,前面的信息可能已经模糊了(长距离依赖问题)。
- Self-Attention (新方式):像全员视频会议。每个人(词)都可以直接看到和听到其他所有人。当你(比如“它”这个词)想理解自己指代谁时,你可以直接看“猫”和“垫子”,并瞬间建立联系。
2. 为什么需要 Self-Attention?
在 Transformer 出现之前,RNN/LSTM 面临两大瓶颈:
- 长距离依赖 (Long-term Dependencies):句子太长时,开头的信息传到结尾会衰减。例如:“The animal didn’t cross the street because it was too tired.” RNN 处理到“it”时,可能已经忘了“animal”。Self-Attention 可以直接计算“it”和“animal”的关联,距离无关。
- 无法并行 (Sequential Computation):RNN 必须算完 t−1t−1t−1 才能算 ttt,训练速度慢。Self-Attention 可以同时计算所有位置的关联,极大提升了训练效率。
3. 数学原理:五步计算法
Self-Attention 的核心是 Query (Q), Key (K), Value (V) 的交互。
假设输入矩阵为 XXX (包含句子中所有词的向量)。
Step 1: 线性变换 (生成 Q, K, V)
通过三个可学习的权重矩阵WQ,WK,WVW_Q,W_K,W_VWQ,WK,WV ,将输入 XXX 映射到三个空间:
Q=XWQ,K=XWK,V=XWVQ=X_WQ,K=X_WK,V=X_WVQ=XWQ,K=XWK,V=XWV
- Q (查询):我在找什么?(例如:“坐”这个词在找主语)
- K (键):我包含什么特征?(例如:“猫”的特征是动物、主语)
- V (值):我实际的内容是什么?(例如:“猫”的具体语义向量)
Step 2: 计算相似度 (点积)
计算 QQQ 和 KKK 的点积,衡量词与词之间的相关性:
Scores=QKTScores=QK^TScores=QKT
Step 3: 缩放 (Scaling)
除以 dk\sqrt{d_k}dk ( dkd_kdk 是向量维度),防止点积过大导致 Softmax 梯度消失:
Scaled Scores=QKTdk\text{Scaled Scores}=\frac{QK^T}{\sqrt{d_k}}Scaled Scores=dkQKT
Step 4: 归一化 (Softmax)
将分数转化为概率分布(权重 ααα ),所有权重之和为 1:
Weights=softmax(Scaled Scores)Weights=softmax(\text{Scaled Scores})Weights=softmax(Scaled Scores)
Step 5: 加权求和
根据权重,从 VVV 中提取信息,生成新的上下文向量:
Output=Weights⋅VOutput=Weights⋅VOutput=Weights⋅V
最终公式:
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V
4. 实例解析:“猫坐在垫子上”
当模型处理 “坐” 这个词时,Self-Attention 会发生什么?
- 生成 Q:“坐”生成一个查询向量,意思是“我在找动作的发出者(主语)和承受者(宾语)”。
- 匹配 K:
- “坐”的 Q 与“猫”的 K 点积很高(主语匹配)。
- “坐”的 Q 与“垫子”的 K 点积较高(宾语匹配)。
- “坐”的 Q 与“在”的 K 点积较低(虚词,关系不大)。
- 加权求和 V:
- 最终“坐”的新向量 = 30%的“猫”的信息 + 25%的“坐”自身 + 30%的“垫子”的信息 + 其他。
- 结果:原本单纯的“坐”向量,现在融合了“谁在坐”和“坐在哪”的信息,语义变得极其丰富。
5. 进阶:多头注意力 (Multi-Head Attention)
如果只有一个 Self-Attention,模型只能关注一种关系(比如只看语法)。为了从不同角度理解句子,我们使用多头机制。
- 原理:将 Q,K,VQ,K,VQ,K,V拆分成 hhh 个头(例如 8 个头),每个头独立计算 Self-Attention。
- 比喻:
- 头1:关注语法结构(主谓宾)。
- 头2:关注指代关系(“它”指“猫”)。
- 头3:关注词性特征。
- 融合:最后将所有头的输出拼接(Concat)并线性变换,得到最终结果。这让模型的理解更加立体。
6.PyTorch代码
import torch
import torch.nn as nn
import math
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
# 确保维度能被头数整除
assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
# 定义线性变换层
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query, mask=None):
N = query.shape[0] # Batch size
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
# 1. 拆分多头 (Split embedding into heads)
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
# 2. 计算能量 (Q * K^T)
# einsum 是高效的张量运算,这里执行矩阵乘法
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
# 3. 缩放
energy = energy / (self.embed_size ** (1/2))
# 4. 掩码 (可选,用于防止看未来或处理填充)
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
# 5. Softmax 和 加权求和
attention = torch.softmax(energy, dim=3)
out = torch.einsum("nhqk,nvhd->nqhd", [attention, values]).reshape(
N, query_len, self.heads * self.head_dim
)
# 6. 线性变换
out = self.fc_out(out)
return out
Self-Attention 的本质:它让模型不再通过“记忆”来传递信息,而是通过“观察”来全局捕捉特征。这是 BERT、GPT、LLaMA 等大模型能够理解人类语言逻辑的基石。
多头注意力 (Multi-Head Attention)
多头注意力机制通过并行计算多个独立的注意力头,让模型能够同时关注输入序列中不同位置、不同类型的信息(如语法、语义、指代关系),最后将这些信息融合。它解决了单头注意力“顾此失彼”的问题,极大地增强了模型的表达能力。
1. 为什么需要“多头”?(单头的局限)
在单头注意力(Single-Head Attention)中,模型只能计算一种注意力分布。这就像你只能用一种颜色的眼镜看世界,或者只能带一把钥匙开所有的锁。
举个例子:
句子:“The animal didn’t cross the street because it was too tired.”
对于单词 “it”,单头注意力很难同时处理以下所有关系:
- 指代关系:“it” 指代 “animal”。
- 语法关系:“it” 是 “was” 的主语。
- 语义修饰:“tired” 是形容 “it” 的状态。
如果只用一个头,这些不同的信息会被混合在一个向量空间里,导致“平均化”,丢失细节。
多头的解决方案:
这就好比开了一个“专家会诊”:
- 头 1:专门关注指代(发现 it = animal)。
- 头 2:专门关注句法结构(发现 it 是主语)。
- 头 3:专门关注词义修饰(发现 tired 修饰 it)。
最后,模型将这些专家的意见汇总,得到一个全面、精准的“it”的表示。
2. 核心工作流程
多头注意力的计算过程可以分为四个步骤:投影 -> 并行计算 -> 拼接 -> 融合。
假设输入向量的维度是 dmodeld_{model}dmodel (例如 512),头的数量是 hhh (例如 8)。
Step 1: 线性投影 (Projection)
首先,我们将输入 XXX 通过三组不同的线性变换矩阵( WQ,WK,WVW_Q,W_K,W_VWQ,WK,WV),映射到 hhh 个不同的子空间。
- 每个头只负责处理 dk=dmodel/hd_k=d_{model}/hdk=dmodel/h 维度的信息(例如 512/8=64512/8=64512/8=64 维)。
- 这意味着每个头都在一个更小的、独立的子空间里学习特定的特征。
Step 2: 并行计算 (Parallel Attention)
在每个子空间内,独立计算缩放点积注意力:
headi=Attention(QWiQ,KWiK,VWiV)head_i=Attention(QW_i^Q,KW_i^K,VW_i^V)headi=Attention(QWiQ,KWiK,VWiV)
这一步是高度并行的,8 个头同时计算,互不干扰。
Step 3: 拼接 (Concatenation)
将所有头的输出向量拼在一起,恢复成原始维度:
Concat=[head1,head2,...,headh]Concat=[head_1,head_2,...,head_h]Concat=[head1,head2,...,headh]
Step 4: 最终融合 (Final Projection)
最后,通过一个线性层( WOW^OWO )将拼接后的向量再次投影,混合所有头的信息:
MultiHead(Q,K,V)=Concat⋅WOMultiHead(Q,K,V)=Concat⋅W^OMultiHead(Q,K,V)=Concat⋅WO
完整公式:
MultiHead(Q,K,V)=Concat(head1,...,headh)WOMultiHead(Q,K,V)=Concat(head_1,...,head_h)W^OMultiHead(Q,K,V)=Concat(head1,...,headh)WO
3. 直观类比:RGB 图像
你可以把多头注意力想象成处理一张 RGB 彩色图片:
- 单头注意力:就像只看黑白照片(灰度图),你只能看到亮度信息,丢失了色彩细节。
- 多头注意力:
- 头 1 ®:关注红色通道(比如捕捉情感色彩)。
- 头 2 (G):关注绿色通道(比如捕捉语法结构)。
- 头 3 (B):关注蓝色通道(比如捕捉实体指代)。
- 最后将 R、G、B 三个通道叠加,还原出一张色彩丰富、信息完整的图片。
4.PyTorch代码
这是 Transformer 中 nn.MultiheadAttention 的底层逻辑实现,展示了如何拆分维度和并行计算。
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度
# 1. 定义线性层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model) # 输出投影
self.dropout = nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 2. 线性变换并拆分多头
# 形状变化: [Batch, Seq_Len, d_model] -> [Batch, Seq_Len, d_model]
Q = self.W_q(Q)
K = self.W_k(K)
V = self.W_v(V)
# 形状变化: [Batch, Seq_Len, num_heads, d_k] -> [Batch, num_heads, Seq_Len, d_k]
# 这一步非常关键,将维度拆分并转置,让 heads 变成独立的一维
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 3. 缩放点积注意力计算 (并行处理所有头)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = torch.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# 加权求和
out = torch.matmul(attn_weights, V) # [Batch, num_heads, Seq_Len, d_k]
# 4. 拼接与最终投影
# 转置回 [Batch, Seq_Len, num_heads, d_k]
out = out.transpose(1, 2).contiguous()
# 拼接: [Batch, Seq_Len, d_model]
out = out.view(batch_size, -1, self.d_model)
# 最后通过线性层融合信息
out = self.W_o(out)
return out, attn_weights
5. 关键参数设置
在实际应用中,如何选择头的数量( hhh )?
| 参数 | 典型值 | 说明 |
|---|---|---|
| dmodeld_{model}dmodel | 512 / 768 / 1024 | 模型的总维度。 |
| h(Heads)h (Heads)h(Heads) | 8 / 12 / 16 | 头的数量。通常 dmodeld_{model}dmodel 必须能被 hhh 整除。 |
| dkd_kdk | 64 | 每个头的维度。如果 dk{d_k}dk太小,模型学不到东西;太大则计算量爆炸。经验表明 64 是一个性价比极高的数值。 |
总结:
核心作用:多视角特征提取。它允许模型在同一时间关注不同位置的不同类型的依赖关系。
计算优势:虽然看起来计算量变大了(8个头),但因为每个头的维度变小了(1/8 ),且所有头是并行计算的,所以总的时间复杂度与单头注意力几乎相同。
地位:它是 Transformer 区别于 RNN 的关键创新之一,赋予了模型强大的上下文理解能力和并行计算能力。
注意力机制的优缺点
优点:
- 解决长距离依赖:无论序列多长,任意两个位置之间的距离都是 1(直接计算),信息传递路径最短。
- 并行化:矩阵运算可以高度并行,训练速度远超 RNN。
- 可解释性:通过可视化注意力权重矩阵,我们可以看到模型在生成某个词时“关注”了输入中的哪些词(例如翻译 “Bank” 时关注了 “River”)。
缺点:
- 计算复杂度:标准注意力机制的计算复杂度是 O(n2)O(n^2)O(n2),其中 nnn 是序列长度。处理超长文本(如整本书)时显存消耗巨大。
- 位置信息丢失:纯粹的注意力机制是“排列不变”的(即打乱输入顺序,输出不变)。因此必须引入位置编码来补充位置信息。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)