【算法解析】Transformer架构深度剖析与实现指南

引言

Transformer架构自2017年Google团队在论文《Attention is All You Need》中提出以来,彻底改变了自然语言处理领域的格局。它摒弃了传统的循环神经网络结构,完全基于注意力机制构建,在机器翻译、文本生成等任务上取得了革命性的突破。本文将深入剖析Transformer的核心原理,并结合PyTorch实现完整的Transformer模型。

一、Transformer架构总览

1.1 整体架构

Transformer由编码器(Encoder)和解码器(Decoder)两部分组成:

              ┌─────────────────────────────────────┐
              │         Encoder Stack               │
              │  ┌─────────────────────────────┐   │
              │  │  Multi-Head Attention      │   │
              │  └─────────────┬─────────────┘   │
              │                │                   │
              │  ┌─────────────▼─────────────┐   │
              │  │  Feed Forward Network     │   │
              │  └─────────────────────────────┘   │
              └───────────────────┬───────────────┘
                                  │
              ┌───────────────────▼───────────────┐
              │         Decoder Stack               │
              │  ┌─────────────────────────────┐   │
              │  │  Masked Multi-Head Attention│   │
              │  └─────────────┬─────────────┘   │
              │                │                   │
              │  ┌─────────────▼─────────────┐   │
              │  │  Encoder-Decoder Attention│   │
              │  └─────────────┬─────────────┘   │
              │                │                   │
              │  ┌─────────────▼─────────────┐   │
              │  │  Feed Forward Network     │   │
              │  └─────────────────────────────┘   │
              └───────────────────┬───────────────┘
                                  │
              ┌───────────────────▼───────────────┐
              │          Linear + Softmax         │
              └───────────────────────────────────┘

1.2 核心组件

Transformer的核心组件包括:

  • 多头注意力机制(Multi-Head Attention)
  • 位置编码(Positional Encoding)
  • 前馈神经网络(Feed Forward Network)
  • 层归一化(Layer Normalization)
  • 残差连接(Residual Connection)

二、自注意力机制详解

2.1 注意力机制的本质

注意力机制允许模型在处理序列时,根据当前位置动态地关注序列的不同部分。自注意力机制则是让序列中的每个位置都关注整个序列。

2.2 Scaled Dot-Product Attention

自注意力机制的核心是Scaled Dot-Product Attention:

def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Args:
        q: shape (batch_size, n_heads, seq_len, d_k)
        k: shape (batch_size, n_heads, seq_len, d_k)
        v: shape (batch_size, n_heads, seq_len, d_v)
        mask: shape (batch_size, 1, seq_len, seq_len)
    
    Returns:
        output: shape (batch_size, n_heads, seq_len, d_v)
        attn_weights: shape (batch_size, n_heads, seq_len, seq_len)
    """
    d_k = q.size(-1)
    # 计算注意力分数
    scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    # 应用mask(可选)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # 计算注意力权重
    attn_weights = torch.softmax(scores, dim=-1)
    
    # 计算输出
    output = torch.matmul(attn_weights, v)
    
    return output, attn_weights

2.3 为什么需要Scaled?

当d_k较大时,点积的结果可能很大,导致softmax函数的梯度变得非常小(趋于饱和)。通过除以√d_k,可以将分数的方差归一化到1,避免梯度消失问题。

三、多头注意力机制

3.1 多头注意力的设计思想

多头注意力通过多个并行的注意力头,从不同角度捕捉输入序列的依赖关系:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # 线性变换层
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
    
    def split_heads(self, x, batch_size):
        """将输入按头数分割"""
        x = x.view(batch_size, -1, self.n_heads, self.d_k)
        return x.transpose(1, 2)  # (batch_size, n_heads, seq_len, d_k)
    
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        
        # 线性变换
        q = self.w_q(q)
        k = self.w_k(k)
        v = self.w_v(v)
        
        # 分割多头
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        
        # 计算注意力
        attn_output, attn_weights = scaled_dot_product_attention(q, k, v, mask)
        
        # 拼接多头输出
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        
        # 最终线性变换
        output = self.w_o(attn_output)
        
        return output, attn_weights

3.2 多头注意力的优势

  • 不同的注意力头学习不同的注意力模式
  • 增加模型的表达能力
  • 捕捉多种类型的依赖关系

四、位置编码

4.1 位置编码的必要性

由于Transformer没有循环结构,无法自动获取序列的位置信息。位置编码通过在输入嵌入中添加位置信息,让模型能够理解序列的顺序。

4.2 正弦余弦位置编码

论文中使用的位置编码方式:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # 计算位置编码
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Args:
            x: shape (seq_len, batch_size, d_model)
        
        Returns:
            x + positional encoding
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

4.3 位置编码的特性

  • 周期性:相同位置间隔的位置编码具有相似性
  • 可扩展性:可以处理训练时未见过的序列长度
  • 数学性质:相对位置信息可以通过三角函数的性质得到

五、前馈神经网络

5.1 FFN的结构

每个编码器和解码器层都包含一个前馈神经网络:

class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        """
        Args:
            x: shape (batch_size, seq_len, d_model)
        
        Returns:
            output: shape (batch_size, seq_len, d_model)
        """
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

5.2 FFN的作用

前馈神经网络在每个位置上独立地进行非线性变换,增加模型的表达能力。

六、编码器实现

6.1 单个编码器层

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
        
        # 层归一化
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: shape (batch_size, seq_len, d_model)
            mask: shape (batch_size, 1, 1, seq_len)
        
        Returns:
            output: shape (batch_size, seq_len, d_model)
        """
        # 自注意力
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        
        # 前馈网络
        ffn_output = self.ffn(x)
        x = x + self.dropout2(ffn_output)
        x = self.norm2(x)
        
        return x

6.2 编码器栈

class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        
        # 词嵌入
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        # 位置编码
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        # 编码器层栈
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
    
    def forward(self, x, mask=None):
        """
        Args:
            x: shape (batch_size, seq_len)
            mask: shape (batch_size, 1, 1, seq_len)
        
        Returns:
            output: shape (batch_size, seq_len, d_model)
        """
        # 词嵌入并乘以√d_model
        x = self.embedding(x) * math.sqrt(self.d_model)
        
        # 添加位置编码
        x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)
        
        # 通过编码器层
        for layer in self.layers:
            x = layer(x, mask)
        
        return x

七、解码器实现

7.1 单个解码器层

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.masked_self_attn = MultiHeadAttention(d_model, n_heads)
        self.enc_dec_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
        
        # 层归一化
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
    
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: shape (batch_size, tgt_seq_len, d_model)
            enc_output: shape (batch_size, src_seq_len, d_model)
            src_mask: shape (batch_size, 1, 1, src_seq_len)
            tgt_mask: shape (batch_size, 1, tgt_seq_len, tgt_seq_len)
        
        Returns:
            output: shape (batch_size, tgt_seq_len, d_model)
        """
        # Masked自注意力(防止前瞻)
        attn_output, _ = self.masked_self_attn(x, x, x, tgt_mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        
        # 编码器-解码器注意力
        attn_output, _ = self.enc_dec_attn(x, enc_output, enc_output, src_mask)
        x = x + self.dropout2(attn_output)
        x = self.norm2(x)
        
        # 前馈网络
        ffn_output = self.ffn(x)
        x = x + self.dropout3(ffn_output)
        x = self.norm3(x)
        
        return x

7.2 解码器栈

class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        
        # 词嵌入
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        # 位置编码
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        # 解码器层栈
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
    
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: shape (batch_size, tgt_seq_len)
            enc_output: shape (batch_size, src_seq_len, d_model)
            src_mask: shape (batch_size, 1, 1, src_seq_len)
            tgt_mask: shape (batch_size, 1, tgt_seq_len, tgt_seq_len)
        
        Returns:
            output: shape (batch_size, tgt_seq_len, d_model)
        """
        # 词嵌入并乘以√d_model
        x = self.embedding(x) * math.sqrt(self.d_model)
        
        # 添加位置编码
        x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)
        
        # 通过解码器层
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        
        return x

八、完整Transformer模型

8.1 模型组装

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_layers=6, 
                 n_heads=8, d_ff=2048, max_len=5000, dropout=0.1):
        super().__init__()
        
        self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout)
        self.decoder = Decoder(tgt_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout)
        
        # 最终线性层
        self.fc = nn.Linear(d_model, tgt_vocab_size)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Args:
            src: shape (batch_size, src_seq_len)
            tgt: shape (batch_size, tgt_seq_len)
            src_mask: shape (batch_size, 1, 1, src_seq_len)
            tgt_mask: shape (batch_size, 1, tgt_seq_len, tgt_seq_len)
        
        Returns:
            output: shape (batch_size, tgt_seq_len, tgt_vocab_size)
        """
        # 编码
        enc_output = self.encoder(src, src_mask)
        
        # 解码
        dec_output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        
        # 最终预测
        output = self.fc(dec_output)
        
        return output

8.2 Mask的构建

def create_padding_mask(seq):
    """创建padding mask"""
    mask = (seq == 0).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, seq_len)
    return mask

def create_look_ahead_mask(size):
    """创建look-ahead mask(防止前瞻)"""
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask

九、Transformer的训练与优化

9.1 损失函数

criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略padding token

def compute_loss(output, target):
    """
    Args:
        output: shape (batch_size, seq_len, vocab_size)
        target: shape (batch_size, seq_len)
    
    Returns:
        loss: scalar
    """
    output = output.contiguous().view(-1, output.size(-1))
    target = target.contiguous().view(-1)
    loss = criterion(output, target)
    return loss

9.2 优化器

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

9.3 学习率调度

class WarmupScheduler:
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0
    
    def step(self):
        self.step_num += 1
        lr = (self.d_model ** -0.5) * min(self.step_num ** -0.5, self.step_num * self.warmup_steps ** -1.5)
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

十、Transformer的应用场景

10.1 机器翻译

Transformer最初就是为机器翻译任务设计的,在多个翻译任务上取得了state-of-the-art的结果。

10.2 文本生成

通过调整解码器的结构,可以将Transformer应用于文本生成任务,如GPT系列模型。

10.3 文本分类

可以使用Transformer的编码器部分进行文本分类任务。

10.4 问答系统

Transformer可以用于构建问答系统,处理上下文理解和答案生成。

十一、总结

Transformer架构通过注意力机制彻底改变了序列建模的方式,具有以下优点:

  1. 并行计算:相比RNN,Transformer可以并行处理整个序列
  2. 长距离依赖:注意力机制可以直接捕捉任意位置之间的依赖
  3. 灵活建模:多头注意力可以学习多种类型的依赖关系

同时也存在一些挑战:

  1. 计算复杂度:自注意力的复杂度为O(n²),对于长序列计算成本较高
  2. 内存占用:多头注意力需要存储大量的注意力权重

Transformer已经成为NLP领域的基础架构,后续的BERT、GPT等模型都是基于Transformer发展而来。理解Transformer的原理对于深入学习现代NLP技术至关重要。

#Transformer #NLP #深度学习 #注意力机制

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐