Artificial Intelligence: A Hands-On Guide to Attention Mechanisms and Transformer Models
Attention mechanisms and Transformer models are core technologies in today's AI landscape, playing a key role everywhere from natural language processing to computer vision. This article offers a complete guide, from theory to practice.
📊 Core Concept Comparison

| Component | Core Function | Mathematical Form | Implementation Keys |
|---|---|---|---|
| Self-attention | Models relationships within a sequence | $\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$ | Scaled dot product, softmax normalization |
| Multi-head attention | Parallel attention over multiple subspaces | $\mathrm{MultiHead}=\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)W^O$ | Head splitting, independent computation, concatenation |
| Positional encoding | Injects sequence-order information | $PE_{(pos,2i)}=\sin\!\left(pos/10000^{2i/d}\right)$ | Sine/cosine functions; learnable or fixed |
| Feed-forward network | Position-wise nonlinear transformation | $\mathrm{FFN}(x)=\max(0,\,xW_1+b_1)W_2+b_2$ | Two linear layers, ReLU activation |
| Residual connection | Mitigates vanishing gradients | $\mathrm{LayerNorm}(x+\mathrm{Sublayer}(x))$ | Layer normalization, skip connection |
🔍 Attention Mechanisms in Depth
1. How Self-Attention Works
Self-attention lets the model, at every position in a sequence, compute a weighted sum over the representations of all positions in that same sequence. The core computation looks like this:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention"""
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        # query, key, value shapes: (batch_size, ..., seq_len, d_k)
        d_k = query.size(-1)
        # Compute attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        # Apply the mask (used by the decoder)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # Softmax normalization
        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.dropout(p_attn)
        # Weighted sum
        return torch.matmul(p_attn, value), p_attn
```
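As a quick sanity check, the snippet below (an illustrative sketch with arbitrary tensor sizes, not part of the original post) runs the module on random tensors and verifies that the attention weights for each query position sum to 1:

```python
# Illustrative usage only; tensor sizes are arbitrary assumptions.
attn = ScaledDotProductAttention(dropout=0.0)
q = torch.randn(2, 5, 64)   # (batch, seq_len, d_k)
k = torch.randn(2, 5, 64)
v = torch.randn(2, 5, 64)
out, weights = attn(q, k, v)
print(out.shape)             # torch.Size([2, 5, 64])
print(weights.sum(dim=-1))   # each row sums to 1.0
```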
2. Implementing Multi-Head Attention
Multi-head attention splits the input into several subspaces, computes attention in each of them in parallel, and concatenates the results. This design captures semantic information at different levels:
```python
class MultiHeadAttention(nn.Module):
    """Multi-head attention (residual connection and layer norm built in)"""
    def __init__(self, d_model=512, num_heads=8, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def split_heads(self, x):
        """Split the input into multiple heads"""
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Merge the multi-head outputs"""
        batch_size, _, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, query, key, value, mask=None):
        # Keep the input for the residual connection
        residual = query
        # Project and split into heads
        query = self.split_heads(self.W_q(query))
        key = self.split_heads(self.W_k(key))
        value = self.split_heads(self.W_v(value))
        # Compute attention
        x, attn_weights = self.attention(query, key, value, mask)
        # Merge heads and apply the output projection
        x = self.combine_heads(x)
        x = self.W_o(x)
        x = self.dropout(x)
        # Residual connection and layer normalization
        x = self.layer_norm(x + residual)
        return x, attn_weights
```
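A minimal shape check (again illustrative, with arbitrary sizes) confirms that the module preserves the (batch, seq_len, d_model) shape and returns per-head attention weights:

```python
# Illustrative usage only.
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out, w = mha(x, x, x)   # self-attention: query = key = value
print(out.shape)        # torch.Size([2, 10, 512])
print(w.shape)          # torch.Size([2, 8, 10, 10]): (batch, heads, seq, seq)
```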
🏗️ A Complete Transformer Implementation
1. Positional Encoding Module
Because the Transformer has no recurrence, sequence order must be injected through positional encodings, commonly generated with sine and cosine functions:
```python
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding"""
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Build the positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
```
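To see the encoding itself (purely illustrative), the snippet below applies it to a zero tensor, so the output is just the raw sin/cos pattern:

```python
# Illustrative usage only.
pe = PositionalEncoding(d_model=16, dropout=0.0)
x = torch.zeros(1, 4, 16)
encoded = pe(x)
print(encoded.shape)      # torch.Size([1, 4, 16])
print(encoded[0, 0, :4])  # position 0: sin(0)=0 and cos(0)=1 interleaved
```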
2. Feed-Forward Network Module
```python
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network (residual and layer norm built in)"""
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.activation = nn.GELU()  # or ReLU, as in the original paper

    def forward(self, x):
        residual = x
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)
        x = self.layer_norm(x + residual)
        return x
```
3. Encoder Layer
```python
class EncoderLayer(nn.Module):
    """Transformer encoder layer"""
    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        # Residual connections and layer norm live inside these sub-modules
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

    def forward(self, x, mask=None):
        # Multi-head self-attention sub-layer
        attn_output, attn_weights = self.self_attn(x, x, x, mask)
        # Feed-forward sub-layer
        output = self.feed_forward(attn_output)
        return output, attn_weights
```
4. Decoder Layer
```python
class DecoderLayer(nn.Module):
    """Transformer decoder layer"""
    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        # Masked multi-head self-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        # Encoder-decoder (cross) attention
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # Masked self-attention
        self_attn_output, self_attn_weights = self.self_attn(x, x, x, tgt_mask)
        # Encoder-decoder attention
        cross_attn_output, cross_attn_weights = self.cross_attn(
            self_attn_output, encoder_output, encoder_output, src_mask
        )
        # Feed-forward network
        output = self.feed_forward(cross_attn_output)
        return output, self_attn_weights, cross_attn_weights
```
5. The Complete Transformer Model
```python
class Transformer(nn.Module):
    """The complete Transformer model"""
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
                 num_heads=8, num_layers=6, d_ff=2048, max_len=5000, dropout=0.1):
        super().__init__()
        # Token embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
        # Encoder stack
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        # Decoder stack
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        # Output projection
        self.output_linear = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)
        # Initialize parameters
        self._init_parameters()

    def _init_parameters(self):
        """Xavier initialization for all weight matrices"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src, src_mask):
        """Encoder forward pass"""
        # Token embedding + positional encoding
        x = self.src_embedding(src)
        x = self.positional_encoding(x)
        encoder_attentions = []
        for layer in self.encoder_layers:
            x, attn_weights = layer(x, src_mask)
            encoder_attentions.append(attn_weights)
        return x, encoder_attentions

    def decode(self, tgt, encoder_output, src_mask, tgt_mask):
        """Decoder forward pass"""
        # Token embedding + positional encoding
        x = self.tgt_embedding(tgt)
        x = self.positional_encoding(x)
        decoder_self_attentions = []
        decoder_cross_attentions = []
        for layer in self.decoder_layers:
            x, self_attn, cross_attn = layer(x, encoder_output, src_mask, tgt_mask)
            decoder_self_attentions.append(self_attn)
            decoder_cross_attentions.append(cross_attn)
        return x, decoder_self_attentions, decoder_cross_attentions

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Encode
        encoder_output, encoder_attentions = self.encode(src, src_mask)
        # Decode
        decoder_output, decoder_self_attns, decoder_cross_attns = self.decode(
            tgt, encoder_output, src_mask, tgt_mask
        )
        # Output projection
        output = self.output_linear(decoder_output)
        return output, encoder_attentions, decoder_self_attns, decoder_cross_attns

    def generate_mask(self, src, tgt, pad_idx=0):
        """Build padding and causal masks"""
        # Source padding mask: (batch, 1, 1, src_len)
        src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
        # Target padding mask: (batch, 1, tgt_len, 1)
        tgt_pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(3)
        # Subsequent-position mask (prevents attending to future tokens)
        seq_len = tgt.size(1)
        tgt_sub_mask = torch.tril(torch.ones(
            (seq_len, seq_len), device=tgt.device
        )).bool()
        tgt_mask = tgt_pad_mask & tgt_sub_mask.unsqueeze(0).unsqueeze(0)
        return src_mask, tgt_mask
```
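A quick smoke test (an illustrative sketch with made-up vocabulary sizes and sequence lengths) runs random token IDs through the full model:

```python
# Illustrative usage only; vocabulary sizes and lengths are arbitrary.
model = Transformer(src_vocab_size=1000, tgt_vocab_size=1200,
                    d_model=128, num_heads=4, num_layers=2, d_ff=512)
src = torch.randint(1, 1000, (2, 7))   # avoid the pad index 0
tgt = torch.randint(1, 1200, (2, 5))
src_mask, tgt_mask = model.generate_mask(src, tgt)
logits, *_ = model(src, tgt, src_mask, tgt_mask)
print(logits.shape)  # torch.Size([2, 5, 1200]): per-token target-vocab logits
```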
🚀 Hands-On Application: Machine Translation
1. Data Preparation and Preprocessing
```python
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import sentencepiece as spm

class TranslationDataset(Dataset):
    """Machine translation dataset"""
    def __init__(self, src_file, tgt_file, src_spm, tgt_spm, max_len=100):
        self.src_sentences = self.load_file(src_file)
        self.tgt_sentences = self.load_file(tgt_file)
        self.src_spm = spm.SentencePieceProcessor(model_file=src_spm)
        self.tgt_spm = spm.SentencePieceProcessor(model_file=tgt_spm)
        self.max_len = max_len
        self.pad_id = 0
        self.bos_id = 1
        self.eos_id = 2

    def load_file(self, filepath):
        with open(filepath, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src_tokens = self.src_spm.encode(self.src_sentences[idx])
        tgt_tokens = self.tgt_spm.encode(self.tgt_sentences[idx])
        # Add special tokens, truncating to leave room for BOS/EOS
        src_tokens = [self.bos_id] + src_tokens[:self.max_len-2] + [self.eos_id]
        tgt_tokens = [self.bos_id] + tgt_tokens[:self.max_len-2] + [self.eos_id]
        return {
            'src': torch.tensor(src_tokens, dtype=torch.long),
            'tgt': torch.tensor(tgt_tokens, dtype=torch.long)
        }

def collate_fn(batch):
    """Batch collation: pad sequences to a common length"""
    src_batch = [item['src'] for item in batch]
    tgt_batch = [item['tgt'] for item in batch]
    # Pad sequences
    src_padded = pad_sequence(src_batch, batch_first=True, padding_value=0)
    tgt_padded = pad_sequence(tgt_batch, batch_first=True, padding_value=0)
    return {
        'src': src_padded,
        'tgt': tgt_padded
    }
```
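Wiring the dataset into a DataLoader might look like this (the file paths and SentencePiece model names below are placeholders, not files shipped with this post):

```python
# Placeholder paths; substitute your own corpus and SentencePiece models.
train_dataset = TranslationDataset(
    'train.src', 'train.tgt', 'src.model', 'tgt.model', max_len=100
)
train_loader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn
)
```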
2. The Training Loop
```python
class TransformerTrainer:
    """Trainer for the Transformer"""
    def __init__(self, model, device, learning_rate=0.0001, label_smoothing=0.1):
        self.model = model.to(device)
        self.device = device
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=0, label_smoothing=label_smoothing
        )
        self.optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9
        )
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=0.95
        )

    def train_epoch(self, dataloader, epoch):
        self.model.train()
        total_loss = 0
        total_tokens = 0
        for batch_idx, batch in enumerate(dataloader):
            src = batch['src'].to(self.device)
            tgt = batch['tgt'].to(self.device)
            # Teacher forcing: input is tgt shifted right, target is tgt shifted left
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:].contiguous().view(-1)
            # Build masks
            src_mask, tgt_mask = self.model.generate_mask(src, tgt_input)
            # Forward pass
            self.optimizer.zero_grad()
            output, _, _, _ = self.model(src, tgt_input, src_mask, tgt_mask)
            # Compute the loss
            output = output.view(-1, output.size(-1))
            loss = self.criterion(output, tgt_output)
            # Backward pass
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            # Statistics
            total_loss += loss.item() * tgt_output.size(0)
            total_tokens += tgt_output.size(0)
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        avg_loss = total_loss / total_tokens
        return avg_loss

    def evaluate(self, dataloader):
        self.model.eval()
        total_loss = 0
        total_tokens = 0
        with torch.no_grad():
            for batch in dataloader:
                src = batch['src'].to(self.device)
                tgt = batch['tgt'].to(self.device)
                tgt_input = tgt[:, :-1]
                tgt_output = tgt[:, 1:].contiguous().view(-1)
                src_mask, tgt_mask = self.model.generate_mask(src, tgt_input)
                output, _, _, _ = self.model(src, tgt_input, src_mask, tgt_mask)
                output = output.view(-1, output.size(-1))
                loss = self.criterion(output, tgt_output)
                total_loss += loss.item() * tgt_output.size(0)
                total_tokens += tgt_output.size(0)
        return total_loss / total_tokens

    def train(self, train_loader, val_loader, num_epochs=10):
        best_val_loss = float('inf')
        for epoch in range(1, num_epochs + 1):
            print(f'\nEpoch {epoch}/{num_epochs}')
            print('-' * 50)
            # Train
            train_loss = self.train_epoch(train_loader, epoch)
            print(f'Train Loss: {train_loss:.4f}')
            # Validate
            val_loss = self.evaluate(val_loader)
            print(f'Val Loss: {val_loss:.4f}')
            # Learning rate schedule
            self.scheduler.step()
            # Save the best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), 'best_transformer.pth')
                print(f'Best model saved with val loss: {val_loss:.4f}')
```
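Putting the model and trainer together might look like this (vocabulary sizes are placeholders for your tokenizers' actual sizes, and the loaders come from the data-preparation section above):

```python
# Illustrative wiring; vocabulary sizes are assumed, not prescribed.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Transformer(src_vocab_size=32000, tgt_vocab_size=32000)
trainer = TransformerTrainer(model, device, learning_rate=1e-4)
# train_loader / val_loader built as shown in the data-preparation section
# trainer.train(train_loader, val_loader, num_epochs=10)
```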
3. Inference and Decoding
```python
class TransformerTranslator:
    """Transformer-based translator"""
    def __init__(self, model, src_spm, tgt_spm, device, max_len=100):
        self.model = model.to(device)
        self.model.eval()
        self.src_spm = src_spm
        self.tgt_spm = tgt_spm
        self.device = device
        self.max_len = max_len
        self.bos_id = 1
        self.eos_id = 2
        self.pad_id = 0

    def greedy_decode(self, src_sentence):
        """Greedy decoding"""
        # Tokenize the source sentence
        src_tokens = self.src_spm.encode(src_sentence)
        src_tokens = [self.bos_id] + src_tokens + [self.eos_id]
        src = torch.tensor([src_tokens], device=self.device)
        # Source padding mask
        src_mask = (src != self.pad_id).unsqueeze(1).unsqueeze(2)
        # Encoder forward pass
        with torch.no_grad():
            encoder_output, _ = self.model.encode(src, src_mask)
        # Initialize the target sequence
        tgt_tokens = [self.bos_id]
        for _ in range(self.max_len):
            tgt = torch.tensor([tgt_tokens], device=self.device)
            tgt_mask = self.model.generate_mask(src, tgt)[1]
            # Decoder forward pass (decode returns three values)
            with torch.no_grad():
                output, _, _ = self.model.decode(
                    tgt, encoder_output, src_mask, tgt_mask
                )
                output = self.model.output_linear(output)
            # Pick the next token greedily
            next_token = output[0, -1].argmax().item()
            tgt_tokens.append(next_token)
            if next_token == self.eos_id:
                break
        # Strip BOS (and the trailing EOS, if generation produced one)
        out_tokens = tgt_tokens[1:]
        if out_tokens and out_tokens[-1] == self.eos_id:
            out_tokens = out_tokens[:-1]
        return self.tgt_spm.decode(out_tokens)

    def beam_search_decode(self, src_sentence, beam_size=5):
        """Beam search decoding with length-normalized scores"""
        src_tokens = self.src_spm.encode(src_sentence)
        src_tokens = [self.bos_id] + src_tokens + [self.eos_id]
        src = torch.tensor([src_tokens], device=self.device)
        src_mask = (src != self.pad_id).unsqueeze(1).unsqueeze(2)
        with torch.no_grad():
            encoder_output, _ = self.model.encode(src, src_mask)
        # Initialize the beam
        beams = [([self.bos_id], 0.0)]  # (tokens, log-prob score)
        completed = []
        for step in range(self.max_len):
            new_beams = []
            for tokens, score in beams:
                if tokens[-1] == self.eos_id:
                    completed.append((tokens, score))
                    continue
                tgt = torch.tensor([tokens], device=self.device)
                tgt_mask = self.model.generate_mask(src, tgt)[1]
                with torch.no_grad():
                    output, _, _ = self.model.decode(
                        tgt, encoder_output, src_mask, tgt_mask
                    )
                    logits = self.model.output_linear(output)[0, -1]
                probs = torch.softmax(logits, dim=-1)
                # Expand with the top-k candidates
                topk_probs, topk_indices = torch.topk(probs, beam_size)
                for i in range(beam_size):
                    new_tokens = tokens + [topk_indices[i].item()]
                    new_score = score + torch.log(topk_probs[i]).item()
                    new_beams.append((new_tokens, new_score))
            # Keep the top beam_size sequences (length-normalized)
            new_beams.sort(key=lambda x: x[1] / len(x[0]), reverse=True)
            beams = new_beams[:beam_size]
            if all(tokens[-1] == self.eos_id for tokens, _ in beams):
                break
        # Merge open and completed hypotheses
        all_candidates = beams + completed
        all_candidates.sort(key=lambda x: x[1] / len(x[0]), reverse=True)
        best_tokens = all_candidates[0][0]
        out_tokens = best_tokens[1:]
        if out_tokens and out_tokens[-1] == self.eos_id:
            out_tokens = out_tokens[:-1]
        return self.tgt_spm.decode(out_tokens)
```
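Note the length normalization in the beam search (scores divided by sequence length), which keeps it from systematically favoring short outputs. Usage could look like this (model weights, SentencePiece model paths, and the example sentence are all placeholders):

```python
# Illustrative usage; the file paths are placeholders.
src_sp = spm.SentencePieceProcessor(model_file='src.model')
tgt_sp = spm.SentencePieceProcessor(model_file='tgt.model')
model.load_state_dict(torch.load('best_transformer.pth', map_location=device))
translator = TransformerTranslator(model, src_sp, tgt_sp, device)
print(translator.greedy_decode("Hello, world!"))
print(translator.beam_search_decode("Hello, world!", beam_size=5))
```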
🖼️ Vision Transformer in Practice
1. A Basic ViT Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class PatchEmbedding(nn.Module):
    """Image patch embedding"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_patches + 1, embed_dim)
        )

    def forward(self, x):
        # x shape: (B, C, H, W)
        B = x.shape[0]
        # Patch projection
        x = self.projection(x)  # (B, embed_dim, H/patch, W/patch)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        # Prepend the CLS token
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_patches+1, embed_dim)
        # Add the learnable position embedding
        x = x + self.position_embedding
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        # Transformer encoder layers (d_ff must be an int, so cast the product)
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(embed_dim, num_heads, int(embed_dim * mlp_ratio), dropout)
            for _ in range(depth)
        ])
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        self.dropout = nn.Dropout(dropout)
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, x):
        # Patch embedding
        x = self.patch_embed(x)
        # Transformer encoding
        for layer in self.encoder_layers:
            x, _ = layer(x)
        # Take the CLS token
        x = self.layer_norm(x)
        cls_token = x[:, 0]
        # Classification head
        output = self.head(self.dropout(cls_token))
        return output
```
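A shape-level smoke test (illustrative only, with a tiny configuration chosen to run quickly on CPU, not a pretrained ViT setup):

```python
# Illustrative usage only.
vit = VisionTransformer(img_size=32, patch_size=8, num_classes=10,
                        embed_dim=64, depth=2, num_heads=4)
images = torch.randn(2, 3, 32, 32)
logits = vit(images)
print(logits.shape)  # torch.Size([2, 10])
```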
2. A Swin Transformer Implementation
```python
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention with relative position bias"""
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # Relative position bias table
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )
        # Precompute the pairwise relative position index
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        # Add the relative position bias
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
```
```python
class SwinTransformerBlock(nn.Module):
    """Swin Transformer block"""
    def __init__(self, dim, input_resolution, num_heads, window_size=7,
                 shift_size=0, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        # Window attention
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size),
            num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop
        )
        # Feed-forward network
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop)
        )
        # Layer normalization
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # Attention mask for shifted windows
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # Cyclically shift the feature map
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # Partition into windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        # Window attention
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        # Merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        # Reverse the cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # Residual connection
        x = shortcut + x
        # Feed-forward network
        x = x + self.mlp(self.norm2(x))
        return x
```
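The block above calls window_partition and window_reverse, which the post does not define. These are the standard helpers from the reference Swin Transformer implementation; a matching sketch:

```python
def window_partition(x, window_size):
    """(B, H, W, C) -> (num_windows * B, window_size, window_size, C)"""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """(num_windows * B, window_size, window_size, C) -> (B, H, W, C)"""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
```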
⚡ Performance Optimization Tips
1. Mixed-Precision Training
```python
from torch.cuda.amp import autocast, GradScaler

class AMPTrainer:
    """Automatic mixed-precision training"""
    def __init__(self, model, optimizer, device):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.scaler = GradScaler()
        # The original post omitted this, but train_step needs a criterion
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

    def train_step(self, src, tgt):
        src_mask, tgt_mask = self.model.generate_mask(src, tgt[:, :-1])
        with autocast():
            output, _, _, _ = self.model(
                src, tgt[:, :-1], src_mask, tgt_mask
            )
            loss = self.criterion(
                output.view(-1, output.size(-1)),
                tgt[:, 1:].contiguous().view(-1)
            )
        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss.item()
```
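If you also want gradient clipping under AMP (as the earlier trainer uses), the gradients must be unscaled before clipping; otherwise the norm is computed on scaled values. A sketch of the adjusted optimizer step (the helper name and trainer argument are hypothetical):

```python
def amp_step_with_clipping(trainer, loss, max_norm=1.0):
    """Hypothetical helper: AMP optimizer step with gradient clipping."""
    trainer.optimizer.zero_grad()
    trainer.scaler.scale(loss).backward()
    # Restore true gradient magnitudes before clipping
    trainer.scaler.unscale_(trainer.optimizer)
    torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), max_norm=max_norm)
    trainer.scaler.step(trainer.optimizer)
    trainer.scaler.update()
```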
2. Gradient Checkpointing
```python
from torch.utils.checkpoint import checkpoint

class CheckpointTransformer(nn.Module):
    """Transformer encoder stack with gradient checkpointing"""
    def __init__(self, num_layers=12, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.layers = nn.ModuleList([
            EncoderLayer(d_model=512, num_heads=8)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        for layer in self.layers:
            if self.use_checkpoint and self.training:
                # EncoderLayer returns (output, attn_weights), so unpack both
                x, _ = checkpoint(layer, x, mask, use_reentrant=False)
            else:
                x, _ = layer(x, mask)
        return x
```
3. Distributed Training
```python
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed():
    """Initialize the distributed process group"""
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank

def create_ddp_model(model, local_rank):
    """Wrap a model in DistributedDataParallel"""
    model = model.cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    return model
```
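These helpers assume a torchrun-style launcher, which sets LOCAL_RANK for each process; each process should also shard the data with a DistributedSampler. A sketch (model, train_dataset, and collate_fn are assumed to come from the earlier sections):

```python
# Launch (shell): torchrun --nproc_per_node=4 train.py
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

local_rank = setup_distributed()
ddp_model = create_ddp_model(model, local_rank)
# Each process sees a distinct shard of the dataset
sampler = DistributedSampler(train_dataset)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler,
                    collate_fn=collate_fn)
```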
📈 Hands-On Project: Text Classification
1. A Transformer-Based Text Classifier
```python
class TransformerClassifier(nn.Module):
    """Transformer-based text classifier"""
    def __init__(self, vocab_size, num_classes, d_model=512,
                 num_heads=8, num_layers=6, max_len=512, dropout=0.1):
        super().__init__()
        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoding = PositionalEncoding(d_model, max_len, dropout)
        # Transformer encoder
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_model * 4, dropout)
            for _ in range(num_layers)
        ])
        # Classification head
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Embedding layers
        x = self.embedding(x)
        x = self.position_encoding(x)
        # Transformer encoding
        for layer in self.encoder_layers:
            x, _ = layer(x, mask)
        # Pooling (a CLS token or mean pooling both work)
        x = self.layer_norm(x)
        x = x.mean(dim=1)  # mean pooling
        # Classify
        x = self.dropout(x)
        output = self.fc(x)
        return output

# Training example (TextClassificationDataset, train_texts, and train_labels
# are assumed to come from your own data pipeline)
def train_text_classifier():
    # Data preparation
    train_dataset = TextClassificationDataset(train_texts, train_labels)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    # Model initialization
    model = TransformerClassifier(
        vocab_size=50000,
        num_classes=10,
        d_model=256,
        num_heads=8,
        num_layers=4
    )
    # Training configuration
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    # Training loop
    for epoch in range(10):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        for batch in train_loader:
            texts, labels = batch
            mask = (texts != 0).unsqueeze(1).unsqueeze(2)  # padding mask
            optimizer.zero_grad()
            outputs = model(texts, mask)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        scheduler.step()
        accuracy = 100. * correct / total
        print(f'Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, Acc={accuracy:.2f}%')
```
🎯 Best-Practice Summary
1. Model Selection Guide
| Task Type | Recommended Model | Key Configuration | Typical Scenarios |
|---|---|---|---|
| Machine translation | Standard Transformer | 6 encoder/decoder layers, d_model 512, 8 heads | Language pairs such as Chinese-English, English-French |
| Text classification | Transformer encoder | 4-6 layers, d_model 256-512, multi-head attention | Sentiment analysis, topic classification |
| Image classification | Vision Transformer | 12-24 layers, d_model 768-1024, 16×16 patches | ImageNet-scale datasets |
| Object detection | Swin Transformer | Hierarchical structure, shifted-window attention | COCO, VOC datasets |
| Sequence generation | GPT-style decoder | Decoder only, causal attention mask | Text generation, code completion |
2. Hyperparameter Tuning Strategy

```python
# Hyperparameter search configurations
hyperparameter_configs = {
    'small': {
        'd_model': 256,
        'num_heads': 4,
        'num_layers': 4,
        'd_ff': 1024,
        'dropout': 0.1,  # value assumed; the original post was cut off here
    },
}
```
❤️❤️❤️ My own knowledge is limited, so if you spot any mistakes, corrections are very welcome! 😄😄😄
💘💘💘 If this article helped you, please consider liking and bookmarking it. Thank you! 👍👍👍
🔥🔥🔥 Stay Hungry, Stay Foolish. The road is long, but keep walking and we will get there. Let's keep at it together! 🌙🌙🌙