Table of Contents

Artificial Intelligence: A Hands-On Guide to Attention Mechanisms and Transformer Models
📊 Core Concepts at a Glance
🔍 Attention Mechanisms in Depth
1. How Self-Attention Works
2. Implementing Multi-Head Attention
🏗️ A Complete Transformer Implementation
1. Positional Encoding
2. Feed-Forward Network
3. Encoder Layer
4. Decoder Layer
5. The Full Transformer Model
🚀 Hands-On Application: Machine Translation
1. Data Preparation and Preprocessing
2. Training Loop
3. Inference and Decoding
🖼️ Vision Transformer in Practice
1. Basic ViT Implementation
2. Swin Transformer Implementation
⚡ Performance Optimization Techniques
1. Mixed-Precision Training
2. Gradient Checkpointing
3. Distributed Training
📈 Hands-On Project: Text Classification
1. A Transformer-Based Text Classifier
🎯 Best Practices
1. Model Selection Guide
2. Hyperparameter Tuning Strategy


Artificial Intelligence: A Hands-On Guide to Attention Mechanisms and Transformer Models

Attention mechanisms and Transformer models are core technologies of modern artificial intelligence, playing a key role in everything from natural language processing to computer vision. This article walks through them from theory to practice.

📊 Core Concepts at a Glance

| Component | Core Function | Mathematical Form | Implementation Keys |
| --- | --- | --- | --- |
| Self-attention | Models relationships within a sequence | $\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$ | Scaled dot product, softmax normalization |
| Multi-head attention | Parallel attention over multiple subspaces | $\mathrm{MultiHead}=\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)W^O$ | Split heads, compute independently, concatenate |
| Positional encoding | Injects sequence-order information | $PE_{(pos,2i)}=\sin\left(pos/10000^{2i/d}\right)$ | Sine/cosine functions, fixed or learned |
| Feed-forward network | Position-wise nonlinear transformation | $\mathrm{FFN}(x)=\max(0,\,xW_1+b_1)W_2+b_2$ | Two linear layers, ReLU activation |
| Residual connection | Mitigates vanishing gradients | $\mathrm{LayerNorm}(x+\mathrm{Sublayer}(x))$ | Layer normalization, skip connection |

🔍 Attention Mechanisms in Depth

1. How Self-Attention Works

Self-attention lets the model compute, for every position in a sequence, a weighted sum over the representations of all positions in that same sequence. The core computation is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    """缩放点积注意力"""
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask=None):
        # query, key, value shapes: (..., seq_len, d_k); works for 3-D and 4-D (multi-head) inputs
        d_k = query.size(-1)
        
        # Compute attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        
        # Apply the mask (padding mask and/or causal mask in the decoder)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Softmax normalization
        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.dropout(p_attn)
        
        # Weighted sum of the values
        return torch.matmul(p_attn, value), p_attn
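
A quick sanity check of the module above; the tensor sizes here are illustrative assumptions, not values from the article:

attn = ScaledDotProductAttention(dropout=0.0)
q = torch.randn(2, 5, 64)   # (batch_size, seq_len, d_k)
k = torch.randn(2, 5, 64)
v = torch.randn(2, 5, 64)
out, weights = attn(q, k, v)
print(out.shape)      # torch.Size([2, 5, 64])
print(weights.shape)  # torch.Size([2, 5, 5]); each row sums to 1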

2. Implementing Multi-Head Attention

Multi-head attention splits the input into several subspaces, computes attention in each of them in parallel, and concatenates the results. This design lets the model capture different kinds of relationships at the same time:

class MultiHeadAttention(nn.Module):
    """多头注意力机制"""
    def __init__(self, d_model=512, num_heads=8, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projection layers
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.attention = ScaledDotProductAttention(dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
    
    def split_heads(self, x):
        """将输入拆分为多头"""
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        """合并多头输出"""
        batch_size, _, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # Keep the input for the residual connection
        residual = query
        
        # Project and split into heads
        query = self.split_heads(self.W_q(query))
        key = self.split_heads(self.W_k(key))
        value = self.split_heads(self.W_v(value))
        
        # Scaled dot-product attention
        x, attn_weights = self.attention(query, key, value, mask)
        
        # Merge heads and apply the output projection
        x = self.combine_heads(x)
        x = self.W_o(x)
        x = self.dropout(x)
        
        # Residual connection and layer normalization
        x = self.layer_norm(x + residual)
        
        return x, attn_weights
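
And a short self-attention call on random data, again with illustrative sizes:

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)   # (batch_size, seq_len, d_model)
out, attn = mha(x, x, x)      # self-attention: query = key = value
print(out.shape)    # torch.Size([2, 10, 512])
print(attn.shape)   # torch.Size([2, 8, 10, 10]), one seq_len x seq_len map per head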

🏗️ A Complete Transformer Implementation

1. Positional Encoding

Because the Transformer has no recurrence, the order of the sequence has to be injected through positional encodings. The standard choice is fixed sine and cosine functions:

class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Build the positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

2. Feed-Forward Network

class PositionwiseFeedForward(nn.Module):
    """位置前馈网络"""
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.activation = nn.GELU()  # ReLU is also common and is what the original paper uses
    
    def forward(self, x):
        residual = x
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)
        x = self.layer_norm(x + residual)
        return x

3. Encoder Layer

class EncoderLayer(nn.Module):
    """Transformer编码器层"""
    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Multi-head self-attention sub-layer
        attn_output, attn_weights = self.self_attn(x, x, x, mask)
        
        # Feed-forward sub-layer
        output = self.feed_forward(attn_output)
        
        return output, attn_weights

4. Decoder Layer

class DecoderLayer(nn.Module):
    """Transformer解码器层"""
    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        # Masked multi-head self-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        # Encoder-decoder (cross) attention
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # Masked self-attention
        self_attn_output, self_attn_weights = self.self_attn(
            x, x, x, tgt_mask
        )
        
        # Encoder-decoder attention
        cross_attn_output, cross_attn_weights = self.cross_attn(
            self_attn_output, encoder_output, encoder_output, src_mask
        )
        
        # Feed-forward network
        output = self.feed_forward(cross_attn_output)
        
        return output, self_attn_weights, cross_attn_weights

5. The Full Transformer Model

class Transformer(nn.Module):
    """完整的Transformer模型"""
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, 
                 num_heads=8, num_layers=6, d_ff=2048, max_len=5000, dropout=0.1):
        super().__init__()
        
        # Token embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        # Encoder stack
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_layers)
        ])
        
        # Decoder stack
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) 
            for _ in range(num_layers)
        ])
        
        # Output projection
        self.output_linear = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)
        
        # Parameter initialization
        self._init_parameters()
    
    def _init_parameters(self):
        """参数初始化"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def encode(self, src, src_mask):
        """编码器前向传播"""
        # 词嵌入 + 位置编码
        x = self.src_embedding(src)
        x = self.positional_encoding(x)
        
        encoder_attentions = []
        for layer in self.encoder_layers:
            x, attn_weights = layer(x, src_mask)
            encoder_attentions.append(attn_weights)
        
        return x, encoder_attentions
    
    def decode(self, tgt, encoder_output, src_mask, tgt_mask):
        """解码器前向传播"""
        # 词嵌入 + 位置编码
        x = self.tgt_embedding(tgt)
        x = self.positional_encoding(x)
        
        decoder_self_attentions = []
        decoder_cross_attentions = []
        
        for layer in self.decoder_layers:
            x, self_attn, cross_attn = layer(
                x, encoder_output, src_mask, tgt_mask
            )
            decoder_self_attentions.append(self_attn)
            decoder_cross_attentions.append(cross_attn)
        
        return x, decoder_self_attentions, decoder_cross_attentions
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Encode
        encoder_output, encoder_attentions = self.encode(src, src_mask)
        
        # Decode
        decoder_output, decoder_self_attns, decoder_cross_attns = self.decode(
            tgt, encoder_output, src_mask, tgt_mask
        )
        
        # Project to vocabulary logits
        output = self.output_linear(decoder_output)
        
        return output, encoder_attentions, decoder_self_attns, decoder_cross_attns
    
    def generate_mask(self, src, tgt, pad_idx=0):
        """生成掩码"""
        # 源序列填充掩码
        src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
        
        # Target padding mask: (batch, 1, tgt_len, 1)
        tgt_pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(3)
        
        # Causal mask: prevent each position from attending to future tokens
        seq_len = tgt.size(1)
        tgt_sub_mask = torch.tril(torch.ones(
            (seq_len, seq_len), device=tgt.device
        )).bool()
        
        tgt_mask = tgt_pad_mask & tgt_sub_mask.unsqueeze(0).unsqueeze(0)
        
        return src_mask, tgt_mask
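
Putting the pieces together, here is a tiny forward pass on random token IDs; the vocabulary and model sizes below are illustrative assumptions:

model = Transformer(src_vocab_size=1000, tgt_vocab_size=1000,
                    d_model=128, num_heads=4, num_layers=2, d_ff=512)
src = torch.randint(1, 1000, (2, 7))   # (batch, src_len); index 0 is reserved for padding
tgt = torch.randint(1, 1000, (2, 9))   # (batch, tgt_len)
src_mask, tgt_mask = model.generate_mask(src, tgt)
logits, *_ = model(src, tgt, src_mask, tgt_mask)
print(logits.shape)  # torch.Size([2, 9, 1000])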

🚀 Hands-On Application: Machine Translation

1. Data Preparation and Preprocessing

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import sentencepiece as spm

class TranslationDataset(Dataset):
    """机器翻译数据集"""
    def __init__(self, src_file, tgt_file, src_spm, tgt_spm, max_len=100):
        self.src_sentences = self.load_file(src_file)
        self.tgt_sentences = self.load_file(tgt_file)
        self.src_spm = spm.SentencePieceProcessor(model_file=src_spm)
        self.tgt_spm = spm.SentencePieceProcessor(model_file=tgt_spm)
        self.max_len = max_len
        self.pad_id = 0
        self.bos_id = 1
        self.eos_id = 2
    
    def load_file(self, filepath):
        with open(filepath, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]
    
    def __len__(self):
        return len(self.src_sentences)
    
    def __getitem__(self, idx):
        src_tokens = self.src_spm.encode(self.src_sentences[idx])
        tgt_tokens = self.tgt_spm.encode(self.tgt_sentences[idx])
        
        # Add BOS/EOS special tokens and truncate to max_len
        src_tokens = [self.bos_id] + src_tokens[:self.max_len-2] + [self.eos_id]
        tgt_tokens = [self.bos_id] + tgt_tokens[:self.max_len-2] + [self.eos_id]
        
        return {
            'src': torch.tensor(src_tokens, dtype=torch.long),
            'tgt': torch.tensor(tgt_tokens, dtype=torch.long)
        }

def collate_fn(batch):
    """批次处理函数"""
    src_batch = [item['src'] for item in batch]
    tgt_batch = [item['tgt'] for item in batch]
    
    # Pad sequences
    src_padded = pad_sequence(src_batch, batch_first=True, padding_value=0)
    tgt_padded = pad_sequence(tgt_batch, batch_first=True, padding_value=0)
    
    return {
        'src': src_padded,
        'tgt': tgt_padded
    }
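
Wiring the dataset into a DataLoader might look like this; the corpus files and SentencePiece model paths are placeholders, not files from this article:

train_dataset = TranslationDataset(
    src_file='train.en', tgt_file='train.de',
    src_spm='spm_en.model', tgt_spm='spm_de.model', max_len=100
)
train_loader = DataLoader(
    train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn
)
batch = next(iter(train_loader))
print(batch['src'].shape, batch['tgt'].shape)  # (64, max_src_len) and (64, max_tgt_len)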

2. Training Loop

class TransformerTrainer:
    """Transformer训练器"""
    def __init__(self, model, device, learning_rate=0.0001, label_smoothing=0.1):
        self.model = model.to(device)
        self.device = device
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=0, label_smoothing=label_smoothing
        )
        self.optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9
        )
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=0.95
        )
        
    def train_epoch(self, dataloader, epoch):
        self.model.train()
        total_loss = 0
        total_tokens = 0
        
        for batch_idx, batch in enumerate(dataloader):
            src = batch['src'].to(self.device)
            tgt = batch['tgt'].to(self.device)
            
            # Teacher forcing: input is tgt[:-1], prediction target is tgt[1:]
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:].contiguous().view(-1)
            
            # Build masks
            src_mask, tgt_mask = self.model.generate_mask(src, tgt_input)
            
            # Forward pass
            self.optimizer.zero_grad()
            output, _, _, _ = self.model(src, tgt_input, src_mask, tgt_mask)
            
            # Compute the loss
            output = output.view(-1, output.size(-1))
            loss = self.criterion(output, tgt_output)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            self.optimizer.step()
            
            # Running statistics
            total_loss += loss.item() * tgt_output.size(0)
            total_tokens += tgt_output.size(0)
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        avg_loss = total_loss / total_tokens
        return avg_loss
    
    def evaluate(self, dataloader):
        self.model.eval()
        total_loss = 0
        total_tokens = 0
        
        with torch.no_grad():
            for batch in dataloader:
                src = batch['src'].to(self.device)
                tgt = batch['tgt'].to(self.device)
                
                tgt_input = tgt[:, :-1]
                tgt_output = tgt[:, 1:].contiguous().view(-1)
                
                src_mask, tgt_mask = self.model.generate_mask(src, tgt_input)
                
                output, _, _, _ = self.model(src, tgt_input, src_mask, tgt_mask)
                output = output.view(-1, output.size(-1))
                
                loss = self.criterion(output, tgt_output)
                total_loss += loss.item() * tgt_output.size(0)
                total_tokens += tgt_output.size(0)
        
        return total_loss / total_tokens
    
    def train(self, train_loader, val_loader, num_epochs=10):
        best_val_loss = float('inf')
        
        for epoch in range(1, num_epochs + 1):
            print(f'\nEpoch {epoch}/{num_epochs}')
            print('-' * 50)
            
            # Train
            train_loss = self.train_epoch(train_loader, epoch)
            print(f'Train Loss: {train_loss:.4f}')
            
            # Validate
            val_loss = self.evaluate(val_loader)
            print(f'Val Loss: {val_loss:.4f}')
            
            # Step the learning-rate schedule
            self.scheduler.step()
            
            # Save the best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), 'best_transformer.pth')
                print(f'Best model saved with val loss: {val_loss:.4f}')
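
A minimal way to run the trainer, assuming train_loader and val_loader were built as in the data-preparation step and the vocabulary sizes match your SentencePiece models:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Transformer(src_vocab_size=32000, tgt_vocab_size=32000,
                    d_model=512, num_heads=8, num_layers=6)
trainer = TransformerTrainer(model, device, learning_rate=1e-4)
trainer.train(train_loader, val_loader, num_epochs=10)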

3. Inference and Decoding

class TransformerTranslator:
    """Transformer翻译器"""
    def __init__(self, model, src_spm, tgt_spm, device, max_len=100):
        self.model = model.to(device)
        self.model.eval()
        self.src_spm = src_spm
        self.tgt_spm = tgt_spm
        self.device = device
        self.max_len = max_len
        self.bos_id = 1
        self.eos_id = 2
        self.pad_id = 0
    
    def greedy_decode(self, src_sentence):
        """贪心解码"""
        # 编码源句子
        src_tokens = self.src_spm.encode(src_sentence)
        src_tokens = [self.bos_id] + src_tokens + [self.eos_id]
        src = torch.tensor([src_tokens], device=self.device)
        
        # Source padding mask
        src_mask = (src != self.pad_id).unsqueeze(1).unsqueeze(2)
        
        # Run the encoder once
        with torch.no_grad():
            encoder_output, _ = self.model.encode(src, src_mask)
        
        # Start the target sequence with BOS
        tgt_tokens = [self.bos_id]
        
        for _ in range(self.max_len):
            tgt = torch.tensor([tgt_tokens], device=self.device)
            tgt_mask = self.model.generate_mask(src, tgt)[1]
            
            # Run the decoder over the tokens generated so far
            with torch.no_grad():
                output, _, _, _ = self.model.decode(
                    tgt, encoder_output, src_mask, tgt_mask
                )
                output = self.model.output_linear(output)
            
            # Pick the most likely next token
            next_token = output[0, -1].argmax().item()
            tgt_tokens.append(next_token)
            
            if next_token == self.eos_id:
                break
        
        # Convert token IDs back to text
        translation = self.tgt_spm.decode(tgt_tokens[1:-1])  # drop BOS and EOS
        return translation
    
    def beam_search_decode(self, src_sentence, beam_size=5):
        """束搜索解码"""
        src_tokens = self.src_spm.encode(src_sentence)
        src_tokens = [self.bos_id] + src_tokens + [self.eos_id]
        src = torch.tensor([src_tokens], device=self.device)
        src_mask = (src != self.pad_id).unsqueeze(1).unsqueeze(2)
        
        with torch.no_grad():
            encoder_output, _ = self.model.encode(src, src_mask)
        
        # Initialize the beam
        beams = [([self.bos_id], 0.0)]  # (tokens, cumulative log-probability)
        completed = []
        
        for step in range(self.max_len):
            new_beams = []
            
            for tokens, score in beams:
                if tokens[-1] == self.eos_id:
                    completed.append((tokens, score))
                    continue
                
                tgt = torch.tensor([tokens], device=self.device)
                tgt_mask = self.model.generate_mask(src, tgt)[1]
                
                with torch.no_grad():
                    output, _, _, _ = self.model.decode(
                        tgt, encoder_output, src_mask, tgt_mask
                    )
                    logits = self.model.output_linear(output)[0, -1]
                    probs = torch.softmax(logits, dim=-1)
                
                # Take the top-k candidate next tokens
                topk_probs, topk_indices = torch.topk(probs, beam_size)
                
                for i in range(beam_size):
                    new_tokens = tokens + [topk_indices[i].item()]
                    new_score = score + torch.log(topk_probs[i]).item()
                    new_beams.append((new_tokens, new_score))
            
            # Keep the beam_size best length-normalized sequences
            new_beams.sort(key=lambda x: x[1] / len(x[0]), reverse=True)
            beams = new_beams[:beam_size]
            
            if all(tokens[-1] == self.eos_id for tokens, _ in beams):
                break
        
        # Merge finished and unfinished hypotheses
        all_candidates = beams + completed
        all_candidates.sort(key=lambda x: x[1] / len(x[0]), reverse=True)
        
        best_tokens = all_candidates[0][0]
        translation = self.tgt_spm.decode(best_tokens[1:-1])
        return translation

🖼️ Vision Transformer in Practice

1. Basic ViT Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class PatchEmbedding(nn.Module):
    """图像分块嵌入"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        self.projection = nn.Conv2d(
            in_channels, embed_dim, 
            kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_patches + 1, embed_dim)
        )
    
    def forward(self, x):
        # x shape: (B, C, H, W)
        B = x.shape[0]
        
        # Patch projection
        x = self.projection(x)  # (B, embed_dim, H/patch, W/patch)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        
        # Prepend the CLS token
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_patches+1, embed_dim)
        
        # Add learned position embeddings
        x = x + self.position_embedding
        
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, 
                 num_classes=1000, embed_dim=768, depth=12, 
                 num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        
        # Transformer encoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(embed_dim, num_heads, int(embed_dim * mlp_ratio), dropout)
            for _ in range(depth)
        ])
        
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        self.dropout = nn.Dropout(dropout)
        
        # Weight initialization
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
    
    def forward(self, x):
        # Patch embedding
        x = self.patch_embed(x)
        
        # Transformer encoding
        for layer in self.encoder_layers:
            x, _ = layer(x)
        
        # Take the CLS token
        x = self.layer_norm(x)
        cls_token = x[:, 0]
        
        # Classification head
        output = self.head(self.dropout(cls_token))
        
        return output
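
A forward pass on a random image batch; the small configuration below is an illustrative assumption, not a recommended setting:

vit = VisionTransformer(img_size=224, patch_size=16, num_classes=10,
                        embed_dim=192, depth=4, num_heads=4)
images = torch.randn(2, 3, 224, 224)
logits = vit(images)
print(logits.shape)  # torch.Size([2, 10])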

2. Swin Transformer Implementation

class WindowAttention(nn.Module):
    """窗口注意力"""
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        
        # Relative position bias table
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )
        
        # Pairwise relative position index within a window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
    
    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        
        # Add the relative position bias
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
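
The SwinTransformerBlock below calls window_partition and window_reverse, which are not defined in the article; the following is the standard formulation of these two helpers, following the reference Swin implementation:

def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping windows
    of shape (num_windows * B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: merge windows back into a (B, H, W, C) feature map."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)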

class SwinTransformerBlock(nn.Module):
    """Swin Transformer块"""
    def __init__(self, dim, input_resolution, num_heads, window_size=7, 
                 shift_size=0, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        
        # Window attention
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size),
            num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop
        )
        
        # MLP (feed-forward network)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop)
        )
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        
        # Attention mask for shifted windows
        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                       slice(-self.window_size, -self.shift_size),
                       slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                       slice(-self.window_size, -self.shift_size),
                       slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        
        self.register_buffer("attn_mask", attn_mask)
    
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        
        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        
        # Partition into windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        
        # Window attention
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        
        # Merge windows back
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        
        # Reverse the cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        
        x = x.view(B, H * W, C)
        
        # Residual connection
        x = shortcut + x
        
        # MLP with residual connection
        x = x + self.mlp(self.norm2(x))
        
        return x
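
A quick shape check for one block; the resolution and channel sizes are illustrative:

block = SwinTransformerBlock(dim=96, input_resolution=(56, 56),
                             num_heads=3, window_size=7, shift_size=3)
x = torch.randn(2, 56 * 56, 96)   # (B, H*W, C)
print(block(x).shape)             # torch.Size([2, 3136, 96])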

⚡ Performance Optimization Techniques

1. Mixed-Precision Training

from torch.cuda.amp import autocast, GradScaler

class AMPTrainer:
    """自动混合精度训练"""
    def __init__(self, model, optimizer, device):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)  # same loss setup as the trainer above
        self.scaler = GradScaler()
    
    def train_step(self, src, tgt):
        src_mask, tgt_mask = self.model.generate_mask(src, tgt[:, :-1])
        
        with autocast():
            output, _, _, _ = self.model(
                src, tgt[:, :-1], src_mask, tgt_mask
            )
            loss = self.criterion(
                output.view(-1, output.size(-1)),
                tgt[:, 1:].contiguous().view(-1)
            )
        
        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        
        return loss.item()

2. Gradient Checkpointing

from torch.utils.checkpoint import checkpoint

class CheckpointTransformer(nn.Module):
    """使用梯度检查点的Transformer"""
    def __init__(self, num_layers=12, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.layers = nn.ModuleList([
            EncoderLayer(d_model=512, num_heads=8) 
            for _ in range(num_layers)
        ])
    
    def forward(self, x, mask=None):
        for layer in self.layers:
            if self.use_checkpoint and self.training:
                x, _ = checkpoint(layer, x, mask)
            else:
                x, _ = layer(x, mask)
        return x
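
Enabling checkpointing trades extra compute for lower activation memory; usage is otherwise unchanged (the hidden size 512 matches the EncoderLayer configuration above):

model = CheckpointTransformer(num_layers=12, use_checkpoint=True)
x = torch.randn(4, 128, 512, requires_grad=True)
out = model(x)
out.sum().backward()  # activations inside each layer are recomputed during the backward pass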

3. Distributed Training

import os

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed():
    """设置分布式训练"""
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank

def create_ddp_model(model, local_rank):
    """创建DDP模型"""
    model = model.cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    return model
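
A minimal launch sketch, assuming the script is started with torchrun and that train_dataset and collate_fn come from the machine-translation section; the vocabulary sizes are placeholders:

def main():
    local_rank = setup_distributed()
    model = create_ddp_model(
        Transformer(src_vocab_size=32000, tgt_vocab_size=32000), local_rank
    )
    # Each process sees a different shard of the data
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    loader = DataLoader(train_dataset, batch_size=32, sampler=sampler, collate_fn=collate_fn)
    # ... run the training loop from the trainer section on `loader` ...

# Launch with: torchrun --nproc_per_node=<num_gpus> train.py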

📈 Hands-On Project: Text Classification

1. A Transformer-Based Text Classifier

class TransformerClassifier(nn.Module):
    """基于Transformer的文本分类器"""
    def __init__(self, vocab_size, num_classes, d_model=512, 
                 num_heads=8, num_layers=6, max_len=512, dropout=0.1):
        super().__init__()
        
        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        # Transformer encoder stack
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_model*4, dropout)
            for _ in range(num_layers)
        ])
        
        # Classification head
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Embedding + positional encoding
        x = self.embedding(x)
        x = self.position_encoding(x)
        
        # Transformer encoding
        for layer in self.encoder_layers:
            x, _ = layer(x, mask)
        
        # Pooling (mean pooling here; taking a CLS token would also work)
        x = self.layer_norm(x)
        x = x.mean(dim=1)  # mean pooling over the sequence dimension
        
        # Classify
        x = self.dropout(x)
        output = self.fc(x)
        
        return output

# Training example
def train_text_classifier():
    # Data preparation (TextClassificationDataset, train_texts and train_labels are assumed to exist)
    train_dataset = TextClassificationDataset(train_texts, train_labels)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    
    # Model initialization
    model = TransformerClassifier(
        vocab_size=50000,
        num_classes=10,
        d_model=256,
        num_heads=8,
        num_layers=4
    )
    
    # Training configuration
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    
    # Training loop
    for epoch in range(10):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        for batch in train_loader:
            texts, labels = batch
            mask = (texts != 0).unsqueeze(1).unsqueeze(2)
            
            optimizer.zero_grad()
            outputs = model(texts, mask)
            loss = criterion(outputs, labels)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        
        scheduler.step()
        accuracy = 100. * correct / total
        print(f'Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, Acc={accuracy:.2f}%')

🎯 Best Practices

1. Model Selection Guide

| Task Type | Recommended Model | Key Configuration | Typical Scenarios |
| --- | --- | --- | --- |
| Machine translation | Standard Transformer | 6 encoder/decoder layers, d_model 512, 8 heads | Language pairs such as Chinese-English or English-French |
| Text classification | Transformer encoder | 4-6 layers, d_model 256-512, multi-head attention | Sentiment analysis, topic classification |
| Image classification | Vision Transformer | 12-24 layers, embed dim 768-1024, 16×16 patches | ImageNet-scale datasets |
| Object detection | Swin Transformer | Hierarchical stages, shifted-window attention | COCO, VOC datasets |
| Sequence generation | GPT-style decoder | Decoder only, causal attention mask | Text generation, code completion |

2. Hyperparameter Tuning Strategy

# Hyperparameter search configurations
hyperparameter_configs = {
    'small': {
        'd_model': 256,
        'num_heads': 4,
        'num_layers': 4,
        'd_ff': 1024,
        'dropout': 0.1,  # 0.1 is the default used throughout this article
    },
}
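
One straightforward way to use such a dictionary is a plain loop over the configurations, training and validating a model for each; the validate call below is a placeholder for whatever evaluation routine you use:

for name, cfg in hyperparameter_configs.items():
    model = TransformerClassifier(
        vocab_size=50000, num_classes=10,
        d_model=cfg['d_model'], num_heads=cfg['num_heads'],
        num_layers=cfg['num_layers'], dropout=cfg['dropout']
    )
    num_params = sum(p.numel() for p in model.parameters())
    print(f"{name}: {num_params / 1e6:.1f}M parameters")
    # score = validate(model)  # placeholder: plug in your own training + validation here
    # keep the configuration with the best validation score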

❤️❤️❤️ My knowledge is limited, so if you spot any mistakes, corrections in the comments are very welcome! 😄😄😄

💘💘💘 If this article helped you, a like and a bookmark would be greatly appreciated! 👍 👍 👍

🔥🔥🔥 Stay Hungry, Stay Foolish. The road is long, but those who keep walking will get there. Let's keep at it together! 🌙🌙🌙
