Implementing a Transformer from Scratch, Part 8: Dataset Construction and Full Training Workflow

flyfish

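The complete source code is at the end of the article.

This part walks through data generation, encoding, and mask construction, then the training loop, loss computation, and gradient updates, covering the engineering details of a Transformer implementation end to end.

A simple task: sequence copying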

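The goal of the sequence copy task is simple: the model must output the input digit sequence unchanged (e.g., input [SOS]17095[EOS], output [SOS]17095[EOS]). It makes a good first task because:

  1. No real corpus is needed; random digits are enough to build the dataset;
  2. It exercises the Transformer's encoder-decoder architecture and attention mechanism;
  3. It strips away task-specific logic, so we can focus on the data-processing and training code.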

1. Data Processing Module: From Raw Sequences to Model Inputs

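1.1 Vocabulary and special tokens

First, define the vocabulary, including the special tokens every sequence task needs. This is the basis for converting text/sequences into numbers: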

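# Special tokens
PAD_TOKEN = '[PAD]'    # padding token: makes sequence lengths uniform
SOS_TOKEN = '[SOS]'    # start-of-sequence token: first decoder input
EOS_TOKEN = '[EOS]'    # end-of-sequence token: signals the end of generation

# Vocabulary: special tokens + digits 0-9
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]
# Bidirectional token <-> id mapping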
token_to_id = {token: i for i, token in enumerate(VOCAB)}
id_to_token = {i: token for token, i in token_to_id.items()}
VOCAB_SIZE = len(VOCAB)
PAD_ID = token_to_id[PAD_TOKEN]
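What this does:

Special tokens are standard in sequence tasks: PAD handles unequal lengths within a batch, and SOS/EOS mark where the decoder starts and stops generating;
Token-id mapping: the model cannot process text, so every token must be converted to an integer index;
The vocabulary for this task has only 13 tokens (3 special tokens + 10 digits), which keeps it minimal and easy to train.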
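To make the id layout concrete, here is a small sketch that rebuilds the same VOCAB and checks the mapping: the three special tokens take ids 0-2, and digit d lands at id d + 3. The `encode`/`decode` helpers are hypothetical names introduced for illustration. The article's own code uses `tokenize_sequence` and `id_to_token` directly.

```python
# Rebuild the vocabulary exactly as in the article
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = '[PAD]', '[SOS]', '[EOS]'
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]
token_to_id = {token: i for i, token in enumerate(VOCAB)}
id_to_token = {i: token for token, i in token_to_id.items()}

def encode(tokens):
    """Hypothetical helper: list of tokens -> list of ids."""
    return [token_to_id[t] for t in tokens]

def decode(ids):
    """Hypothetical helper: list of ids -> list of tokens."""
    return [id_to_token[i] for i in ids]

print(len(VOCAB))                                # 13
print(token_to_id[PAD_TOKEN])                    # 0
print(encode([SOS_TOKEN, '1', '7', EOS_TOKEN]))  # [1, 4, 10, 2]
```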

1.2 Generating the copy-task dataset

Generate random digit sequences by hand to build copy-task data in which input = output:

def generate_copy_task_data(num_examples: int, min_len: int, max_len: int):
    data = []
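    # Keep only the digits; special tokens are excluded
    content_vocab = [token for token in VOCAB if token not in [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]]
    for _ in range(num_examples):
        # Random sequence length (random.randint includes both ends)
        seq_len = random.randint(min_len, max_len)
        # Pick random digits to form the sequence
        sequence = [random.choice(content_vocab) for _ in range(seq_len)]
        # Input/output sequences: wrap with SOS at the start and EOS at the end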
        src = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        tgt = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        data.append({'src': src, 'tgt': tgt})
    return data
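Logic:
  1. Generate random digit sequences of 5 to 10 digits (min_seq_len/max_seq_len in CONFIG);
  2. The input sequence src and the target sequence tgt are identical (copy task);
  3. Every sequence starts with SOS and ends with EOS, the sequence format the Transformer expects.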
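One quick way to check the three properties above is to generate a dataset and assert them. This sketch reuses the article's `generate_copy_task_data` logic in condensed form (`tgt` is built as a copy of `src`). It also adds a fixed `random.seed(0)`, which the original code does not set, so the data is reproducible.

```python
import random

PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = '[PAD]', '[SOS]', '[EOS]'
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]

def generate_copy_task_data(num_examples, min_len, max_len):
    data = []
    content_vocab = [t for t in VOCAB if t not in (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN)]
    for _ in range(num_examples):
        seq_len = random.randint(min_len, max_len)   # inclusive on both ends
        sequence = [random.choice(content_vocab) for _ in range(seq_len)]
        src = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        data.append({'src': src, 'tgt': list(src)})  # copy task: tgt == src
    return data

random.seed(0)  # added for reproducibility; not in the original code
data = generate_copy_task_data(1000, 5, 10)

# Content length = total length minus SOS and EOS
lengths = [len(item['src']) - 2 for item in data]
print(min(lengths), max(lengths))
print(data[0]['src'])
```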

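1.3 Sequence encoding and padding

The model takes integer indices as input, and all sequences in a batch must have the same length, so we need both encoding and padding: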

# Sequence encoding: token -> id
def tokenize_sequence(sequence, token_to_id_map):
    return [token_to_id_map[token] for token in sequence]

# Sequence padding: uniform length (longer sequences are truncated)
def pad_sequence(sequence_ids, max_len, pad_id):
    padded_ids = sequence_ids + [pad_id] * (max_len - len(sequence_ids))
    return padded_ids[:max_len]
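Key points:

Encoding: ['[SOS]','1','7','0','9','5','[EOS]'] becomes [1,4,10,3,12,8,2] (the three special tokens occupy ids 0-2, so digit d maps to id d+3);
Padding: PAD tokens extend every sequence to MAX_PADDED_LEN (maximum sequence length + 2 for SOS/EOS), so tensor shapes are consistent during batched training. pad_sequence also truncates anything longer than max_len.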
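You can check the encoding example and the padding/truncation behavior of `pad_sequence` directly. This sketch copies the article's two helpers. Note that truncation would silently drop [EOS] if `max_len` were too small. That is not an issue here, because MAX_PADDED_LEN = 12 covers the longest sequence.

```python
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = '[PAD]', '[SOS]', '[EOS]'
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]
token_to_id = {token: i for i, token in enumerate(VOCAB)}
PAD_ID = token_to_id[PAD_TOKEN]
MAX_PADDED_LEN = 10 + 2  # max_seq_len + SOS + EOS

def tokenize_sequence(sequence, token_to_id_map):
    return [token_to_id_map[token] for token in sequence]

def pad_sequence(sequence_ids, max_len, pad_id):
    # Append PAD ids (a negative repeat count yields []), then cut to max_len
    padded_ids = sequence_ids + [pad_id] * (max_len - len(sequence_ids))
    return padded_ids[:max_len]

ids = tokenize_sequence([SOS_TOKEN, '1', '7', '0', '9', '5', EOS_TOKEN], token_to_id)
print(ids)                                        # [1, 4, 10, 3, 12, 8, 2]
print(pad_sequence(ids, MAX_PADDED_LEN, PAD_ID))  # [1, 4, 10, 3, 12, 8, 2, 0, 0, 0, 0, 0]
print(pad_sequence(ids, 5, PAD_ID))               # [1, 4, 10, 3, 12]  (EOS is lost)
```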

1.4 Mask generation

There are two kinds of masks, the source mask and the target mask. Both hide positions that attention must not use:

# Source mask: marks PAD positions (True = masked)
def create_src_mask(src_ids, pad_id):
    return (src_ids == pad_id).unsqueeze(1).unsqueeze(2)

# Target mask: masks PAD + future tokens (look-ahead mask)
def create_tgt_mask(tgt_ids, pad_id):
    batch_size, tgt_seq_len = tgt_ids.shape
    # 1. Padding mask: marks PAD positions
    tgt_padding_mask = (tgt_ids == pad_id).unsqueeze(1).unsqueeze(2)
    # 2. Look-ahead mask: upper-triangular matrix that masks future positions
    #    (created on tgt_ids' device so the | below also works on GPU)
    look_ahead_mask = torch.triu(torch.ones(tgt_seq_len, tgt_seq_len, device=tgt_ids.device), diagonal=1).bool()
    look_ahead_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)
    # Combine both masks
    return tgt_padding_mask | look_ahead_mask
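What the masks do:
  1. Source mask (src_mask)
    Tells the encoder which positions are PAD. Their attention scores are set to -inf, so after softmax they receive zero attention weight.
  2. Target mask (tgt_mask)
    Padding mask: same as src_mask; it masks PAD positions.
    Look-ahead mask: torch.triu builds an upper-triangular boolean matrix that forces the decoder to see only the current and earlier tokens, so it cannot peek at future tokens.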
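To see the shapes and the combined pattern, the sketch below applies the two mask functions to a tiny batch. The functions are defined as above, with the look-ahead mask created on the input's device. `True`/1 means "masked", which matches `masked_fill(mask, float('-inf'))` in `scaled_dot_product_attention`. The sample ids are arbitrary choices for illustration.

```python
import torch

PAD_ID = 0

def create_src_mask(src_ids, pad_id):
    # [batch, src_len] -> [batch, 1, 1, src_len]; True marks PAD
    return (src_ids == pad_id).unsqueeze(1).unsqueeze(2)

def create_tgt_mask(tgt_ids, pad_id):
    _, tgt_len = tgt_ids.shape
    padding = (tgt_ids == pad_id).unsqueeze(1).unsqueeze(2)      # [B, 1, 1, T]
    look_ahead = torch.triu(torch.ones(tgt_len, tgt_len, device=tgt_ids.device),
                            diagonal=1).bool()                   # [T, T]
    return padding | look_ahead.unsqueeze(0).unsqueeze(0)        # [B, 1, T, T]

src = torch.tensor([[1, 8, 2, 0, 0]])   # [SOS, 5, EOS, PAD, PAD]
tgt = torch.tensor([[1, 8, 10, 0]])     # [SOS, 5, 7, PAD]: one trailing PAD

src_mask = create_src_mask(src, PAD_ID)
tgt_mask = create_tgt_mask(tgt, PAD_ID)
print(src_mask.shape)  # torch.Size([1, 1, 1, 5])
print(tgt_mask.shape)  # torch.Size([1, 1, 4, 4])

# 1 = masked; rows are query positions, columns are key positions:
# [[0, 1, 1, 1],
#  [0, 0, 1, 1],
#  [0, 0, 0, 1],
#  [0, 0, 0, 1]]
print(tgt_mask[0, 0].int().tolist())
```

The source mask has shape [batch, 1, 1, src_len], so it broadcasts over heads and query positions. The target mask has shape [batch, 1, tgt_len, tgt_len]. Even the PAD query row (the last one) still sees position 0, so no row is fully masked and softmax does not produce NaN.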

1.5 Custom Dataset and DataLoader

In PyTorch, batched data loading uses Dataset + DataLoader, the standard pattern for training:

class CopyTaskDataset(Dataset):
    def __init__(self, data, max_len, pad_id):
        self.data = data
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Encode + pad the source sequence
        src_ids = pad_sequence(tokenize_sequence(item['src'], token_to_id), self.max_len, self.pad_id)
        tgt_ids = tokenize_sequence(item['tgt'], token_to_id)
        
        # Decoder: input and labels are offset by one position
        decoder_input_ids = pad_sequence(tgt_ids[:-1], self.max_len, self.pad_id)
        label_ids = pad_sequence(tgt_ids[1:], self.max_len, self.pad_id)

        return {
            "src_ids": torch.tensor(src_ids, dtype=torch.long),
            "decoder_input_ids": torch.tensor(decoder_input_ids, dtype=torch.long),
            "label_ids": torch.tensor(label_ids, dtype=torch.long)
        }
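Design: offsetting the decoder input against the labels

Decoder input: the target sequence without its last token (tgt_ids[:-1]);
Labels: the target sequence without its first token (tgt_ids[1:]).

Principle: the model predicts token n+1 from the first n tokens (teacher forcing).
Example: target sequence [SOS,1,7,0,EOS]
Decoder input: [SOS,1,7,0]
Labels: [1,7,0,EOS]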
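The offset can be traced position by position. This sketch uses ids from the article's vocabulary, padded to a hypothetical length of 6 for brevity. At position i, the decoder sees `decoder_input_ids[:i+1]` and is trained to predict `label_ids[i]`. The loss later ignores the trailing PAD labels.

```python
PAD_ID = 0

def pad_sequence(sequence_ids, max_len, pad_id):
    padded_ids = sequence_ids + [pad_id] * (max_len - len(sequence_ids))
    return padded_ids[:max_len]

tgt_ids = [1, 4, 10, 3, 2]            # [SOS, 1, 7, 0, EOS]
decoder_input_ids = pad_sequence(tgt_ids[:-1], 6, PAD_ID)
label_ids = pad_sequence(tgt_ids[1:], 6, PAD_ID)

print(decoder_input_ids)  # [1, 4, 10, 3, 0, 0]
print(label_ids)          # [4, 10, 3, 2, 0, 0]

# Pair each visible prefix with the token it must predict (skip PAD labels)
for i, label in enumerate(label_ids):
    if label == PAD_ID:
        continue
    print(decoder_input_ids[:i + 1], "->", label)
```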

Finally, DataLoader provides batched, shuffled loading:

train_dataset = CopyTaskDataset(raw_data, MAX_PADDED_LEN, PAD_ID)
train_dataloader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
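Under the hood, DataLoader's default collate function stacks each key of the per-sample dicts into one batch tensor. The sketch below shows the resulting shapes using a hypothetical `ToyCopyDataset` that works directly on id lists (it is not the article's class). Every tensor comes out as [batch_size, MAX_PADDED_LEN] with dtype int64 (torch.long).

```python
import random
import torch
from torch.utils.data import Dataset, DataLoader

PAD_ID, SOS_ID, EOS_ID = 0, 1, 2
MAX_PADDED_LEN = 12

def pad(ids, max_len, pad_id):
    return (ids + [pad_id] * (max_len - len(ids)))[:max_len]

class ToyCopyDataset(Dataset):
    """Hypothetical stand-in for CopyTaskDataset that stores digit ids (3-12) directly."""
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        tgt = [SOS_ID] + self.sequences[idx] + [EOS_ID]
        return {
            "src_ids": torch.tensor(pad(tgt, MAX_PADDED_LEN, PAD_ID)),
            "decoder_input_ids": torch.tensor(pad(tgt[:-1], MAX_PADDED_LEN, PAD_ID)),
            "label_ids": torch.tensor(pad(tgt[1:], MAX_PADDED_LEN, PAD_ID)),
        }

random.seed(0)
seqs = [[random.randint(3, 12) for _ in range(random.randint(5, 10))] for _ in range(10)]
loader = DataLoader(ToyCopyDataset(seqs), batch_size=4, shuffle=True)

batch = next(iter(loader))
for key, value in batch.items():
    print(key, tuple(value.shape), value.dtype)
print(len(loader))  # 3 batches: 4 + 4 + 2 samples (drop_last=False by default)
```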

2. Model Training Module

Once the data is ready, the training module handles the forward pass, loss computation, and gradient updates.

2.1 Centralized hyperparameter configuration

An engineering best practice: keep all hyperparameters in one place to make tuning easier:

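CONFIG = {
    "num_examples": 1000,    # dataset size
    "min_seq_len": 5,        # minimum sequence length
    "max_seq_len": 10,       # maximum sequence length
    "batch_size": 64,        # batch size
    "d_model": 128,          # model feature dimension
    "num_layers": 3,         # number of encoder/decoder layers
    "num_heads": 4,          # number of attention heads
    "d_ff": 512,             # feed-forward hidden dimension
    "dropout": 0.1,          # dropout probability
    "lr": 1e-3,              # learning rate
    "epochs": 40,            # number of training epochs
}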

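Hyperparameter notes:
d_model: the Transformer's feature dimension; it must be divisible by the number of attention heads;
num_heads: the number of heads in multi-head attention; each head works in d_k = d_model / num_heads dimensions (128 / 4 = 32 here).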
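The divisibility requirement comes from how multi-head attention splits the feature dimension. This sketch repeats the reshape used in `MultiHeadAttention.forward` on a random tensor. The batch size 2 and length 12 are arbitrary. It shows that 128 features become 4 heads of d_k = 32, and that merging the heads back recovers the original tensor exactly.

```python
import torch

d_model, num_heads = 128, 4              # values from CONFIG
assert d_model % num_heads == 0          # same check as MultiHeadAttention.__init__
d_k = d_model // num_heads               # 32 features per head

batch_size, seq_len = 2, 12              # arbitrary sizes for the demo
x = torch.randn(batch_size, seq_len, d_model)

# Split: [B, T, d_model] -> [B, T, h, d_k] -> [B, h, T, d_k]
heads = x.view(batch_size, -1, num_heads, d_k).transpose(1, 2)
print(heads.shape)   # torch.Size([2, 4, 12, 32])

# Merge back exactly as MultiHeadAttention.forward does after attention
merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
print(torch.equal(merged, x))  # True
```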

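2.2 Single-epoch training function: train_one_epoch

This is the core of training. It wraps the forward pass, loss computation, backpropagation, and parameter update: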

def train_one_epoch(model, dataloader, loss_fn, optimizer, pad_id):
    # Switch to training mode: enables Dropout (and BatchNorm, if the model had any)
    model.train()
    total_loss = 0
    start_time = time.time()

    for batch_idx, batch in enumerate(dataloader):
        # 1. Move data to the GPU/CPU
        src_ids = batch['src_ids'].to(device)
        decoder_input_ids = batch['decoder_input_ids'].to(device)
        label_ids = batch['label_ids'].to(device)

        # 2. Build masks
        src_mask = create_src_mask(src_ids, pad_id).to(device)
        tgt_mask = create_tgt_mask(decoder_input_ids, pad_id).to(device)

        # 3. Forward pass
        optimizer.zero_grad()  # clear gradients left over from the previous step
        logits = model(src_ids, decoder_input_ids, src_mask, tgt_mask)
        
        # 4. Loss: flatten the tensors; PAD positions are skipped via ignore_index
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), label_ids.reshape(-1))
        
        # 5. Backpropagation + parameter update
        loss.backward()   # compute gradients
        optimizer.step()  # update weights

        total_loss += loss.item()

        # Print training progress every 10 batches
        if (batch_idx + 1) % 10 == 0:
            elapsed = time.time() - start_time
            print(f"  Batch {batch_idx+1}/{len(dataloader)} | Loss: {loss.item():.4f} | Time: {elapsed:.2f}s")
            start_time = time.time()

    return total_loss / len(dataloader)
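Training logic:
  1. Mode switch: model.train() enables training-time layers such as Dropout, as opposed to inference mode (model.eval());
  2. Device transfer: tensors are moved to the GPU for faster training, or to the CPU when no GPU is available;
  3. Zeroing gradients: optimizer.zero_grad() clears old gradients, because PyTorch accumulates gradients across backward() calls by default;
  4. Forward pass: takes the source sequence, decoder input, and masks, and outputs logits (unnormalized scores over the vocabulary at each position, not probabilities);
  5. Loss computation
    CrossEntropyLoss applies log-softmax to the logits internally;
    ignore_index=PAD_ID excludes padding positions from the loss, so PAD does not interfere with training;
    Flattening: logits go from [batch, seq_len, vocab_size] to [batch*seq_len, vocab_size], and labels from [batch, seq_len] to [batch*seq_len], the input format the loss function expects;
  6. Backpropagation: loss.backward() computes the gradients of all parameters;
  7. Parameter update: optimizer.step() updates the model weights using the gradients.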
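To confirm what `ignore_index` does, this sketch compares `nn.CrossEntropyLoss(ignore_index=PAD_ID)` on flattened logits with a manual computation: the mean negative log-likelihood over non-PAD positions only. The logits are random and the labels are made-up ids, purely for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
PAD_ID, VOCAB_SIZE = 0, 13
logits = torch.randn(2, 4, VOCAB_SIZE)        # [batch, seq_len, vocab_size]
labels = torch.tensor([[4, 10, 2, PAD_ID],    # made-up label ids
                       [5, 2, PAD_ID, PAD_ID]])

loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
# Flatten to [batch*seq_len, vocab_size] and [batch*seq_len]
loss = loss_fn(logits.reshape(-1, VOCAB_SIZE), labels.reshape(-1))

# Manual check: mean negative log-likelihood over non-PAD positions only
log_probs = torch.log_softmax(logits, dim=-1)
nll = -log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)   # [batch, seq_len]
keep = labels != PAD_ID
manual = nll[keep].mean()

print(keep.sum().item())             # 5 positions contribute to the loss
print(torch.allclose(loss, manual))  # True
```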

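2.3 Main training loop + saving the model

The main program calls the single-epoch function repeatedly until the model converges, then saves the weights: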

if __name__ == "__main__":
    # 1. Generate the dataset + build the data loader
    raw_data = generate_copy_task_data(...)
    train_dataloader = DataLoader(...)

    # 2. Initialize the model, optimizer, and loss function
    model = build_transformer(CONFIG).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"], betas=(0.9, 0.98), eps=1e-9)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    # 3. Train for multiple epochs
    for epoch in range(1, CONFIG["epochs"] + 1):
        epoch_start = time.time()
        avg_loss = train_one_epoch(...)
        print(f"[Epoch {epoch}] average loss: {avg_loss:.4f}")

    # 4. Save the model weights
    torch.save(model.state_dict(), "transformer_copy_task.pth")
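Optimizer choice:

We use Adam, the optimizer from the original Transformer paper, with betas=(0.9, 0.98) and eps=1e-9 for stable training. The paper also applies a learning-rate warmup schedule; this code keeps the learning rate fixed at 1e-3.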
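For reference, the paper pairs these Adam settings with a warmup schedule: lrate = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5), with warmup_steps = 4000. The article's code does not use it. The sketch below shows one possible way to wire it up with `torch.optim.lr_scheduler.LambdaLR`, with the base lr set to 1.0 so the lambda gives the actual rate. The `noam_lambda` name and the placeholder `nn.Linear` model are assumptions for illustration.

```python
import torch
import torch.nn as nn

def noam_lambda(d_model, warmup_steps):
    """Hypothetical helper: LR multiplier d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    def fn(step):
        step = step + 1  # LambdaLR starts at step 0; shift by one to avoid 0 ** -0.5
        return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return fn

model = nn.Linear(4, 4)  # placeholder model for the sketch
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda(128, 4000))

lrs = []
for _ in range(8000):
    # In real training, loss.backward() would come first
    optimizer.step()
    scheduler.step()     # one scheduler step per batch, after optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(f"peak lr {max(lrs):.6f} at index {lrs.index(max(lrs))}")
```

The rate grows linearly for the first 4000 steps and then decays with the inverse square root of the step number.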

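2.4 Inference check: testing the copy behavior

After training, the infer function performs greedy decoding to check whether the model copies sequences perfectly: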

def infer(model, input_sequence, max_len=MAX_PADDED_LEN):
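    model.eval()  # switch to inference mode: disables Dropout
    # Encode the input → run the encoder → let the decoder generate token by token
    with torch.no_grad():  # disable gradient tracking to speed up inference
        # Greedy decoding: take the highest-scoring token at each step until EOS is generated
        ...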
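The elided body follows the standard greedy loop, which the complete code at the end implements with `model.encode`, `model.decode`, and `model.project`. Stripped of the model, the control flow looks like the sketch below. `next_token_fn` is a hypothetical stand-in for "run the decoder on the prefix and take the argmax of the last position's logits".

```python
SOS_ID, EOS_ID = 1, 2

def greedy_decode(next_token_fn, max_len):
    """Greedy loop mirroring infer(): start from [SOS], append the predicted id, stop at EOS."""
    output = [SOS_ID]
    for _ in range(max_len):
        pred = next_token_fn(output)   # hypothetical: argmax id for the next position
        output.append(pred)
        if pred == EOS_ID:
            break
    return output

# Hypothetical "perfect copier" standing in for a trained model
source = [1, 4, 10, 3, 12, 8, 2]       # [SOS, 1, 7, 0, 9, 5, EOS]

def perfect_copier(prefix):
    return source[len(prefix)]

print(greedy_decode(perfect_copier, max_len=12))  # [1, 4, 10, 3, 12, 8, 2]
```

If the model never emits EOS, the loop stops after max_len predictions, so the output holds at most max_len + 1 ids.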

Test result:

Input sequence: ['[SOS]', '1', '7', '0', '9', '5', '[EOS]']
Model output: ['[SOS]', '1', '7', '0', '9', '5', '[EOS]']
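A single example is a weak signal. One way to quantify copying quality is sequence-level exact-match accuracy over many fresh sequences. The helper below is a hypothetical sketch. The `preds` list is invented to illustrate the metric and is not real model output.

```python
def exact_match_accuracy(predictions, references):
    """Hypothetical helper: fraction of sequences reproduced token for token."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    hits = sum(pred == ref for pred, ref in zip(predictions, references))
    return hits / len(references)

refs = [['[SOS]', '1', '7', '0', '9', '5', '[EOS]'],
        ['[SOS]', '1', '2', '3', '4', '5', '[EOS]']]
# Invented predictions for illustration: the second one has one wrong digit
preds = [['[SOS]', '1', '7', '0', '9', '5', '[EOS]'],
         ['[SOS]', '1', '2', '3', '4', '4', '[EOS]']]

print(exact_match_accuracy(preds, refs))  # 0.5
```

With the article's functions, the predictions could come from something like `[infer(model, item['src']) for item in generate_copy_task_data(100, 5, 10)]`, compared against each `item['tgt']`.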

Training workflow

┌────────────────────────────────────────────────────────────────────┐
│                             Epoch loop                             │
│   ┌────────────────────────────────────────────────────────────┐   │
│   │                      Mini-batch loop                       │   │
│   │  1. zero gradients    →  model.zero_grad()                 │   │
│   │  2. forward pass      →  outputs = model(inputs)           │   │
│   │  3. compute loss      →  loss = criterion(outputs, labels) │   │
│   │  4. compute gradients →  loss.backward()                   │   │
│   │  5. update weights    →  optimizer.step()                  │   │
│   └────────────────────────────────────────────────────────────┘   │
└────────────────────────────────────────────────────────────────────┘
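The five steps in the diagram are the same for any PyTorch model. As a minimal, self-contained illustration, the loop below fits a toy linear regression, y = 2x + 1, with plain SGD instead of the Transformer. The data, model, and learning rate are arbitrary choices for the sketch.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
x = torch.randn(256, 1)
y = 2 * x + 1                                  # toy target: y = 2x + 1 (no noise)

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

history = []
for epoch in range(20):                        # epoch loop
    epoch_loss = 0.0
    for inputs, labels in loader:              # mini-batch loop
        optimizer.zero_grad()                  # 1. zero gradients
        outputs = model(inputs)                # 2. forward pass
        loss = criterion(outputs, labels)      # 3. compute loss
        loss.backward()                        # 4. compute gradients
        optimizer.step()                       # 5. update weights
        epoch_loss += loss.item()
    history.append(epoch_loss / len(loader))

print(f"epoch 1 loss: {history[0]:.4f}, epoch 20 loss: {history[-1]:.6f}")
print(f"learned w = {model.weight.item():.3f}, b = {model.bias.item():.3f}")
```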

Complete code

import torch
import torch.nn as nn
import math
import random
import copy
import time
from torch.utils.data import Dataset, DataLoader

# ====================== 1. Transformer building blocks ======================
class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.shape[1], :].requires_grad_(False)
        return self.dropout(x)

def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    d_k = query.shape[-1]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    
    attn = torch.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    
    return torch.matmul(attn, value), attn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]

        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        output, self.attention_weights = scaled_dot_product_attention(Q, K, V, mask, self.dropout)

        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)

class LayerNormalization(nn.Module):
    def __init__(self, features: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))

    def forward(self, x: torch.Tensor):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mean) / torch.sqrt(std ** 2 + self.eps)
        return self.gamma * normalized + self.beta

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.linear_2(self.dropout(self.activation(self.linear_1(x))))

class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        return self.norm(x + self.dropout(sublayer(x)))

# ====================== 2. Encoder & Decoder modules ======================
class EncoderBlock(nn.Module):
    def __init__(self, features: int, self_attention_block: MultiHeadAttention, 
                 feed_forward_block: PositionwiseFeedForward, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x_res: self.self_attention_block(x_res, x_res, x_res, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

class Encoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderBlock(nn.Module):
    def __init__(self, features: int, self_attention_block: MultiHeadAttention,
                 cross_attention_block: MultiHeadAttention, feed_forward_block: PositionwiseFeedForward, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x_res: self.self_attention_block(x_res, x_res, x_res, tgt_mask))
        x = self.residual_connections[1](x, lambda x_res: 
            self.cross_attention_block(x_res, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x

class Decoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)

class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.proj(x)

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        return self.projection_layer(x)

    def forward(self, src, tgt, src_mask, tgt_mask):
        enc_out = self.encode(src, src_mask)
        dec_out = self.decode(enc_out, src_mask, tgt, tgt_mask)
        return self.project(dec_out)

# ====================== 3. Data processing ======================
# Special tokens
PAD_TOKEN = '[PAD]'
SOS_TOKEN = '[SOS]'
EOS_TOKEN = '[EOS]'
# Vocabulary: special tokens + digits 0-9
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]
token_to_id = {token: i for i, token in enumerate(VOCAB)}
id_to_token = {i: token for token, i in token_to_id.items()}
VOCAB_SIZE = len(VOCAB)
PAD_ID = token_to_id[PAD_TOKEN]

# Generate the copy-task dataset
def generate_copy_task_data(num_examples: int, min_len: int, max_len: int):
    data = []
    content_vocab = [token for token in VOCAB if token not in [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]]
    for _ in range(num_examples):
        seq_len = random.randint(min_len, max_len)
        sequence = [random.choice(content_vocab) for _ in range(seq_len)]
        src = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        tgt = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        data.append({'src': src, 'tgt': tgt})
    return data

# Sequence encoding and padding
def tokenize_sequence(sequence, token_to_id_map):
    return [token_to_id_map[token] for token in sequence]

def pad_sequence(sequence_ids, max_len, pad_id):
    padded_ids = sequence_ids + [pad_id] * (max_len - len(sequence_ids))
    return padded_ids[:max_len]

# Mask generation
def create_src_mask(src_ids, pad_id):
    return (src_ids == pad_id).unsqueeze(1).unsqueeze(2)

def create_tgt_mask(tgt_ids, pad_id):
    batch_size, tgt_seq_len = tgt_ids.shape
    tgt_padding_mask = (tgt_ids == pad_id).unsqueeze(1).unsqueeze(2)
    look_ahead_mask = torch.triu(torch.ones(tgt_seq_len, tgt_seq_len, device=tgt_ids.device), diagonal=1).bool()
    look_ahead_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)
    return tgt_padding_mask | look_ahead_mask

# Custom dataset
class CopyTaskDataset(Dataset):
    def __init__(self, data, max_len, pad_id):
        self.data = data
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        src_ids = pad_sequence(tokenize_sequence(item['src'], token_to_id), self.max_len, self.pad_id)
        tgt_ids = tokenize_sequence(item['tgt'], token_to_id)
        
        decoder_input_ids = pad_sequence(tgt_ids[:-1], self.max_len, self.pad_id)
        label_ids = pad_sequence(tgt_ids[1:], self.max_len, self.pad_id)

        return {
            "src_ids": torch.tensor(src_ids, dtype=torch.long),
            "decoder_input_ids": torch.tensor(decoder_input_ids, dtype=torch.long),
            "label_ids": torch.tensor(label_ids, dtype=torch.long)
        }

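# ====================== 4. Hyperparameters (centralized) ======================
CONFIG = {
    "num_examples": 1000,    # dataset size
    "min_seq_len": 5,        # minimum sequence length
    "max_seq_len": 10,       # maximum sequence length
    "batch_size": 64,        # batch size
    "d_model": 128,          # model dimension
    "num_layers": 3,         # number of encoder/decoder layers
    "num_heads": 4,          # number of attention heads
    "d_ff": 512,             # feed-forward hidden dimension
    "dropout": 0.1,          # dropout probability
    "lr": 1e-3,              # learning rate
    "epochs": 40,            # number of training epochs
}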
MAX_PADDED_LEN = CONFIG["max_seq_len"] + 2  # maximum length including SOS/EOS
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====================== 5. Model construction ======================
def build_transformer(config):
    # Embedding layers
    src_embed = InputEmbeddings(config["d_model"], VOCAB_SIZE)
    tgt_embed = InputEmbeddings(config["d_model"], VOCAB_SIZE)
    # Positional encodings
    src_pos = PositionalEncoding(config["d_model"], MAX_PADDED_LEN, config["dropout"])
    tgt_pos = PositionalEncoding(config["d_model"], MAX_PADDED_LEN, config["dropout"])
    # Attention and feed-forward templates (deep-copied per layer below)
    attention = MultiHeadAttention(config["d_model"], config["num_heads"], config["dropout"])
    ff = PositionwiseFeedForward(config["d_model"], config["d_ff"], config["dropout"])
    # Encoder
    encoder_blocks = nn.ModuleList([
        EncoderBlock(config["d_model"], copy.deepcopy(attention), copy.deepcopy(ff), config["dropout"]) 
        for _ in range(config["num_layers"])
    ])
    encoder = Encoder(config["d_model"], encoder_blocks)
    # Decoder
    decoder_blocks = nn.ModuleList([
        DecoderBlock(config["d_model"], copy.deepcopy(attention), copy.deepcopy(attention), 
                     copy.deepcopy(ff), config["dropout"]) 
        for _ in range(config["num_layers"])
    ])
    decoder = Decoder(config["d_model"], decoder_blocks)
    # Output projection layer
    projection = ProjectionLayer(config["d_model"], VOCAB_SIZE)
    # Assemble the model
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection)
    # Parameter initialization (Xavier uniform for weight matrices)
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return transformer

# ====================== 6. Training function ======================
def train_one_epoch(model, dataloader, loss_fn, optimizer, pad_id):
    model.train()
    total_loss = 0
    start_time = time.time()

    for batch_idx, batch in enumerate(dataloader):
        # Move data to the device
        src_ids = batch['src_ids'].to(device)
        decoder_input_ids = batch['decoder_input_ids'].to(device)
        label_ids = batch['label_ids'].to(device)

        # Build masks
        src_mask = create_src_mask(src_ids, pad_id).to(device)
        tgt_mask = create_tgt_mask(decoder_input_ids, pad_id).to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(src_ids, decoder_input_ids, src_mask, tgt_mask)
        
        # Compute the loss (PAD positions are ignored)
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), label_ids.reshape(-1))
        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Print progress
        if (batch_idx + 1) % 10 == 0:
            elapsed = time.time() - start_time
            print(f"  Batch {batch_idx+1}/{len(dataloader)} | Loss: {loss.item():.4f} | Time: {elapsed:.2f}s")
            start_time = time.time()

    return total_loss / len(dataloader)

# ====================== 7. Inference check (test copying after training) ======================
def infer(model, input_sequence, max_len=MAX_PADDED_LEN):
    model.eval()
    # Encode the input sequence
    src_ids = pad_sequence(tokenize_sequence(input_sequence, token_to_id), max_len, PAD_ID)
    src_ids = torch.tensor([src_ids], dtype=torch.long).to(device)
    src_mask = create_src_mask(src_ids, PAD_ID).to(device)

    with torch.no_grad():
        enc_out = model.encode(src_ids, src_mask)
        # Initialize the decoder input with [SOS]
        decoder_input = torch.tensor([[token_to_id[SOS_TOKEN]]], dtype=torch.long).to(device)
        
        for _ in range(max_len):
            tgt_mask = create_tgt_mask(decoder_input, PAD_ID).to(device)
            dec_out = model.decode(enc_out, src_mask, decoder_input, tgt_mask)
            pred_token = model.project(dec_out[:, -1, :]).argmax(-1).item()
            
            # Append the predicted token
            decoder_input = torch.cat([decoder_input, torch.tensor([[pred_token]]).to(device)], dim=1)
            # Stop once the end token is generated
            if pred_token == token_to_id[EOS_TOKEN]:
                break

    # Convert the output ids back to tokens
    output_ids = decoder_input.squeeze(0).tolist()
    output_tokens = [id_to_token[idx] for idx in output_ids if idx != PAD_ID]
    return output_tokens

# ====================== 8. Main program: data loading + training + testing ======================
if __name__ == "__main__":
    # 1. Generate the dataset
    print("===== Generating the copy-task dataset =====")
    raw_data = generate_copy_task_data(
        CONFIG["num_examples"],
        CONFIG["min_seq_len"],
        CONFIG["max_seq_len"]
    )
    # Data loader
    train_dataset = CopyTaskDataset(raw_data, MAX_PADDED_LEN, PAD_ID)
    train_dataloader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)

    # 2. Initialize the model, optimizer, and loss function
    model = build_transformer(CONFIG).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"], betas=(0.9, 0.98), eps=1e-9)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    # 3. Start training
    print(f"\n===== Start training | device: {device} =====")
    for epoch in range(1, CONFIG["epochs"] + 1):
        epoch_start = time.time()
        avg_loss = train_one_epoch(model, train_dataloader, loss_fn, optimizer, PAD_ID)
        duration = time.time() - epoch_start
        print(f"\n[Epoch {epoch}/{CONFIG['epochs']}] average loss: {avg_loss:.4f} | time: {duration:.2f}s")
        print("-" * 60)

    # 4. Save the model
    torch.save(model.state_dict(), "transformer_copy_task.pth")
    print("\nModel saved to: transformer_copy_task.pth")

    # 5. Inference test
    print("\n===== Inference test (copy task) =====")
    # test_seq = [SOS_TOKEN, '1', '2', '3', '4', '5', EOS_TOKEN]
    test_seq = [SOS_TOKEN, '1', '7', '0', '9', '5', EOS_TOKEN]
    print(f"Input sequence: {test_seq}")
    pred_seq = infer(model, test_seq)
    print(f"Model output: {pred_seq}")