从零实现Transformer:第 7 部分 - 集成组件搭建完整模型(Integrating Components to Assemble the Transformer Model)

flyfish

完整代码在文末

在前 6 部分构建了所有基础组件后,本部分将集成嵌入层、位置编码、编码器、解码器和投影层,实现完整 Transformer 模型。
看下图 3个注意力已经搞清楚了,如果概念还是不清楚,可以点击 这里理清概念

请添加图片描述

注意看Encoder 和 Decoder 怎么连接起来的

Encoder 和 Decoder没有直接拼接,而是通过 Decoder 里的 Cross-Attention 层做桥梁,经过Encoder (多层 Self-Attention + Feed-Forward)处理后,得到丰富的上下文表示(Encoder Output)。Decoder 的每一层都拿着这个 Encoder Output 当 K/V,和自己的 Q 做注意力计算,从而把源序列的信息传递到解码过程中

左侧(橙色框):Encoder(编码器)
右侧(绿色框):Decoder(解码器)

Encoder 负责理解输入序列(提取特征),Decoder 负责生成输出序列。

请添加图片描述
论文《Attention Is All You Need》在 3.2.3 Applications of Attention in our Model 这一节写道:

In “encoder-decoder attention” layers, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder.

编码器的输出也可以用如下词表示
the output of the encoder
memory
encoder hidden states
states
encoder outputs
encoder representations 例如下图用states在这里插入图片描述
再详细一些涉及到层了
叫法可以是 EncoderBlock ,Decoder Block 或者 Encoder Layer,Decoder Layer
Encoder 的输出(States),直接连到了 Decoder 中每一层 Cross-Attention 的 K(Key)和 V(Value)输入。每一层 Decoder 的 Cross-Attention 使用的 K 和 V 都是相同的。
这份 States 会同时、直接、完全相同地 连接到 Decoder 每一层 的 Multi-Headed Cross-Attention 的 K(Key) 和 V(Value) 输入上。
也就是说:
Decoder 第 1 层 Cross-Attention 的 K、V = Encoder States
Decoder 第 2 层 Cross-Attention 的 K、V = 还是同一个 Encoder States
……(如果是 N=6 层,也全部一样)
在这里插入图片描述
Encoder 是串行堆叠(N 层顺序执行)
Encoder 1 → Encoder 2 → Encoder 3 → … → Encoder N
Decoder 也是串行堆叠(N 层顺序执行)
Decoder 1 → Decoder 2 → Decoder 3 → … → Decoder N
只有 Encoder 的最后一层(Encoder N) 的输出会连接到 Decoder 的每一层 的 Cross-Attention 上。
在这里插入图片描述

组件1:最终的投影层(Projection Layer)

作用

Decoder 输出的是 [batch, seq_len, d_model] 的特征向量,但需要预测词汇表中每个 token 的概率。投影层完成维度转换:
ProjectionLayer 是 Transformer Decoder 的最后一步 —— 输出投影层(Output Projection)。

把 Decoder 输出的隐状态向量(hidden states)转换为词汇表上的 logits(预测分数)。logits 不是概率。它是模型输出的原始未归一化分数。只有经过 softmax(logits) 之后,才会变成真正的概率。

d_model 维特征 → vocab_size 维 logits

它就是一个线性层(Linear Layer):nn.Linear(d_model, vocab_size)
把每个位置的 d_model 维向量,线性映射到整个词汇表的大小。
输出结果称为 logits(未归一化的概率分数)。
后续通常会接 softmax(训练时用 CrossEntropyLoss 会自动处理)来得到每个 token 的概率分布。

PyTorch 实现

class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.proj(x)

加注释

import torch
import torch.nn as nn

class ProjectionLayer(nn.Module):
    """
    最终线性投影层:将 Decoder 输出映射到词汇表空间
    """
    def __init__(self, d_model: int, vocab_size: int):
        """
        Args:
            d_model: 模型特征维度(输入)
            vocab_size: 目标词汇表大小(输出)
        """
        super().__init__()
        # 简单线性层:d_model → vocab_size
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: Decoder 输出 [batch, seq_len, d_model]
        
        Returns:
            logits: [batch, seq_len, vocab_size],未归一化的分数
        """
        return self.proj(x)  # 线性投影

组件2:完整 Transformer 类

数据流向总览

源序列 src (token IDs)
    │
    ▼
┌─────────────────┐
│ src_embed       │ ← 词嵌入 [B, S, d_model]
└─────────────────┘
    │
    ▼
┌─────────────────┐
│ src_pos         │ ← + 位置编码 [B, S, d_model]
└─────────────────┘
    │
    ▼
┌─────────────────┐
│ Encoder × N     │ ← 使用 src_mask [B, S, d_model]
└─────────────────┘
    │
    ▼  encoder_output
    │
目标序列 tgt (token IDs)
    │
    ▼
┌─────────────────┐
│ tgt_embed       │ ← 词嵌入 [B, T, d_model]
└─────────────────┘
    │
    ▼
┌─────────────────┐
│ tgt_pos         │ ← + 位置编码 [B, T, d_model]
└─────────────────┘
    │
    ▼
┌─────────────────┐
│ Decoder × N     │ ← 使用 encoder_output + src_mask + tgt_mask
└─────────────────┘
    │
    ▼  decoder_output [B, T, d_model]
    │
    ▼
┌─────────────────┐
│ Projection      │ ← 线性投影 [B, T, vocab_size]
└─────────────────┘
    │
    ▼
logits → Softmax → 下一个 token 概率分布

Transformer 模型的顶层封装类,它把之前写的所有组件(Encoder、Decoder、Embedding、Positional Encoding、ProjectionLayer)组合成一个完整的 Transformer 模型

Transformer 类实现

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        return self.projection_layer(x)

    def forward(self, src, tgt, src_mask, tgt_mask):
        enc_out = self.encode(src, src_mask)
        dec_out = self.decode(enc_out, src_mask, tgt, tgt_mask)
        return self.project(dec_out)
class Transformer(nn.Module):

继承 nn.Module,这是 PyTorch 中所有神经网络模型的基类。

    def __init__(self, encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

初始化函数 __init__

接收 7 个已经构建好的组件,并把它们保存为类的属性:

参数 含义 作用
encoder Encoder 堆叠(6层) 理解输入序列
decoder Decoder 堆叠(6层) 生成输出序列
src_embed 源语言(输入)的 Embedding 层 将 token 转为向量
tgt_embed 目标语言(输出)的 Embedding 层 将 token 转为向量
src_pos 源序列的位置编码 加入位置信息
tgt_pos 目标序列的位置编码 加入位置信息
projection_layer 输出投影层(Linear) 把隐状态转成词汇表 logits
    def encode(self, src, src_mask):
        src = self.src_embed(src)      # [batch, src_len] → [batch, src_len, d_model]
        src = self.src_pos(src)        # 加上位置编码
        return self.encoder(src, src_mask)   # 通过多层 Encoder

encode 方法:负责编码阶段(Encoder 前向传播)

  1. 输入的 token ID 序列 → Embedding
  2. 加上位置编码
  3. 送入 Encoder 多层处理
  4. 返回 Encoder 的最终输出(Memory / States)
    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        tgt = self.tgt_embed(tgt)           # Decoder 输入的 token 做 Embedding
        tgt = self.tgt_pos(tgt)             # 加上位置编码
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

decode 方法:负责解码阶段(Decoder 前向传播)

tgt:当前已生成的 target 序列(训练时是 teacher forcing)
encoder_output:Encoder 输出的 Memory(K 和 V 来源)
内部会依次经过 Decoder 的每一层(Masked Self-Attention + Cross-Attention + FeedForward)

    def project(self, x):
        return self.projection_layer(x)

project 方法:把 Decoder 的输出 [batch, tgt_len, d_model] 映射到词汇表维度 [batch, tgt_len, vocab_size],得到 logits

    def forward(self, src, tgt, src_mask, tgt_mask):
        """端到端前向传播"""
        enc_out = self.encode(src, src_mask)
        dec_out = self.decode(enc_out, src_mask, tgt, tgt_mask)
        return self.project(dec_out)

forward 方法(最重要):

这是 PyTorch 模型被调用时(model(...))实际执行的函数。它把整个流程串联起来:

完整流程

  1. src → Encoder → enc_out
  2. tgt + enc_out → Decoder → dec_out
  3. dec_out → Projection → logits

完整代码

import torch
import torch.nn as nn
import math
import random
import copy
import time

# ====================== 基础组件 ======================

class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.shape[1], :].requires_grad_(False)
        return self.dropout(x)


def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    d_k = query.shape[-1]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    
    attn = torch.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    
    return torch.matmul(attn, value), attn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]

        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        output, self.attention_weights = scaled_dot_product_attention(Q, K, V, mask, self.dropout)

        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)


class LayerNormalization(nn.Module):
    def __init__(self, features: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))

    def forward(self, x: torch.Tensor):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mean) / torch.sqrt(std ** 2 + self.eps)
        return self.gamma * normalized + self.beta


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.linear_2(self.dropout(self.activation(self.linear_1(x))))


class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        return self.norm(x + self.dropout(sublayer(x)))


# ====================== Encoder & Decoder ======================

class EncoderBlock(nn.Module):
    def __init__(self, features: int, self_attention_block: MultiHeadAttention, 
                 feed_forward_block: PositionwiseFeedForward, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x_res: self.self_attention_block(x_res, x_res, x_res, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x


class Encoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class DecoderBlock(nn.Module):
    def __init__(self, features: int, self_attention_block: MultiHeadAttention,
                 cross_attention_block: MultiHeadAttention, feed_forward_block: PositionwiseFeedForward, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        # Masked Self-Attention
        x = self.residual_connections[0](x, lambda x_res: self.self_attention_block(x_res, x_res, x_res, tgt_mask))
        # Cross-Attention (Q from decoder, K/V from encoder)
        x = self.residual_connections[1](x, lambda x_res: 
            self.cross_attention_block(x_res, encoder_output, encoder_output, src_mask))
        # Feed Forward
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x


class Decoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)


class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.proj(x)


class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        return self.projection_layer(x)

    def forward(self, src, tgt, src_mask, tgt_mask):
        enc_out = self.encode(src, src_mask)
        dec_out = self.decode(enc_out, src_mask, tgt, tgt_mask)
        return self.project(dec_out)


# ====================== 测试完整 Transformer 模型 ======================
print("\n--- 测试完整 Transformer 模型 ---")

# 超参数配置
src_vocab_size = 10000
tgt_vocab_size = 11000
d_model = 512
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.1
max_seq_len = 100

# 创建组件
src_embed = InputEmbeddings(d_model, src_vocab_size)
tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
src_pos = PositionalEncoding(d_model, max_seq_len, dropout)
tgt_pos = PositionalEncoding(d_model, max_seq_len, dropout)

attention = MultiHeadAttention(d_model, num_heads, dropout)
feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

# 构建 Encoder
encoder_blocks = nn.ModuleList([
    EncoderBlock(d_model, copy.deepcopy(attention), copy.deepcopy(feed_forward), dropout) 
    for _ in range(num_layers)
])
encoder = Encoder(d_model, encoder_blocks)

# 构建 Decoder
decoder_blocks = nn.ModuleList([
    DecoderBlock(d_model, copy.deepcopy(attention), copy.deepcopy(attention), 
                 copy.deepcopy(feed_forward), dropout) 
    for _ in range(num_layers)
])
decoder = Decoder(d_model, decoder_blocks)

projection = ProjectionLayer(d_model, tgt_vocab_size)

# 组装完整模型
transformer = Transformer(
    encoder=encoder,
    decoder=decoder,
    src_embed=src_embed,
    tgt_embed=tgt_embed,
    src_pos=src_pos,
    tgt_pos=tgt_pos,
    projection_layer=projection
)

# 参数初始化
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

print(f"模型创建成功,参数量: {sum(p.numel() for p in transformer.parameters()):,}")

# ==================== 虚拟数据测试 ====================
batch_size, src_len, tgt_len = 2, 10, 12

dummy_src = torch.randint(1, src_vocab_size, (batch_size, src_len))
dummy_tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len))

# 掩码
src_mask = torch.zeros(batch_size, 1, 1, src_len, dtype=torch.bool)
look_ahead = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()
tgt_mask = look_ahead.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, tgt_len, tgt_len)

print(f"\n输入形状:")
print(f"  src: {dummy_src.shape}, tgt: {dummy_tgt.shape}")
print(f"  src_mask: {src_mask.shape}, tgt_mask: {tgt_mask.shape}")

transformer.eval()
with torch.no_grad():
    enc_out = transformer.encode(dummy_src, src_mask)
    print(f"\nEncoder 输出: {enc_out.shape}")   # [2, 10, 512]
    
    dec_out = transformer.decode(enc_out, src_mask, dummy_tgt, tgt_mask)
    print(f"Decoder 输出: {dec_out.shape}")   # [2, 12, 512]
    
    logits = transformer.project(dec_out)
    print(f"最终 logits: {logits.shape}")     # [2, 12, 11000]
    
    full_output = transformer(dummy_src, dummy_tgt, src_mask, tgt_mask)
    print(f"完整 forward: {full_output.shape}")
    
    assert torch.allclose(full_output, logits)
    print("\n完整 Transformer 测试全部通过!")

输出

--- 测试完整 Transformer 模型 ---
模型创建成功,参数量: 60,535,544

输入形状:
  src: torch.Size([2, 10]), tgt: torch.Size([2, 12])
  src_mask: torch.Size([2, 1, 1, 10]), tgt_mask: torch.Size([2, 1, 12, 12])

Encoder 输出: torch.Size([2, 10, 512])
Decoder 输出: torch.Size([2, 12, 512])
最终 logits: torch.Size([2, 12, 11000])
完整 forward: torch.Size([2, 12, 11000])

完整 Transformer 测试全部通过!

下一步预告:训练

第 8 部分 将实现完整的训练流程

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐