《Attention Is All You Need》阅读笔记
在如今的神经网络中,我们大量讨论的莫过于Transformer架构,这个2017年由《Attention Is All You Need》带来的变革之作。在这篇具有突破意义的论文中首次提出了Transformer这种神经网络架构,其完全基于注意力机制,摒弃了传统的卷积操作。在自注意力机制下,Transformer能够有效捕捉输入序列中的长距离依赖关系,使得模型在处理长文本时更为高效和准确。而Transformer逐渐也从翻译走向了问答任务,我们所熟知的GPT就是根源于Transformer架构。如今作为学习者再读这篇经典论文,希望有新的理解。
一、研究背景与问题提出
1. 传统模型的局限性
此前序列建模和转导任务(如机器翻译)的SOTA模型基于循环神经网络(RNN/LSTM/GRU)或卷积神经网络(CNN),并结合注意力机制(通过注意力机制实现编码器与解码器之间的连接),但存在一些核心问题:首先是并行性差,RNN类模型按序列位置逐步计算,隐藏状态依赖前一位置结果,无法在训练样本内并行计算,长序列下效率极低;其次是长距离依赖捕捉困难,RNN的梯度消失问题会导致长序列依赖学习难度大,CNN 则需堆叠多层才能捕捉长距离依赖,路径长度随序列增长呈对数增加;最后注意力仅作为辅助,注意力机制多与 RNN/CNN 结合使用,未成为模型核心。
2. 现有改进的不足
部分工作尝试用 CNN 实现并行计算,但捕捉任意两个位置的依赖所需操作数随位置距离增长,仍不如直接的注意力机制高效。而当时主流的序列转导模型均基于结构复杂的循环神经网络或卷积神经网络,且都包含编码器和解码器模块。
3. 研究问题
因此作者思考能否设计完全基于注意力机制的序列转导模型,既保留注意力捕捉长距离依赖的优势,又实现高度并行化,同时在效果和训练效率上超越传统模型。
二、核心概念:自注意力(Self-Attention)
自注意力是模型的核心,指单个序列内不同位置之间的注意力机制,通过计算序列中各位置的关联权重,生成序列的表示,而我们所最熟知的需要该方法的就是阅读理解等,早期工作再过长位置下很难让机器理解。与传统注意力(编码器-解码器注意力)相比,自注意力无需跨序列输入或输出,仅在自身序列内建模依赖,是Transformer实现全局依赖捕捉的关键。
三、Transformer模型架构
Transformer仍采用编码器-解码器的经典序列转导结构,通过叠加自注意力机制和逐点全连接的层来构建编码器和解码器,但所有层均由自注意力和逐位置前馈网络构成,整体架构分为编码器栈、解码器栈及连接组件三部分。
3.1 编码器栈(Encoder Stack)
结构:由6个完全相同的层堆叠而成,每层包含 2 个子层:多头自注意力机制+逐点相连的全连接前馈网络;
残差连接与层归一化:每个子层均采用残差连接+层归一化的结构,即LayerNorm(x+Sublayer(x)),所有子层和嵌入层的输出维度均为dmodel=512,保证残差连接的维度匹配;
特性:编码器的自注意力为双向自注意力,每个位置可关注输入序列的所有位置,实现全局依赖建模。
3.2 解码器栈(Decoder Stack)
结构:同样由6个完全相同的层堆叠而成,每层包含3个子层:带掩码的多头自注意力 +编码器-解码器多头注意力+逐位置前馈网络(与解码器不同在于加入专门对编码器的输出进行多头自注意力处理的子层);
改进 1:掩码自注意力:通过掩码(将 softmax 输入的非法连接设为−∞)阻止位置关注后续位置,保证模型的自回归特性(预测第t个位置仅依赖前t−1个位置);
改进 2:编码器-解码器注意力:查询(query)来自解码器前一层,键(key)和值(value)来自编码器输出,使解码器每个位置能关注输入序列的所有位置,实现传统的编码器 - 解码器依赖建模;(在注意力机制中,输入数据被分成三部分:查询(query)、键(key)和值(value),q表示当前关注的问题,k表示输入数据的不同特征,v表示与k相关的回答)
残差与归一化:与编码器一致,采用残差连接+层归一化。
3.3 核心组件1:注意力机制的具体实现
Transformer设计了缩放点积注意力和多头注意力,解决了传统点积注意力的缺陷并提升了模型的表达能力。
(1)缩放点积注意力(Scaled Dot-Product Attention)
缩放的必要性:当d_k(查询和键的维度)较大时,点积结果的方差会增大,导致softmax 进入梯度极小的饱和区域,除以d_k可将点积结果的方差归一化,缓解该问题;
优势:相比加性注意力(通过单隐层前馈网络计算兼容性),点积注意力可通过高度优化的矩阵乘法实现,速度更快、空间效率更高。
(2)多头注意力(Multi-Head Attention)
核心思想:将Q、K、V通过不同的可学习线性投影,分成h组低维子空间,在每个子空间独立计算缩放点积注意力,最后将结果拼接并投影得到最终输出;
优势:使模型能同时关注不同子空间、不同位置的信息,单头注意力的平均操作会丢失这种细粒度的依赖关系,多头注意力的表达能力更强。
(3)注意力的三种应用场景
Transformer 在不同子层中使用多头注意力,分别实现不同的依赖建模:
编码器自注意力:Q/K/V 均来自编码器前一层,双向关注输入序列所有位置;
解码器掩码自注意力:Q/K/V 均来自解码器前一层,单向关注解码器前t−1个位置;
编码器-解码器注意力:Q 来自解码器,K/V 来自编码器,实现解码器对输入序列的关注。
3.4 核心组件2:逐位置前馈网络
特性:对序列中每个位置的向量独立执行相同的两层线性变换+ ReLU激活,等价于 1×1 的卷积操作;
维度:输入输出维度为dmodel=512,中间层维度为dff=2048,实现特征的非线性变换和维度扩展。
3.5 核心组件3:嵌入层与位置编码(Embeddings & Positional Encoding)
Transformer由于无循环和卷积结构,无法天然捕捉序列的位置信息,因此必须通过额外方式注入位置特征,同时嵌入层实现词符到向量的转换。
(1)嵌入层
将输入和输出词符转换为维向量,且共享输入和输出嵌入层的权重矩阵,并与预 softmax的线性变换权重共享;嵌入向量需乘以dmodel,平衡嵌入层和位置编码的尺度。
(2)位置编码
核心要求:与嵌入层维度相同(512 维),可与嵌入向量逐元素相加;能捕捉相对位置和绝对位置信息;可泛化到训练中未见过的长序列;
优势:对于任意固定偏移k,PE(pos+k)可表示为PE(pos)的线性组合,使模型能学习相对位置依赖;同时正弦余弦函数的周期性可支持任意长度的序列,泛化性优于可学习的位置嵌入。
3.6 输出层
解码器的最终输出通过线性变换将dmodel维向量映射到词汇表维度,再通过softmax生成下一个词符的概率分布,实现自回归生成。
四、Self-Attention 的优势
该论文从计算复杂度、并行化能力、长距离依赖捕捉三个维度,将自注意力与RNN、CNN进行对比(n为序列长度,d为表示维度,k为 CNN 核大小,r为受限自注意力的邻域大小)
4.1 并行化能力最优
自注意力和 CNN 的顺序操作数为O(1),可实现样本内的完全并行,远优于RNN的O(n)。
4.2 长距离依赖捕捉更高效
自注意力的最大路径长度为O(1),即任意两个位置的依赖可通过一次注意力操作直接建模;而RNN的路径长度为O(n),CNN为O(logkn),路径越短,梯度传播越顺畅,长距离依赖学习越容易。
4.3 计算复杂度可控
当n≪d时,自注意力的复杂度O(n2d)低于RNN的O(nd2);对于极长序列,可采用受限自注意力(仅关注邻域r内的位置),将复杂度降至O(rnd),兼顾效率和长距离依赖。
4.4 可解释性更强
自注意力的权重可直接可视化,实验发现不同的注意力头会学习到不同的语言特征(如句法依赖、指代消解、长距离短语关联),使模型具有一定的可解释性。
五、模型训练细节
5.1 训练数据与批处理
论文所使用的机器翻译数据集为WMT 2014英德翻译数据集(450 万句对,字节对编码,词汇量 37k)、WMT 2014英法翻译数据集(3600 万句对,词片编码,词汇量 32k);
批处理策略:按序列长度近似分组,每个批次包含约25000个源端词符和25000个目标端词符,提升训练效率。
5.2 正则化策略
残差 Dropout:子层输出和嵌入+位置编码的结果均施加 Dropout,基础模型dropout 率0.1;
标签平滑(Label Smoothing):平滑系数ϵls=0.1,牺牲困惑度但提升 BLEU 值,避免模型过度自信。
六、实验结果与分析
6.1 机器翻译任务
在 WMT 2014 英德、英法翻译任务上,Transformer 大幅超越当时的SOTA模型,且训练成本显著降低。
(1)英德翻译(WMT 2014)
Transformer-big 取得28.4 BLEU,超越此前所有模型(包括集成模型)超过2.0 BLEU,创当时 SOTA;基础模型也达到 27.3 BLEU,优于所有传统模型,且训练成本仅为 GNMT 的 1/7。
(2)英法翻译(WMT 2014)
Transformer-big 取得41.8 BLEU,成为首个单模型突破 41 BLEU 的模型,训练成本仅为此前 SOTA 模型的 1/4;即使是基础模型,也在远低于传统模型的训练成本下取得接近 SOTA 的效果。
6.2 模型变体实验(消融实验)
为验证各组件的重要性,论文在英德翻译开发集(newstest2013)上做了大量变体实验,核心结论:
-
多头注意力的必要性:单头注意力 BLEU 值下降 0.9,头数过多(32)也会导致效果下降,8 头为最优选择;
-
注意力键维度dk的影响:减小dk会显著降低模型效果,说明点积兼容性函数需要足够的维度来捕捉依赖;
-
模型规模的影响:增加编码器 / 解码器层数、扩大dmodel/dff,模型效果持续提升(大模型效果最优);
-
Dropout 的作用:无 Dropout 时模型过拟合,BLEU 值下降,适当的 Dropout(0.1)是必要的;
-
位置编码的鲁棒性:可学习的位置嵌入与正弦余弦位置编码效果几乎一致,证明位置编码的核心是注入位置信息,而非具体实现方式。
6.3 跨任务泛化:英语句法分析
为验证模型的通用性,将 Transformer 应用于英语成分句法分析(WSJ 语料),结果表明:小数据场景(仅 40k 句):4 层 Transformer 取得 91.3 F1,超越 RNN 序列转导模型,仅略低于 RNN 语法模型(91.7 F1);半监督场景(17M 句):取得 92.7 F1,超越所有此前的半监督模型,证明 Transformer 具有良好的跨任务泛化能力。
七、结论与未来展望
7.1 核心结论
-
Transformer 是首个完全基于注意力机制的序列转导模型,摒弃了循环和卷积结构,实现了高度并行化,训练速度远快于传统 RNN/CNN 模型;
-
在机器翻译任务上,Transformer 取得了远超当时 SOTA 的效果,即使是基础模型也能在低训练成本下超越集成模型;
-
自注意力机制在长距离依赖捕捉、并行化、可解释性上均优于 RNN/CNN,是序列建模的更优选择;
-
Transformer 具有良好的跨任务泛化能力,可成功应用于句法分析等非翻译任务。
7.2 未来研究方向
论文提出了 Transformer 的后续研究方向,均成为后续 NLP/AI 领域的研究热点:
-
将 Transformer 扩展到非文本模态(图像、音频、视频),实现多模态建模;
-
设计局部 / 受限自注意力,高效处理极长序列(如文档、视频帧);
-
降低生成任务的自回归性,提升生成速度;
-
探索注意力机制在更多任务中的应用(如阅读理解、摘要、问答)。
八、论文的里程碑意义
-
开启注意力时代:Transformer 的提出彻底改变了 NLP 的模型架构,取代 RNN/CNN 成为 NLP 的基础模型,后续的 BERT、GPT、T5 等大语言模型均基于 Transformer;
-
并行化训练:解决了 RNN 的并行化瓶颈,为大模型的训练奠定了基础,使海量数据和大参数量模型的训练成为可能;
-
长距离依赖建模:自注意力机制为长序列建模提供了高效方法,适用于文档理解、长文本生成等任务;
-
跨领域扩展:Transformer 不仅在 NLP 领域称霸,还被扩展到计算机视觉(ViT)、语音识别(Audio Transformer)、强化学习等领域,成为通用的深度学习架构
尝试的复现
受制于性能只做出了结构上的还原并没有使用数据集进行测试
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.autograd import Variable
# 掩码生成 + 缩放点积注意力
def create_padding_mask(seq):
"""生成掩码"""
mask = (seq == 0).unsqueeze(1).unsqueeze(2)
return mask # (batch_size, 1, 1, seq_len)
def create_subsequent_mask(seq):
"""生成解码器未来信息掩码"""
seq_len = seq.size(1)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask.unsqueeze(0).unsqueeze(1) # (1, 1, seq_len, seq_len)
def scaled_dot_product_attention(q, k, v, mask=None):
"""
缩放点积注意力
"""
d_k = q.size(-1)
# Q*K^T / sqrt(d_k)
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
# 掩码
if mask is not None:
attn_scores += mask
# softmax
attn_weights = F.softmax(attn_scores, dim=-1)
# 加权求和
output = torch.matmul(attn_weights, v)
return output, attn_weights
# 多头注意力
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, h):
super().__init__()
self.d_model = d_model
self.h = h
self.d_k = d_model // h
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
def split_heads(self, x):
"""拆分为多头"""
batch_size = x.size(0)
return x.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
# 线性投影
q = self.w_q(q)
k = self.w_k(k)
v = self.w_v(v)
q = self.split_heads(q)
k = self.split_heads(k)
v = self.split_heads(v)
# 缩放点积注意力
attn_output, attn_weights = scaled_dot_product_attention(q, k, v, mask)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.w_o(attn_output)
return output, attn_weights
# 逐位置前馈网络
class PositionWiseFeedForward(nn.Module):
def __init__(self, d_model, d_ff):
super().__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
def forward(self, x):
"""FFN(x) = max(0, xW1 + b1)W2 + b2"""
return self.fc2(F.relu(self.fc1(x)))
# 位置编码
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
"""x: (batch_size, seq_len, d_model)"""
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
# 编码器层 + 编码器栈N=6
class EncoderLayer(nn.Module):
def __init__(self, d_model, h, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, h)
self.ffn = PositionWiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask):
# 自注意力 + 残差 + 归一化
attn_out, _ = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout1(attn_out))
# FFN + 残差 + 归一化
ffn_out = self.ffn(x)
x = self.norm2(x + self.dropout2(ffn_out))
return x
class Encoder(nn.Module):
def __init__(self, d_model, h, d_ff, N, vocab_size, max_len, dropout=0.1):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
self.layers = nn.ModuleList([EncoderLayer(d_model, h, d_ff, dropout) for _ in range(N)])
def forward(self, x, mask):
x = self.embedding(x) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
# N 层编码
for layer in self.layers:
x = layer(x, mask)
return x
# 解码器层 + 解码器栈N=6
class DecoderLayer(nn.Module):
def __init__(self, d_model, h, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, h)
self.cross_attn = MultiHeadAttention(d_model, h)
self.ffn = PositionWiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(self, x, enc_out, src_mask, tgt_mask):
# 掩码自注意力
attn_out, _ = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout1(attn_out))
# 编码器-解码器注意力
cross_out, _ = self.cross_attn(x, enc_out, enc_out, src_mask)
x = self.norm2(x + self.dropout2(cross_out))
# FFN
ffn_out = self.ffn(x)
x = self.norm3(x + self.dropout3(ffn_out))
return x
class Decoder(nn.Module):
def __init__(self, d_model, h, d_ff, N, vocab_size, max_len, dropout=0.1):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
self.layers = nn.ModuleList([DecoderLayer(d_model, h, d_ff, dropout) for _ in range(N)])
def forward(self, x, enc_out, src_mask, tgt_mask):
x = self.embedding(x) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
for layer in self.layers:
x = layer(x, enc_out, src_mask, tgt_mask)
return x
class Transformer(nn.Module):
def __init__(
self,
src_vocab_size,
tgt_vocab_size,
d_model=512,
h=8,
d_ff=2048,
N=6,
max_len=5000,
dropout=0.1
):
super().__init__()
self.encoder = Encoder(d_model, h, d_ff, N, src_vocab_size, max_len, dropout)
self.decoder = Decoder(d_model, h, d_ff, N, tgt_vocab_size, max_len, dropout)
self.fc = nn.Linear(d_model, tgt_vocab_size)
def forward(self, src, tgt):
# 生成掩码
src_mask = create_padding_mask(src)
tgt_pad_mask = create_padding_mask(tgt)
tgt_sub_mask = create_subsequent_mask(tgt)
tgt_mask = torch.min(tgt_pad_mask, tgt_sub_mask)
# 编码 + 解码
enc_out = self.encoder(src, src_mask)
dec_out = self.decoder(tgt, enc_out, src_mask, tgt_mask)
# 输出层
output = self.fc(dec_out)
return output
class AdamOptimizer:
def __init__(self, d_model, warmup_steps=4000, optimizer=None):
self.d_model = d_model
self.warmup_steps = warmup_steps
self.optimizer = optimizer
self.step_num = 0
def step(self):
self.step_num += 1
lr = self.calc_lr()
for p in self.optimizer.param_groups:
p['lr'] = lr
self.optimizer.step()
def calc_lr(self):
"""学习率公式"""
return self.d_model ** (-0.5) * min(
self.step_num ** (-0.5),
self.step_num * self.warmup_steps ** (-1.5)
)
if __name__ == "__main__":
SRC_VOCAB_SIZE = 37000
TGT_VOCAB_SIZE = 37000
D_MODEL = 512
H = 8
D_FF = 2048
N = 6
DROPOUT = 0.1
model = Transformer(
src_vocab_size=SRC_VOCAB_SIZE,
tgt_vocab_size=TGT_VOCAB_SIZE,
d_model=D_MODEL,
h=H,
d_ff=D_FF,
N=N,
dropout=DROPOUT
)
src = torch.randint(1, SRC_VOCAB_SIZE, (32, 20))
tgt = torch.randint(1, TGT_VOCAB_SIZE, (32, 20))
# 前向传播
pred = model(src, tgt)
print(f"模型输出形状: {pred.shape}") # 应输出: torch.Size([32, 20, 37000])
print("运行成功!")
optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)
opt = AdamOptimizer(d_model=D_MODEL, optimizer=optimizer)
print("初始化成功!")
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)