注意力机制 (Attention Mechanism)

核心思想

让模型在处理信息时,学会“聚焦”于输入序列中与当前任务最相关的部分,而不是同等地看待所有信息。它模拟了人类视觉的注意力机制——我们看一张图或读一句话时,不会平均分配精力,而是关注重点。

🤔为什么需要注意力?

在注意力机制出现之前,主流的序列模型(如 RNN、LSTM、GRU)主要依赖编码器-解码器架构来处理序列到序列的任务(如机器翻译)。

  • 瓶颈问题:编码器需要将整个输入序列压缩成一个固定长度的上下文向量。如果句子很长,这个固定长度的向量很难完美保存所有信息,导致“长距离依赖”丢失。
  • 计算效率:RNN 必须按顺序处理序列( t1,t2,...t1,t2,...t1,t2,...),无法并行计算,训练速度慢。

注意力机制的诞生:它允许解码器在生成每一个输出词时,直接“回头”查看编码器的所有隐藏状态,从中挑选出最相关的信息,而不是只依赖最后一个状态。

直观类比

我们可以把注意力机制想象成一个字典查找的过程:

  1. 你有一个查询词(比如“苹果”)。
  2. 你有一本字典(输入序列的所有单词)。
  3. 你拿着“苹果”去字典里比对,发现和“水果”、“红色”这两个词条最相关。
  4. 你根据相关程度(权重),提取出这两个词条对应的解释(值)。

在深度学习中,这个过程被形式化为 Query (查询), Key (键), Value (值) 的交互。

核心数学原理 (Q, K, V)

注意力机制的本质是:根据 QueryKey 的相似度计算权重,然后对 Value 进行加权求和。

  • Query ( QQQ ):当前时刻我想找什么?(例如:解码器当前的状态)
  • Key ( KKK ):输入序列里有什么?(例如:编码器所有时刻的隐藏状态)
  • Value ( VVV ):输入序列里对应的具体内容是什么?(通常与 KKK 相同,但在自注意力中可能不同)
计算步骤:
  1. 计算相似度:计算 QQQ 和每一个 KKK 的点积(或余弦相似度)。
  2. 缩放:除以 dk\sqrt{d_k}dk (防止点积过大导致梯度消失)。
  3. 归一化:使用 Softmax 函数将分数转化为概率分布(权重 ααα )。
  4. 加权求和:用权重 ααα 乘以对应的 VVV
数学公式(缩放点积注意力):

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk QKT)V

其中:

  • dkd_kdk 是键向量的维度。
  • 1dk\frac{1}{\sqrt{d_k}}dk 1 是缩放因子。

Hard vs. Soft Attention

Hard Attention (硬注意力)

Hard Attention 是一种“非此即彼”的机制。它在某一时刻只选择输入序列中的一个或几个特定位置进行关注,完全忽略其他部分。它是不可微的(离散的),通常需要特殊的训练技巧。

核心概念

如果说 Soft Attention 是“看着整张图,但眼神聚焦在鸟身上”,那么 Hard Attention 就是“用剪刀把鸟剪下来,只看这一块”。

  • 机制:它是一个离散的随机过程或确定性过程。

    • 输入序列长度为LLL
    • 模型输出一个位置索引 iii (或者一组索引)。
    • 输出仅包含位置 iii 的信息,其他位置的信息被直接丢弃(权重为0)。
  • 数学表达
    不同于 Soft Attention 的加权求和 ∑αivi\sumα_iv_iαivi,Hard Attention 的输出通常是:

    c=viwhere i∼P(i∣q,K)c = v_i \quad \text{where } i \sim P(i|q, K)c=viwhere iP(iq,K)

    或者通过一个二进制掩码(Mask) mmm 来实现,其中 mmm 只有一个位置是 1,其余是 0。

🤔为什么 Hard Attention 很难训练?

这是 Hard Attention 最关键的知识点。

  • 不可微性:

    标准的神经网络训练依赖反向传播,这要求计算图中的所有操作都是可微的(即可以求导数)。

    • Soft Attention 使用了 Softmax,它是连续且可微的。
    • Hard Attention 涉及“采样”或“取最大值索引”(Argmax),这是一个阶跃函数。阶跃函数的导数在几乎所有地方都是 0
  • 后果
    如果使用普通的梯度下降,梯度无法穿过 Hard Attention 层传回到前面的网络。网络不知道“选错了位置”是因为特征提取没做好,还是因为运气不好。

如何训练 Hard Attention?

既然不能直接求导,研究人员提出了几种解决方案:

  • 强化学习 (Reinforcement Learning, RL):
    • 把注意力位置的选择看作是一个“动作”(Action)。
    • 把最终的模型输出准确率(或损失函数的负值)看作“奖励”(Reward)。
    • 使用 策略梯度 (Policy Gradient)REINFORCE 算法 来更新参数。简单来说,就是告诉网络:“刚才你选了这个位置,结果导致得分高了/低了,下次要多/少选这种位置。”
  • Gumbel-Softmax 技巧:
    • 这是一种“欺骗”梯度的技巧。它用一种特殊的重参数化方法,生成近似离散的样本,但在反向传播时又假装它是连续可微的。
  • 直通估计器 (Straight-Through Estimator, STE):
    • 在前向传播时,执行硬性的 Argmax(真的选一个)。
    • 在反向传播时,假装梯度可以直接传过去(通常直接复制梯度或使用 Softmax 的梯度)。
典型应用场景

虽然 Soft Attention 更流行,但 Hard Attention 在以下领域有独特优势:

  • 视觉问答 (VQA) - DeepMind 的 HAN 模型:
    • 在回答“图片里有几只狗?”时,模型不需要关注背景里的树。Hard Attention 可以根据特征向量的范数(L2-norm),直接选出包含“狗”特征的区域进行计算,忽略背景。这不仅提高了准确率,还大大减少了计算量。
  • 图像描述生成 (Show, Attend and Tell):
    • 早期的研究对比了 Soft 和 Hard。Hard Attention 模型在生成单词 “bird” 时,会真正地把注意力“移动”到鸟所在的图像块上,生成的注意力图(Attention Map)非常清晰锐利。
  • 克服灾难性遗忘:
    • 在持续学习中,Hard Attention 可以用来“冻结”网络中对旧任务重要的神经元,强制模型只修改与新任务相关的部分参数。
PyTorch代码
import torch

def hard_attention_sample(probs):
    """
    probs: 概率分布张量 [batch_size, seq_len]
    返回: 选中的索引
    """
    # 1. 根据概率分布进行多项式采样 (Multinomial Sampling)
    # 这是一个不可微的操作
    selected_indices = torch.multinomial(probs, num_samples=1)
    
    return selected_indices

# 模拟数据
# 假设模型预测第3个位置的概率最高
probs = torch.tensor([[0.05, 0.05, 0.80, 0.05, 0.05]]) 

# 执行硬注意力选择
idx = hard_attention_sample(probs)
print(f"选中的位置索引: {idx.item()}") 
# 输出可能是 2 (对应概率0.80),但也可能是其他(因为有随机性)

# 注意:在反向传播时,这个操作会阻断梯度,
# 除非使用 Gumbel-Softmax 或强化学习库来处理。
Soft Attention (软注意力)

Soft Attention 是一种“雨露均沾”但“重点突出”的机制。它不硬性选择某一个输入,而是给所有输入分配一个权重(概率分布),通过加权求和来提取信息。它是可微的,因此可以直接通过反向传播进行端到端训练。

核心概念

为了理解 Soft Attention,我们需要对比它的对立面——Hard Attention。

  • Hard Attention (硬注意力):
    • 行为:像探照灯。在某一时刻,只关注输入序列中的某一个特定位置,完全忽略其他位置。
    • 机制:通常是离散的操作(例如:采样或取 argmax)。
    • 缺点:因为是不可导的离散操作,无法直接使用梯度下降,通常需要强化学习(如策略梯度)或变分推断来训练,训练难度大且不稳定。
  • Soft Attention (软注意力):
    • 行为:像调节亮度的灯光。关注所有位置,但相关的部分亮(权重高),不相关的部分暗(权重接近0)
    • 机制:连续的数学运算(加权求和)。
    • 优点完全可微。梯度可以顺畅地流过注意力层,直接更新网络参数。
数学原理与计算流程

Soft Attention 的核心在于利用 Softmax 函数将原始的注意力分数转化为概率分布。

假设我们有一个查询向量 qqq (Query),以及一组输入信息的键值对 (k1,v1),(k2,v2),...,(kn,vn){(k_1,v_1),(k_2,v_2),...,(k_n,v_n)}(k1,v1),(k2,v2),...,(kn,vn)

计算步骤如下:

  1. 打分 (Scoring)
    计算查询 qqq 与每一个键 kik_iki 的相似度(分数 eie_iei​ )。常用的打分函数有点积、加性模型等。

    ei=score(q,ki)e_i=score(q,k_i)ei=score(q,ki)

  2. 归一化 (Softmax)

    使用 Softmax 将分数 eie_iei 转化为注意力权重 αiα_iαi 。这确保了所有权重之和为 1,且都在 (0, 1) 之间。

    αi=exp⁡(ei)∑j=1nexp⁡(ej)\alpha_i = \frac{\exp(e_i)}{\sum_{j=1}^{n} \exp(e_j)}αi=j=1nexp(ej)exp(ei)

  3. 加权求和 (Weighted Sum)
    利用计算出的权重 αiα_iαi​ ,对值向量 viv_ivi 进行加权求和,得到最终的上下文向量 ccc

    c=∑i=1nαivic=\sum_{i=1}^{n}α_iv_ic=i=1nαivi

直观理解

如果 qqqk3k_3k3 最相关,那么 e3e_3e3 会很大,经过 Softmax 后 α3α_3α3 会接近 1,而其他的 ααα 会接近 0。最终结果 ccc 就会主要由 v3v_3v3 决定,但也包含了一点点其他 vvv 的信息。

数学公式

Attention(Query,Source)=∑i=1LxSimilarity(Query,Keyi)∗ValueiAttention(Query, Source) = \sum_{i=1}^{L_x} Similarity(Query, Key_i) * Value_iAttention(Query,Source)=i=1LxSimilarity(Query,Keyi)Valuei

经典应用场景

Soft Attention 是许多经典模型的基石:

  • 机器翻译 (Seq2Seq + Attention)
    在生成目标句子的每一个单词时,模型计算当前解码器状态(Query)与源句子所有单词隐藏状态(Keys)的 Soft Attention。这使得模型在翻译 “bank” 时,能根据上下文给予 “river” 或 “money” 不同的权重。

  • 图像描述生成 (Show, Attend and Tell):

    这是 Soft Attention 在计算机视觉中的经典应用。

    • 输入:CNN 提取的图像特征图(被切分为 LLL 个区域)。
    • 过程:RNN 在生成每一个描述词(如 “bird”)时,计算 Soft Attention 权重。
    • 结果:模型会给图像中包含“鸟”的区域分配高权重,给背景(如“水”、“天空”)分配低权重,从而生成准确的描述。
  • Transformer (Self-Attention)
    Transformer 中的核心组件 Scaled Dot-Product Attention 本质上就是 Soft Attention 的一种高效实现形式。

进阶:关于 Softmax 的讨论

虽然 Soft Attention 依赖 Softmax,但在处理超长序列时,Softmax 带来的 O(N2)O(N^2)O(N2)计算复杂度是一个瓶颈。

  • FlashAttention:这是一种优化算法,它不改变 Soft Attention 的数学结果,但通过分块计算和重计算技术,极大地减少了 GPU 显存读写(HBM 访问),使得 Soft Attention 在处理长文本时更快、更省显存。
  • Linear Attention / SOFT:有些研究试图去掉 Softmax(例如使用线性核函数代替),以实现线性复杂度,但这通常属于更前沿的探索领域。目前主流的 Soft Attention 依然牢牢占据统治地位。
PyTorch代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftAttention(nn.Module):
    def __init__(self, query_dim, key_dim, value_dim):
        super(SoftAttention, self).__init__()
        # 定义线性层用于计算分数 (这里使用加性模型作为示例)
        self.attn = nn.Linear(query_dim + key_dim, key_dim)
        self.v = nn.Parameter(torch.rand(key_dim)) # 可学习参数向量

    def forward(self, query, keys, values):
        # query: [batch, query_dim]
        # keys: [batch, seq_len, key_dim]
        # values: [batch, seq_len, value_dim]
        
        batch_size, seq_len, _ = keys.size()
        
        # 1. 计算能量分数 (Energy/Score)
        # 这里演示广播机制:将 query 扩展以匹配 keys 的序列长度
        query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)
        energy = torch.tanh(self.attn(torch.cat((query_expanded, keys), dim=2)))
        scores = torch.matmul(energy, self.v) # [batch, seq_len]
        
        # 2. Softmax 归一化 -> 得到 Soft Attention 权重
        # 这就是 "Soft" 的核心:概率分布
        attn_weights = F.softmax(scores, dim=1) 
        
        # 3. 加权求和
        context = torch.bmm(attn_weights.unsqueeze(1), values).squeeze(1)
        
        return context, attn_weights
Soft vs. Hard Attention 对比表
特性 Soft Attention Hard Attention
关注范围 全局关注(所有位置),权重不同 局部关注(只选一个或少数位置)
可微性 可微 (Differentiable) 不可微 (Non-differentiable)
训练方法 标准反向传播 (Backpropagation) 强化学习 (REINFORCE) 或 变分法
确定性 确定性输出 (给定输入,输出固定) 随机性输出 (通常涉及采样)
计算效率 较高 (并行矩阵运算) 较低 (通常涉及循环或采样)
主要用途 绝大多数现代 NLP/CV 模型 需要明确“定位”或“裁剪”的任务

注意力机制的演变

以下都属于软注意力,这是目前深度学习中最主流、最常用的注意力机制。

类型 描述 典型应用
Bahdanau Attention 最早提出的注意力机制。使用加法计算能量(Additive/MLP)。计算量稍大,但在某些任务上效果好。 神经机器翻译 (RNN)
Luong Attention 简化了计算,使用点积或矩阵乘法。效率更高。 神经机器翻译 (RNN)
Self-Attention Q, K, V 都来自同一个输入序列。用于捕捉句子内部单词之间的关联(例如代词指代)。 Transformer, BERT, GPT
Multi-Head Attention 将 Q, K, V 映射到不同的子空间,并行计算多次注意力,最后拼接。让模型关注不同位置的不同信息。 Transformer

加性注意力 (Additive)

加性注意力(也叫 Bahdanau Attention)是注意力机制的开山之作。它通过一个**小型前馈神经网络(MLP)**来计算查询(Query)和键(Key)的相似度,而不是简单的向量点积。

1.解决的问题

在早期的 RNN 机器翻译中,编码器必须将整个句子压缩成一个固定向量。加性注意力允许解码器在生成翻译时,动态地“回头看”源句子的所有隐藏状态。

2.核心数学原理

与“点积注意力”直接将两个向量相乘不同,加性注意力的核心在于拼接非线性变换

计算公式:

score(q,k)=vTtanh⁡(Wqq+Wkk)score(q,k)=v^Ttanh⁡(W_qq+W_kk)score(q,k)=vTtanh(Wqq+Wkk)

符号解释:

  • qqq :查询向量(例如解码器当前的隐藏状态)。
  • kkk :键向量(例如编码器某个时刻的隐藏状态)。
  • Wq,WkW_q,W_kWq,Wk :可学习的权重矩阵。它们的作用是将 qqqkkk 投影到同一个维度(隐藏维度 hhh )。
  • vvv :可学习的权重向量(通常是一维向量)。它的作用是将经过tanhtanhtanh 激活后的向量再次映射为一个标量分数。
  • ⁡tanh⁡tanhtanh :非线性激活函数,引入了非线性能力。

计算流程:

  1. 线性变换:分别用矩阵 WqW_qWqWkW_kWk 处理 qqqkkk
  2. 相加:将变换后的两个向量相加(Element-wise addition)。
  3. 激活:通过 ⁡tanh⁡tanhtanh 函数。
  4. 打分:与向量 vvv 做点积,得到最终的相似度分数。
  5. 归一化:最后通过 Softmax 得到注意力权重。
3.为什么叫“加性”?

这个名字来源于其计算过程的核心操作是向量的相加Wqq+WkkW_qq+W_kkWqq+Wkk)。

  • 对比点积:点积注意力是 q⋅kq⋅kqk (乘法)。
  • 对比拼接:有些教材也将其视为一种特殊的拼接(Concatenation),因为 Wqq+WkkW_qq+W_kkWqq+Wkk 在数学上等价于先拼接 [q;k][q;k][q;k] 再乘一个大矩阵,但“加性”这个术语更强调其将两个向量映射到同一空间后相加的特性。
PyTorch代码实现
import torch
import torch.nn as nn

class AdditiveAttention(nn.Module):
    def __init__(self, query_dim, key_dim, hidden_dim):
        super(AdditiveAttention, self).__init__()
        # 1. 定义线性变换层
        self.W_q = nn.Linear(query_dim, hidden_dim, bias=False)
        self.W_k = nn.Linear(key_dim, hidden_dim, bias=False)
        # 2. 定义最后的打分向量 v (这里用 Linear 模拟向量点积)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
        # 3. 激活函数
        self.activation = nn.Tanh()

    def forward(self, query, keys):
        # query: [batch_size, query_dim]
        # keys: [batch_size, seq_len, key_dim]
        
        batch_size, seq_len, _ = keys.size()
        
        # 扩展 query 以匹配 keys 的序列长度
        # query_expanded: [batch_size, seq_len, query_dim]
        query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)
        
        # 1. 线性变换并相加
        # energy: [batch_size, seq_len, hidden_dim]
        energy = self.W_q(query_expanded) + self.W_k(keys)
        
        # 2. 激活函数
        energy = self.activation(energy)
        
        # 3. 与 v 做点积,得到分数
        # scores: [batch_size, seq_len, 1] ->  squeeze -> [batch_size, seq_len]
        scores = self.v(energy).squeeze(-1)
        
        return scores

# --- 测试 ---
# 假设 query 维度 64, key 维度 64, 隐藏层维度 32
attn = AdditiveAttention(64, 64, 32)
q = torch.randn(2, 64)
k = torch.randn(2, 10, 64) # 10 是序列长度

scores = attn(q, k)
print(f"分数形状: {scores.shape}") # [2, 10]

总结

历史地位:它是注意力机制的鼻祖,由 Bahdanau 提出。

核心机制:使用 MLP + Tanh 计算相似度,公式为 vTtanh⁡(Wqq+Wkk)v^Ttanh⁡(W_qq+W_kk)vTtanh(Wqq+Wkk)

优缺点:优点是灵活、表达能力强;缺点是计算慢,不适合超长序列。

现状:虽然被点积注意力取代成为主流,但在追求线性复杂度移动端高效推理的新型架构中,加性注意力正在“文艺复兴”。

点积注意力 (Dot-Product)

点积注意力(也叫乘性注意力)通过计算查询向量(Query)和键向量(Key)的点积来衡量相似度。它是 Transformer 架构的核心计算单元,利用高度优化的矩阵乘法,比加性注意力更快、更节省空间。

1. 核心数学原理

点积注意力的核心思想非常直观:两个向量越相似,它们的点积就越大。

基础公式(未缩放):

score(q,k)=qTkscore(q,k)=q^Tkscore(q,k)=qTk

完整计算流程(矩阵形式):

给定查询矩阵 QQQ 、键矩阵 KKK 和值矩阵 VVV

  1. 计算相似度:计算 QQ 和 KK 的转置的矩阵乘法。

    Scores=QKTScores=QK^TScores=QKT

    这一步会得到一个注意力分数矩阵,每个元素代表一个 Query 和一个 Key 的匹配程度。

  2. 缩放 (Scaling):除以 dk\sqrt{d_k}dk dkd_kdk 是向量的维度)。

    Scaled Scores=QKTdk\text{Scaled Scores}=\frac{QK^T}{\sqrt{d_k}}Scaled Scores=dk QKT

  3. 归一化:使用 Softmax 将分数转化为概率(权重)。

    Weights=softmax(Scaled Scores)Weights=softmax(\text{Scaled Scores})Weights=softmax(Scaled Scores)

  4. 加权求和:用权重乘以 VVV

    Output=Weights⋅VOutput=Weights⋅VOutput=WeightsV

最终公式:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk QKT)V

2.为什么要“缩放”?(关键点)

为什么公式里要除以 dk\sqrt{d_k}dk

  • 问题根源
    当向量维度 dk\sqrt{d_k}dk 很大时(例如 512 或 1024), QQQKKK 的点积结果会变得非常大(数值范围会随维度线性增长)。

  • Softmax 的特性

    Softmax 函数在输入数值过大时,会进入梯度极小的区域(饱和区)。

    • 如果输入很大,Softmax 的输出会接近 One-hot 分布(一个接近 1,其他接近 0)。
    • 在这种极端分布下,反向传播的梯度会接近于 0(梯度消失),导致模型无法有效训练。
  • 缩放的作用
    除以 dkd_kdk​ 可以将点积结果的方差控制在 1 左右,把数值拉回到 Softmax 的敏感区域(梯度较大区域),保证模型训练的稳定性和收敛速度。

3.PyTorch代码
import torch
import torch.nn as nn
import math

class DotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super(DotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        # Q: [batch_size, seq_len_q, d_k]
        # K: [batch_size, seq_len_k, d_k]
        # V: [batch_size, seq_len_v, d_v]
        
        d_k = Q.size(-1)
        
        # 1. 计算 Q 和 K 的点积 (矩阵乘法)
        # scores: [batch_size, seq_len_q, seq_len_k]
        scores = torch.matmul(Q, K.transpose(-2, -1))
        
        # 2. 缩放 (关键步骤!)
        scores = scores / math.sqrt(d_k)
        
        # 3. 掩码处理 (可选,用于 Padding 或 防止看未来)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 4. Softmax 归一化
        attn_weights = self.dropout(torch.softmax(scores, dim=-1))
        
        # 5. 加权求和
        # output: [batch_size, seq_len_q, d_v]
        output = torch.matmul(attn_weights, V)
        
        return output, attn_weights

# --- 测试 ---
# 假设 batch=2, 序列长度=5, 维度=64
Q = torch.randn(2, 5, 64)
K = torch.randn(2, 5, 64)
V = torch.randn(2, 5, 64)

attn = DotProductAttention()
out, weights = attn(Q, K, V)
print(f"输出形状: {out.shape}") # [2, 5, 64]
4.几何意义

从几何角度看,点积注意力是在计算向量之间的夹角

  • 点积定义A⋅B=∣A∣∣B∣cos⁡(θ)A⋅B=∣A∣∣B∣cos⁡(θ)AB=∣A∣∣Bcos(θ)
  • 归一化后:如果 QQQKKK 经过了归一化(长度为 1),点积就是 cos⁡(θ)cos⁡(θ)cos(θ)
  • 含义:
    • 点积越大 → 夹角越小 → 向量方向越一致 → 语义越相似
    • 点积为 0 → 正交 → 无关
    • 点积为负 → 方向相反 → 语义相斥

总结

地位:它是现代 AI 的引擎,支撑起了 Transformer 和所有大语言模型。

核心优势。利用 GPU 对矩阵乘法的极致优化,实现了并行计算。

关键细节缩放因子 dk\sqrt{d_k}dk 必不可少,它是防止梯度消失、保证模型收敛的数学保障。

直观理解:通过计算向量夹角的余弦值,快速筛选出语义最相关的信息。

点积 vs. 加性 (Dot-Product vs. Additive)

特性 点积注意力 (Luong/Transformer) 加性注意力 (Bahdanau)
计算方式 qTkq^TkqTk (向量内积) vTtanh(Wqq+Wkk)v^Ttanh(W_qq+W_kk)vTtanh(Wqq+Wkk) (MLP)
计算速度 极快 (利用矩阵乘法优化) 较慢 (涉及逐元素加法和激活函数)
空间效率 高 (参数少) 较低 (需要额外的 WWWvvv 参数)
维度要求 QQQKKK 维度必须相同 QQQKKK 维度可以不同
主要应用 Transformer, GPT, BERT 早期 RNN 翻译模型

自注意力机制 (Self-Attention)

自注意力机制是一种**“自我反思”的机制。它允许序列中的每一个元素(如单词)都与序列中的所有其他元素**进行交互,从而根据上下文动态地更新自己的表示。

1. 核心概念:从“寻找”到“自省”

在之前的注意力机制(如 RNN 中的 Attention)中,Query 来自解码器,Key/Value 来自编码器,这是一种“跨序列”的关注。

Self-Attention 不同:

  • 输入:只有一个序列(例如一句话)。
  • 机制:序列中的每个词都会变成 Query,去查询序列中所有词(作为 Key/Value)。
  • 目的:为了捕捉句子内部的依赖关系(如指代关系、语法结构)。

直观类比:全员会议

  • RNN (旧方式):像击鼓传花。信息从第一个人传到第二个人,再传到第三个……传到后面时,前面的信息可能已经模糊了(长距离依赖问题)。
  • Self-Attention (新方式):像全员视频会议。每个人(词)都可以直接看到和听到其他所有人。当你(比如“它”这个词)想理解自己指代谁时,你可以直接看“猫”和“垫子”,并瞬间建立联系。
2. 为什么需要 Self-Attention?

在 Transformer 出现之前,RNN/LSTM 面临两大瓶颈:

  1. 长距离依赖 (Long-term Dependencies):句子太长时,开头的信息传到结尾会衰减。例如:“The animal didn’t cross the street because it was too tired.” RNN 处理到“it”时,可能已经忘了“animal”。Self-Attention 可以直接计算“it”和“animal”的关联,距离无关。
  2. 无法并行 (Sequential Computation):RNN 必须算完 t−1t−1t1 才能算 ttt,训练速度慢。Self-Attention 可以同时计算所有位置的关联,极大提升了训练效率。
3. 数学原理:五步计算法

Self-Attention 的核心是 Query (Q), Key (K), Value (V) 的交互。

假设输入矩阵为 XXX (包含句子中所有词的向量)。

Step 1: 线性变换 (生成 Q, K, V)
通过三个可学习的权重矩阵WQ,WK,WVW_Q,W_K,W_VWQ,WK,WV ,将输入 XXX 映射到三个空间:

Q=XWQ,K=XWK,V=XWVQ=X_WQ,K=X_WK,V=X_WVQ=XWQ,K=XWK,V=XWV

  • Q (查询):我在找什么?(例如:“坐”这个词在找主语)
  • K (键):我包含什么特征?(例如:“猫”的特征是动物、主语)
  • V (值):我实际的内容是什么?(例如:“猫”的具体语义向量)

Step 2: 计算相似度 (点积)
计算 QQQKKK 的点积,衡量词与词之间的相关性:

Scores=QKT​Scores=QK^T​Scores=QKT

Step 3: 缩放 (Scaling)
除以 dk\sqrt{d_k}dk dkd_kdk 是向量维度),防止点积过大导致 Softmax 梯度消失:

Scaled Scores=QKTdk\text{Scaled Scores}=\frac{QK^T}{\sqrt{d_k}}Scaled Scores=dk QKT

Step 4: 归一化 (Softmax)
将分数转化为概率分布(权重 ααα ),所有权重之和为 1:

Weights=softmax(Scaled Scores)Weights=softmax(\text{Scaled Scores})Weights=softmax(Scaled Scores)

Step 5: 加权求和
根据权重,从 VVV 中提取信息,生成新的上下文向量:

Output=Weights⋅VOutput=Weights⋅VOutput=WeightsV

最终公式:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk QKT)V

4. 实例解析:“猫坐在垫子上”

当模型处理 “坐” 这个词时,Self-Attention 会发生什么?

  1. 生成 Q:“坐”生成一个查询向量,意思是“我在找动作的发出者(主语)和承受者(宾语)”。
  2. 匹配 K:
    • “坐”的 Q 与“猫”的 K 点积很高(主语匹配)。
    • “坐”的 Q 与“垫子”的 K 点积较高(宾语匹配)。
    • “坐”的 Q 与“在”的 K 点积较低(虚词,关系不大)。
  3. 加权求和 V:
    • 最终“坐”的新向量 = 30%的“猫”的信息 + 25%的“坐”自身 + 30%的“垫子”的信息 + 其他。
  4. 结果:原本单纯的“坐”向量,现在融合了“谁在坐”和“坐在哪”的信息,语义变得极其丰富。
5. 进阶:多头注意力 (Multi-Head Attention)

如果只有一个 Self-Attention,模型只能关注一种关系(比如只看语法)。为了从不同角度理解句子,我们使用多头机制

  • 原理:将 Q,K,VQ,K,VQ,K,V拆分成 hhh 个头(例如 8 个头),每个头独立计算 Self-Attention。
  • 比喻:
    • 头1:关注语法结构(主谓宾)。
    • 头2:关注指代关系(“它”指“猫”)。
    • 头3:关注词性特征。
  • 融合:最后将所有头的输出拼接(Concat)并线性变换,得到最终结果。这让模型的理解更加立体。
6.PyTorch代码
import torch
import torch.nn as nn
import math

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        # 确保维度能被头数整除
        assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
        
        # 定义线性变换层
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask=None):
        N = query.shape[0] # Batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        # 1. 拆分多头 (Split embedding into heads)
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        
        # 2. 计算能量 (Q * K^T)
        # einsum 是高效的张量运算,这里执行矩阵乘法
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        
        # 3. 缩放
        energy = energy / (self.embed_size ** (1/2))
        
        # 4. 掩码 (可选,用于防止看未来或处理填充)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        # 5. Softmax 和 加权求和
        attention = torch.softmax(energy, dim=3)
        out = torch.einsum("nhqk,nvhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        
        # 6. 线性变换
        out = self.fc_out(out)
        return out

Self-Attention 的本质:它让模型不再通过“记忆”来传递信息,而是通过“观察”来全局捕捉特征。这是 BERT、GPT、LLaMA 等大模型能够理解人类语言逻辑的基石。

多头注意力 (Multi-Head Attention)

多头注意力机制通过并行计算多个独立的注意力头,让模型能够同时关注输入序列中不同位置、不同类型的信息(如语法、语义、指代关系),最后将这些信息融合。它解决了单头注意力“顾此失彼”的问题,极大地增强了模型的表达能力。

1. 为什么需要“多头”?(单头的局限)

在单头注意力(Single-Head Attention)中,模型只能计算一种注意力分布。这就像你只能用一种颜色的眼镜看世界,或者只能带一把钥匙开所有的锁。

举个例子:
句子:“The animal didn’t cross the street because it was too tired.”

对于单词 “it”,单头注意力很难同时处理以下所有关系:

  1. 指代关系:“it” 指代 “animal”。
  2. 语法关系:“it” 是 “was” 的主语。
  3. 语义修饰:“tired” 是形容 “it” 的状态。

如果只用一个头,这些不同的信息会被混合在一个向量空间里,导致“平均化”,丢失细节。

多头的解决方案:
这就好比开了一个“专家会诊”:

  • 头 1:专门关注指代(发现 it = animal)。
  • 头 2:专门关注句法结构(发现 it 是主语)。
  • 头 3:专门关注词义修饰(发现 tired 修饰 it)。
    最后,模型将这些专家的意见汇总,得到一个全面、精准的“it”的表示。
2. 核心工作流程

多头注意力的计算过程可以分为四个步骤:投影 -> 并行计算 -> 拼接 -> 融合

假设输入向量的维度是 dmodeld_{model}dmodel (例如 512),头的数量是 hhh (例如 8)。

Step 1: 线性投影 (Projection)
首先,我们将输入 XXX 通过三组不同的线性变换矩阵( WQ,WK,WVW_Q,W_K,W_VWQ,WK,WV),映射到 hhh 个不同的子空间。

  • 每个头只负责处理 dk=dmodel/hd_k=d_{model}/hdk=dmodel/h 维度的信息(例如 512/8=64512/8=64512/8=64 维)。
  • 这意味着每个头都在一个更小的、独立的子空间里学习特定的特征。

Step 2: 并行计算 (Parallel Attention)
在每个子空间内,独立计算缩放点积注意力:

headi=Attention(QWiQ,KWiK,VWiV)head_i=Attention(QW_i^Q,KW_i^K,VW_i^V)headi=Attention(QWiQ,KWiK,VWiV)

这一步是高度并行的,8 个头同时计算,互不干扰。

Step 3: 拼接 (Concatenation)
将所有头的输出向量拼在一起,恢复成原始维度:

Concat=[head1,head2,...,headh]Concat=[head_1,head_2,...,head_h]Concat=[head1,head2,...,headh]

Step 4: 最终融合 (Final Projection)
最后,通过一个线性层( WOW^OWO )将拼接后的向量再次投影,混合所有头的信息:

MultiHead(Q,K,V)=Concat⋅WOMultiHead(Q,K,V)=Concat⋅W^OMultiHead(Q,K,V)=ConcatWO

完整公式:

MultiHead(Q,K,V)=Concat(head1,...,headh)WOMultiHead(Q,K,V)=Concat(head_1,...,head_h)W^OMultiHead(Q,K,V)=Concat(head1,...,headh)WO

3. 直观类比:RGB 图像

你可以把多头注意力想象成处理一张 RGB 彩色图片

  • 单头注意力:就像只看黑白照片(灰度图),你只能看到亮度信息,丢失了色彩细节。
  • 多头注意力:
    • 头 1 ®:关注红色通道(比如捕捉情感色彩)。
    • 头 2 (G):关注绿色通道(比如捕捉语法结构)。
    • 头 3 (B):关注蓝色通道(比如捕捉实体指代)。
    • 最后将 R、G、B 三个通道叠加,还原出一张色彩丰富、信息完整的图片。
4.PyTorch代码

这是 Transformer 中 nn.MultiheadAttention 的底层逻辑实现,展示了如何拆分维度和并行计算。

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度
        
        # 1. 定义线性层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model) # 输出投影
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # 2. 线性变换并拆分多头
        # 形状变化: [Batch, Seq_Len, d_model] -> [Batch, Seq_Len, d_model]
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)
        
        # 形状变化: [Batch, Seq_Len, num_heads, d_k] -> [Batch, num_heads, Seq_Len, d_k]
        # 这一步非常关键,将维度拆分并转置,让 heads 变成独立的一维
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 3. 缩放点积注意力计算 (并行处理所有头)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
            
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # 加权求和
        out = torch.matmul(attn_weights, V) # [Batch, num_heads, Seq_Len, d_k]
        
        # 4. 拼接与最终投影
        # 转置回 [Batch, Seq_Len, num_heads, d_k]
        out = out.transpose(1, 2).contiguous()
        # 拼接: [Batch, Seq_Len, d_model]
        out = out.view(batch_size, -1, self.d_model)
        
        # 最后通过线性层融合信息
        out = self.W_o(out)
        
        return out, attn_weights
5. 关键参数设置

在实际应用中,如何选择头的数量( hhh )?

参数 典型值 说明
dmodeld_{model}dmodel 512 / 768 / 1024 模型的总维度。
h(Heads)h (Heads)h(Heads) 8 / 12 / 16 头的数量。通常 dmodeld_{model}dmodel 必须能被 hhh 整除。
dkd_kdk 64 每个头的维度。如果 dk{d_k}dk太小,模型学不到东西;太大则计算量爆炸。经验表明 64 是一个性价比极高的数值。

总结:

核心作用多视角特征提取。它允许模型在同一时间关注不同位置的不同类型的依赖关系。

计算优势:虽然看起来计算量变大了(8个头),但因为每个头的维度变小了(1/8 ),且所有头是并行计算的,所以总的时间复杂度与单头注意力几乎相同。

地位:它是 Transformer 区别于 RNN 的关键创新之一,赋予了模型强大的上下文理解能力和并行计算能力。

注意力机制的优缺点

优点:

  • 解决长距离依赖:无论序列多长,任意两个位置之间的距离都是 1(直接计算),信息传递路径最短。
  • 并行化:矩阵运算可以高度并行,训练速度远超 RNN。
  • 可解释性:通过可视化注意力权重矩阵,我们可以看到模型在生成某个词时“关注”了输入中的哪些词(例如翻译 “Bank” 时关注了 “River”)。

缺点:

  • 计算复杂度:标准注意力机制的计算复杂度是 O(n2)O(n^2)O(n2),其中 nnn 是序列长度。处理超长文本(如整本书)时显存消耗巨大。
  • 位置信息丢失:纯粹的注意力机制是“排列不变”的(即打乱输入顺序,输出不变)。因此必须引入位置编码来补充位置信息。
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐