【算法解析】Transformer架构深度剖析与实现指南
【算法解析】Transformer架构深度剖析与实现指南
引言
Transformer架构自2017年Google团队在论文《Attention is All You Need》中提出以来,彻底改变了自然语言处理领域的格局。它摒弃了传统的循环神经网络结构,完全基于注意力机制构建,在机器翻译、文本生成等任务上取得了革命性的突破。本文将深入剖析Transformer的核心原理,并结合PyTorch实现完整的Transformer模型。
一、Transformer架构总览
1.1 整体架构
Transformer由编码器(Encoder)和解码器(Decoder)两部分组成:
┌─────────────────────────────────────┐
│ Encoder Stack │
│ ┌─────────────────────────────┐ │
│ │ Multi-Head Attention │ │
│ └─────────────┬─────────────┘ │
│ │ │
│ ┌─────────────▼─────────────┐ │
│ │ Feed Forward Network │ │
│ └─────────────────────────────┘ │
└───────────────────┬───────────────┘
│
┌───────────────────▼───────────────┐
│ Decoder Stack │
│ ┌─────────────────────────────┐ │
│ │ Masked Multi-Head Attention│ │
│ └─────────────┬─────────────┘ │
│ │ │
│ ┌─────────────▼─────────────┐ │
│ │ Encoder-Decoder Attention│ │
│ └─────────────┬─────────────┘ │
│ │ │
│ ┌─────────────▼─────────────┐ │
│ │ Feed Forward Network │ │
│ └─────────────────────────────┘ │
└───────────────────┬───────────────┘
│
┌───────────────────▼───────────────┐
│ Linear + Softmax │
└───────────────────────────────────┘
1.2 核心组件
Transformer的核心组件包括:
- 多头注意力机制(Multi-Head Attention)
- 位置编码(Positional Encoding)
- 前馈神经网络(Feed Forward Network)
- 层归一化(Layer Normalization)
- 残差连接(Residual Connection)
二、自注意力机制详解
2.1 注意力机制的本质
注意力机制允许模型在处理序列时,根据当前位置动态地关注序列的不同部分。自注意力机制则是让序列中的每个位置都关注整个序列。
2.2 Scaled Dot-Product Attention
自注意力机制的核心是Scaled Dot-Product Attention:
def scaled_dot_product_attention(q, k, v, mask=None):
"""
Args:
q: shape (batch_size, n_heads, seq_len, d_k)
k: shape (batch_size, n_heads, seq_len, d_k)
v: shape (batch_size, n_heads, seq_len, d_v)
mask: shape (batch_size, 1, seq_len, seq_len)
Returns:
output: shape (batch_size, n_heads, seq_len, d_v)
attn_weights: shape (batch_size, n_heads, seq_len, seq_len)
"""
d_k = q.size(-1)
# 计算注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
# 应用mask(可选)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 计算注意力权重
attn_weights = torch.softmax(scores, dim=-1)
# 计算输出
output = torch.matmul(attn_weights, v)
return output, attn_weights
2.3 为什么需要Scaled?
当d_k较大时,点积的结果可能很大,导致softmax函数的梯度变得非常小(趋于饱和)。通过除以√d_k,可以将分数的方差归一化到1,避免梯度消失问题。
三、多头注意力机制
3.1 多头注意力的设计思想
多头注意力通过多个并行的注意力头,从不同角度捕捉输入序列的依赖关系:
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 线性变换层
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
def split_heads(self, x, batch_size):
"""将输入按头数分割"""
x = x.view(batch_size, -1, self.n_heads, self.d_k)
return x.transpose(1, 2) # (batch_size, n_heads, seq_len, d_k)
def forward(self, q, k, v, mask=None):
batch_size = q.size(0)
# 线性变换
q = self.w_q(q)
k = self.w_k(k)
v = self.w_v(v)
# 分割多头
q = self.split_heads(q, batch_size)
k = self.split_heads(k, batch_size)
v = self.split_heads(v, batch_size)
# 计算注意力
attn_output, attn_weights = scaled_dot_product_attention(q, k, v, mask)
# 拼接多头输出
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.d_model)
# 最终线性变换
output = self.w_o(attn_output)
return output, attn_weights
3.2 多头注意力的优势
- 不同的注意力头学习不同的注意力模式
- 增加模型的表达能力
- 捕捉多种类型的依赖关系
四、位置编码
4.1 位置编码的必要性
由于Transformer没有循环结构,无法自动获取序列的位置信息。位置编码通过在输入嵌入中添加位置信息,让模型能够理解序列的顺序。
4.2 正弦余弦位置编码
论文中使用的位置编码方式:
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 计算位置编码
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: shape (seq_len, batch_size, d_model)
Returns:
x + positional encoding
"""
x = x + self.pe[:x.size(0)]
return self.dropout(x)
4.3 位置编码的特性
- 周期性:相同位置间隔的位置编码具有相似性
- 可扩展性:可以处理训练时未见过的序列长度
- 数学性质:相对位置信息可以通过三角函数的性质得到
五、前馈神经网络
5.1 FFN的结构
每个编码器和解码器层都包含一个前馈神经网络:
class FeedForwardNetwork(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
Args:
x: shape (batch_size, seq_len, d_model)
Returns:
output: shape (batch_size, seq_len, d_model)
"""
x = self.fc1(x)
x = F.relu(x)
x = self.dropout(x)
x = self.fc2(x)
return x
5.2 FFN的作用
前馈神经网络在每个位置上独立地进行非线性变换,增加模型的表达能力。
六、编码器实现
6.1 单个编码器层
class EncoderLayer(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads)
self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
# 层归一化
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: shape (batch_size, seq_len, d_model)
mask: shape (batch_size, 1, 1, seq_len)
Returns:
output: shape (batch_size, seq_len, d_model)
"""
# 自注意力
attn_output, _ = self.self_attn(x, x, x, mask)
x = x + self.dropout1(attn_output)
x = self.norm1(x)
# 前馈网络
ffn_output = self.ffn(x)
x = x + self.dropout2(ffn_output)
x = self.norm2(x)
return x
6.2 编码器栈
class Encoder(nn.Module):
def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, max_len=5000, dropout=0.1):
super().__init__()
self.d_model = d_model
# 词嵌入
self.embedding = nn.Embedding(vocab_size, d_model)
# 位置编码
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
# 编码器层栈
self.layers = nn.ModuleList([
EncoderLayer(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
def forward(self, x, mask=None):
"""
Args:
x: shape (batch_size, seq_len)
mask: shape (batch_size, 1, 1, seq_len)
Returns:
output: shape (batch_size, seq_len, d_model)
"""
# 词嵌入并乘以√d_model
x = self.embedding(x) * math.sqrt(self.d_model)
# 添加位置编码
x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)
# 通过编码器层
for layer in self.layers:
x = layer(x, mask)
return x
七、解码器实现
7.1 单个解码器层
class DecoderLayer(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
self.masked_self_attn = MultiHeadAttention(d_model, n_heads)
self.enc_dec_attn = MultiHeadAttention(d_model, n_heads)
self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
# 层归一化
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
# Dropout
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
"""
Args:
x: shape (batch_size, tgt_seq_len, d_model)
enc_output: shape (batch_size, src_seq_len, d_model)
src_mask: shape (batch_size, 1, 1, src_seq_len)
tgt_mask: shape (batch_size, 1, tgt_seq_len, tgt_seq_len)
Returns:
output: shape (batch_size, tgt_seq_len, d_model)
"""
# Masked自注意力(防止前瞻)
attn_output, _ = self.masked_self_attn(x, x, x, tgt_mask)
x = x + self.dropout1(attn_output)
x = self.norm1(x)
# 编码器-解码器注意力
attn_output, _ = self.enc_dec_attn(x, enc_output, enc_output, src_mask)
x = x + self.dropout2(attn_output)
x = self.norm2(x)
# 前馈网络
ffn_output = self.ffn(x)
x = x + self.dropout3(ffn_output)
x = self.norm3(x)
return x
7.2 解码器栈
class Decoder(nn.Module):
def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, max_len=5000, dropout=0.1):
super().__init__()
self.d_model = d_model
# 词嵌入
self.embedding = nn.Embedding(vocab_size, d_model)
# 位置编码
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
# 解码器层栈
self.layers = nn.ModuleList([
DecoderLayer(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
"""
Args:
x: shape (batch_size, tgt_seq_len)
enc_output: shape (batch_size, src_seq_len, d_model)
src_mask: shape (batch_size, 1, 1, src_seq_len)
tgt_mask: shape (batch_size, 1, tgt_seq_len, tgt_seq_len)
Returns:
output: shape (batch_size, tgt_seq_len, d_model)
"""
# 词嵌入并乘以√d_model
x = self.embedding(x) * math.sqrt(self.d_model)
# 添加位置编码
x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)
# 通过解码器层
for layer in self.layers:
x = layer(x, enc_output, src_mask, tgt_mask)
return x
八、完整Transformer模型
8.1 模型组装
class Transformer(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_layers=6,
n_heads=8, d_ff=2048, max_len=5000, dropout=0.1):
super().__init__()
self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout)
self.decoder = Decoder(tgt_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout)
# 最终线性层
self.fc = nn.Linear(d_model, tgt_vocab_size)
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
"""
Args:
src: shape (batch_size, src_seq_len)
tgt: shape (batch_size, tgt_seq_len)
src_mask: shape (batch_size, 1, 1, src_seq_len)
tgt_mask: shape (batch_size, 1, tgt_seq_len, tgt_seq_len)
Returns:
output: shape (batch_size, tgt_seq_len, tgt_vocab_size)
"""
# 编码
enc_output = self.encoder(src, src_mask)
# 解码
dec_output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
# 最终预测
output = self.fc(dec_output)
return output
8.2 Mask的构建
def create_padding_mask(seq):
"""创建padding mask"""
mask = (seq == 0).unsqueeze(1).unsqueeze(2) # (batch_size, 1, 1, seq_len)
return mask
def create_look_ahead_mask(size):
"""创建look-ahead mask(防止前瞻)"""
mask = torch.triu(torch.ones(size, size), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
return mask
九、Transformer的训练与优化
9.1 损失函数
criterion = nn.CrossEntropyLoss(ignore_index=0) # 忽略padding token
def compute_loss(output, target):
"""
Args:
output: shape (batch_size, seq_len, vocab_size)
target: shape (batch_size, seq_len)
Returns:
loss: scalar
"""
output = output.contiguous().view(-1, output.size(-1))
target = target.contiguous().view(-1)
loss = criterion(output, target)
return loss
9.2 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
9.3 学习率调度
class WarmupScheduler:
def __init__(self, optimizer, d_model, warmup_steps=4000):
self.optimizer = optimizer
self.d_model = d_model
self.warmup_steps = warmup_steps
self.step_num = 0
def step(self):
self.step_num += 1
lr = (self.d_model ** -0.5) * min(self.step_num ** -0.5, self.step_num * self.warmup_steps ** -1.5)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
十、Transformer的应用场景
10.1 机器翻译
Transformer最初就是为机器翻译任务设计的,在多个翻译任务上取得了state-of-the-art的结果。
10.2 文本生成
通过调整解码器的结构,可以将Transformer应用于文本生成任务,如GPT系列模型。
10.3 文本分类
可以使用Transformer的编码器部分进行文本分类任务。
10.4 问答系统
Transformer可以用于构建问答系统,处理上下文理解和答案生成。
十一、总结
Transformer架构通过注意力机制彻底改变了序列建模的方式,具有以下优点:
- 并行计算:相比RNN,Transformer可以并行处理整个序列
- 长距离依赖:注意力机制可以直接捕捉任意位置之间的依赖
- 灵活建模:多头注意力可以学习多种类型的依赖关系
同时也存在一些挑战:
- 计算复杂度:自注意力的复杂度为O(n²),对于长序列计算成本较高
- 内存占用:多头注意力需要存储大量的注意力权重
Transformer已经成为NLP领域的基础架构,后续的BERT、GPT等模型都是基于Transformer发展而来。理解Transformer的原理对于深入学习现代NLP技术至关重要。
#Transformer #NLP #深度学习 #注意力机制
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)