Implementing a Transformer from Scratch, Part 8: Dataset Construction and Full Training Workflow
flyfish
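The complete source code is at the end of the article.
This part covers every engineering detail of a Transformer implementation: data generation, encoding, mask construction, the training loop, loss computation, and gradient updates.
A simple task: sequence copying
The goal of the sequence copy task is simple: the model reproduces the input digit sequence exactly (e.g. input [SOS]17095[EOS], output [SOS]17095[EOS]).
- No real corpus is needed; random digits are enough to build the dataset;
- It still exercises the Transformer's encoder-decoder architecture and attention mechanism;
- It removes task-specific logic, so you can focus on the data-processing and training code.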
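I. Data Processing Module: From Raw Sequences to Model Inputs
1.1 Vocabulary and special tokens
First, define the special tokens and the vocabulary that every sequence task needs; they are the basis for turning text or sequences into numbers: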
# Special tokens
PAD_TOKEN = '[PAD]'  # padding token: makes sequence lengths uniform
SOS_TOKEN = '[SOS]'  # start-of-sequence token: the decoder's first input
EOS_TOKEN = '[EOS]'  # end-of-sequence token: stops generation
# Vocabulary: special tokens + digits 0-9
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]
# Bidirectional token <-> id mapping
token_to_id = {token: i for i, token in enumerate(VOCAB)}
id_to_token = {i: token for token, i in token_to_id.items()}
VOCAB_SIZE = len(VOCAB)
PAD_ID = token_to_id[PAD_TOKEN]
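What it does:
Special tokens are standard in sequence tasks: PAD handles sequences of different lengths within a batch, and SOS/EOS mark where the decoder starts and stops generating;
Token-id mapping: the model cannot process text directly, so every token must be converted to an integer index;
The vocabulary here has only 13 tokens (3 special tokens + 10 digits), which keeps it minimal and easy to train.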
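As a quick check of the mapping, here is a small sketch that rebuilds the vocabulary above; `detokenize` is a hypothetical helper, not part of the original code. Because the three special tokens occupy ids 0-2, digit d always maps to id d + 3, and decoding just inverts the lookup while optionally dropping special tokens.

```python
# Rebuild the vocabulary exactly as defined above
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = '[PAD]', '[SOS]', '[EOS]'
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]
token_to_id = {token: i for i, token in enumerate(VOCAB)}
id_to_token = {i: token for token, i in token_to_id.items()}

def detokenize(ids, skip_special=True):
    """Hypothetical helper: map ids back to tokens, optionally dropping special tokens."""
    special = {token_to_id[PAD_TOKEN], token_to_id[SOS_TOKEN], token_to_id[EOS_TOKEN]}
    return [id_to_token[i] for i in ids if not (skip_special and i in special)]

print(len(VOCAB))                            # 13
print(token_to_id['0'], token_to_id['9'])    # 3 12
print(detokenize([1, 4, 10, 3, 2, 0, 0]))    # ['1', '7', '0']
```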
1.2 Generating the copy-task dataset
Random digit sequences are generated by hand to build copy-task data where input = output:
import random
def generate_copy_task_data(num_examples: int, min_len: int, max_len: int):
    data = []
    # Exclude special tokens: sequences are built from digits only
    content_vocab = [token for token in VOCAB if token not in [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]]
    for _ in range(num_examples):
        # Random sequence length in [min_len, max_len] (both inclusive)
        seq_len = random.randint(min_len, max_len)
        # Randomly pick digits to form the sequence
        sequence = [random.choice(content_vocab) for _ in range(seq_len)]
        # Source and target sequences: wrap with SOS and EOS
        src = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        tgt = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        data.append({'src': src, 'tgt': tgt})
    return data
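Logic:
- Generate random digit sequences of 5 to 10 digits (min_seq_len to max_seq_len in CONFIG);
- The source sequence src and the target sequence tgt are identical (this is the copy task);
- Every sequence is wrapped with SOS and EOS, the sequence format the Transformer expects.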
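The original code does not seed the random generator, so every run produces a different dataset. A minimal sketch, assuming you want reproducible runs, fixes the seed with `random.seed` (the value 0 is arbitrary) and checks the properties listed above: lengths between 5 and 10 and src identical to tgt.

```python
import random

PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = '[PAD]', '[SOS]', '[EOS]'
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]

def generate_copy_task_data(num_examples, min_len, max_len):
    # Same logic as above: random digit strings wrapped in [SOS] ... [EOS]
    content_vocab = [t for t in VOCAB if t not in (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN)]
    data = []
    for _ in range(num_examples):
        seq = [random.choice(content_vocab) for _ in range(random.randint(min_len, max_len))]
        data.append({'src': [SOS_TOKEN] + seq + [EOS_TOKEN],
                     'tgt': [SOS_TOKEN] + seq + [EOS_TOKEN]})
    return data

random.seed(0)  # assumption: fixed seed so the dataset is the same on every run
data = generate_copy_task_data(1000, 5, 10)

# Sanity checks on the generated samples
lengths = [len(d['src']) - 2 for d in data]      # digits only, without SOS/EOS
print(min(lengths) >= 5, max(lengths) <= 10)     # True True
print(all(d['src'] == d['tgt'] for d in data))   # True
```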
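1.3 Sequence encoding and padding
The model takes integer indices as input, and all sequences in a batch must have the same length, so each sequence is encoded and then padded: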
# Encoding: token -> id
def tokenize_sequence(sequence, token_to_id_map):
    return [token_to_id_map[token] for token in sequence]
# Padding: bring every sequence to the same length
def pad_sequence(sequence_ids, max_len, pad_id):
    padded_ids = sequence_ids + [pad_id] * (max_len - len(sequence_ids))
    return padded_ids[:max_len]  # truncate sequences longer than max_len
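Key points:
Encoding: ['[SOS]', '1', '7', '0', '9', '5', '[EOS]'] becomes [1, 4, 10, 3, 12, 8, 2] (digit d maps to id d + 3, since ids 0-2 belong to the special tokens);
Padding: PAD tokens extend every sequence to MAX_PADDED_LEN (the maximum sequence length + 2 for [SOS] and [EOS], i.e. 12), so all tensors in a batch have the same shape.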
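A runnable version of the two steps, reproducing the example above with MAX_PADDED_LEN = 12 (max_seq_len = 10 plus [SOS] and [EOS]):

```python
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = '[PAD]', '[SOS]', '[EOS]'
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]
token_to_id = {token: i for i, token in enumerate(VOCAB)}
PAD_ID = token_to_id[PAD_TOKEN]
MAX_PADDED_LEN = 10 + 2  # max_seq_len + [SOS] + [EOS]

def tokenize_sequence(sequence, token_to_id_map):
    return [token_to_id_map[token] for token in sequence]

def pad_sequence(sequence_ids, max_len, pad_id):
    # Append PAD ids up to max_len, then truncate anything longer
    padded_ids = sequence_ids + [pad_id] * (max_len - len(sequence_ids))
    return padded_ids[:max_len]

ids = tokenize_sequence([SOS_TOKEN, '1', '7', '0', '9', '5', EOS_TOKEN], token_to_id)
padded = pad_sequence(ids, MAX_PADDED_LEN, PAD_ID)
print(ids)     # [1, 4, 10, 3, 12, 8, 2]
print(padded)  # [1, 4, 10, 3, 12, 8, 2, 0, 0, 0, 0, 0]
```

A sequence longer than max_len would be truncated, which would also cut off its [EOS]; with max_seq_len = 10 and MAX_PADDED_LEN = 12 that cannot happen in this task.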
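1.4 Mask generation
There are two masks, the source mask and the target mask; both hide positions the model must not attend to. In this code, True means "masked":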
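import torch
# Source mask: hide PAD positions (True = masked)
def create_src_mask(src_ids, pad_id):
    return (src_ids == pad_id).unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, src_len]
# Target mask: hide PAD + hide future tokens (look-ahead mask)
def create_tgt_mask(tgt_ids, pad_id):
    batch_size, tgt_seq_len = tgt_ids.shape
    # 1. Padding mask: hide PAD
    tgt_padding_mask = (tgt_ids == pad_id).unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, tgt_len]
    # 2. Look-ahead mask: upper-triangular matrix hiding future positions
    #    (created on tgt_ids' device so the | below also works on the GPU)
    look_ahead_mask = torch.triu(torch.ones(tgt_seq_len, tgt_seq_len, device=tgt_ids.device), diagonal=1).bool()
    look_ahead_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, tgt_len, tgt_len]
    # Combine both masks; broadcasting yields [batch, 1, tgt_len, tgt_len]
    return tgt_padding_mask | look_ahead_mask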
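What the masks do:
- Source mask (src_mask): tells the encoder which positions are PAD. Their attention scores are set to -inf before the softmax, so no attention weight lands on padding. The shape [batch, 1, 1, src_len] broadcasts over heads and query positions.
- Target mask (tgt_mask), the OR of two masks:
  - Padding mask: same as src_mask, hides PAD;
  - Look-ahead mask: torch.triu(..., diagonal=1) builds an upper-triangular boolean matrix, so each position can attend only to itself and earlier tokens and never sees future information.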
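To see the combined target mask concretely, here is a small sketch that reuses create_tgt_mask on a made-up 5-token decoder input whose last two positions are PAD. True means masked: row i hides every column j > i plus the PAD columns 3 and 4. Column 0 ([SOS]) is never masked, so no row becomes all -inf, which would otherwise turn that row's softmax into NaN.

```python
import torch

def create_tgt_mask(tgt_ids, pad_id):
    # True = masked (its attention score will be set to -inf)
    _, tgt_seq_len = tgt_ids.shape
    tgt_padding_mask = (tgt_ids == pad_id).unsqueeze(1).unsqueeze(2)        # [B, 1, 1, T]
    look_ahead_mask = torch.triu(torch.ones(tgt_seq_len, tgt_seq_len,
                                            device=tgt_ids.device), diagonal=1).bool()
    return tgt_padding_mask | look_ahead_mask.unsqueeze(0).unsqueeze(0)   # [B, 1, T, T]

# Hypothetical decoder input: [SOS], '1', '7', then two PAD (id 0) positions
decoder_input = torch.tensor([[1, 4, 10, 0, 0]])
mask = create_tgt_mask(decoder_input, pad_id=0)
print(mask.shape)  # torch.Size([1, 1, 5, 5])
# Each row i masks columns j > i and the PAD columns 3 and 4
print(mask[0, 0].int())
```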
1.5 Custom Dataset and DataLoader
In PyTorch, Dataset + DataLoader is the standard pattern for loading training data in batches:
import torch
from torch.utils.data import Dataset
class CopyTaskDataset(Dataset):
    def __init__(self, data, max_len, pad_id):
        self.data = data
        self.max_len = max_len
        self.pad_id = pad_id
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        item = self.data[idx]
        # Source sequence: encode + pad
        src_ids = pad_sequence(tokenize_sequence(item['src'], token_to_id), self.max_len, self.pad_id)
        tgt_ids = tokenize_sequence(item['tgt'], token_to_id)
        # Decoder input and labels are offset by one position
        decoder_input_ids = pad_sequence(tgt_ids[:-1], self.max_len, self.pad_id)
        label_ids = pad_sequence(tgt_ids[1:], self.max_len, self.pad_id)
        return {
            "src_ids": torch.tensor(src_ids, dtype=torch.long),
            "decoder_input_ids": torch.tensor(decoder_input_ids, dtype=torch.long),
            "label_ids": torch.tensor(label_ids, dtype=torch.long)
        }
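Design: offsetting the decoder input and the labels
Decoder input: the target sequence without its last token (tgt_ids[:-1]);
Labels: the target sequence without its first token (tgt_ids[1:]).
Principle: the model predicts token n+1 from the first n tokens.
For example, for the target sequence [SOS,1,7,0,EOS]:
Decoder input: [SOS,1,7,0]
Labels: [1,7,0,EOS]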
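The same offset in code, using the ids of [SOS, 1, 7, 0, EOS] from the vocabulary above and a hypothetical padded length of 6. Thanks to the look-ahead mask, at step t the decoder only sees decoder_input_ids[:t+1] and is trained to predict label_ids[t] (teacher forcing). The trailing PAD-to-PAD pairs are excluded from the loss by ignore_index.

```python
# ids from the vocabulary above: [SOS]=1, '1'=4, '7'=10, '0'=3, [EOS]=2, [PAD]=0
tgt_ids = [1, 4, 10, 3, 2]
MAX_LEN = 6  # hypothetical padded length, just for this illustration

def pad_sequence(sequence_ids, max_len, pad_id):
    padded_ids = sequence_ids + [pad_id] * (max_len - len(sequence_ids))
    return padded_ids[:max_len]

decoder_input_ids = pad_sequence(tgt_ids[:-1], MAX_LEN, 0)  # drop the last token
label_ids = pad_sequence(tgt_ids[1:], MAX_LEN, 0)           # drop the first token

# At step t the decoder sees decoder_input_ids[:t+1] and must predict label_ids[t]
for t, (inp, lab) in enumerate(zip(decoder_input_ids, label_ids)):
    print(t, inp, '->', lab)
```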
Finally, DataLoader provides batched, shuffled loading:
train_dataset = CopyTaskDataset(raw_data, MAX_PADDED_LEN, PAD_ID)
train_dataloader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
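Because every item is already padded to MAX_PADDED_LEN, DataLoader's default collate function can stack the dict fields into batched tensors with no custom collate_fn. The sketch below shows this with a hypothetical ToyDictDataset standing in for CopyTaskDataset. With this article's settings (1000 examples, batch_size=64, drop_last left at its default False), an epoch has ceil(1000 / 64) = 16 batches, the last holding 40 samples, so the every-10-batches progress print in train_one_epoch fires once per epoch.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDictDataset(Dataset):
    """Hypothetical stand-in for CopyTaskDataset: every item is a dict of equal-length tensors."""
    def __init__(self, n, seq_len):
        self.items = [{"src_ids": torch.full((seq_len,), i, dtype=torch.long),
                       "label_ids": torch.full((seq_len,), i, dtype=torch.long)}
                      for i in range(n)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

loader = DataLoader(ToyDictDataset(n=10, seq_len=12), batch_size=4, shuffle=True)
batch = next(iter(loader))
# The default collate_fn stacks each dict field into a [batch_size, seq_len] tensor
print(batch["src_ids"].shape)  # torch.Size([4, 12])
print(len(loader))             # 3 batches: 4 + 4 + 2 samples
```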
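II. Model Training Module
Once the data is ready, the training module handles the forward pass, loss computation, and gradient updates.
2.1 Centralized hyperparameter configuration
An engineering best practice: keep all hyperparameters in one place so they are easy to tune: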
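CONFIG = {
    "num_examples": 1000,  # dataset size
    "min_seq_len": 5,      # minimum sequence length
    "max_seq_len": 10,     # maximum sequence length
    "batch_size": 64,      # batch size
    "d_model": 128,        # model (feature) dimension
    "num_layers": 3,       # number of encoder/decoder layers
    "num_heads": 4,        # number of attention heads
    "d_ff": 512,           # feed-forward hidden dimension
    "dropout": 0.1,        # dropout probability
    "lr": 1e-3,            # learning rate
    "epochs": 40,          # number of training epochs
}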
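Hyperparameter notes: d_model is the Transformer's feature dimension and must be divisible by the number of attention heads; num_heads is the number of heads in multi-head attention, so each head works in d_k = d_model / num_heads = 128 / 4 = 32 dimensions.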
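A small check along these lines can catch a bad config before the model is built. This is a sketch: check_config is a hypothetical helper, while the divisibility rule is the same one MultiHeadAttention enforces with its assert.

```python
CONFIG = {"d_model": 128, "num_heads": 4, "max_seq_len": 10}

def check_config(config):
    """Hypothetical pre-build check; returns (d_k, max_padded_len)."""
    d_model, num_heads = config["d_model"], config["num_heads"]
    if d_model % num_heads != 0:
        raise ValueError(f"d_model={d_model} must be divisible by num_heads={num_heads}")
    d_k = d_model // num_heads                  # per-head dimension in scaled dot-product attention
    max_padded_len = config["max_seq_len"] + 2  # room for [SOS] and [EOS]
    return d_k, max_padded_len

print(check_config(CONFIG))  # (32, 12)
```

Raising a ValueError with both numbers in the message is friendlier than a bare assert failing deep inside model construction.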
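2.2 Single-epoch training function: train_one_epoch
This is the core of training; it wraps the forward pass, loss computation, backpropagation, and parameter update: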
import time
def train_one_epoch(model, dataloader, loss_fn, optimizer, pad_id):
    # Training mode: enables Dropout (and BatchNorm statistics updates, if a model has BatchNorm)
    model.train()
    total_loss = 0
    start_time = time.time()
    for batch_idx, batch in enumerate(dataloader):
        # 1. Move data to the GPU/CPU
        src_ids = batch['src_ids'].to(device)
        decoder_input_ids = batch['decoder_input_ids'].to(device)
        label_ids = batch['label_ids'].to(device)
        # 2. Build masks
        src_mask = create_src_mask(src_ids, pad_id).to(device)
        tgt_mask = create_tgt_mask(decoder_input_ids, pad_id).to(device)
        # 3. Forward pass
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        logits = model(src_ids, decoder_input_ids, src_mask, tgt_mask)
        # 4. Loss: flatten tensors; PAD positions are ignored via loss_fn's ignore_index
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), label_ids.reshape(-1))
        # 5. Backward pass + parameter update
        loss.backward()   # compute gradients
        optimizer.step()  # update weights
        total_loss += loss.item()
        # Print progress every 10 batches
        if (batch_idx + 1) % 10 == 0:
            elapsed = time.time() - start_time
            print(f"  Batch {batch_idx+1}/{len(dataloader)} | loss: {loss.item():.4f} | time: {elapsed:.2f}s")
            start_time = time.time()
    return total_loss / len(dataloader)
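Training logic:
- Mode switch: model.train() enables training-only behavior such as Dropout, as opposed to inference mode (model.eval());
- Device transfer: tensors are moved to the GPU to speed up training, falling back to the CPU when no GPU is available;
- Zeroing gradients: optimizer.zero_grad() clears the previous batch's gradients, because PyTorch accumulates gradients by default;
- Forward pass: the source sequence, decoder input, and masks go in; logits come out, i.e. unnormalized scores over the vocabulary (CrossEntropyLoss applies log-softmax internally);
- Loss computation: CrossEntropyLoss with ignore_index=PAD_ID, so padded positions contribute nothing to the loss. The logits are flattened from [batch, seq_len, vocab_size] to [batch*seq_len, vocab_size], and the labels to [batch*seq_len], to match the loss function's expected input;
- Backpropagation: loss.backward() computes the gradients of all parameters;
- Parameter update: optimizer.step() updates the model weights using those gradients.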
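The claim that ignore_index removes PAD positions from the loss can be checked directly. This sketch uses random logits (purely illustrative) and compares CrossEntropyLoss(ignore_index=PAD_ID) against a manual mean over the non-PAD positions; with the default reduction='mean', the two agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
PAD_ID, VOCAB_SIZE = 0, 13
batch, seq_len = 2, 4

logits = torch.randn(batch, seq_len, VOCAB_SIZE)  # [batch, seq_len, vocab_size]
labels = torch.tensor([[4, 10, 2, PAD_ID],
                       [5, 2, PAD_ID, PAD_ID]])   # [batch, seq_len]

loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
loss = loss_fn(logits.reshape(-1, VOCAB_SIZE), labels.reshape(-1))  # [8, 13] vs [8]

# Manual check: per-token losses, then the mean over non-PAD positions only
per_token = nn.functional.cross_entropy(logits.reshape(-1, VOCAB_SIZE), labels.reshape(-1),
                                        reduction="none")
keep = labels.reshape(-1) != PAD_ID
manual = per_token[keep].mean()
print(torch.allclose(loss, manual))  # True
```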
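2.3 Main training loop and model saving
The main program calls the single-epoch function repeatedly until the model converges, then saves the weights: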
if __name__ == "__main__":
    # 1. Generate the dataset and build the data loader
    raw_data = generate_copy_task_data(...)
    train_dataloader = DataLoader(...)
    # 2. Initialize the model, optimizer, and loss function
    model = build_transformer(CONFIG).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"], betas=(0.9, 0.98), eps=1e-9)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
    # 3. Train for multiple epochs
    for epoch in range(1, CONFIG["epochs"] + 1):
        epoch_start = time.time()
        avg_loss = train_one_epoch(...)
        print(f"[Epoch {epoch}] avg loss: {avg_loss:.4f}")
    # 4. Save the model weights
    torch.save(model.state_dict(), "transformer_copy_task.pth")
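state_dict() stores only the weights, not the architecture, so loading them later means rebuilding the model with the same CONFIG first. A minimal sketch of the round trip, with a tiny hypothetical build_model standing in for build_transformer(CONFIG) and the file written to the system temp directory:

```python
import os
import tempfile
import torch
import torch.nn as nn

def build_model():
    # Hypothetical small model standing in for build_transformer(CONFIG)
    return nn.Sequential(nn.Embedding(13, 16), nn.Linear(16, 13))

model = build_model()
path = os.path.join(tempfile.gettempdir(), "transformer_copy_task.pth")
torch.save(model.state_dict(), path)

# Later, e.g. in a separate inference script: rebuild the same architecture, then load
restored = build_model()
state = torch.load(path, map_location="cpu")  # map GPU-saved tensors onto the CPU
restored.load_state_dict(state)
restored.eval()  # switch off Dropout before inference

same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                              restored.state_dict().values()))
print(same)  # True
```

map_location="cpu" lets weights saved from a GPU run be loaded on a machine without CUDA.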
Optimizer choice:
Adam with betas=(0.9, 0.98) and eps=1e-9, the settings used in the original Transformer paper, for a stable training process.
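The original paper combined these Adam settings with a learning-rate schedule: linear warmup followed by inverse-square-root decay, lrate = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5). The code in this article does not use it and keeps a fixed lr=1e-3. If you want to try it, here is a minimal sketch with PyTorch's LambdaLR; noam_lambda is a hypothetical helper name, and warmup_steps=400 is an untuned assumption (the paper used 4000, more than the roughly 640 optimizer steps this task runs: 16 batches per epoch × 40 epochs).

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

def noam_lambda(d_model, warmup_steps):
    """Hypothetical helper returning the paper's warmup + inverse-sqrt-decay factor."""
    def factor(step):
        step = max(step, 1)  # LambdaLR calls this with step 0 when it is constructed
        return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return factor

model = nn.Linear(128, 13)  # stand-in for the Transformer
# Base lr = 1.0, so the LambdaLR factor is exactly the learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = LambdaLR(optimizer, lr_lambda=noam_lambda(d_model=128, warmup_steps=400))

lrs = []
for _ in range(640):        # about 16 batches x 40 epochs in this task
    optimizer.step()        # in real training this follows loss.backward()
    scheduler.step()        # advance the schedule once per batch
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs.index(max(lrs)))  # 399: the peak is reached at step 400
```

In train_one_epoch, scheduler.step() would go right after optimizer.step(); PyTorch warns if the two are called in the opposite order.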
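2.4 Inference check: testing the copy behavior
After training, the infer function runs greedy decoding to check whether the model copies sequences perfectly: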
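def infer(model, input_sequence, max_len=MAX_PADDED_LEN):
    model.eval()  # inference mode: disables Dropout
    # encode the input -> encoder output -> decoder generates one token at a time
    with torch.no_grad():  # no gradient tracking: faster inference, less memory
        # Greedy decoding: pick the most probable token at each step until EOS is generated
        ...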
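The elided body follows the standard greedy-decoding loop (the full version is in the complete code at the end). Without tensors, the control flow looks like this; next_token_scores and copy_scorer are hypothetical stand-ins for the model's per-step output, so the loop can run without a trained model.

```python
SOS_ID, EOS_ID = 1, 2

def greedy_decode(next_token_scores, max_len):
    """Greedy decoding; next_token_scores(prefix) returns one score per vocabulary id.

    next_token_scores is a hypothetical stand-in for
    model.project(model.decode(enc_out, src_mask, prefix, tgt_mask)[:, -1, :]).
    """
    output = [SOS_ID]
    for _ in range(max_len):
        scores = next_token_scores(output)
        pred = max(range(len(scores)), key=scores.__getitem__)  # argmax
        output.append(pred)
        if pred == EOS_ID:  # stop as soon as [EOS] is produced
            break
    return output

# Toy scorer that "copies" a fixed target by scoring only the next target id
target = [1, 4, 10, 3, 12, 8, 2]
def copy_scorer(prefix):
    scores = [0.0] * 13
    scores[target[len(prefix)]] = 1.0
    return scores

print(greedy_decode(copy_scorer, max_len=12))  # [1, 4, 10, 3, 12, 8, 2]
```

If the model never produces [EOS], max_len bounds the loop, so decoding always terminates.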
Test result:
Input sequence: ['[SOS]', '1', '7', '0', '9', '5', '[EOS]']
Model output: ['[SOS]', '1', '7', '0', '9', '5', '[EOS]']
Training workflow
┌────────────────────────────────────────────────────────────────┐
│ Epoch loop                                                     │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │ Mini-batch loop                                          │  │
│  │ 1. zero gradients    → optimizer.zero_grad()             │  │
│  │ 2. forward pass      → outputs = model(inputs)           │  │
│  │ 3. compute loss      → loss = criterion(outputs, labels) │  │
│  │ 4. compute gradients → loss.backward()                   │  │
│  │ 5. update weights    → optimizer.step()                  │  │
│  └──────────────────────────────────────────────────────────┘  │
└────────────────────────────────────────────────────────────────┘
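The five steps can be exercised on a problem small enough to converge in a fraction of a second. This sketch replaces the Transformer with a single linear layer fitting y = 3x + 1 (a toy assumption) and uses plain SGD. It calls optimizer.zero_grad(), which is equivalent to model.zero_grad() when the optimizer holds all of the model's parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy regression problem standing in for the copy task: learn y = 3x + 1
x = torch.randn(256, 1)
y = 3 * x + 1

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(50):                    # epoch loop
    for i in range(0, len(x), 64):         # mini-batch loop
        inputs, labels = x[i:i + 64], y[i:i + 64]
        optimizer.zero_grad()              # 1. zero gradients
        outputs = model(inputs)            # 2. forward pass
        loss = criterion(outputs, labels)  # 3. compute loss
        loss.backward()                    # 4. compute gradients
        optimizer.step()                   # 5. update weights
    losses.append(loss.item())

print(model.weight.item(), model.bias.item())  # close to 3 and 1
```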
Complete code
import torch
import torch.nn as nn
import math
import random
import copy
import time
from torch.utils.data import Dataset, DataLoader
# ====================== 1. Transformer building blocks ======================
class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
    def forward(self, x):
        # Scale embeddings by sqrt(d_model)
        return self.embedding(x) * math.sqrt(self.d_model)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute sinusoidal encodings: sin on even dimensions, cos on odd dimensions
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, seq_len, d_model]
        self.register_buffer('pe', pe)  # saved with the model, but not trained
    def forward(self, x):
        x = x + self.pe[:, :x.shape[1], :].requires_grad_(False)
        return self.dropout(x)
def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    d_k = query.shape[-1]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # mask == True marks positions that must not be attended to
        scores = scores.masked_fill(mask, float('-inf'))
    attn = torch.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    return torch.matmul(attn, value), attn
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None
    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        # Split into heads: [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        output, self.attention_weights = scaled_dot_product_attention(Q, K, V, mask, self.dropout)
        # Merge heads: [batch, num_heads, seq_len, d_k] -> [batch, seq_len, d_model]
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)
class LayerNormalization(nn.Module):
    def __init__(self, features: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
    def forward(self, x: torch.Tensor):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mean) / torch.sqrt(std ** 2 + self.eps)
        return self.gamma * normalized + self.beta
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()
    def forward(self, x):
        return self.linear_2(self.dropout(self.activation(self.linear_1(x))))
class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)
    def forward(self, x, sublayer):
        # Post-norm residual: LayerNorm(x + Dropout(sublayer(x)))
        return self.norm(x + self.dropout(sublayer(x)))
# ====================== 2. Encoder & Decoder modules ======================
class EncoderBlock(nn.Module):
    def __init__(self, features: int, self_attention_block: MultiHeadAttention,
                 feed_forward_block: PositionwiseFeedForward, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])
    def forward(self, x, src_mask):
        # Sublayer 1: self-attention; sublayer 2: feed-forward
        x = self.residual_connections[0](x, lambda x_res: self.self_attention_block(x_res, x_res, x_res, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x
class Encoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)
    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
class DecoderBlock(nn.Module):
    def __init__(self, features: int, self_attention_block: MultiHeadAttention,
                 cross_attention_block: MultiHeadAttention, feed_forward_block: PositionwiseFeedForward, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        # Sublayer 1: masked self-attention; 2: cross-attention over the encoder output; 3: feed-forward
        x = self.residual_connections[0](x, lambda x_res: self.self_attention_block(x_res, x_res, x_res, tgt_mask))
        x = self.residual_connections[1](x, lambda x_res:
                                         self.cross_attention_block(x_res, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x
class Decoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)
class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)
    def forward(self, x):
        # [batch, seq_len, d_model] -> [batch, seq_len, vocab_size] logits
        return self.proj(x)
class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
    def encode(self, src, src_mask):
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)
    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
    def project(self, x):
        return self.projection_layer(x)
    def forward(self, src, tgt, src_mask, tgt_mask):
        enc_out = self.encode(src, src_mask)
        dec_out = self.decode(enc_out, src_mask, tgt, tgt_mask)
        return self.project(dec_out)
# ====================== 3. Data processing ======================
# Special tokens
PAD_TOKEN = '[PAD]'
SOS_TOKEN = '[SOS]'
EOS_TOKEN = '[EOS]'
# Vocabulary: special tokens + digits 0-9
VOCAB = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + [str(i) for i in range(10)]
token_to_id = {token: i for i, token in enumerate(VOCAB)}
id_to_token = {i: token for token, i in token_to_id.items()}
VOCAB_SIZE = len(VOCAB)
PAD_ID = token_to_id[PAD_TOKEN]
# Generate the copy-task dataset
def generate_copy_task_data(num_examples: int, min_len: int, max_len: int):
    data = []
    content_vocab = [token for token in VOCAB if token not in [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]]
    for _ in range(num_examples):
        seq_len = random.randint(min_len, max_len)
        sequence = [random.choice(content_vocab) for _ in range(seq_len)]
        src = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        tgt = [SOS_TOKEN] + sequence + [EOS_TOKEN]
        data.append({'src': src, 'tgt': tgt})
    return data
# Sequence encoding and padding
def tokenize_sequence(sequence, token_to_id_map):
    return [token_to_id_map[token] for token in sequence]
def pad_sequence(sequence_ids, max_len, pad_id):
    padded_ids = sequence_ids + [pad_id] * (max_len - len(sequence_ids))
    return padded_ids[:max_len]
# Mask generation (True = masked)
def create_src_mask(src_ids, pad_id):
    return (src_ids == pad_id).unsqueeze(1).unsqueeze(2)
def create_tgt_mask(tgt_ids, pad_id):
    batch_size, tgt_seq_len = tgt_ids.shape
    tgt_padding_mask = (tgt_ids == pad_id).unsqueeze(1).unsqueeze(2)
    look_ahead_mask = torch.triu(torch.ones(tgt_seq_len, tgt_seq_len, device=tgt_ids.device), diagonal=1).bool()
    look_ahead_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)
    return tgt_padding_mask | look_ahead_mask
# Custom dataset
class CopyTaskDataset(Dataset):
    def __init__(self, data, max_len, pad_id):
        self.data = data
        self.max_len = max_len
        self.pad_id = pad_id
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        item = self.data[idx]
        src_ids = pad_sequence(tokenize_sequence(item['src'], token_to_id), self.max_len, self.pad_id)
        tgt_ids = tokenize_sequence(item['tgt'], token_to_id)
        # Decoder input and labels are offset by one position
        decoder_input_ids = pad_sequence(tgt_ids[:-1], self.max_len, self.pad_id)
        label_ids = pad_sequence(tgt_ids[1:], self.max_len, self.pad_id)
        return {
            "src_ids": torch.tensor(src_ids, dtype=torch.long),
            "decoder_input_ids": torch.tensor(decoder_input_ids, dtype=torch.long),
            "label_ids": torch.tensor(label_ids, dtype=torch.long)
        }
# ====================== 4. Hyperparameters (kept in one place) ======================
CONFIG = {
    "num_examples": 1000,  # dataset size
    "min_seq_len": 5,      # minimum sequence length
    "max_seq_len": 10,     # maximum sequence length
    "batch_size": 64,      # batch size
    "d_model": 128,        # model dimension
    "num_layers": 3,       # number of encoder/decoder layers
    "num_heads": 4,        # number of attention heads
    "d_ff": 512,           # feed-forward hidden dimension
    "dropout": 0.1,        # dropout probability
    "lr": 1e-3,            # learning rate
    "epochs": 40,          # number of training epochs
}
MAX_PADDED_LEN = CONFIG["max_seq_len"] + 2  # maximum length including [SOS] and [EOS]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ====================== 5. Model construction ======================
def build_transformer(config):
    # Embedding layers
    src_embed = InputEmbeddings(config["d_model"], VOCAB_SIZE)
    tgt_embed = InputEmbeddings(config["d_model"], VOCAB_SIZE)
    # Positional encodings
    src_pos = PositionalEncoding(config["d_model"], MAX_PADDED_LEN, config["dropout"])
    tgt_pos = PositionalEncoding(config["d_model"], MAX_PADDED_LEN, config["dropout"])
    # Attention and feed-forward templates; every layer gets its own deep copy
    attention = MultiHeadAttention(config["d_model"], config["num_heads"], config["dropout"])
    ff = PositionwiseFeedForward(config["d_model"], config["d_ff"], config["dropout"])
    # Encoder
    encoder_blocks = nn.ModuleList([
        EncoderBlock(config["d_model"], copy.deepcopy(attention), copy.deepcopy(ff), config["dropout"])
        for _ in range(config["num_layers"])
    ])
    encoder = Encoder(config["d_model"], encoder_blocks)
    # Decoder (self-attention and cross-attention are separate copies)
    decoder_blocks = nn.ModuleList([
        DecoderBlock(config["d_model"], copy.deepcopy(attention), copy.deepcopy(attention),
                     copy.deepcopy(ff), config["dropout"])
        for _ in range(config["num_layers"])
    ])
    decoder = Decoder(config["d_model"], decoder_blocks)
    # Projection layer to vocabulary logits
    projection = ProjectionLayer(config["d_model"], VOCAB_SIZE)
    # Assemble the model
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection)
    # Parameter initialization: Xavier uniform for every weight matrix
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return transformer
# ====================== 6. Training function ======================
def train_one_epoch(model, dataloader, loss_fn, optimizer, pad_id):
    model.train()
    total_loss = 0
    start_time = time.time()
    for batch_idx, batch in enumerate(dataloader):
        # Move data to the device
        src_ids = batch['src_ids'].to(device)
        decoder_input_ids = batch['decoder_input_ids'].to(device)
        label_ids = batch['label_ids'].to(device)
        # Build masks
        src_mask = create_src_mask(src_ids, pad_id).to(device)
        tgt_mask = create_tgt_mask(decoder_input_ids, pad_id).to(device)
        # Forward pass
        optimizer.zero_grad()
        logits = model(src_ids, decoder_input_ids, src_mask, tgt_mask)
        # Compute the loss (PAD positions ignored via ignore_index)
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), label_ids.reshape(-1))
        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Print progress
        if (batch_idx + 1) % 10 == 0:
            elapsed = time.time() - start_time
            print(f"  Batch {batch_idx+1}/{len(dataloader)} | loss: {loss.item():.4f} | time: {elapsed:.2f}s")
            start_time = time.time()
    return total_loss / len(dataloader)
# ====================== 7. Inference check (test copying after training) ======================
def infer(model, input_sequence, max_len=MAX_PADDED_LEN):
    model.eval()
    # Encode the input sequence
    src_ids = pad_sequence(tokenize_sequence(input_sequence, token_to_id), max_len, PAD_ID)
    src_ids = torch.tensor([src_ids], dtype=torch.long).to(device)
    src_mask = create_src_mask(src_ids, PAD_ID).to(device)
    with torch.no_grad():
        enc_out = model.encode(src_ids, src_mask)
        # Initialize the decoder input with [SOS]
        decoder_input = torch.tensor([[token_to_id[SOS_TOKEN]]], dtype=torch.long).to(device)
        for _ in range(max_len):
            tgt_mask = create_tgt_mask(decoder_input, PAD_ID).to(device)
            dec_out = model.decode(enc_out, src_mask, decoder_input, tgt_mask)
            # Greedy choice: highest-scoring token at the last position
            pred_token = model.project(dec_out[:, -1, :]).argmax(-1).item()
            # Append the prediction
            decoder_input = torch.cat([decoder_input, torch.tensor([[pred_token]]).to(device)], dim=1)
            # Stop at the end token
            if pred_token == token_to_id[EOS_TOKEN]:
                break
    # Convert output ids back to tokens
    output_ids = decoder_input.squeeze(0).tolist()
    output_tokens = [id_to_token[idx] for idx in output_ids if idx != PAD_ID]
    return output_tokens
# ====================== 8. Main: data loading + training + testing ======================
if __name__ == "__main__":
    # 1. Generate the dataset
    print("===== Generating copy-task dataset =====")
    raw_data = generate_copy_task_data(
        CONFIG["num_examples"],
        CONFIG["min_seq_len"],
        CONFIG["max_seq_len"]
    )
    # Data loader
    train_dataset = CopyTaskDataset(raw_data, MAX_PADDED_LEN, PAD_ID)
    train_dataloader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
    # 2. Initialize the model, optimizer, and loss function
    model = build_transformer(CONFIG).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"], betas=(0.9, 0.98), eps=1e-9)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
    # 3. Train
    print(f"\n===== Training started | device: {device} =====")
    for epoch in range(1, CONFIG["epochs"] + 1):
        epoch_start = time.time()
        avg_loss = train_one_epoch(model, train_dataloader, loss_fn, optimizer, PAD_ID)
        duration = time.time() - epoch_start
        print(f"\n[Epoch {epoch}/{CONFIG['epochs']}] avg loss: {avg_loss:.4f} | time: {duration:.2f}s")
        print("-" * 60)
    # 4. Save the model
    torch.save(model.state_dict(), "transformer_copy_task.pth")
    print("\nModel saved to: transformer_copy_task.pth")
    # 5. Inference test
    print("\n===== Inference test (copy task) =====")
    # test_seq = [SOS_TOKEN, '1', '2', '3', '4', '5', EOS_TOKEN]
    test_seq = [SOS_TOKEN, '1', '7', '0', '9', '5', EOS_TOKEN]
    print(f"Input sequence: {test_seq}")
    pred_seq = infer(model, test_seq)
    print(f"Model output: {pred_seq}")