【深度学习精通】第14章 | 注意力机制革命 - 从Seq2Seq到Self-Attention
环境声明
- Python版本:Python 3.10+
- PyTorch版本:PyTorch 2.0+
- 推荐开发工具:PyCharm / VS Code / Jupyter Notebook
- 操作系统:Windows / macOS / Linux(通用)
学习目标和摘要
摘要:本章将深入探讨注意力机制的发展历程,从2014年Bahdanau等人开创性的工作开始,到2017年Transformer中Self-Attention的横空出世,再到近年来各种高效注意力变体的涌现。你将理解注意力机制的数学本质,掌握Query-Key-Value的计算范式,并学会实现多头注意力机制。
学习目标:
- 理解注意力机制的生物学启发和直观意义
- 掌握Seq2Seq架构中注意力机制的工作原理
- 深入理解Self-Attention的数学形式和计算过程
- 学会实现多头注意力机制和位置编码
- 了解高效注意力变体(Linear Attention、Sparse Attention等)
- 掌握注意力可视化的方法
1. 注意力机制的生物学启发
1.1 人类视觉注意力系统
想象你正在一个拥挤的火车站寻找你的朋友。你的眼睛不会同时清晰地看到所有人和物体,而是会快速扫视,将注意力集中在可能与朋友相关的特征上——比如相似的身高、穿着的颜色、发型等。这种选择性关注的能力就是注意力的本质。
核心比喻:注意力机制就像是给神经网络装上了一个"聚光灯",让它能够在处理大量信息时,动态地选择关注最重要的部分。
人类大脑处理视觉信息时,存在两种注意力机制:
- 自下而上的注意力:由外界刺激驱动,例如突然的响声或明亮的闪光会自动吸引注意
- 自上而下的注意力:由目标和任务驱动,例如主动寻找特定物体
深度学习中的注意力机制主要模拟的是自上而下的注意力——根据当前任务目标,动态调整对不同输入部分的关注程度。
1.2 从RNN到注意力:为什么需要变革
在注意力机制出现之前,序列到序列(Seq2Seq)模型使用固定长度的上下文向量来编码整个输入序列。这就像试图用一张小纸条记录整本书的内容——信息损失不可避免。
关键问题:当输入序列很长时,编码器必须将所有信息压缩到一个固定维度的向量中,导致信息瓶颈(Information Bottleneck)。
注意力机制通过允许解码器在生成每个输出时"回看"输入序列的不同部分,彻底解决了这个问题。
2. Seq2Seq与编码器-解码器架构
2.1 基础Seq2Seq架构
Seq2Seq(Sequence to Sequence)模型是深度学习中处理序列转换任务的基石,典型应用包括机器翻译、文本摘要、语音识别等。
架构组成:
import torch
import torch.nn as nn
class Encoder(nn.Module):
"""
基础编码器:将输入序列编码为上下文向量
"""
def __init__(self, vocab_size, embed_dim, hidden_dim):
super(Encoder, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
def forward(self, x):
# x: (batch_size, seq_len)
embedded = self.embedding(x) # (batch_size, seq_len, embed_dim)
outputs, (hidden, cell) = self.lstm(embedded)
# outputs: (batch_size, seq_len, hidden_dim)
# hidden: (1, batch_size, hidden_dim)
return outputs, hidden, cell
class Decoder(nn.Module):
"""
基础解码器:从上下文向量生成输出序列
"""
def __init__(self, vocab_size, embed_dim, hidden_dim):
super(Decoder, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
self.fc = nn.Linear(hidden_dim, vocab_size)
def forward(self, x, hidden, cell):
# x: (batch_size, 1) - 单个词
embedded = self.embedding(x) # (batch_size, 1, embed_dim)
output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
prediction = self.fc(output.squeeze(1))
return prediction, hidden, cell
2.2 信息瓶颈问题
在基础Seq2Seq中,编码器将变长输入序列压缩为固定长度的上下文向量。对于长序列,这种压缩会导致严重的信息损失。
一句话总结:注意力机制让模型学会"该看哪里",而不是被迫记住所有信息。
3. 注意力机制的数学形式
3.1 Query、Key、Value范式
注意力机制的核心思想可以抽象为三个概念:
- Query(查询):当前需要关注什么,代表"我要找什么"
- Key(键):输入序列中各位置的标识,代表"我是什么"
- Value(值):输入序列中各位置的实际内容,代表"我有什么信息"
计算过程类比:想象你在图书馆找书(Query),每本书都有书名标签(Key)和实际内容(Value)。你通过比较Query和Key的相似度,决定从哪些Value中获取信息。
3.2 注意力计算的一般形式
注意力函数可以描述为将一个Query和一组Key-Value对映射到输出,其中输出是Value的加权和,权重由Query与对应Key的相似度计算得到。
Attention(Q, K, V) = softmax(similarity(Q, K)) * V
其中similarity函数可以有多种形式:
| 注意力类型 | 相似度计算方式 | 特点 |
|---|---|---|
| 加性注意力 | v^T * tanh(W_qQ + W_kK) | 灵活性强,可学习参数多 |
| 点积注意力 | Q * K^T | 计算简单,速度快 |
| 缩放点积注意力 | (Q * K^T) / sqrt(d_k) | 防止softmax梯度消失 |
4. 加性注意力与点积注意力
4.1 加性注意力(Additive Attention)
Bahdanau等人在2014年提出的加性注意力使用一个前馈网络来计算相似度:
class AdditiveAttention(nn.Module):
"""
加性注意力机制(Bahdanau Attention)
适用于Query和Key维度不同的情况
"""
def __init__(self, query_dim, key_dim, hidden_dim):
super(AdditiveAttention, self).__init__()
self.W_query = nn.Linear(query_dim, hidden_dim, bias=False)
self.W_key = nn.Linear(key_dim, hidden_dim, bias=False)
self.v = nn.Linear(hidden_dim, 1, bias=False)
def forward(self, query, keys, values, mask=None):
"""
Args:
query: (batch_size, query_dim) 或 (batch_size, num_queries, query_dim)
keys: (batch_size, seq_len, key_dim)
values: (batch_size, seq_len, value_dim)
mask: (batch_size, seq_len) 可选的掩码
Returns:
context: (batch_size, value_dim) 加权后的上下文向量
attention_weights: (batch_size, seq_len) 注意力权重
"""
# 扩展query维度以便广播
if query.dim() == 2:
query = query.unsqueeze(1) # (batch_size, 1, query_dim)
# 计算加性分数: v^T * tanh(W_q*Q + W_k*K)
# query_transformed: (batch_size, 1, hidden_dim)
query_transformed = self.W_query(query)
# keys_transformed: (batch_size, seq_len, hidden_dim)
keys_transformed = self.W_key(keys)
# 广播相加: (batch_size, seq_len, hidden_dim)
combined = torch.tanh(query_transformed + keys_transformed)
# 计算分数: (batch_size, seq_len, 1) -> (batch_size, seq_len)
scores = self.v(combined).squeeze(-1)
# 应用掩码(如果提供)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax归一化
attention_weights = torch.softmax(scores, dim=-1) # (batch_size, seq_len)
# 加权求和: (batch_size, 1, seq_len) @ (batch_size, seq_len, value_dim)
context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)
return context, attention_weights
4.2 点积注意力与缩放点积注意力
点积注意力计算更加直接高效:
class ScaledDotProductAttention(nn.Module):
"""
缩放点积注意力机制(Scaled Dot-Product Attention)
Transformer中使用的标准注意力机制
"""
def __init__(self, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
"""
Args:
query: (batch_size, num_heads, seq_len_q, d_k)
key: (batch_size, num_heads, seq_len_k, d_k)
value: (batch_size, num_heads, seq_len_v, d_v)
mask: (batch_size, 1, seq_len_q, seq_len_k) 可选
Returns:
output: (batch_size, num_heads, seq_len_q, d_v)
attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
"""
d_k = query.size(-1)
# 计算点积: Q @ K^T / sqrt(d_k)
# (batch, heads, seq_q, d_k) @ (batch, heads, d_k, seq_k)
# = (batch, heads, seq_q, seq_k)
scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
# 应用掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax归一化
attention_weights = torch.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# 与Value相乘: (batch, heads, seq_q, seq_k) @ (batch, heads, seq_v, d_v)
# = (batch, heads, seq_q, d_v)
output = torch.matmul(attention_weights, value)
return output, attention_weights
缩放因子的重要性:除以sqrt(d_k)可以防止点积结果过大导致softmax梯度消失。当d_k较大时,点积的方差会增大,导致softmax进入饱和区。
5. Self-Attention:注意力的新范式
5.1 从Cross-Attention到Self-Attention
传统的注意力机制(Cross-Attention)中,Query来自解码器,Key和Value来自编码器。而Self-Attention则让Query、Key、Value都来自同一个序列——序列中的每个位置都能"看到"其他所有位置。
核心思想:Self-Attention允许模型直接建模序列中任意两个位置之间的关系,无论它们相距多远。这克服了RNN中远距离依赖难以捕捉的问题。
5.2 Self-Attention的直观理解
考虑句子:“The animal didn’t cross the street because it was too tired.”
这里的"it"指代什么?人类读者很容易理解"it"指的是"animal"而不是"street"。Self-Attention通过计算"it"与句子中所有词的关联度,让模型学会这种指代消解。
5.3 Self-Attention的完整实现
class SelfAttention(nn.Module):
"""
自注意力机制的实现
"""
def __init__(self, d_model, dropout=0.1):
super(SelfAttention, self).__init__()
self.d_model = d_model
self.W_query = nn.Linear(d_model, d_model)
self.W_key = nn.Linear(d_model, d_model)
self.W_value = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: (batch_size, seq_len, d_model)
mask: (batch_size, seq_len, seq_len) 可选的注意力掩码
Returns:
output: (batch_size, seq_len, d_model)
attention_weights: (batch_size, seq_len, seq_len)
"""
batch_size, seq_len, _ = x.size()
# 生成Q、K、V
Q = self.W_query(x) # (batch, seq, d_model)
K = self.W_key(x)
V = self.W_value(x)
# 计算注意力分数
scores = torch.bmm(Q, K.transpose(1, 2)) / torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
# scores: (batch, seq, seq)
# 应用掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax和加权
attention_weights = torch.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.bmm(attention_weights, V) # (batch, seq, d_model)
return output, attention_weights
6. 多头注意力机制(Multi-Head Attention)
6.1 为什么需要多头
单一的注意力机制只能捕捉一种类型的关系。但在自然语言中,词与词之间可能存在多种关系:语法关系、语义关系、指代关系等。
核心思想:多头注意力使用多组独立的Q、K、V投影,让模型在不同的"表示子空间"中并行学习不同类型的依赖关系。
6.2 多头注意力的数学表达
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W^O
where head_i = Attention(Q*W_i^Q, K*W_i^K, V*W_i^V)
6.3 多头注意力完整实现
class MultiHeadAttention(nn.Module):
"""
多头注意力机制的完整实现
"""
def __init__(self, d_model, num_heads, dropout=0.1):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度
# 线性投影层
self.W_query = nn.Linear(d_model, d_model)
self.W_key = nn.Linear(d_model, d_model)
self.W_value = nn.Linear(d_model, d_model)
self.W_output = nn.Linear(d_model, d_model)
self.attention = ScaledDotProductAttention(dropout)
self.dropout = nn.Dropout(dropout)
def split_heads(self, x, batch_size):
"""
将输入分割成多个头
x: (batch_size, seq_len, d_model)
return: (batch_size, num_heads, seq_len, d_k)
"""
seq_len = x.size(1)
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2) # (batch, heads, seq, d_k)
def forward(self, query, key, value, mask=None):
"""
Args:
query, key, value: (batch_size, seq_len, d_model)
mask: (batch_size, 1, seq_len, seq_len) 或兼容形状
Returns:
output: (batch_size, seq_len, d_model)
attention_weights: (batch_size, num_heads, seq_len, seq_len)
"""
batch_size = query.size(0)
# 线性投影并分割多头
Q = self.split_heads(self.W_query(query), batch_size) # (batch, heads, seq_q, d_k)
K = self.split_heads(self.W_key(key), batch_size) # (batch, heads, seq_k, d_k)
V = self.split_heads(self.W_value(value), batch_size) # (batch, heads, seq_v, d_v)
# 调整mask形状以适配多头
if mask is not None:
if mask.dim() == 3:
mask = mask.unsqueeze(1) # (batch, 1, seq_q, seq_k)
# 计算注意力
attn_output, attention_weights = self.attention(Q, K, V, mask)
# attn_output: (batch, heads, seq_q, d_k)
# 合并多头
attn_output = attn_output.transpose(1, 2).contiguous() # (batch, seq_q, heads, d_k)
attn_output = attn_output.view(batch_size, -1, self.d_model) # (batch, seq_q, d_model)
# 最终线性投影
output = self.W_output(attn_output)
output = self.dropout(output)
return output, attention_weights
7. 注意力可视化与可解释性
7.1 注意力权重的意义
注意力权重矩阵直观地展示了模型在处理序列时"关注"了哪些位置。通过可视化这些权重,我们可以:
- 理解模型的决策过程
- 发现模型学到的语言规律
- 诊断模型的问题
7.2 注意力热力图可视化代码
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def visualize_attention(attention_weights, tokens=None, title="Attention Heatmap"):
"""
可视化注意力权重热力图
Args:
attention_weights: numpy数组,形状为 (seq_len, seq_len) 或 (num_heads, seq_len, seq_len)
tokens: 可选,token列表用于坐标轴标注
title: 图表标题
"""
if isinstance(attention_weights, torch.Tensor):
attention_weights = attention_weights.detach().cpu().numpy()
# 如果是多头注意力,取平均
if attention_weights.ndim == 3:
attention_weights = attention_weights.mean(axis=0)
seq_len = attention_weights.shape[0]
# 如果没有提供tokens,使用索引
if tokens is None:
tokens = [f"{i}" for i in range(seq_len)]
plt.figure(figsize=(10, 8))
sns.heatmap(
attention_weights,
xticklabels=tokens,
yticklabels=tokens,
cmap="YlOrRd",
cbar_kws={'label': 'Attention Weight'},
square=True,
linewidths=0.5
)
plt.title(title, fontsize=14, fontweight='bold')
plt.xlabel('Key Position', fontsize=12)
plt.ylabel('Query Position', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
def visualize_multihead_attention(attention_weights, tokens=None, num_heads_to_show=4):
"""
可视化多头注意力中不同头的注意力模式
Args:
attention_weights: (num_heads, seq_len, seq_len)
tokens: token列表
num_heads_to_show: 要展示的头数
"""
if isinstance(attention_weights, torch.Tensor):
attention_weights = attention_weights.detach().cpu().numpy()
num_heads = attention_weights.shape[0]
seq_len = attention_weights.shape[1]
if tokens is None:
tokens = [f"{i}" for i in range(seq_len)]
# 选择要展示的头
heads_to_show = min(num_heads_to_show, num_heads)
fig, axes = plt.subplots(1, heads_to_show, figsize=(4*heads_to_show, 4))
if heads_to_show == 1:
axes = [axes]
for idx, ax in enumerate(axes):
sns.heatmap(
attention_weights[idx],
xticklabels=tokens if idx == heads_to_show-1 else [],
yticklabels=tokens if idx == 0 else [],
cmap="YlOrRd",
ax=ax,
square=True,
cbar=False
)
ax.set_title(f'Head {idx+1}', fontsize=10)
plt.suptitle('Multi-Head Attention Visualization', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
# 示例:创建演示数据并可视化
def demo_attention_visualization():
"""
演示注意力可视化功能
"""
# 模拟一个简单句子的注意力权重
sentence = ["The", "cat", "sat", "on", "the", "mat", "."]
seq_len = len(sentence)
# 创建模拟的注意力权重(模拟对角线和局部关注)
np.random.seed(42)
attention = np.random.rand(seq_len, seq_len) * 0.1
# 增强对角线(自注意力通常关注自身)
for i in range(seq_len):
attention[i, i] = 0.5
# 增强相邻词的关注
if i > 0:
attention[i, i-1] = 0.2
if i < seq_len - 1:
attention[i, i+1] = 0.2
# 归一化
attention = attention / attention.sum(axis=1, keepdims=True)
# 可视化
visualize_attention(attention, sentence, "Self-Attention Pattern Example")
# 多头注意力演示
num_heads = 8
multihead_attention = np.random.rand(num_heads, seq_len, seq_len)
# 为不同头设置不同模式
for h in range(num_heads):
if h % 2 == 0:
# 偶数头:关注局部
for i in range(seq_len):
for j in range(max(0, i-2), min(seq_len, i+3)):
multihead_attention[h, i, j] += 0.3
else:
# 奇数头:关注全局(模拟长距离依赖)
multihead_attention[h, :, :] += 0.1
# 归一化
for h in range(num_heads):
multihead_attention[h] = multihead_attention[h] / multihead_attention[h].sum(axis=1, keepdims=True)
visualize_multihead_attention(multihead_attention, sentence, num_heads_to_show=4)
# 运行演示
if __name__ == "__main__":
demo_attention_visualization()
7.3 注意力模式分析
通过观察注意力热力图,我们可以发现一些有趣的模式:
| 注意力模式 | 描述 | 示例 |
|---|---|---|
| 对角线模式 | 主要关注当前位置或相邻位置 | 局部特征提取 |
| 垂直/水平条纹 | 某些位置被广泛关注(如标点、特殊token) | [CLS]、[SEP] token |
| 块状结构 | 关注特定短语或句子片段 | 名词短语识别 |
| 稀疏分散 | 长距离依赖关系 | 指代消解 |
8. 高效注意力变体
8.1 标准注意力的计算复杂度问题
标准Self-Attention的计算复杂度为O(n^2),其中n是序列长度。对于长序列(如长文档、高分辨率图像),这成为严重的性能瓶颈。
8.2 Linear Attention
Linear Attention通过核技巧将复杂度从O(n^2)降低到O(n):
class LinearAttention(nn.Module):
"""
Linear Attention实现
通过核技巧将O(n^2)复杂度降低到O(n)
参考: "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
"""
def __init__(self, d_model, num_heads, dropout=0.1):
super(LinearAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_query = nn.Linear(d_model, d_model)
self.W_key = nn.Linear(d_model, d_model)
self.W_value = nn.Linear(d_model, d_model)
self.W_output = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
seq_len = query.size(1)
# 投影
Q = self.W_query(query).view(batch_size, seq_len, self.num_heads, self.d_k)
K = self.W_key(key).view(batch_size, seq_len, self.num_heads, self.d_k)
V = self.W_value(value).view(batch_size, seq_len, self.num_heads, self.d_k)
# 应用核函数(elu+1)
Q = torch.nn.functional.elu(Q) + 1
K = torch.nn.functional.elu(K) + 1
# 转置为 (batch, heads, seq, d_k)
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Linear Attention核心: (Q @ K^T) @ V = Q @ (K^T @ V)
# 先计算 K^T @ V: (batch, heads, d_k, d_k)
KV = torch.matmul(K.transpose(-2, -1), V)
# 再计算 Q @ KV: (batch, heads, seq, d_k)
Z = 1 / (torch.matmul(Q, K.sum(dim=2).unsqueeze(-1)) + 1e-6)
output = torch.matmul(Q, KV) * Z
# 合并头
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
output = self.W_output(output)
return output, None # Linear Attention不直接产生可解释的权重
8.3 Sparse Attention
Sparse Attention通过限制每个位置只能关注部分位置来降低复杂度:
class SparseAttention(nn.Module):
"""
稀疏注意力实现 - Strided Pattern
每个位置只关注固定间隔的位置
"""
def __init__(self, d_model, num_heads, stride=4, dropout=0.1):
super(SparseAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
self.stride = stride
self.d_k = d_model // num_heads
self.W_query = nn.Linear(d_model, d_model)
self.W_key = nn.Linear(d_model, d_model)
self.W_value = nn.Linear(d_model, d_model)
self.W_output = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def create_sparse_mask(self, seq_len, device):
"""创建稀疏注意力掩码"""
mask = torch.zeros(seq_len, seq_len, device=device)
# 每个位置关注:自身、局部窗口、固定间隔位置
for i in range(seq_len):
# 局部窗口(前后各2个)
for j in range(max(0, i-2), min(seq_len, i+3)):
mask[i, j] = 1
# 固定间隔位置
for j in range(0, seq_len, self.stride):
mask[i, j] = 1
return mask
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
seq_len = query.size(1)
# 投影
Q = self.W_query(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_key(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_value(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# 创建稀疏掩码
sparse_mask = self.create_sparse_mask(seq_len, query.device)
sparse_mask = sparse_mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq, seq)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
# 应用稀疏掩码
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
# Softmax和dropout
attn_weights = torch.softmax(scores, dim=-1)
attn_weights = attn_weights.masked_fill(torch.isnan(attn_weights), 0)
attn_weights = self.dropout(attn_weights)
# 加权求和
output = torch.matmul(attn_weights, V)
# 合并头
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
output = self.W_output(output)
return output, attn_weights
8.4 高效注意力变体对比
| 变体名称 | 时间复杂度 | 空间复杂度 | 核心思想 | 适用场景 |
|---|---|---|---|---|
| 标准Attention | O(n^2) | O(n^2) | 全连接注意力 | 短序列 |
| Linear Attention | O(n) | O(n) | 核技巧重排序 | 长序列生成 |
| Sparse Attention | O(n*sqrt(n)) | O(n*sqrt(n)) | 稀疏连接模式 | 长文档处理 |
| Linformer | O(n) | O(n) | 低秩近似 | 长序列分类 |
| Performer | O(n) | O(n) | 正交随机特征 | 超长序列 |
| Flash Attention | O(n^2) | O(1) | IO感知的分块计算 | 硬件优化 |
9. 完整Transformer编码器层实现
class TransformerEncoderLayer(nn.Module):
"""
完整的Transformer编码器层
包含:多头注意力 + 前馈网络 + 残差连接 + LayerNorm
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 多头自注意力子层
attn_output, attn_weights = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output)) # 残差连接 + LayerNorm
# 前馈网络子层
ff_output = self.feed_forward(x)
x = self.norm2(x + ff_output) # 残差连接 + LayerNorm
return x, attn_weights
class PositionalEncoding(nn.Module):
"""
位置编码实现
"""
def __init__(self, d_model, max_seq_length=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_seq_length, d_model)
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-torch.log(torch.tensor(10000.0)) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:, :x.size(1), :]
10. 避坑小贴士
10.1 常见错误与解决方案
问题1:注意力权重全为NaN
原因:输入数值过大导致softmax溢出,或mask使用不当。
解决方案:
# 确保使用缩放因子
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# 检查mask是否正确应用
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
问题2:多头维度不匹配
原因:d_model不能被num_heads整除。
解决方案:
assert d_model % num_heads == 0, f"d_model ({d_model}) 必须能被 num_heads ({num_heads}) 整除"
self.d_k = d_model // num_heads
问题3:注意力可视化时权重归一化错误
原因:对已经softmax过的权重再次归一化。
解决方案:注意力权重已经是概率分布(每行和为1),直接可视化即可,不需要额外归一化。
10.2 性能优化建议
- 使用Flash Attention:对于长序列,使用Flash Attention可以显著减少内存占用并加速计算
- 梯度检查点:对于深层Transformer,使用gradient checkpointing节省显存
- 混合精度训练:使用torch.cuda.amp进行FP16训练,加速并节省显存
11. 本章小结和知识点回顾
核心概念回顾
-
注意力机制的本质:动态加权机制,让模型学会"该看哪里"
-
Q-K-V范式:
- Query:查询向量,代表当前要寻找的信息
- Key:键向量,代表输入各位置的标识
- Value:值向量,代表输入各位置的实际信息
-
主要注意力类型:
- 加性注意力:灵活,适合Q/K维度不同
- 点积注意力:计算高效,是Transformer的标准选择
- 缩放点积注意力:通过除以sqrt(d_k)防止梯度消失
-
多头注意力:在多个子空间并行学习不同类型的依赖关系
-
高效注意力变体:Linear Attention、Sparse Attention等解决长序列问题
关键公式总结
缩放点积注意力: Attention(Q,K,V) = softmax(QK^T / sqrt(d_k))V
多头注意力: MultiHead(Q,K,V) = Concat(head_1,...,head_h)W^O
一句话总结
注意力机制让神经网络拥有了"选择性关注"的能力,而Self-Attention和Multi-Head Attention的出现,让模型能够直接建模序列中任意位置之间的关系,彻底改变了深度学习处理序列数据的方式。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)