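Contents

11.1 Distributed Word Representations and Contextual Encoding

11.1.1 Word2Vec Skip-gram with Negative Sampling and the Tree-Structured Approximation of Hierarchical Softmax; GloVe Co-occurrence Matrix Factorization and Least-Squares Optimization

11.1.2 Contextual Embeddings: ELMo's Bidirectional Language Model and Character-Level CNN; BERT's Masked Language Model (MLM) and Next Sentence Prediction (NSP) Pretraining Objectives

11.1.3 Autoregressive Pretraining (CLM) and Causal Masking in the GPT Series; Span Corruption and T5's Unified Encoder-Decoder Framework

11.2 Parameter-Efficient Fine-Tuning and Prompt Engineering

11.2.1 LoRA (Low-Rank Adaptation): Rank Decomposition and the Scaling Factor α; Adapter Layers: Bottleneck Architecture and Residual Connections; Prefix Tuning: Virtual-Token Optimization and Reparameterization

11.2.2 Prompt Tuning: Soft Prompts and Discrete Optimization of Hard Prompts; Multi-Layer Prompt Injection in P-Tuning v2

11.2.3 Instruction Tuning: FLAN-Style Formatting and Chain-of-Thought Data Construction; RLHF: the PPO Algorithm and Reward-Model Training Details

11.3 Scaling Laws and Inference Optimization

11.3.1 Scaling Laws (Kaplan & Chinchilla): Power-Law Relations Between Loss and Compute (FLOPs), Parameter Count, and Data Size; Compute-Optimal Allocation of Model Size and Training Tokens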

11.3.2 Model Compression: Quantization with LLM.int8() and QLoRA's 4-bit NormalFloat; Knowledge Distillation with MiniLLM and Temperature Tuning

11.3.3 Speculative Decoding: Drafting with a Small Model and Verifying with the Large Model; KV-Cache Memory Management and Block-Table Scheduling in PagedAttention


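This chapter follows the structure of the language-model unit of CMU 10-715 (Fall 2025). It covers three areas:

- the foundations of pretrained language models;
- parameter-efficient fine-tuning;
- inference optimization.

The exposition follows the conventions of an academic survey. Every principle described here is grounded in established work published at major international conferences and journals.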


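11.1 Distributed Word Representations and Contextual Encoding

11.1.1 Word2Vec Skip-gram with Negative Sampling and the Tree-Structured Approximation of Hierarchical Softmax; GloVe Co-occurrence Matrix Factorization and Least-Squares Optimization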

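Distributed word representations rest on the distributional hypothesis: words that occur in similar contexts have similar meanings. Lexical semantics can therefore be captured by geometric relations among dense vectors.

The Skip-gram architecture of Mikolov et al. learns embeddings by maximizing the probability of the words in a context window given the center word. Two techniques make this objective affordable:

- Negative sampling trains the model to distinguish observed context words from k negative words drawn from a noise distribution. This turns a |V|-way softmax classification into k + 1 binary logistic regressions and sharply reduces the computational cost.
- Hierarchical softmax places the vocabulary at the leaves of a Huffman tree. It models each word's probability as a product of binary decisions along its root-to-leaf path, which reduces the normalization cost from O(|V|) to roughly O(log |V|). Frequent words receive short paths and rare words receive longer ones.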
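
The two approximations can be made concrete with a small NumPy sketch. It is an illustration only, not the original word2vec C implementation, and all function names and the toy frequencies are hypothetical:

- `sgns_loss` evaluates the negative-sampling objective for one center/context pair with k negatives.
- `noise_distribution` builds the unigram^(3/4) noise distribution.
- `huffman_code_lengths` counts how many binary decisions hierarchical softmax makes for each word.

```python
import heapq
import itertools

import numpy as np


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))."""
    return -np.logaddexp(0.0, -x)


def sgns_loss(center_vec, context_vec, negative_vecs):
    """Skip-gram negative-sampling loss for one (center, context) pair:
    -log sigma(u_o . v_c) - sum_k log sigma(-u_k . v_c)."""
    positive = log_sigmoid(context_vec @ center_vec)
    negatives = log_sigmoid(-(negative_vecs @ center_vec)).sum()
    return -(positive + negatives)


def noise_distribution(counts, power=0.75):
    """Unigram counts raised to `power` and renormalized (flattens the distribution)."""
    p = np.asarray(counts, dtype=np.float64) ** power
    return p / p.sum()


def huffman_code_lengths(freqs):
    """Depth of each word's leaf in a Huffman tree, i.e. the number of binary
    decisions hierarchical softmax evaluates for that word."""
    tie = itertools.count()  # tie-breaker so equal frequencies never compare the lists
    heap = [(f, next(tie), [w]) for w, f in freqs.items()]
    heapq.heapify(heap)
    depth = {w: 0 for w in freqs}
    while len(heap) > 1:
        f1, _, words1 = heapq.heappop(heap)
        f2, _, words2 = heapq.heappop(heap)
        for w in words1 + words2:
            depth[w] += 1  # every leaf under the new internal node moves one level down
        heapq.heappush(heap, (f1 + f2, next(tie), words1 + words2))
    return depth


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    v_c, u_o, u_neg = rng.normal(size=8), rng.normal(size=8), rng.normal(size=(5, 8))
    print(sgns_loss(v_c, u_o, u_neg))  # 1 + 5 sigmoid terms instead of a |V|-way softmax
    freqs = {"the": 500, "of": 300, "cat": 40, "sat": 30, "mat": 5, "zebra": 1}
    print(huffman_code_lengths(freqs))
    # {'the': 1, 'of': 2, 'cat': 3, 'sat': 4, 'mat': 5, 'zebra': 5}
```

The toy output shows the property described above: the most frequent word needs a single decision, while the rarest words sit five levels deep.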

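The GloVe model of Pennington et al. combines the strengths of global matrix factorization and local context-window methods. It builds a word-word co-occurrence matrix X and minimizes a weighted least-squares objective. That objective fits the dot product of a word vector and a context vector, plus two bias terms, to log X_ij. Because the model is log-bilinear, differences between word vectors encode ratios of co-occurrence probabilities. By using both global corpus statistics and local context, GloVe produces vector spaces with a clear linear substructure, which shows up as strong performance on word-analogy tasks.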
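
The objective being minimized can be written out in a minimal NumPy sketch. The sketch makes two assumptions, and the function names are hypothetical:

- Co-occurrence counts are stored as (i, j, X_ij) triples over the nonzero entries only, as in the training script below.
- The defaults x_max = 100 and alpha = 0.75 are those of the GloVe paper.

Only nonzero X_ij enter the sum, so log X_ij is always defined. The weighting f down-weights rare pairs and caps very frequent ones at 1.

```python
import numpy as np


def glove_weight(x, x_max=100.0, alpha=0.75):
    """GloVe weighting f(x) = (x / x_max)**alpha for x < x_max, and 1 otherwise."""
    x = np.asarray(x, dtype=np.float64)
    return np.minimum((x / x_max) ** alpha, 1.0)


def glove_loss(W, W_ctx, b, b_ctx, triples, x_max=100.0, alpha=0.75):
    """J = sum over nonzero (i, j, X_ij) of f(X_ij) * (w_i . w~_j + b_i + b~_j - log X_ij)**2."""
    loss = 0.0
    for i, j, x_ij in triples:
        residual = W[i] @ W_ctx[j] + b[i] + b_ctx[j] - np.log(x_ij)
        loss += glove_weight(x_ij, x_max, alpha) * residual ** 2
    return float(loss)


if __name__ == "__main__":
    print(glove_weight([1, 10, 100, 1000]))  # rare pairs down-weighted, frequent ones capped at 1
    rng = np.random.default_rng(0)
    V, d = 4, 3
    W, W_ctx = rng.normal(size=(V, d)), rng.normal(size=(V, d))
    b, b_ctx = np.zeros(V), np.zeros(V)
    print(glove_loss(W, W_ctx, b, b_ctx, [(0, 1, 12.0), (2, 3, 250.0)]))
```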

Python

"""
Script: word2vec_glove_implementation.py
Content: Implementation of Skip-gram with Negative Sampling, Hierarchical Softmax approximation,
         and GloVe matrix factorization with least squares optimization.
Usage: python word2vec_glove_implementation.py --corpus_path data.txt --method skipgram_ns --epochs 5
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict, Counter
import random
import heapq
from typing import List, Tuple, Dict
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import argparse


class HuffmanNode:
    """Huffman tree node for Hierarchical Softmax approximation."""
    def __init__(self, word_id=None, frequency=0):
        self.word_id = word_id
        self.frequency = frequency
        self.left = None
        self.right = None
        self.node_id = None  # row of an internal node's vector (assigned while building the tree)
    
    def __lt__(self, other):
        return self.frequency < other.frequency


class Word2VecModel(nn.Module):
    """
    Skip-gram Word2Vec with support for Negative Sampling and Hierarchical Softmax.
    Architecture follows Mikolov et al. (2013) with vector dimensionality typically
    ranging between 100-300 dimensions.
    """
    def __init__(self, vocab_size: int, embedding_dim: int = 200, 
                 use_hierarchical_softmax: bool = False, 
                 use_negative_sampling: bool = True,
                 negative_samples: int = 5):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.use_hierarchical_softmax = use_hierarchical_softmax
        self.use_negative_sampling = use_negative_sampling
        self.negative_samples = negative_samples
        
        # Center word embeddings (input)
        self.in_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Context word embeddings (output)
        self.out_embeddings = nn.Embedding(vocab_size, embedding_dim)
        
        # Xavier initialization for stable training
        nn.init.xavier_uniform_(self.in_embeddings.weight)
        nn.init.xavier_uniform_(self.out_embeddings.weight)
        
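        # Hierarchical softmax: one output vector per internal Huffman node
        # (a binary tree with V leaves has V - 1 internal nodes); the tree is built later
        if use_hierarchical_softmax:
            self.node_embeddings = nn.Embedding(max(vocab_size - 1, 1), embedding_dim)
            nn.init.zeros_(self.node_embeddings.weight)
            self.huffman_tree = None
            self.word_codes = {}
            self.word_paths = {}
    
    def build_huffman_tree(self, word_freqs: Dict[int, int]):
        """Build the Huffman tree and cache every word's code and inner-node path as padded tensors."""
        heap = [HuffmanNode(word_id=wid, frequency=freq) 
                for wid, freq in word_freqs.items()]
        heapq.heapify(heap)
        
        next_node_id = 0
        while len(heap) > 1:
            left = heapq.heappop(heap)
            right = heapq.heappop(heap)
            merged = HuffmanNode(frequency=left.frequency + right.frequency)
            merged.node_id = next_node_id  # index into self.node_embeddings
            next_node_id += 1
            merged.left = left
            merged.right = right
            heapq.heappush(heap, merged)
        
        self.huffman_tree = heap[0]
        self._assign_codes(self.huffman_tree, [], [])
        
        # Pad codes/paths to the longest code; the mask marks real tree decisions
        max_len = max(1, max(len(c) for c in self.word_codes.values()))
        codes = torch.zeros(self.vocab_size, max_len)
        paths = torch.zeros(self.vocab_size, max_len, dtype=torch.long)
        mask = torch.zeros(self.vocab_size, max_len)
        for wid, code in self.word_codes.items():
            n = len(code)
            if n == 0:
                continue
            codes[wid, :n] = torch.tensor(code, dtype=torch.float32)
            paths[wid, :n] = torch.tensor(self.word_paths[wid], dtype=torch.long)
            mask[wid, :n] = 1.0
        self.register_buffer('hs_codes', codes)
        self.register_buffer('hs_paths', paths)
        self.register_buffer('hs_mask', mask)
    
    def _assign_codes(self, node: HuffmanNode, code: List[int], path: List[int]):
        """Recursively record, for every leaf, its binary code (left=0, right=1)
        and the internal nodes visited on the way from the root."""
        if node.word_id is not None:
            self.word_codes[node.word_id] = code.copy()
            self.word_paths[node.word_id] = path.copy()
            return
        
        path.append(node.node_id)
        for bit, child in ((0, node.left), (1, node.right)):
            if child is not None:
                code.append(bit)
                self._assign_codes(child, code, path)
                code.pop()
        path.pop()
    
    def forward(self, center_word: torch.Tensor, context_words: torch.Tensor = None,
                negative_samples: torch.Tensor = None) -> torch.Tensor:
        """
        Compute the Skip-gram loss for a batch.
        
        Args:
            center_word: (batch_size,) center word indices
            context_words: (batch_size, window_size) positive context words
            negative_samples: (batch_size, num_neg) negative sample indices (negative sampling only)
        """
        center_embeds = self.in_embeddings(center_word)  # (batch, dim)
        log_sigmoid = nn.functional.logsigmoid  # numerically stable log(sigmoid(x))
        
        if self.use_negative_sampling:
            # Positive pairs: maximize log sigma(u_o . v_c)
            context_embeds = self.out_embeddings(context_words)  # (batch, window, dim)
            pos_scores = torch.bmm(context_embeds, center_embeds.unsqueeze(2)).squeeze(2)
            loss = -log_sigmoid(pos_scores).mean()
            
            # Negatives: maximize log sigma(-u_k . v_c), summed over the k samples of each pair
            if negative_samples is not None:
                neg_embeds = self.out_embeddings(negative_samples)  # (batch, num_neg, dim)
                neg_scores = torch.bmm(neg_embeds, center_embeds.unsqueeze(2)).squeeze(2)
                loss = loss - log_sigmoid(-neg_scores).sum(dim=1).mean()
            return loss
        
        if self.use_hierarchical_softmax:
            # -log p(context | center) = -sum_k log sigma(sign_k * n_k . v_c)
            # over the internal nodes n_k on the context word's Huffman path
            window = context_words.size(1)
            targets = context_words.reshape(-1)                            # (N,)
            hidden = center_embeds.unsqueeze(1).expand(-1, window, -1)
            hidden = hidden.reshape(-1, self.embedding_dim)                # (N, dim)
            node_vecs = self.node_embeddings(self.hs_paths[targets])       # (N, max_len, dim)
            scores = torch.bmm(node_vecs, hidden.unsqueeze(2)).squeeze(2)  # (N, max_len)
            signs = 1.0 - 2.0 * self.hs_codes[targets]  # code 0 (left) -> +1, code 1 (right) -> -1
            log_probs = log_sigmoid(signs * scores) * self.hs_mask[targets]
            return -log_probs.sum(dim=1).mean()
        
        raise ValueError("Enable either negative sampling or hierarchical softmax.")
    
    def get_embeddings(self) -> np.ndarray:
        """Return the input (center-word) vectors, the usual word2vec output.
        With hierarchical softmax, out_embeddings is never trained, so averaging it in would add noise."""
        return self.in_embeddings.weight.detach().cpu().numpy()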


class GloVeModel(nn.Module):
    """
    GloVe implementation following Pennington et al. (2014).
    Uses weighted least squares to factorize the log-cooccurrence matrix.
    """
    def __init__(self, vocab_size: int, embedding_dim: int = 200,
                 alpha: float = 0.75, x_max: float = 100.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.alpha = alpha  # Exponent for weighting function
        self.x_max = x_max  # Cutoff for weighting function
        
        # Word vectors and bias terms
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_biases = nn.Embedding(vocab_size, 1)
        self.context_biases = nn.Embedding(vocab_size, 1)
        
        # Initialize with small random values
        for param in self.parameters():
            nn.init.uniform_(param, -0.5, 0.5)
    
    def forward(self, word_ids: torch.Tensor, context_ids: torch.Tensor,
                cooccurrence_counts: torch.Tensor) -> torch.Tensor:
        """
        Compute weighted squared error between dot product and log cooccurrence.
        
        Args:
            word_ids: (batch_size,) target word indices
            context_ids: (batch_size,) context word indices  
            cooccurrence_counts: (batch_size,) X_ij values
        """
        # Get embeddings
        word_embeds = self.word_embeddings(word_ids)  # (batch, dim)
        context_embeds = self.context_embeddings(context_ids)  # (batch, dim)
        word_bias = self.word_biases(word_ids).squeeze(-1)  # (batch,)
        context_bias = self.context_biases(context_ids).squeeze(-1)  # (batch,)
        
        # Compute dot product plus biases
        dot_product = torch.sum(word_embeds * context_embeds, dim=1)
        prediction = dot_product + word_bias + context_bias
        
        # Only nonzero X_ij are stored, so log X_ij is well defined; no +1 smoothing,
        # which would shift the regression target away from the GloVe objective
        log_cooccurrence = torch.log(cooccurrence_counts)
        
        # Weighting function: f(x) = (x/x_max)^alpha if x < x_max else 1
        weights = torch.pow(cooccurrence_counts / self.x_max, self.alpha)
        weights = torch.clamp(weights, max=1.0)
        
        # Weighted least squares loss
        squared_error = torch.pow(prediction - log_cooccurrence, 2)
        loss = torch.mean(weights * squared_error)
        
        return loss


class CorpusProcessor:
    """Handle text preprocessing, vocabulary construction, and cooccurrence matrix generation."""
    
    def __init__(self, min_freq: int = 5, window_size: int = 5):
        self.min_freq = min_freq
        self.window_size = window_size
        self.word2id = {}
        self.id2word = {}
        self.word_freqs = Counter()
        self.cooccurrence_data = []  # nonzero (i, j, X_ij) triples, filled by build_cooccurrence_matrix
        self.vocab_size = 0
    
    def build_vocab(self, sentences: List[List[str]]):
        """Build vocabulary filtering by minimum frequency threshold."""
        # Count frequencies
        for sent in sentences:
            for word in sent:
                self.word_freqs[word] += 1
        
        # Filter rare words and assign IDs in descending frequency order (ID 0 = most frequent)
        filtered_words = [w for w, c in self.word_freqs.most_common() if c >= self.min_freq]
        self.word2id = {w: i for i, w in enumerate(filtered_words)}
        self.id2word = {i: w for w, i in self.word2id.items()}
        self.vocab_size = len(filtered_words)
        print(f"Vocabulary size: {self.vocab_size}")
    
    def generate_skipgram_pairs(self, sentences: List[List[str]]) -> List[Tuple[int, int]]:
        """Generate (center, context) word ID pairs for Skip-gram training."""
        pairs = []
        for sent in sentences:
            word_ids = [self.word2id[w] for w in sent if w in self.word2id]
            for i, center_id in enumerate(word_ids):
                # Define context window boundaries
                start = max(0, i - self.window_size)
                end = min(len(word_ids), i + self.window_size + 1)
                for j in range(start, end):
                    if i != j:
                        context_id = word_ids[j]
                        pairs.append((center_id, context_id))
        return pairs
    
    def build_cooccurrence_matrix(self, sentences: List[List[str]]):
        """Construct sparse cooccurrence matrix X for GloVe training."""
        # Symmetric matrix where X[i][j] = sum of cooccurrences of word i and j
        cooccur = defaultdict(float)
        
        for sent in sentences:
            word_ids = [self.word2id[w] for w in sent if w in self.word2id]
            for i, center_id in enumerate(word_ids):
                start = max(0, i - self.window_size)
                end = min(len(word_ids), i + self.window_size + 1)
                for j in range(start, end):
                    if i != j:
                        context_id = word_ids[j]
                        # Distance-weighted cooccurrence
                        distance = abs(i - j)
                        cooccur[(center_id, context_id)] += 1.0 / distance
        
        # Convert to sparse matrix format for efficient storage
        self.cooccurrence_data = []
        for (w, c), count in cooccur.items():
            self.cooccurrence_data.append((w, c, count))
        
        print(f"Cooccurrence pairs: {len(self.cooccurrence_data)}")
    
    def get_negative_samples(self, batch_size: int, num_neg: int) -> torch.Tensor:
        """Sample negative examples from the unigram distribution raised to the 3/4 power."""
        # Build the noise distribution once and reuse it (it depends only on the vocabulary)
        if getattr(self, '_noise_probs', None) is None:
            counts = np.array([self.word_freqs[self.id2word[i]] for i in range(self.vocab_size)],
                              dtype=np.float64)
            probs = counts ** 0.75
            self._noise_probs = probs / probs.sum()
        return torch.tensor(
            np.random.choice(self.vocab_size, size=(batch_size, num_neg), p=self._noise_probs),
            dtype=torch.long
        )


def train_word2vec(corpus_path: str, embedding_dim: int = 200, 
                   epochs: int = 5, method: str = 'skipgram_ns'):
    """
    Complete training pipeline for Word2Vec models.
    
    Args:
        corpus_path: Path to text corpus (one sentence per line, space-tokenized)
        embedding_dim: Dimensionality of word vectors
        epochs: Number of training iterations
        method: 'skipgram_ns' for negative sampling, 'skipgram_hs' for hierarchical softmax
    """
    # Load and preprocess data
    with open(corpus_path, 'r', encoding='utf-8') as f:
        sentences = [line.strip().split() for line in f if line.strip()]
    
    processor = CorpusProcessor(min_freq=5, window_size=5)
    processor.build_vocab(sentences)
    pairs = processor.generate_skipgram_pairs(sentences)
    
    # Initialize model
    use_hs = (method == 'skipgram_hs')
    use_ns = (method == 'skipgram_ns')
    
    model = Word2VecModel(
        vocab_size=processor.vocab_size,
        embedding_dim=embedding_dim,
        use_hierarchical_softmax=use_hs,
        use_negative_sampling=use_ns,
        negative_samples=5 if use_ns else 0
    )
    
    if use_hs:
        word_freq_dict = {i: processor.word_freqs[processor.id2word[i]] 
                         for i in range(processor.vocab_size)}
        model.build_huffman_tree(word_freq_dict)
    
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Mini-batch training loop. Adam with a fixed learning rate is used here; the
    # original word2vec uses SGD with a linearly decaying learning rate instead.
    batch_size = 256
    num_batches = max(1, (len(pairs) + batch_size - 1) // batch_size)
    
    for epoch in range(epochs):
        random.shuffle(pairs)
        total_loss = 0.0
        
        for start in range(0, len(pairs), batch_size):
            batch_pairs = pairs[start:start + batch_size]  # the last batch may be smaller
            center_words = torch.tensor([p[0] for p in batch_pairs], dtype=torch.long)
            context_words = torch.tensor([p[1] for p in batch_pairs], dtype=torch.long)
            
            neg_samples = None
            if use_ns:
                # One row of negatives per pair in this (possibly partial) batch
                neg_samples = processor.get_negative_samples(len(batch_pairs), model.negative_samples)
            
            optimizer.zero_grad()
            loss = model(center_words, context_words.unsqueeze(1), neg_samples)
            loss.backward()
            
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    # Save embeddings and visualize
    embeddings = model.get_embeddings()
    visualize_embeddings(embeddings, processor.id2word, method_name="Word2Vec")
    
    return model, processor


def train_glove(corpus_path: str, embedding_dim: int = 200, epochs: int = 10):
    """
    Training pipeline for GloVe model with cooccurrence matrix factorization.
    """
    # Data preparation
    with open(corpus_path, 'r', encoding='utf-8') as f:
        sentences = [line.strip().split() for line in f if line.strip()]
    
    processor = CorpusProcessor(min_freq=5, window_size=10)  # Larger window for GloVe
    processor.build_vocab(sentences)
    processor.build_cooccurrence_matrix(sentences)
    
    # Model initialization
    model = GloVeModel(
        vocab_size=processor.vocab_size,
        embedding_dim=embedding_dim,
        alpha=0.75,
        x_max=100.0
    )
    
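    # AdaGrad with an initial learning rate of 0.05, the optimizer setup reported by Pennington et al. (2014)
    optimizer = optim.Adagrad(model.parameters(), lr=0.05)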
    
    # Training loop
    batch_size = 512
    data = processor.cooccurrence_data
    num_batches = max(1, (len(data) + batch_size - 1) // batch_size)
    
    for epoch in range(epochs):
        random.shuffle(data)
        total_loss = 0.0
        
        # Iterate over all nonzero entries, including a final partial batch
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            word_ids = torch.tensor([x[0] for x in batch], dtype=torch.long)
            context_ids = torch.tensor([x[1] for x in batch], dtype=torch.long)
            counts = torch.tensor([x[2] for x in batch], dtype=torch.float32)
            
            optimizer.zero_grad()
            loss = model(word_ids, context_ids, counts)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    # Combine word and context embeddings for final representation
    embeddings = (model.word_embeddings.weight.data + 
                  model.context_embeddings.weight.data) / 2
    embeddings = embeddings.cpu().numpy()
    
    visualize_embeddings(embeddings, processor.id2word, method_name="GloVe")
    
    return model, processor


def visualize_embeddings(embeddings: np.ndarray, id2word: Dict[int, str],
                        method_name: str, num_words: int = 300):
    """
    Visualize word embeddings using PCA dimensionality reduction.
    Displays semantic clustering of high-frequency vocabulary items.
    """
    # Take the first num_words IDs; these are the most frequent words when
    # IDs are assigned in descending frequency order
    num_words = min(num_words, len(id2word))
    indices = list(range(num_words))
    words_to_plot = [id2word[i] for i in indices]
    selected_embeds = embeddings[indices]
    
    # PCA reduction to 2D
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(selected_embeds)
    
    # Plot
    plt.figure(figsize=(14, 10))
    plt.scatter(reduced[:, 0], reduced[:, 1], alpha=0.6, s=20)
    
    # Annotate subset of words to avoid overcrowding
    step = max(1, num_words // 50)
    for i in range(0, num_words, step):
        plt.annotate(words_to_plot[i], (reduced[i, 0], reduced[i, 1]),
                    fontsize=8, alpha=0.8)
    
    plt.title(f'{method_name} Word Embeddings Visualization (PCA)', fontsize=14)
    plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
    plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'{method_name.lower()}_embeddings.png', dpi=300, bbox_inches='tight')
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train Word2Vec or GloVe models')
    parser.add_argument('--corpus_path', type=str, required=True, help='Path to training corpus')
    parser.add_argument('--method', type=str, default='skipgram_ns', 
                       choices=['skipgram_ns', 'skipgram_hs', 'glove'])
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--epochs', type=int, default=5)
    
    args = parser.parse_args()
    
    if args.method == 'glove':
        train_glove(args.corpus_path, args.embedding_dim, args.epochs)
    else:
        train_word2vec(args.corpus_path, args.embedding_dim, args.epochs, args.method)

11.1.2 Contextual Embeddings: ELMo's Bidirectional Language Model and Character-Level CNN; BERT's Masked Language Model (MLM) and Next Sentence Prediction (NSP) Pretraining Objectives

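Static embeddings assign one vector per word, which conflates the senses of polysemous words. Contextual word representations remove this limitation because each token's encoding depends on the sentence it appears in.

ELMo (Peters et al.) uses stacked LSTM layers in two independently trained language models, one reading left to right and one reading right to left. It builds a deep contextualized representation as a learned weighted sum of the hidden states from all layers. A character-level convolutional network captures morphological features. As a result, out-of-vocabulary words still receive a representation derived from their internal structure, instead of being limited by the coverage of a fixed vocabulary.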
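
The learned weighted sum is a "scalar mix": ELMo_k = gamma * sum_j s_j * h_{k,j}. Here s = softmax(w) is a set of normalized per-layer weights and gamma is a global scale. Both are learned for each downstream task. Below is a minimal NumPy sketch of that formula; the function name and the (L + 1, seq_len, dim) layout are assumptions made for illustration.

```python
import numpy as np


def scalar_mix(layer_states, w, gamma=1.0):
    """ELMo-style scalar mix: gamma * sum_j softmax(w)_j * h_j.

    layer_states: array of shape (L + 1, seq_len, dim). Layer 0 is the
    character-CNN token representation; layers 1..L are biLSTM outputs
    (forward and backward states concatenated).
    """
    w = np.asarray(w, dtype=np.float64)
    s = np.exp(w - w.max())
    s /= s.sum()  # softmax-normalized layer weights
    return gamma * np.tensordot(s, layer_states, axes=1)  # sums over the layer axis


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h = rng.normal(size=(3, 4, 6))  # 3 layers, 4 tokens, 6 dims
    print(scalar_mix(h, [0.0, 0.0, 0.0]).shape)  # equal weights -> mean over layers
```

With equal raw weights the mix reduces to the average of the layers. A strongly dominant weight approaches selecting a single layer, which is why the mix can adapt to tasks that favor lower (more syntactic) or higher (more semantic) layers.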

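BERT, developed by Devlin et al., uses a bidirectional Transformer encoder and is pretrained with two objectives:

- The masked language model (MLM) objective randomly selects 15% of the input tokens. It replaces 80% of them with [MASK] and 10% with a random token, and leaves 10% unchanged. The model must recover the original tokens from both left and right context.
- Next sentence prediction (NSP) classifies whether the second segment of an input pair actually follows the first. Its aim is to model inter-sentence (discourse-level) relations.

Training proceeds in two stages. The model first learns general-purpose representations on large unlabeled corpora and is then fine-tuned on task-specific data. This approach yielded substantial gains on benchmarks such as question answering and natural language inference.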
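
The 80/10/10 masking rule can be sketched in plain Python. This version draws an independent Bernoulli(mlm_prob) per token, a common simplification; the original BERT data pipeline instead selects a fixed number of positions per sequence. The function name, its parameters, and the toy token ids are hypothetical. Labels use -100 at unselected positions to match the `ignore_index=-100` convention of the pretraining code below.

```python
import random


def mask_tokens(token_ids, vocab_size, mask_id, special_ids=frozenset(),
                mlm_prob=0.15, rng=None):
    """BERT-style 80/10/10 masking over a list of token ids.

    Each non-special position is selected with probability mlm_prob; a selected
    token becomes mask_id 80% of the time, a random id 10% of the time, and is
    left unchanged 10% of the time. Returns (inputs, labels), where labels hold
    the original id at selected positions and -100 (ignored by the loss) elsewhere.
    """
    rng = rng or random.Random()
    inputs = list(token_ids)
    labels = [-100] * len(inputs)
    for pos, tok in enumerate(token_ids):
        if tok in special_ids or rng.random() >= mlm_prob:
            continue
        labels[pos] = tok  # the model is scored on recovering this token
        r = rng.random()
        if r < 0.8:
            inputs[pos] = mask_id
        elif r < 0.9:
            inputs[pos] = rng.randrange(vocab_size)
        # else: keep the original token (the remaining 10%)
    return inputs, labels


if __name__ == "__main__":
    # Hypothetical toy ids: 1 = [CLS], 2 = [SEP], 4 = [MASK]
    inputs, labels = mask_tokens([1, 37, 52, 81, 90, 2], vocab_size=1000, mask_id=4,
                                 special_ids={1, 2}, mlm_prob=0.5, rng=random.Random(0))
    print(inputs, labels)
```

Keeping 10% of selected tokens unchanged and replacing 10% with random tokens means the model cannot assume that only [MASK] positions carry a prediction target. This reduces the mismatch between pretraining inputs and fine-tuning inputs, which never contain [MASK].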

Python

"""
Script: elmo_bert_pretraining.py
Content: Implementation of ELMo bidirectional language model with character-level CNN,
         and BERT Masked Language Model (MLM) with Next Sentence Prediction (NSP).
Usage: python elmo_bert_pretraining.py --model_type bert --corpus_path wiki.txt --epochs 3
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from typing import List, Tuple, Dict, Optional
import matplotlib.pyplot as plt
from transformers import BertTokenizerFast, BertConfig
import argparse


class CharacterCNN(nn.Module):
    """
    Character-level CNN for ELMo following Peters et al. (2018).
    Maps character sequences to fixed-dimensional word representations.
    """
    def __init__(self, char_vocab_size: int = 262, char_embed_dim: int = 16,
                 projection_dim: int = 512, filters: List[Tuple[int, int]] = None,
                 max_word_length: int = 50):
        super().__init__()
        self.char_vocab_size = char_vocab_size
        self.char_embed_dim = char_embed_dim
        self.max_word_length = max_word_length
        
        # Character embedding layer
        self.char_embeddings = nn.Embedding(char_vocab_size, char_embed_dim, padding_idx=0)
        
        # Convolutional filters given as (width, num_filters), widths 1 to 7; 2048 filters in total
        if filters is None:
            filters = [(1, 32), (2, 32), (3, 64), (4, 128), (5, 256), (6, 512), (7, 1024)]
        
        self.conv_layers = nn.ModuleList()
        for width, num_filter in filters:
            self.conv_layers.append(
                nn.Conv1d(char_embed_dim, num_filter, kernel_size=width, padding=0)
            )
        
        self.num_filters = sum(f[1] for f in filters)
        self.highway_layers = nn.ModuleList([
            nn.Linear(self.num_filters, self.num_filters) for _ in range(2)
        ])
        self.highway_gate = nn.ModuleList([
            nn.Linear(self.num_filters, self.num_filters) for _ in range(2)
        ])
        # Project the highway output to the word-representation size fed to the LSTMs
        self.projection = nn.Linear(self.num_filters, projection_dim)
    
    def forward(self, char_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            char_ids: (batch_size, seq_len, max_word_length) character indices
        Returns:
            word_embeds: (batch_size, seq_len, projection_dim)
        """
        batch_size, seq_len, max_word_len = char_ids.shape
        
        # Flatten for convolution: (batch*seq, max_word_len)
        flat_chars = char_ids.view(-1, max_word_len)
        char_embeds = self.char_embeddings(flat_chars)  # (batch*seq, max_word_len, char_dim)
        char_embeds = char_embeds.transpose(1, 2)  # (batch*seq, char_dim, max_word_len)
        
        # Apply convolutions and max pooling
        conv_outputs = []
        for conv in self.conv_layers:
            conv_out = F.relu(conv(char_embeds))  # (batch*seq, num_filter, new_len)
            pooled = F.adaptive_max_pool1d(conv_out, 1).squeeze(-1)  # (batch*seq, num_filter)
            conv_outputs.append(pooled)
        
        # Concatenate all filter outputs
        char_repr = torch.cat(conv_outputs, dim=-1)  # (batch*seq, total_filters)
        
        # Highway network for feature transformation
        for i in range(len(self.highway_layers)):
            gate = torch.sigmoid(self.highway_gate[i](char_repr))
            transform = F.relu(self.highway_layers[i](char_repr))
            char_repr = gate * transform + (1 - gate) * char_repr
        
        # Final projection
        word_repr = self.projection(char_repr)  # (batch*seq, projection_dim)
        word_repr = word_repr.view(batch_size, seq_len, -1)
        
        return word_repr


class ELMoLSTM(nn.Module):
    """
    Stacked forward and backward LSTM language models for ELMo.
    Each layer is an LSTM with a projection (hidden_dim cells projected back to
    input_dim), and a residual connection links consecutive LSTM layers.
    """
    def __init__(self, input_dim: int = 512, hidden_dim: int = 4096,
                 num_layers: int = 2, dropout: float = 0.3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # One single-layer LSTM per layer and direction so every layer's output is exposed;
        # proj_size keeps each output at input_dim so all layers can be mixed later
        self.forward_layers = nn.ModuleList([
            nn.LSTM(input_dim, hidden_dim, batch_first=True, proj_size=input_dim)
            for _ in range(num_layers)
        ])
        # Backward LM layers (run on the time-reversed sequence)
        self.backward_layers = nn.ModuleList([
            nn.LSTM(input_dim, hidden_dim, batch_first=True, proj_size=input_dim)
            for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
    
    def _run_stack(self, layers: nn.ModuleList, x: torch.Tensor) -> List[torch.Tensor]:
        """Run one direction and return [layer-0 input, layer-1 output, ..., layer-L output]."""
        states = [x]  # Layer 0 is the context-independent character-CNN embedding
        h = x
        for i, lstm in enumerate(layers):
            out, _ = lstm(self.dropout(h))
            if i > 0:
                out = out + h  # Residual connection between LSTM layers
            states.append(out)
            h = out
        return states
    
    def forward(self, char_embeds: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Args:
            char_embeds: (batch, seq_len, input_dim) from CharacterCNN
        Returns:
            fwd_states: num_layers + 1 tensors of shape (batch, seq_len, input_dim)
            bwd_states: num_layers + 1 tensors of shape (batch, seq_len, input_dim)
        """
        fwd_states = self._run_stack(self.forward_layers, char_embeds)
        
        # Backward LM: reverse time, run the stack, then flip outputs back to the original order
        # (padding is reversed too; use packed sequences for variable-length batches)
        rev_states = self._run_stack(self.backward_layers, torch.flip(char_embeds, dims=[1]))
        bwd_states = [torch.flip(s, dims=[1]) for s in rev_states]
        
        return fwd_states, bwd_states


class ELMoModel(nn.Module):
    """
    Complete ELMo architecture with character CNN and bidirectional LSTM.
    Pretraining objective: independent forward and backward language modeling.
    """
    def __init__(self, char_vocab_size: int, char_embed_dim: int = 16,
                 projection_dim: int = 512, lstm_hidden: int = 4096,
                 num_layers: int = 2, vocab_size: int = 30000):
        super().__init__()
        self.char_cnn = CharacterCNN(char_vocab_size, char_embed_dim, 
                                     max_word_length=50)
        self.lstm_stack = ELMoLSTM(projection_dim, lstm_hidden, num_layers)
        
        # Softmax layers for language modeling
        self.forward_decoder = nn.Linear(projection_dim, vocab_size)
        self.backward_decoder = nn.Linear(projection_dim, vocab_size)
        
        # Scalar weights for computing final ELMo representation (learned per task)
        self.elmo_weights = nn.Parameter(torch.ones(num_layers + 1))
        self.elmo_gamma = nn.Parameter(torch.ones(1))
    
    def forward(self, char_ids: torch.Tensor, 
                forward_targets: Optional[torch.Tensor] = None,
                backward_targets: Optional[torch.Tensor] = None):
        """
        Args:
            char_ids: (batch, seq_len, max_word_len)
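            forward_targets: (batch, seq_len) next-token targets, forward_targets[:, t] = token t+1
                (pad id 0 at the last position; it is ignored in the loss)
            backward_targets: (batch, seq_len) previous-token targets, backward_targets[:, t] = token t-1
                (pad id 0 at the first position)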
        """
        # Character CNN encoding
        char_embeds = self.char_cnn(char_ids)  # (batch, seq_len, proj_dim)
        
        # Bidirectional LSTM processing
        fwd_states, bwd_states = self.lstm_stack(char_embeds)
        
        # Language modeling losses
        total_loss = 0.0
        if forward_targets is not None:
            # Last layer forward output
            fwd_logits = self.forward_decoder(fwd_states[-1])  # (batch, seq_len, vocab)
            fwd_loss = F.cross_entropy(
                fwd_logits.view(-1, fwd_logits.size(-1)),
                forward_targets.view(-1),
                ignore_index=0
            )
            total_loss += fwd_loss
        
        if backward_targets is not None:
            bwd_logits = self.backward_decoder(bwd_states[-1])
            bwd_loss = F.cross_entropy(
                bwd_logits.view(-1, bwd_logits.size(-1)),
                backward_targets.view(-1),
                ignore_index=0
            )
            total_loss += bwd_loss
        
        # Compute weighted ELMo representation (scalar mix of layers)
        # Stack: (num_layers+1, batch, seq, dim)
        stacked_fwd = torch.stack(fwd_states, dim=0)
        stacked_bwd = torch.stack(bwd_states, dim=0)
        
        weights = F.softmax(self.elmo_weights, dim=0).view(-1, 1, 1, 1)
        weighted_fwd = (stacked_fwd * weights).sum(dim=0)
        weighted_bwd = (stacked_bwd * weights).sum(dim=0)
        
        # Concatenate forward and backward
        elmo_repr = torch.cat([weighted_fwd, weighted_bwd], dim=-1)
        elmo_repr = self.elmo_gamma * elmo_repr
        
        return elmo_repr, total_loss, (fwd_states, bwd_states)


class BertPretrainingTasks(nn.Module):
    """
    BERT pretraining implementation with MLM and NSP objectives.
    Based on Devlin et al. (2019) with separate heads for each task.
    """
    def __init__(self, config: BertConfig):
        super().__init__()
        from transformers import BertModel
        
        self.bert = BertModel(config)
        self.config = config
        
        # MLM head as in BERT: dense -> GELU -> LayerNorm -> vocabulary projection
        # (original BERT ties the output projection to the input embeddings; untied here)
        self.mlm_transform = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
            nn.Linear(config.hidden_size, config.vocab_size)
        )
        
        # NSP head: BertModel's pooler already computes tanh(W h_[CLS] + b),
        # so (as in BERT) a single linear layer on top produces the 2-way logits
        self.nsp_classifier = nn.Linear(config.hidden_size, 2)
        
        # Loss weights
        self.mlm_weight = 1.0
        self.nsp_weight = 1.0
    
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor, mlm_labels: Optional[torch.Tensor] = None,
                nsp_labels: Optional[torch.Tensor] = None,
                masked_lm_positions: Optional[torch.Tensor] = None):
        """
        Args:
            input_ids: (batch, seq_len) token IDs
            attention_mask: (batch, seq_len) attention mask
            token_type_ids: (batch, seq_len) segment IDs (0 or 1)
            mlm_labels: (batch, num_masked) labels for the selected positions (-100 = ignore)
            nsp_labels: (batch,) 0 = segment B follows A (is_next), 1 = random segment (not_next)
            masked_lm_positions: (batch, num_masked) positions selected for MLM
                ([MASK], randomly replaced, or unchanged tokens)
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True
        )
        
        sequence_output = outputs.last_hidden_state  # (batch, seq_len, hidden)
        pooled_output = outputs.pooler_output  # (batch, hidden): tanh(dense(h_[CLS]))
        
        # MLM prediction: gather hidden states at the selected positions
        if masked_lm_positions is not None:
            masked_positions = masked_lm_positions.unsqueeze(-1).expand(-1, -1, sequence_output.size(-1))
            masked_states = torch.gather(sequence_output, 1, masked_positions)  # (batch, num_masked, hidden)
        else:
            masked_states = sequence_output
        
        # Apply MLM transformation
        mlm_logits = self.mlm_transform(masked_states)  # (batch, num_masked, vocab)
        
        # NSP prediction on the pooled [CLS] representation
        nsp_logits = self.nsp_classifier(pooled_output)  # (batch, 2)
        
        # Compute losses
        total_loss = 0.0
        mlm_loss = 0.0
        nsp_loss = 0.0
        
        if mlm_labels is not None:
            mlm_loss = F.cross_entropy(
                mlm_logits.view(-1, self.config.vocab_size),
                mlm_labels.view(-1),
                ignore_index=-100
            )
            total_loss += self.mlm_weight * mlm_loss
        
        if nsp_labels is not None:
            nsp_loss = F.cross_entropy(nsp_logits, nsp_labels)
            total_loss += self.nsp_weight * nsp_loss
        
        return {
            'loss': total_loss,
            'mlm_loss': mlm_loss,
            'nsp_loss': nsp_loss,
            'mlm_logits': mlm_logits,
            'nsp_logits': nsp_logits,
            'hidden_states': outputs.hidden_states
        }


class MLMDataset(Dataset):
    """
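    Dataset for BERT MLM + NSP pretraining with dynamic masking (masks are re-sampled on
    every __getitem__ call). Masking is per token, not whole-word, and follows the
    80% [MASK] / 10% random token / 10% unchanged replacement rule.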
    """
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512,
                 mlm_probability: float = 0.15, nsp_probability: float = 0.5):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mlm_probability = mlm_probability
        self.nsp_probability = nsp_probability
        self.examples = self._prepare_examples(texts)
    
    def _prepare_examples(self, texts: List[str]) -> List[Dict]:
        """Create training examples with NSP and MLM preparation."""
        examples = []
        # Split texts into sentences for NSP
        sentences = []
        for text in texts:
            sents = text.split('.')
            sentences.extend([s.strip() for s in sents if len(s.strip()) > 10])
        
        # Create sentence pairs
        for i in range(len(sentences) - 1):
            # A negative pair needs at least one sentence other than i and i + 1 to sample from
            is_next = random.random() < self.nsp_probability or len(sentences) < 3
            sent_a = sentences[i]
            if is_next:
                sent_b = sentences[i + 1]
            else:
                # Random sentence from the pooled corpus (possibly the same document),
                # excluding sentence i itself and its true successor i + 1
                rand_idx = random.randint(0, len(sentences) - 1)
                while rand_idx in (i, i + 1):
                    rand_idx = random.randint(0, len(sentences) - 1)
                sent_b = sentences[rand_idx]
            
            examples.append({
                'text_a': sent_a,
                'text_b': sent_b,
                'is_next': int(is_next)
            })
        
        return examples
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        example = self.examples[idx]
        
        # Tokenize pair
        encoding = self.tokenizer(
            example['text_a'],
            example['text_b'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        token_type_ids = encoding['token_type_ids'].squeeze(0)
        
        # Create MLM labels from the original (pre-masking) token IDs
        labels = input_ids.clone()
        special_ids = set(self.tokenizer.all_special_ids)  # [CLS], [SEP], [PAD], ...
        max_predictions = 64  # fixed per-example budget so all examples share one shape
        
        # Select ~15% of the non-special tokens as prediction targets
        masked_indices = [
            i for i in range(len(input_ids))
            if input_ids[i].item() not in special_ids and random.random() < self.mlm_probability
        ][:max_predictions]
        
        for i in masked_indices:
            # 80% [MASK], 10% random token, 10% unchanged
            prob = random.random()
            if prob < 0.8:
                input_ids[i] = self.tokenizer.mask_token_id
            elif prob < 0.9:
                input_ids[i] = random.randint(0, self.tokenizer.vocab_size - 1)
        
        # Pad to max_predictions so the default collate_fn can stack examples;
        # padded slots point at position 0 and carry label -100 (ignored by the loss)
        num_pad = max_predictions - len(masked_indices)
        masked_positions = torch.tensor(masked_indices + [0] * num_pad, dtype=torch.long)
        masked_labels = torch.cat([
            labels[masked_positions[:len(masked_indices)]],
            torch.full((num_pad,), -100, dtype=torch.long)
        ])
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'mlm_labels': masked_labels,
            'nsp_labels': torch.tensor(example['is_next'], dtype=torch.long),
            'masked_lm_positions': masked_positions
        }


def create_elmo_batches(sentences: List[List[str]], char2id: Dict[str, int],
                       max_word_len: int = 50, batch_size: int = 32):
    """
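    Create batches for ELMo training with character-level encoding.
    Word-level targets are looked up in the module-level `word2id` dictionary,
    which train_elmo() defines via `global word2id` before calling this function.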
    """
    batches = []
    for i in range(0, len(sentences), batch_size):
        batch_sents = sentences[i:i+batch_size]
        max_len = max(len(s) for s in batch_sents)
        
        char_ids = torch.zeros(len(batch_sents), max_len, max_word_len, dtype=torch.long)
        forward_targets = torch.zeros(len(batch_sents), max_len, dtype=torch.long)
        backward_targets = torch.zeros(len(batch_sents), max_len, dtype=torch.long)
        
        for b_idx, sent in enumerate(batch_sents):
            for w_idx, word in enumerate(sent):
                if w_idx >= max_len:
                    break
                # Encode characters
                chars = list(word.lower())[:max_word_len]
                for c_idx, char in enumerate(chars):
                    char_ids[b_idx, w_idx, c_idx] = char2id.get(char, char2id.get('<unk>', 1))
            
            # Set forward targets (next word prediction)
            for w_idx in range(len(sent) - 1):
                word = sent[w_idx + 1]
                forward_targets[b_idx, w_idx] = word2id.get(word, word2id.get('<unk>', 1))
            
            # Set backward targets (previous word prediction)
            for w_idx in range(1, len(sent)):
                word = sent[w_idx - 1]
                backward_targets[b_idx, w_idx] = word2id.get(word, word2id.get('<unk>', 1))
        
        batches.append((char_ids, forward_targets, backward_targets))
    
    return batches


def train_elmo(corpus_path: str, epochs: int = 5):
    """Training loop for ELMo bidirectional language model."""
    # Load and preprocess
    with open(corpus_path, 'r', encoding='utf-8') as f:
        texts = f.read().splitlines()
    
    # Build character vocabulary
    chars = set()
    word_freqs = Counter()
    for text in texts:
        words = text.split()
        for word in words:
            word_freqs[word] += 1
            chars.update(list(word.lower()))
    
    char2id = {c: i+2 for i, c in enumerate(sorted(chars))}
    char2id['<pad>'] = 0
    char2id['<unk>'] = 1
    
    global word2id
    word2id = {w: i+2 for i, (w, _) in enumerate(word_freqs.most_common(30000))}
    word2id['<pad>'] = 0
    word2id['<unk>'] = 1
    
    # Initialize model
    model = ELMoModel(
        char_vocab_size=len(char2id),
        char_embed_dim=16,
        projection_dim=512,
        lstm_hidden=4096,
        num_layers=2,
        vocab_size=len(word2id)
    )
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Prepare sentences
    sentences = [text.split()[:50] for text in texts if text.strip()]  # Truncate long sentences
    
    losses = []
    for epoch in range(epochs):
        batches = create_elmo_batches(sentences, char2id)
        epoch_loss = 0.0
        
        for char_ids, fwd_targets, bwd_targets in batches:
            optimizer.zero_grad()
            _, loss, _ = model(char_ids, fwd_targets, bwd_targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(batches)
        losses.append(avg_loss)
        print(f"ELMo Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    # Visualize training loss and layer weights
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(losses, marker='o')
    plt.title('ELMo Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    
    plt.subplot(1, 2, 2)
    weights = F.softmax(model.elmo_weights, dim=0).detach().numpy()
    plt.bar(['Input', 'Layer 1', 'Layer 2'], weights)
    plt.title('ELMo Layer Contribution Weights')
    plt.ylabel('Weight')
    
    plt.tight_layout()
    plt.savefig('elmo_training_analysis.png', dpi=300)
    plt.show()
    
    return model


def train_bert(corpus_path: str, epochs: int = 3):
    """Training loop for BERT MLM and NSP pretraining."""
    from transformers import BertTokenizerFast
    
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    config = BertConfig(
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        max_position_embeddings=512
    )
    
    model = BertPretrainingTasks(config)
    
    # Load data
    with open(corpus_path, 'r', encoding='utf-8') as f:
        texts = f.read().splitlines()[:1000]  # Limit for demo
    
    dataset = MLMDataset(texts, tokenizer, max_length=128, mlm_probability=0.15)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=epochs*len(dataloader))
    
    mlm_losses = []
    nsp_losses = []
    
    for epoch in range(epochs):
        model.train()
        epoch_mlm_loss = 0.0
        epoch_nsp_loss = 0.0
        
        for batch in dataloader:
            optimizer.zero_grad()
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                token_type_ids=batch['token_type_ids'],
                mlm_labels=batch['mlm_labels'],
                nsp_labels=batch['nsp_labels'],
                masked_lm_positions=batch['masked_lm_positions']
            )
            
            loss = outputs['loss']
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            epoch_mlm_loss += outputs['mlm_loss'].item()
            epoch_nsp_loss += outputs['nsp_loss'].item()
        
        avg_mlm = epoch_mlm_loss / len(dataloader)
        avg_nsp = epoch_nsp_loss / len(dataloader)
        mlm_losses.append(avg_mlm)
        nsp_losses.append(avg_nsp)
        
        print(f"BERT Epoch {epoch+1}/{epochs}, MLM Loss: {avg_mlm:.4f}, NSP Loss: {avg_nsp:.4f}")
    
    # Visualize pretraining task losses
    plt.figure(figsize=(10, 6))
    plt.plot(mlm_losses, marker='o', label='MLM Loss')
    plt.plot(nsp_losses, marker='s', label='NSP Loss')
    plt.title('BERT Pretraining Task Losses')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('bert_pretraining_losses.png', dpi=300)
    plt.show()
    
    return model, tokenizer


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', type=str, choices=['elmo', 'bert'], required=True)
    parser.add_argument('--corpus_path', type=str, required=True)
    parser.add_argument('--epochs', type=int, default=3)
    
    args = parser.parse_args()
    
    if args.model_type == 'elmo':
        train_elmo(args.corpus_path, args.epochs)
    else:
        train_bert(args.corpus_path, args.epochs)
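
The 80/10/10 replacement rule applied in MLMDataset.__getitem__ can be isolated into a small framework-free function, which makes it easy to check statistically. The following is a minimal sketch under stated assumptions: it works on plain Python lists, the function name bert_mask_tokens is hypothetical, and the toy IDs in the demo (0 for [PAD], 1 for [CLS], 2 for [SEP], 3 for [MASK]) are placeholders rather than the real bert-base-uncased IDs.

```python
import random
from typing import AbstractSet, List, Optional, Sequence, Tuple

IGNORE_INDEX = -100  # label value skipped by F.cross_entropy(ignore_index=-100)


def bert_mask_tokens(token_ids: Sequence[int], mask_id: int, vocab_size: int,
                     special_ids: AbstractSet[int], mlm_probability: float = 0.15,
                     rng: Optional[random.Random] = None) -> Tuple[List[int], List[int]]:
    """Apply BERT-style dynamic masking to one sequence of token IDs.

    Returns (corrupted_ids, labels): labels hold the original ID at every selected
    position and IGNORE_INDEX everywhere else.
    """
    if rng is None:
        rng = random.Random()
    corrupted = list(token_ids)
    labels = [IGNORE_INDEX] * len(token_ids)
    for i, tok in enumerate(token_ids):
        # Special tokens are never selected; others are selected with mlm_probability
        if tok in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = tok  # the loss targets the original token
        r = rng.random()
        if r < 0.8:
            corrupted[i] = mask_id                    # 80%: replace with [MASK]
        elif r < 0.9:
            corrupted[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return corrupted, labels


if __name__ == "__main__":
    ids = [1] + list(range(4, 54)) + [2, 0, 0]  # [CLS] words... [SEP] [PAD] [PAD]
    inp, lab = bert_mask_tokens(ids, mask_id=3, vocab_size=100,
                                special_ids={0, 1, 2, 3}, rng=random.Random(0))
    selected = [i for i, l in enumerate(lab) if l != IGNORE_INDEX]
    print(f"selected positions: {selected}")
    print(f"corrupted values:   {[inp[i] for i in selected]}")
```

Because each call draws fresh random numbers, calling such a function inside __getitem__ re-masks every example on every epoch (dynamic masking); passing a seeded random.Random makes a run reproducible.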

11.1.3 Autoregressive Pretraining (CLM) and Causal Masking in the GPT Series; Span Corruption and the Unified Encoder-Decoder Framework of T5

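Autoregressive language modeling learns to generate text by maximizing the joint probability of a sequence, factorized as a product of conditionals, p(x_1, ..., x_T) = ∏_t p(x_t | x_<t): the distribution at each position depends only on the tokens that precede it. The GPT architecture proposed by Radford et al. is a stack of Transformer decoder blocks in which a causal mask restricts each position's attention to itself and the context on its left. This preserves the autoregressive property while still allowing the next-token loss at every position to be computed in parallel in a single forward pass. The causal mask is a lower-triangular matrix that blocks information leakage from future positions, which makes the model suitable for generation tasks such as text generation and summarization.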
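The effect of the lower-triangular mask can be checked without any framework. The following is a minimal pure-Python sketch, assuming a single head with no learned projections (queries, keys and values are taken directly from the inputs, purely for illustration); it mirrors the masked_fill(..., float('-inf')) step of CausalSelfAttention in the listing that follows and confirms that perturbing the last token leaves every earlier output unchanged.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def causal_mask(n: int) -> List[List[bool]]:
    """Lower-triangular mask: row i may attend to columns 0..i only."""
    return [[j <= i for j in range(n)] for i in range(n)]


def causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head scaled dot-product attention with a causal mask (no learned weights)."""
    n, d = len(q), len(q[0])
    mask = causal_mask(n)
    out: Matrix = []
    for i in range(n):
        # Future keys (mask False) get score -inf, which becomes weight exactly 0.0
        scores = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d) if mask[i][j]
                  else float("-inf") for j in range(n)]
        m = max(scores)  # finite, since the diagonal is always allowed
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / z
                    for c in range(len(v[0]))])
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    x = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(5)]
    y = causal_attention(x, x, x)
    x_changed = [row[:] for row in x]
    x_changed[-1] = [5.0, -5.0, 5.0, -5.0]  # perturb only the last token
    y_changed = causal_attention(x_changed, x_changed, x_changed)
    # Positions 0..3 cannot see position 4, so their outputs are identical
    print(y[:-1] == y_changed[:-1], y[-1] == y_changed[-1])
```

Since no output row depends on later rows, a single forward pass over a length-T sequence yields valid predictions for all T next-token targets at once, which is what makes CLM training parallel despite the sequential factorization.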

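The T5 model designed by Raffel et al. casts every natural language processing task as a text-to-text transformation and introduces span corruption as its pretraining objective: contiguous spans of the input sequence are randomly selected, each span is replaced by a unique sentinel token, and the model is trained to reconstruct the masked spans. In the encoder-decoder architecture, the encoder uses bidirectional attention to encode the complete source text, while the decoder attends to the encoder representations through cross-attention layers and generates the target sequence autoregressively. This unified paradigm lets translation, question answering, classification and other tasks be handled by the same model, with a task-specific text prefix (for example "translate English to German:") identifying the task type.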
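The input/target format of span corruption can be made concrete on the example sentence used in the T5 paper. This is a minimal plain-Python sketch, assuming the spans are given explicitly as (start, length) pairs rather than sampled, and using string sentinels named after the Hugging Face `<extra_id_k>` convention; the function name span_corrupt is hypothetical. It also appends the closing sentinel that T5 places at the end of every target.

```python
from typing import List, Sequence, Tuple


def span_corrupt(tokens: Sequence[str],
                 spans: Sequence[Tuple[int, int]]) -> Tuple[List[str], List[str]]:
    """T5-style span corruption for explicit, non-overlapping (start, length) spans.

    In the input, each span is replaced by one sentinel; the target lists each
    sentinel followed by the tokens it hides and ends with one extra sentinel.
    """
    inputs: List[str] = []
    targets: List[str] = []
    pos = 0  # first token not yet copied to the input
    for k, (start, length) in enumerate(sorted(spans)):
        if start < pos or length < 1 or start + length > len(tokens):
            raise ValueError("spans must be in range, non-overlapping, and non-empty")
        sentinel = f"<extra_id_{k}>"
        inputs.extend(tokens[pos:start])  # unmasked tokens before this span
        inputs.append(sentinel)
        targets.append(sentinel)
        targets.extend(tokens[start:start + length])  # hidden tokens go to the target
        pos = start + length
    inputs.extend(tokens[pos:])
    targets.append(f"<extra_id_{len(spans)}>")  # closing sentinel
    return inputs, targets


if __name__ == "__main__":
    words = "Thank you for inviting me to your party last week .".split()
    inp, tgt = span_corrupt(words, [(2, 2), (8, 1)])
    print("input: ", " ".join(inp))
    print("target:", " ".join(tgt))
```

Running it prints `Thank you <extra_id_0> me to your party <extra_id_1> week .` as the input and `<extra_id_0> for inviting <extra_id_1> last <extra_id_2>` as the target, the same structure as the paper's example (written there with <X>, <Y>, <Z>). The target is much shorter than the input, which is one reason span corruption is cheaper to train than reconstructing the full sequence.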

Python

"""
Script: gpt_t5_pretraining.py
Content: Implementation of GPT causal language modeling with causal masking,
         and T5 span corruption with encoder-decoder architecture.
Usage: python gpt_t5_pretraining.py --model_type t5 --corpus_path data.txt --epochs 3
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import random
from typing import Dict, List, Tuple, Optional
import matplotlib.pyplot as plt
import numpy as np
import argparse


class CausalSelfAttention(nn.Module):
    """
    Causal self-attention for GPT with lower-triangular masking.
    Ensures positions only attend to previous positions and themselves.
    """
    def __init__(self, embed_dim: int = 768, num_heads: int = 12, 
                 dropout: float = 0.1, block_size: int = 1024):
        super().__init__()
        assert embed_dim % num_heads == 0
        
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        
        # Q, K, V projections
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        # Causal mask (lower triangular)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, embed_dim = x.shape
        
        # Compute Q, K, V
        qkv = self.qkv(x)
        q, k, v = qkv.split(self.embed_dim, dim=2)
        
        # Reshape to (batch, heads, seq, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Attention scores
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Apply causal mask
        scores = scores.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        # Combine heads
        out = (attn @ v).transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        out = self.proj(out)
        
        return out


class GPTBlock(nn.Module):
    """Transformer decoder block for GPT with pre-norm architecture."""
    def __init__(self, embed_dim: int = 768, num_heads: int = 12,
                 mlp_ratio: float = 4.0, dropout: float = 0.1, block_size: int = 1024):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = CausalSelfAttention(embed_dim, num_heads, dropout, block_size)
        self.ln2 = nn.LayerNorm(embed_dim)
        
        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPTModel(nn.Module):
    """
    GPT architecture for causal language modeling (CLM).
    Follows Radford et al. (2019) with token and positional embeddings.
    """
    def __init__(self, vocab_size: int = 50257, embed_dim: int = 768,
                 num_layers: int = 12, num_heads: int = 12,
                 block_size: int = 1024, dropout: float = 0.1):
        super().__init__()
        
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(block_size, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        self.blocks = nn.ModuleList([
            GPTBlock(embed_dim, num_heads, dropout=dropout, block_size=block_size)
            for _ in range(num_layers)
        ])
        
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)
        
        self.block_size = block_size
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
    
    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """
        Args:
            input_ids: (batch, seq_len) token indices
            labels: (batch, seq_len) target tokens for loss computation
        """
        batch_size, seq_len = input_ids.shape
        
        # Token + positional embeddings
        token_emb = self.token_embed(input_ids)
        pos_emb = self.pos_embed(torch.arange(seq_len, device=input_ids.device))
        x = self.dropout(token_emb + pos_emb)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        
        x = self.ln_f(x)
        logits = self.head(x)  # (batch, seq_len, vocab_size)
        
        loss = None
        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
        
        return {'logits': logits, 'loss': loss}
    
    @torch.no_grad()
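    def generate(self, input_ids: torch.Tensor, max_length: int = 100,
                 temperature: float = 1.0, top_k: Optional[int] = 50) -> torch.Tensor:
        """
        Autoregressive generation of `max_length` new tokens with temperature sampling
        and optional top-k filtering (pass top_k=None to disable). Leaves the model in eval mode.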
        """
        self.eval()
        for _ in range(max_length):
            # Crop to block size
            input_crop = input_ids if input_ids.size(1) <= self.block_size else input_ids[:, -self.block_size:]
            
            outputs = self(input_crop)
            logits = outputs['logits'][:, -1, :] / temperature
            
            # Top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            input_ids = torch.cat([input_ids, next_token], dim=1)
        
        return input_ids


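class T5Encoder(nn.Module):
    """
    Simplified T5-style encoder with bidirectional (non-causal) self-attention.
    The original T5 adds a learned relative position bias inside attention; this
    sketch substitutes learned absolute position embeddings, because
    nn.MultiheadAttention alone carries no positional information.
    """
    def __init__(self, vocab_size: int = 32128, embed_dim: int = 768,
                 num_layers: int = 12, num_heads: int = 12, dropout: float = 0.1,
                 max_len: int = 512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        self.layers = nn.ModuleList([
            self._make_layer(embed_dim, num_heads, dropout) 
            for _ in range(num_layers)
        ])
        
        self.ln = nn.LayerNorm(embed_dim)
    
    def _make_layer(self, embed_dim, num_heads, dropout):
        """Pre-norm layer: self-attention followed by a position-wise feed-forward network."""
        return nn.ModuleDict({
            'self_attn': nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True),
            'ln1': nn.LayerNorm(embed_dim),
            'ffn': nn.Sequential(
                nn.Linear(embed_dim, embed_dim * 4),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(embed_dim * 4, embed_dim),
                nn.Dropout(dropout)
            ),
            'ln2': nn.LayerNorm(embed_dim)
        })
    
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """
        Args:
            input_ids: (batch, enc_len) token IDs
            attention_mask: (batch, enc_len) True/1 for real tokens, False/0 for padding
        """
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        x = self.dropout(self.token_embed(input_ids) + self.pos_embed(positions))
        
        # nn.MultiheadAttention ignores keys where key_padding_mask is True, so invert
        # the "True = real token" mask; fully padded rows stay unmasked to avoid NaN
        key_padding_mask = None
        if attention_mask is not None:
            real = attention_mask.bool()
            key_padding_mask = ~real & real.any(dim=1, keepdim=True)
        
        for layer in self.layers:
            # Bidirectional self-attention: only padding is masked, no causal mask
            h = layer['ln1'](x)
            attn_out, _ = layer['self_attn'](h, h, h, key_padding_mask=key_padding_mask)
            x = x + attn_out
            x = x + layer['ffn'](layer['ln2'](x))
        
        return self.ln(x)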


class T5Decoder(nn.Module):
    """
    Simplified T5-style decoder: causal self-attention, cross-attention to the
    encoder output, and a feed-forward network in each layer. Learned absolute
    position embeddings stand in for T5's relative position bias.
    Generates the output sequence autoregressively.
    """
    def __init__(self, vocab_size: int = 32128, embed_dim: int = 768,
                 num_layers: int = 12, num_heads: int = 12, dropout: float = 0.1,
                 max_len: int = 512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        self.layers = nn.ModuleList([
            self._make_layer(embed_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])
        
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)
    
    def _make_layer(self, embed_dim, num_heads, dropout):
        return nn.ModuleDict({
            'self_attn': nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True),
            'ln1': nn.LayerNorm(embed_dim),
            'cross_attn': nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True),
            'ln2': nn.LayerNorm(embed_dim),
            'ffn': nn.Sequential(
                nn.Linear(embed_dim, embed_dim * 4),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(embed_dim * 4, embed_dim),
                nn.Dropout(dropout)
            ),
            'ln3': nn.LayerNorm(embed_dim)
        })
    
    def forward(self, input_ids: torch.Tensor, encoder_hidden: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                encoder_mask: Optional[torch.Tensor] = None):
        """
        Args:
            input_ids: (batch, dec_len) right-shifted target IDs
            encoder_hidden: (batch, enc_len, embed_dim) encoder output
            attention_mask: unused; decoder padding sits at the end of each sequence,
                so the causal mask already keeps real tokens from attending to it
            encoder_mask: (batch, enc_len) True/1 for real encoder tokens, False/0 for padding
        """
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.dropout(self.token_embed(input_ids) + self.pos_embed(positions))
        
        # Causal mask: True above the diagonal marks future positions that may not be attended
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()
        
        # Invert the "True = real token" encoder mask for key_padding_mask (True = ignore)
        enc_padding_mask = None
        if encoder_mask is not None:
            real = encoder_mask.bool()
            enc_padding_mask = ~real & real.any(dim=1, keepdim=True)
        
        for layer in self.layers:
            # Causal self-attention
            h = layer['ln1'](x)
            attn_out, _ = layer['self_attn'](h, h, h, attn_mask=causal_mask)
            x = x + attn_out
            
            # Cross-attention: queries from the decoder, keys/values from the encoder
            cross_out, _ = layer['cross_attn'](layer['ln2'](x), encoder_hidden, encoder_hidden,
                                              key_padding_mask=enc_padding_mask)
            x = x + cross_out
            
            # Feed-forward
            x = x + layer['ffn'](layer['ln3'](x))
        
        x = self.ln(x)
        return self.head(x)


class T5Model(nn.Module):
    """
    Complete T5 model for span corruption pretraining.
    Encoder-decoder architecture following Raffel et al. (2020).
    """
    def __init__(self, vocab_size: int = 32128, embed_dim: int = 768,
                 num_layers: int = 12, num_heads: int = 12, dropout: float = 0.1):
        super().__init__()
        self.encoder = T5Encoder(vocab_size, embed_dim, num_layers, num_heads, dropout)
        self.decoder = T5Decoder(vocab_size, embed_dim, num_layers, num_heads, dropout)
        self.vocab_size = vocab_size
    
    def forward(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor,
                labels: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None):
        """
        Forward pass for span corruption training.
        
        Args:
            input_ids: (batch, enc_len) corrupted input with sentinel tokens
            decoder_input_ids: (batch, dec_len) target sequence for decoder
            labels: (batch, dec_len) target labels for loss
            attention_mask: (batch, enc_len) encoder padding mask, True for real tokens and False for padding
        """
        # Encode input
        encoder_hidden = self.encoder(input_ids, attention_mask)
        
        # Decode output
        logits = self.decoder(decoder_input_ids, encoder_hidden, encoder_mask=attention_mask)
        
        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                labels.view(-1),
                ignore_index=-100
            )
        
        return {'loss': loss, 'logits': logits, 'encoder_hidden': encoder_hidden}


class SpanCorruptionDataset(Dataset):
    """
    Dataset for T5 span corruption pretraining.
    Randomly selects spans of tokens and replaces each span with a unique sentinel token.
    """
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512,
                 mean_noise_span_length: float = 3.0, noise_density: float = 0.15):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mean_noise_span_length = mean_noise_span_length
        self.noise_density = noise_density
        
        # Sentinel tokens (assuming last 100 tokens in vocab are reserved)
        self.sentinel_start = tokenizer.vocab_size - 100
    
    def __len__(self):
        return len(self.texts)
    
    def _corrupt_span(self, tokens: List[int]) -> Tuple[List[int], List[int]]:
        """
        Apply span corruption to token sequence.
        Returns corrupted input and target output.
        """
        n_tokens = len(tokens)
        n_to_mask = int(n_tokens * self.noise_density)
        
        if n_to_mask == 0:
            return tokens, tokens
        
        # Sample span lengths: floor of an exponential draw, i.e. roughly geometric (clamped to >= 1)
        span_lengths = []
        total_masked = 0
        while total_masked < n_to_mask:
            span_len = max(1, int(random.expovariate(1.0 / self.mean_noise_span_length)))
            span_lengths.append(min(span_len, n_to_mask - total_masked))
            total_masked += span_lengths[-1]
        
        # Randomly choose start positions for spans
        start_positions = random.sample(range(n_tokens - max(span_lengths)), len(span_lengths))
        start_positions.sort()
        
        # Create masks
        masked_indices = set()
        for start, length in zip(start_positions, span_lengths):
            for i in range(start, min(start + length, n_tokens)):
                masked_indices.add(i)
        
        # Build corrupted input and target
        sentinel_id = 0
        corrupted_input = []
        target_output = []
        
        i = 0
        while i < n_tokens:
            if i in masked_indices:
                # Add sentinel to input
                corrupted_input.append(self.sentinel_start + sentinel_id)
                # Add masked span to target with sentinel prefix
                target_output.append(self.sentinel_start + sentinel_id)
                
                # Collect all consecutive masked tokens
                while i < n_tokens and i in masked_indices:
                    target_output.append(tokens[i])
                    i += 1
                
                sentinel_id += 1
            else:
                corrupted_input.append(tokens[i])
                i += 1
        
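        # Close the target with one final sentinel, following the T5 target format
        target_output.append(self.sentinel_start + sentinel_id)
        
        return corrupted_input, target_output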
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = self.tokenizer.encode(text, max_length=self.max_length, truncation=True)
        
        # Corrupt spans
        input_tokens, target_tokens = self._corrupt_span(tokens)
        
        # Pad sequences
        input_ids = input_tokens[:self.max_length]
        target_ids = target_tokens[:self.max_length]
        
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        target_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(target_ids))
        
        # Create decoder input (shifted right with start token)
        decoder_input = [self.tokenizer.pad_token_id] + target_ids[:-1]
        
        # Labels are the unshifted targets (decoder_input is the right-shifted copy); padding becomes -100
        labels = [t if t != self.tokenizer.pad_token_id else -100 for t in target_ids]
        
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'decoder_input_ids': torch.tensor(decoder_input, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'attention_mask': torch.tensor([i != self.tokenizer.pad_token_id for i in input_ids], dtype=torch.bool)
        }


def train_gpt(corpus_path: str, epochs: int = 3):
    """Training pipeline for GPT causal language modeling."""
    # Simple word-level tokenizer for demo (whitespace split, not BPE)
    with open(corpus_path, 'r', encoding='utf-8') as f:
        text = f.read()
    
    # Build vocabulary: the first 10k distinct words in order of appearance (not frequency-ranked)
    words = text.split()
    vocab = {'<|endoftext|>': 0, '<|pad|>': 1}
    for word in words:
        if word not in vocab and len(vocab) < 10000:
            vocab[word] = len(vocab)
    
    def encode(text):
        return [vocab.get(w, vocab['<|endoftext|>']) for w in text.split()]
    
    # Prepare data
    tokens = encode(text)
    block_size = 128
    
    def get_batch(batch_size=32):
        ix = torch.randint(len(tokens) - block_size, (batch_size,))
        x = torch.stack([torch.tensor(tokens[i:i+block_size]) for i in ix])
        y = x.clone()  # labels equal inputs; GPTModel.forward shifts them by one position
        return x, y
    
    # Initialize model
    model = GPTModel(
        vocab_size=len(vocab),
        embed_dim=256,  # Small for demo
        num_layers=6,
        num_heads=8,
        block_size=block_size
    )
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    
    losses = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 100
        
        for _ in range(num_batches):
            x, y = get_batch()
            
            optimizer.zero_grad()
            outputs = model(x, labels=y)
            loss = outputs['loss']
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / num_batches
        losses.append(avg_loss)
        print(f"GPT Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    # Visualize training dynamics
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(losses, marker='o')
    plt.title('GPT Training Loss (CLM)')
    plt.xlabel('Epoch')
    plt.ylabel('Cross-Entropy Loss')
    plt.grid(True)
    
    # Generate sample text
    plt.subplot(1, 2, 2)
    prompt = torch.tensor([encode("The")], dtype=torch.long)
    generated = model.generate(prompt, max_length=20, temperature=0.8)
    id2token = {idx: tok for tok, idx in vocab.items()}
    generated_text = " ".join(id2token[i] for i in generated[0].tolist())
    plt.text(0.1, 0.5, f"Prompt: 'The'\nGenerated: '{generated_text}'", 
             fontsize=10, wrap=True)
    plt.axis('off')
    plt.title('Autoregressive Generation Sample')
    
    plt.tight_layout()
    plt.savefig('gpt_training_analysis.png', dpi=300)
    plt.show()
    
    return model, vocab


def train_t5(corpus_path: str, epochs: int = 3):
    """Training pipeline for T5 span corruption pretraining."""
    with open(corpus_path, 'r', encoding='utf-8') as f:
        texts = f.read().splitlines()[:1000]  # Limit for demo
    
    # Simple whitespace tokenizer; out-of-vocabulary words map to <unk> (not <s>)
    vocab = {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3}
    for text in texts:
        for word in text.split():
            if word not in vocab and len(vocab) < 32000:
                vocab[word] = len(vocab)
    
    class SimpleTokenizer:
        def __init__(self, vocab):
            self.vocab = vocab
            self.vocab_size = len(vocab)
            self.pad_token_id = 0
            self.unk_token_id = vocab['<unk>']
        
        def encode(self, text, max_length=512, truncation=True):
            tokens = [self.vocab.get(w, self.unk_token_id) for w in text.split()]
            if truncation and len(tokens) > max_length:
                tokens = tokens[:max_length]
            return tokens
    
    tokenizer = SimpleTokenizer(vocab)
    dataset = SpanCorruptionDataset(texts, tokenizer, max_length=128)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
    
    # Initialize model
    model = T5Model(
        vocab_size=len(vocab),
        embed_dim=512,
        num_layers=6,
        num_heads=8
    )
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    
    mlm_losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0.0
        
        for batch in dataloader:
            optimizer.zero_grad()
            outputs = model(
                input_ids=batch['input_ids'],
                decoder_input_ids=batch['decoder_input_ids'],
                labels=batch['labels'],
                attention_mask=batch['attention_mask']
            )
            
            loss = outputs['loss']
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(dataloader)
        mlm_losses.append(avg_loss)
        print(f"T5 Epoch {epoch+1}/{epochs}, Span Corruption Loss: {avg_loss:.4f}")
    
    # Visualize span corruption training
    plt.figure(figsize=(10, 6))
    plt.plot(mlm_losses, marker='o', color='green')
    plt.title('T5 Span Corruption Pretraining Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Reconstruction Loss')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('t5_span_corruption_training.png', dpi=300)
    plt.show()
    
    return model, tokenizer


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', type=str, choices=['gpt', 't5'], required=True)
    parser.add_argument('--corpus_path', type=str, required=True)
    parser.add_argument('--epochs', type=int, default=3)
    
    args = parser.parse_args()
    
    if args.model_type == 'gpt':
        train_gpt(args.corpus_path, args.epochs)
    else:
        train_t5(args.corpus_path, args.epochs)
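
In train_gpt, get_batch returns y = x.clone(), so the labels are the unshifted inputs. The loop therefore assumes that GPTModel (defined outside this excerpt) shifts logits and labels internally before computing the cross-entropy. A minimal sketch of that shift, with causal_lm_loss as a hypothetical helper name:

```python
import torch
import torch.nn.functional as F


def causal_lm_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy when the labels are the unshifted input ids.

    logits: (batch, seq_len, vocab), input_ids: (batch, seq_len).
    Position t predicts token t + 1, so the last logit and the first label are dropped.
    """
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
    )


torch.manual_seed(0)
ids = torch.randint(0, 10, (2, 6))

# Uniform logits over a 10-token vocabulary: loss equals log(10)
uniform_loss = causal_lm_loss(torch.zeros(2, 6, 10), ids)

# Logits that put all mass on the true next token: loss close to zero
perfect = torch.full((2, 6, 10), -1e4)
perfect[:, :-1, :].scatter_(2, ids[:, 1:].unsqueeze(-1), 1e4)
perfect_loss = causal_lm_loss(perfect, ids)

print(round(uniform_loss.item(), 4), perfect_loss.item() < 1e-6)  # 2.3026 True
```

Dropping the last logit and the first label aligns position t with token t + 1. The uniform-logit value log 10 ≈ 2.3026 is a convenient sanity check for the loss of an untrained model.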

11.2 Parameter-Efficient Fine-Tuning and Prompt Engineering

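11.2.1 LoRA (Low-Rank Adaptation): Rank Decomposition and the Scaling Factor α; Adapter Layers: Bottleneck Architecture and Residual Connections; Prefix Tuning: Virtual-Token Optimization and Reparameterization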

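Full-parameter fine-tuning of large pretrained language models is expensive in both compute and storage, because every downstream task needs its own copy of all weights. The Low-Rank Adaptation (LoRA) method proposed by Hu et al. freezes the pretrained weight matrix W0 and learns an additive update ΔW = BA for the attention weight matrices, where B is d×r, A is r×k, and the rank r ≪ min(d, k); the forward pass becomes h = W0x + (α/r)·BAx. The rank r controls the number of trainable parameters (r(d + k) instead of dk per matrix), and the scaling factor α/r controls the strength of the adaptation. Mathematically, this constrains the weight change to be low-rank, and because BA can be merged into W0 after training, inference incurs no extra latency.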
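
The merge argument can be checked numerically. The sketch below uses illustrative sizes (d = 64, k = 32, r = 4, α = 8, all assumptions) to confirm that the unmerged forward W0x + (α/r)·BAx equals the forward through the merged matrix, and compares the trainable-parameter counts dk and r(d + k):

```python
import torch

torch.manual_seed(0)
d, k, r, alpha = 64, 32, 4, 8.0   # illustrative sizes: W0 is d x k (out x in)
scaling = alpha / r

W0 = torch.randn(d, k)            # frozen pretrained weight
A = torch.randn(r, k) * 0.01      # trainable down-projection (r x k)
# LoRA initializes B to zero; random values are used here so that delta_W is nonzero
B = torch.randn(d, r) * 0.01      # trainable up-projection (d x r)

x = torch.randn(5, k)             # batch of 5 input rows

# Training-time path: frozen product plus scaled low-rank update
h_unmerged = x @ W0.T + scaling * (x @ A.T) @ B.T

# Inference-time path: fold delta_W = (alpha/r) * B @ A into a single matrix
W_merged = W0 + scaling * (B @ A)
h_merged = x @ W_merged.T

full_params = d * k               # entries updated by full fine-tuning
lora_params = r * (d + k)         # entries of A and B

print(torch.allclose(h_unmerged, h_merged, atol=1e-5), full_params, lora_params)  # True 2048 384
```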

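The adapter layers designed by Houlsby et al. insert bottleneck modules after the Transformer sublayers: a down-projection compresses the d-dimensional hidden state to a bottleneck of size m ≪ d, a nonlinearity is applied, and an up-projection restores the original dimension, while a residual connection preserves the original information flow. During training only the adapter parameters and the layer-normalization parameters are optimized; the attention and feed-forward weights stay frozen, so one backbone can be shared across tasks and tasks can be switched quickly by swapping adapters.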
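
For hidden size d and bottleneck size m, each adapter adds 2md + d + m parameters (two biased projections). The sketch below defines a hypothetical BottleneckAdapter (ReLU chosen arbitrarily, illustrative sizes d = 768 and m = 64), checks that count, and shows that near-zero initialization keeps the adapted layer close to the identity at the start of training:

```python
import torch
import torch.nn as nn


def adapter_param_count(d: int, m: int) -> int:
    """Down-projection (d*m + m) plus up-projection (m*d + d) = 2md + d + m."""
    return 2 * d * m + d + m


class BottleneckAdapter(nn.Module):
    """Hypothetical minimal adapter: x + up(relu(down(x)))."""

    def __init__(self, d: int, m: int, init_std: float = 1e-3):
        super().__init__()
        self.down = nn.Linear(d, m)
        self.up = nn.Linear(m, d)
        # Near-zero init makes the adapter start close to the identity map
        for lin in (self.down, self.up):
            nn.init.normal_(lin.weight, std=init_std)
            nn.init.zeros_(lin.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.up(torch.relu(self.down(x)))


torch.manual_seed(0)
adapter = BottleneckAdapter(d=768, m=64)
x = torch.randn(2, 10, 768)
max_dev = (adapter(x) - x).abs().max().item()

print(adapter_param_count(768, 64), sum(p.numel() for p in adapter.parameters()))  # 99136 99136
print(max_dev < 1e-2)  # True
```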

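The prefix-tuning method developed by Li and Liang prepends trainable virtual-token vectors to the attention keys and values of every layer. Rather than optimizing these prefixes directly, which leads to unstable optimization, a reparameterization network (a small MLP over a smaller matrix) generates the per-layer key-value prefixes; after training, the reparameterization parameters can be discarded and only the prefixes kept. The prefix acts as a soft prompt that conditions the network: attention heads attend to the prefix positions, which steers the model toward task-specific patterns through deep prompt injection without modifying the backbone parameters.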
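
The effect of a prefix on one attention layer can be sketched with PyTorch's F.scaled_dot_product_attention (PyTorch 2.0 or later assumed). Here the prefix keys and values are optimized directly rather than through a reparameterization network, the sizes are illustrative, and real tokens remain causal while every query may attend to all prefix positions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, prefix_len, head_dim = 2, 4, 6, 3, 8  # illustrative sizes

# Stand-ins for the outputs of the frozen Q/K/V projections (no gradients)
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Trainable prefix for this layer, shared across the batch
prefix_k = torch.randn(heads, prefix_len, head_dim, requires_grad=True)
prefix_v = torch.randn(heads, prefix_len, head_dim, requires_grad=True)

# Prepend the prefix along the key/value sequence axis
k_full = torch.cat([prefix_k.expand(batch, -1, -1, -1), k], dim=2)
v_full = torch.cat([prefix_v.expand(batch, -1, -1, -1), v], dim=2)

# Boolean mask for SDPA: True = may attend. Prefix always visible, real tokens causal.
causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
mask = torch.cat([torch.ones(seq_len, prefix_len, dtype=torch.bool), causal], dim=1)

out = F.scaled_dot_product_attention(q, k_full, v_full, attn_mask=mask)
out.sum().backward()

print(out.shape, prefix_k.grad.shape)  # torch.Size([2, 4, 6, 8]) torch.Size([4, 3, 8])
```

Only prefix_k and prefix_v receive gradients. The prefix changes attention solely through the extra key/value positions, so the backbone weights are never modified.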

Python

"""
Script: peft_implementation.py
Content: Implementation of LoRA (Low-Rank Adaptation), Adapter layers with bottleneck architecture,
         and Prefix Tuning with reparameterization.
Usage: python peft_implementation.py --method lora --epochs 5
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import math
from typing import Dict, List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
import argparse


class LoRALayer(nn.Module):
    """
    LoRA layer implementing low-rank adaptation of dense layers.
    Adds trainable rank-decomposition matrices while freezing base weights.
    """
    def __init__(self, base_layer: nn.Linear, rank: int = 8, lora_alpha: float = 16,
                 lora_dropout: float = 0.0, merge_weights: bool = False):
        super().__init__()
        self.base_layer = base_layer
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / rank  # Scaling factor alpha/r
        
        # Freeze base layer
        for param in self.base_layer.parameters():
            param.requires_grad = False
        
        in_features = base_layer.in_features
        out_features = base_layer.out_features
        
        # Low-rank decomposition matrices
        self.lora_A = nn.Parameter(torch.zeros((rank, in_features)))
        self.lora_B = nn.Parameter(torch.zeros((out_features, rank)))
        
        self.lora_dropout = nn.Dropout(lora_dropout) if lora_dropout > 0 else nn.Identity()
        
        # Random init for A (Kaiming-uniform here; the paper describes a random Gaussian),
        # zeros for B, so that delta_W = B @ A is zero at the start of training
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)
        
        self.merged = False
        self.merge_weights = merge_weights
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: base output + scaled low-rank adaptation.
        h = Wx + (alpha/r) * B * A * x
        """
        base_output = self.base_layer(x)
        
        if not self.merged:
            # LoRA path: dropout -> linear A -> linear B -> scale
            lora_output = self.lora_dropout(x)
            lora_output = F.linear(lora_output, self.lora_A)  # Down-project
            lora_output = F.linear(lora_output, self.lora_B)  # Up-project
            lora_output = lora_output * self.scaling
            
            return base_output + lora_output
        
        return base_output
    
    def merge(self):
        """Merge LoRA weights into base layer for inference efficiency."""
        if self.merge_weights and not self.merged:
            # Compute delta_W = (alpha/r) * B @ A
            delta_W = (self.lora_B @ self.lora_A) * self.scaling
            # Add to base weight
            self.base_layer.weight.data += delta_W
            self.merged = True
    
    def unmerge(self):
        """Unmerge weights to restore training state."""
        if self.merged:
            delta_W = (self.lora_B @ self.lora_A) * self.scaling
            self.base_layer.weight.data -= delta_W
            self.merged = False


class AdapterLayer(nn.Module):
    """
    Bottleneck adapter layer with residual connection.
    Architecture: Down-project -> Activation -> Up-project -> Residual
    """
    def __init__(self, input_dim: int, adapter_dim: int = 64,
                 activation: str = 'gelu', init_scale: float = 1e-3):
        super().__init__()
        self.input_dim = input_dim
        self.adapter_dim = adapter_dim
        
        self.down_project = nn.Linear(input_dim, adapter_dim)
        self.activation = nn.GELU() if activation == 'gelu' else nn.ReLU()
        self.up_project = nn.Linear(adapter_dim, input_dim)
        
        # Initialize near zero for stable training (Houlsby et al.)
        nn.init.normal_(self.down_project.weight, std=init_scale)
        nn.init.normal_(self.up_project.weight, std=init_scale)
        nn.init.zeros_(self.down_project.bias)
        nn.init.zeros_(self.up_project.bias)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward with residual: output = x + f(x)
        where f(x) = up(activation(down(x)))
        """
        residual = x
        h = self.down_project(x)
        h = self.activation(h)
        h = self.up_project(h)
        return residual + h  # Residual connection preserves base features


class PrefixTuning(nn.Module):
    """
    Prefix Tuning with reparameterization for stable training.
    Optimizes embedding vectors prepended to keys and values in attention.
    """
    def __init__(self, num_layers: int, num_heads: int, embed_dim: int,
                 prefix_length: int = 20, prefix_dim: int = 512,
                 reparam: bool = True):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.prefix_length = prefix_length
        
        # Reparameterization: an embedding table followed by a linear projection
        # generates the per-layer key/value prefixes (Li & Liang use an MLP here)
        if reparam:
            self.prefix_embed = nn.Sequential(
                nn.Embedding(prefix_length, prefix_dim),
                nn.Linear(prefix_dim, num_layers * 2 * embed_dim)  # 2 for key and value
            )
            # Small random initialization of the prefix embedding table
            nn.init.normal_(self.prefix_embed[0].weight, std=0.02)
        else:
            # Direct optimization (less stable)
            self.prefix_embed = nn.Parameter(
                torch.randn(prefix_length, num_layers, 2, num_heads, self.head_dim)
            )
        
        self.reparam = reparam
    
    def forward(self, batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate prefix key-values for all layers.
        
        Returns:
            past_key: (num_layers, batch, num_heads, prefix_len, head_dim)
            past_value: (num_layers, batch, num_heads, prefix_len, head_dim)
        """
        if self.reparam:
            # Generate through MLP
            indices = torch.arange(self.prefix_length, device=device)
            prefix_params = self.prefix_embed(indices)  # (prefix_len, layers*2*embed_dim)
            prefix_params = prefix_params.view(
                self.prefix_length, self.num_layers, 2, self.num_heads, self.head_dim
            )
        else:
            prefix_params = self.prefix_embed
        
        # Split into key and value prefixes
        key_prefix = prefix_params[:, :, 0, :, :]  # (prefix_len, layers, heads, head_dim)
        value_prefix = prefix_params[:, :, 1, :, :]
        
        # Expand for batch dimension and transpose to standard shape
        # (num_layers, batch, num_heads, prefix_len, head_dim)
        key_prefix = key_prefix.permute(1, 2, 0, 3).unsqueeze(1).expand(-1, batch_size, -1, -1, -1)
        value_prefix = value_prefix.permute(1, 2, 0, 3).unsqueeze(1).expand(-1, batch_size, -1, -1, -1)
        
        return key_prefix, value_prefix


class PEFTTransformer(nn.Module):
    """
    Transformer model supporting LoRA, Adapter, and Prefix Tuning methods.
    Base architecture follows standard Transformer with PEFT modifications.
    """
    def __init__(self, vocab_size: int = 32000, embed_dim: int = 768,
                 num_layers: int = 12, num_heads: int = 12,
                 peft_method: str = 'lora', peft_config: Dict = None):
        super().__init__()
        self.peft_method = peft_method
        peft_config = peft_config or {}  # Allow peft_config=None
        
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(2048, embed_dim)
        
        # Transformer layers
        self.layers = nn.ModuleList([
            self._make_transformer_layer(embed_dim, num_heads, peft_method, peft_config)
            for _ in range(num_layers)
        ])
        
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)
        
        # Initialize prefix tuning if selected
        if peft_method == 'prefix':
            self.prefix_tuning = PrefixTuning(
                num_layers=num_layers,
                num_heads=num_heads,
                embed_dim=embed_dim,
                prefix_length=peft_config.get('prefix_length', 20),
                prefix_dim=peft_config.get('prefix_dim', 512),
                reparam=peft_config.get('reparam', True)
            )
        else:
            self.prefix_tuning = None
        
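        self.apply(self._init_weights)
        
        # Freeze the backbone so that only the PEFT parameters are updated
        self._freeze_base_parameters()
    
    def _freeze_base_parameters(self):
        """Keep only the PEFT-specific parameters trainable."""
        trainable_keys = {
            'lora': ('lora_A', 'lora_B'),
            # Adapters are trained together with the layer norms (Houlsby et al.)
            'adapter': ('adapter', '.ln1.', '.ln2.', 'ln_f.'),
            'prefix': ('prefix_tuning',),
        }.get(self.peft_method, ())
        for name, param in self.named_parameters():
            param.requires_grad = any(key in name for key in trainable_keys)
    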
    def _make_transformer_layer(self, embed_dim, num_heads, peft_method, peft_config):
        """Create transformer layer with specified PEFT modification."""
        layer = nn.ModuleDict({
            'ln1': nn.LayerNorm(embed_dim),
            'attn': nn.MultiheadAttention(embed_dim, num_heads, batch_first=True),
            'ln2': nn.LayerNorm(embed_dim),
            'ffn': nn.Sequential(
                nn.Linear(embed_dim, 4 * embed_dim),
                nn.GELU(),
                nn.Linear(4 * embed_dim, embed_dim),
                nn.Dropout(0.1)
            )
        })
        
        # Add PEFT components
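        if peft_method == 'lora':
            rank = peft_config.get('rank', 8)
            alpha = peft_config.get('alpha', 16)
            dropout = peft_config.get('dropout', 0.0)
            # Wrap FFN layers with LoRA (the paper targets attention projections, but
            # nn.MultiheadAttention packs Q/K/V into one in_proj_weight, so FFN is used here)
            layer['ffn_lora1'] = LoRALayer(layer['ffn'][0], rank=rank, lora_alpha=alpha,
                                           lora_dropout=dropout)
            layer['ffn_lora2'] = LoRALayer(layer['ffn'][2], rank=rank, lora_alpha=alpha,
                                           lora_dropout=dropout)
        elif peft_method == 'adapter':
            adapter_dim = peft_config.get('adapter_dim', 64)
            # One adapter after the FFN block (simplified: Houlsby et al. insert
            # an adapter after both the attention and the FFN sublayers)
            layer['adapter'] = AdapterLayer(embed_dim, adapter_dim)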
        
        return layer
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
    
    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        
        # Embeddings
        token_emb = self.token_embed(input_ids)
        pos_emb = self.pos_embed(torch.arange(seq_len, device=device))
        x = token_emb + pos_emb
        
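        # Prepare prefix if using prefix tuning
        prefix_key = None
        prefix_value = None
        prefix_len = 0
        if self.peft_method == 'prefix':
            prefix_key, prefix_value = self.prefix_tuning(batch_size, device)
            prefix_len = self.prefix_tuning.prefix_length
        
        # Causal mask for next-token prediction (True = blocked). Prefix positions
        # are always visible, so the mask gets prefix_len extra leading columns.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
        )
        if prefix_len > 0:
            prefix_cols = torch.zeros(seq_len, prefix_len, dtype=torch.bool, device=device)
            causal_mask = torch.cat([prefix_cols, causal_mask], dim=1)
        
        # Transformer layers
        for layer_idx, layer in enumerate(self.layers):
            # Self-attention with residual
            residual = x
            x = layer['ln1'](x)
            
            if prefix_key is not None:
                # nn.MultiheadAttention applies W_k / W_v internally, so the prefix is
                # prepended in embedding space (before projection), a simplification of
                # Li & Liang, who prepend activations to the projected keys/values.
                embed_dim = x.size(-1)
                pk = prefix_key[layer_idx].transpose(1, 2).reshape(batch_size, -1, embed_dim)
                pv = prefix_value[layer_idx].transpose(1, 2).reshape(batch_size, -1, embed_dim)
                keys = torch.cat([pk, x], dim=1)
                values = torch.cat([pv, x], dim=1)
            else:
                keys, values = x, x
            attn_out, _ = layer['attn'](x, keys, values, attn_mask=causal_mask, need_weights=False)
            
            x = residual + attn_out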
            
            # FFN with residual
            residual = x
            x = layer['ln2'](x)
            
            if self.peft_method == 'lora':
                # LoRALayer wraps the base linear, so it replaces (not follows) ffn[0]/ffn[2]
                h = layer['ffn_lora1'](x)  # W1 x + (alpha/r) B1 A1 x
                h = layer['ffn'][1](h)     # GELU
                h = layer['ffn_lora2'](h)  # W2 h + (alpha/r) B2 A2 h
                h = layer['ffn'][3](h)     # Dropout
                x = residual + h
            else:
                x = residual + layer['ffn'](x)
            
            # Adapter insertion (post-FFN)
            if self.peft_method == 'adapter':
                x = layer['adapter'](x)
        
        x = self.ln_f(x)
        logits = self.head(x)
        
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
        
        return {'loss': loss, 'logits': logits, 'hidden_states': x}
    
    def count_parameters(self) -> Dict[str, int]:
        """Count total, trainable, and frozen parameters."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            'total': total,
            'trainable': trainable,
            'frozen': total - trainable,
            'trainable_pct': 100 * trainable / total
        }


def create_classification_dataset(num_samples: int = 1000, seq_length: int = 128):
    """Create synthetic classification dataset for demonstration."""
    vocab_size = 1000
    num_classes = 2
    
    # Random token sequences
    input_ids = torch.randint(0, vocab_size, (num_samples, seq_length))
    labels = torch.randint(0, num_classes, (num_samples,))
    
    return input_ids, labels


def train_peft_model(method: str = 'lora', epochs: int = 5):
    """Training pipeline comparing different PEFT methods."""
    # Create base model
    peft_config = {
        'lora': {'rank': 8, 'alpha': 16, 'dropout': 0.05},
        'adapter': {'adapter_dim': 64},
        'prefix': {'prefix_length': 20, 'prefix_dim': 512, 'reparam': True}
    }[method]
    
    model = PEFTTransformer(
        vocab_size=1000,
        embed_dim=512,  # Smaller for demo
        num_layers=6,
        num_heads=8,
        peft_method=method,
        peft_config=peft_config
    )
    
    # Print parameter statistics
    stats = model.count_parameters()
    print(f"\n{method.upper()} Parameter Statistics:")
    print(f"Total: {stats['total']:,}")
    print(f"Trainable: {stats['trainable']:,} ({stats['trainable_pct']:.4f}%)")
    print(f"Frozen: {stats['frozen']:,}")
    
    # Prepare data
    input_ids, labels = create_classification_dataset()
    # For LM task, labels are input_ids shifted
    train_data = torch.utils.data.TensorDataset(input_ids, input_ids)
    dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
    
    # Optimizer: only trainable parameters
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=5e-4,
        weight_decay=0.01
    )
    
    # Training
    losses = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        
        for batch_input, batch_labels in dataloader:
            optimizer.zero_grad()
            
            # Shift for next-token prediction: position t predicts token t + 1
            inputs_shifted = batch_input[:, :-1].contiguous()
            labels_shifted = batch_labels[:, 1:].contiguous()
            
            outputs = model(inputs_shifted, labels=labels_shifted)
            loss = outputs['loss']
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
            num_batches += 1
        
        avg_loss = epoch_loss / num_batches
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    # Visualization
    plt.figure(figsize=(15, 5))
    
    # Loss curve
    plt.subplot(1, 3, 1)
    plt.plot(losses, marker='o')
    plt.title(f'{method.upper()} Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Cross-Entropy Loss')
    plt.grid(True)
    
    # Parameter efficiency comparison
    plt.subplot(1, 3, 2)
    methods = ['Full Fine-tune', 'LoRA', 'Adapter', 'Prefix']
    params = [100.0, 0.5, 2.0, 0.1]  # Illustrative values, not measured by this script
    colors = ['red', 'blue', 'green', 'orange']
    plt.bar(methods, params, color=colors)
    plt.ylabel('Trainable Parameters (%)')
    plt.title('Parameter Efficiency Comparison')
    plt.xticks(rotation=15)
    
    # LoRA weight visualization (if applicable)
    if method == 'lora':
        plt.subplot(1, 3, 3)
        # Collect LoRA B weights from first few layers
        lora_weights = []
        for layer in model.layers[:3]:
            if 'ffn_lora1' in layer:
                w = layer['ffn_lora1'].lora_B.detach().numpy().flatten()
                lora_weights.extend(w[:100])  # Sample
        
        plt.hist(lora_weights, bins=30, alpha=0.7)
        plt.title('LoRA B Matrix Weight Distribution')
        plt.xlabel('Weight Value')
        plt.ylabel('Frequency')
    
    plt.tight_layout()
    plt.savefig(f'{method}_training_analysis.png', dpi=300)
    plt.show()
    
    return model, stats


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, default='lora', 
                       choices=['lora', 'adapter', 'prefix'])
    parser.add_argument('--epochs', type=int, default=5)
    
    args = parser.parse_args()
    train_peft_model(args.method, args.epochs)
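
The bar chart in train_peft_model plots hard-coded approximate percentages. For the demo configuration (embed_dim=512, six layers, rank 8, adapter_dim 64, prefix_length 20, prefix_dim 512), the trainable parameters of the PEFT modules can instead be derived from the layer shapes in the script. The sketch below excludes LayerNorm parameters and assumes the modules are configured exactly as above:

```python
def lora_count(in_features: int, out_features: int, rank: int) -> int:
    # A: rank x in_features, B: out_features x rank
    return rank * (in_features + out_features)


def peft_trainable_counts(embed_dim=512, num_layers=6, rank=8, adapter_dim=64,
                          prefix_length=20, prefix_dim=512):
    """Trainable-parameter counts of the PEFT modules in the demo (LayerNorms excluded)."""
    ffn_dim = 4 * embed_dim
    # LoRA wraps both FFN linears: embed -> 4*embed and 4*embed -> embed
    lora = num_layers * (lora_count(embed_dim, ffn_dim, rank)
                         + lora_count(ffn_dim, embed_dim, rank))
    # AdapterLayer: two biased linears around the bottleneck
    adapter = num_layers * (2 * embed_dim * adapter_dim + embed_dim + adapter_dim)
    # PrefixTuning (reparam=True): Embedding(prefix_length, prefix_dim)
    # followed by Linear(prefix_dim, num_layers * 2 * embed_dim) with bias
    kv_width = num_layers * 2 * embed_dim
    prefix_train = prefix_length * prefix_dim + prefix_dim * kv_width + kv_width
    # After training, only the generated key/value prefixes need to be kept
    prefix_stored = prefix_length * kv_width
    return {'lora': lora, 'adapter': adapter,
            'prefix_train': prefix_train, 'prefix_stored': prefix_stored}


print(peft_trainable_counts())
# {'lora': 245760, 'adapter': 396672, 'prefix_train': 3162112, 'prefix_stored': 122880}
```

With reparameterization, prefix tuning trains more parameters than LoRA or the adapters in this configuration because of the prefix_dim × (num_layers · 2 · embed_dim) projection, although only the generated prefixes need to be stored afterwards.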

11.2.2 Prompt Tuning: Soft Prompts and Discrete Optimization of Hard Prompts; Multi-Layer Prompt Injection in P-Tuning v2

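Prompt tuning steers a pretrained model by prepending a sequence of optimizable continuous vectors to its input, avoiding the storage cost of per-task parameter updates. The soft-prompt method proposed by Lester et al. freezes the entire pretrained network and adds trainable virtual-token representations only in the input embedding space; the prompt vectors are optimized by backpropagation to maximize downstream task performance. Because continuous prompts can occupy any point of the embedding space instead of being restricted to the embeddings of real vocabulary tokens, they can locate task-specific regions of that space and are more expressive than discrete text (hard) prompts.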
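
Mechanically, soft prompt tuning amounts to prepending a trainable matrix to the frozen input embeddings. The following minimal sketch uses illustrative sizes, and a squared-mean objective stands in for the frozen network's task loss (an assumption made only to produce gradients):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d_model, n_prompt, batch, seq_len = 100, 16, 5, 3, 7  # illustrative sizes

embedding = nn.Embedding(vocab_size, d_model)
embedding.weight.requires_grad_(False)          # frozen backbone embedding

# Soft prompt initialized from the embeddings of randomly sampled vocabulary tokens
init_ids = torch.randint(0, vocab_size, (n_prompt,))
soft_prompt = nn.Parameter(embedding(init_ids).detach().clone())

input_ids = torch.randint(0, vocab_size, (batch, seq_len))
prompt_embeds = soft_prompt.unsqueeze(0).expand(batch, -1, -1)
combined = torch.cat([prompt_embeds, embedding(input_ids)], dim=1)

# Stand-in for the frozen network's loss: gradients reach only the soft prompt
combined.pow(2).mean().backward()

print(combined.shape, soft_prompt.numel())       # torch.Size([3, 12, 16]) 80
print(soft_prompt.grad is not None, embedding.weight.grad is None)  # True True
```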

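P-Tuning combines discrete prompts with continuous optimization: a prompt encoder (a bidirectional LSTM or an MLP) generates context-dependent virtual-token representations, reducing the position sensitivity of static prompts. P-Tuning v2, developed by Liu et al., injects prompt vectors into every layer of the Transformer rather than only the input layer, prepending trainable prefixes at each layer for deep task adaptation. Multi-layer injection increases the number of tunable parameters and lets the prompts influence the representations at every layer, balancing parameter count against use of the model's capacity while still training only a small fraction of the model. P-Tuning v2 reports performance comparable to full fine-tuning across a wide range of model scales and natural language understanding tasks, including sequence labeling.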
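
The parameter budgets of input-layer prompt tuning and deep prompt injection can be compared with simple arithmetic. The sizes below (20 virtual tokens, hidden size 1024, 24 layers) are illustrative, and counting separate key and value prefixes per layer is an assumption about how the deep prompts are injected:

```python
def prompt_tuning_params(n_prompt: int, d_model: int) -> int:
    """Input-layer prompt tuning: one trainable vector per virtual token."""
    return n_prompt * d_model


def deep_prompt_params(n_prompt: int, d_model: int, num_layers: int,
                       key_value: bool = True) -> int:
    """Deep prompts at every layer.

    key_value=True counts separate key and value prefixes per layer (assumed
    prefix-style injection); key_value=False counts one hidden-state prompt per
    layer, closer to the script's PTuningV2 class.
    """
    per_layer = n_prompt * d_model * (2 if key_value else 1)
    return num_layers * per_layer


print(prompt_tuning_params(20, 1024))                     # 20480
print(deep_prompt_params(20, 1024, 24))                   # 983040
print(deep_prompt_params(20, 1024, 24, key_value=False))  # 491520
```

Deep injection multiplies the prompt budget by the number of layers (and by two for key/value prefixes), yet it remains a small fraction of a backbone of this size.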

Python

"""
Script: prompt_tuning_methods.py
Content: Implementation of Soft Prompt Tuning with continuous embeddings,
         Hard Prompt discrete optimization, and P-Tuning v2 with deep prompt injection.
Usage: python prompt_tuning_methods.py --method soft_prompt --num_prompt_tokens 20 --epochs 10
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
from typing import List, Dict, Optional
import matplotlib.pyplot as plt
import argparse


class SoftPromptTuning(nn.Module):
    """
    Soft Prompt Tuning (Lester et al., 2021).
    Prepends trainable continuous vectors to input embeddings.
    """
    def __init__(self, base_model: nn.Module, num_prompt_tokens: int = 20,
                 prompt_dim: int = 768, vocab_size: int = 32000):
        super().__init__()
        self.base_model = base_model
        self.num_prompt_tokens = num_prompt_tokens
        
        # Freeze base model parameters
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        # Random initialization (overwritten below if an embedding layer is found)
        self.soft_prompts = nn.Parameter(
            torch.randn(num_prompt_tokens, prompt_dim) * 0.02
        )
        
        # Optional: initialize from embeddings of sampled vocabulary tokens for better stability
        embedding_layer = list(base_model.children())[0]  # Usually the embedding layer
        if isinstance(embedding_layer, nn.Embedding):
            with torch.no_grad():
                high = min(vocab_size, embedding_layer.num_embeddings)  # Stay within the table
                indices = torch.randint(0, high, (num_prompt_tokens,))
                self.soft_prompts.data = embedding_layer(indices).data.clone()
        
        self.prompt_dropout = nn.Dropout(0.1)
    
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        """
        Forward pass with soft prompt prepending.
        Input: [soft_prompts] + [input_ids]
        """
        batch_size = input_ids.size(0)
        
        # Get input embeddings from frozen base model
        input_embeds = self.get_input_embeddings(input_ids)
        
        # Expand soft prompts for batch
        prompt_embeds = self.soft_prompts.unsqueeze(0).expand(batch_size, -1, -1)
        prompt_embeds = self.prompt_dropout(prompt_embeds)
        
        # Concatenate: [prompts; inputs]
        combined_embeds = torch.cat([prompt_embeds, input_embeds], dim=1)
        
        # Adjust attention mask for prepended prompts
        if attention_mask is not None:
            prompt_mask = torch.ones(batch_size, self.num_prompt_tokens, 
                                    dtype=attention_mask.dtype, device=attention_mask.device)
            combined_mask = torch.cat([prompt_mask, attention_mask], dim=1)
        else:
            combined_mask = None
        
        # Pass through frozen transformer
        # Note: Assuming base model accepts inputs_embeds (simplified here)
        outputs = self.forward_from_embeds(combined_embeds, combined_mask)
        
        # Calculate loss (only on non-prompt positions)
        loss = None
        if labels is not None:
            # Shift logits to exclude prompt positions
            logits = outputs['logits'][:, self.num_prompt_tokens:, :]
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1),
                                ignore_index=-100)
        
        return {
            'loss': loss,
            'logits': outputs['logits'][:, self.num_prompt_tokens:, :],
            'prompt_embeds': prompt_embeds
        }
    
    def get_input_embeddings(self, input_ids):
        """Extract embeddings from base model."""
        # Simplified: assume first layer is embedding
        embed_layer = list(self.base_model.children())[0]
        return embed_layer(input_ids)
    
    def forward_from_embeds(self, embeds, attention_mask):
        """Forward pass through transformer from embeddings."""
        # Simplified forward for demonstration
        # In practice, would use base_model's full forward with inputs_embeds
        hidden = embeds
        for layer in list(self.base_model.children())[1:-1]:
            hidden = layer(hidden)
        logits = list(self.base_model.children())[-1](hidden)
        return {'logits': logits}


class HardPromptOptimizer:
    """
    Discrete optimization for hard prompts using gradient-based search.
    Implements AutoPrompt-style discrete token optimization.
    """
    def __init__(self, model: nn.Module, tokenizer, num_prompt_tokens: int = 5,
                 vocab_size: int = 32000):
        self.model = model
        self.tokenizer = tokenizer
        self.num_prompt_tokens = num_prompt_tokens
        self.vocab_size = vocab_size
        
        # Initialize with random tokens
        self.prompt_token_ids = torch.randint(100, vocab_size, (num_prompt_tokens,))
        self.best_tokens = self.prompt_token_ids.clone()
        self.best_loss = float('inf')
    
    def token_gradient_guided_search(self, batch_input: torch.Tensor, 
                                     batch_labels: torch.Tensor,
                                     top_k: int = 10) -> torch.Tensor:
        """
        One step of gradient-guided token replacement (AutoPrompt-style).
        Candidate tokens for each prompt position are ranked by the first-order
        estimate -e_w^T grad of the loss change; positions are then updated
        greedily, keeping a candidate only if it lowers the batch loss.
        """
        device = batch_input.device
        embed_layer = self.model.get_input_embeddings()
        batch_size = batch_input.size(0)
        
        with torch.no_grad():
            input_embeds = embed_layer(batch_input)
        
        def prompt_loss(prompt_embeds: torch.Tensor) -> torch.Tensor:
            """Loss on non-prompt positions with the given prompt embeddings prepended."""
            expanded = prompt_embeds.unsqueeze(0).expand(batch_size, -1, -1)
            combined = torch.cat([expanded, input_embeds], dim=1)
            logits = self.model.forward_from_embeds(combined, None)['logits']
            logits = logits[:, self.num_prompt_tokens:, :]
            return F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch_labels.reshape(-1))
        
        # Detached leaf copy, so the gradient w.r.t. the prompt embeddings is available
        current_ids = self.prompt_token_ids.to(device)
        prompt_embeds = embed_layer(current_ids).detach().requires_grad_(True)
        loss = prompt_loss(prompt_embeds)
        # autograd.grad avoids accumulating gradients into the model parameters
        grad = torch.autograd.grad(loss, prompt_embeds)[0]  # (num_prompts, embed_dim)
        
        embeddings = embed_layer.weight.detach()  # (vocab, embed_dim)
        current_loss = loss.item()
        
        with torch.no_grad():
            for i in range(self.num_prompt_tokens):
                # Tokens whose embeddings point along -grad are expected to lower the loss most
                scores = -(embeddings @ grad[i])
                candidates = torch.topk(scores, k=top_k).indices
                
                # Keep the current token unless a candidate lowers the loss
                for candidate in candidates:
                    trial_ids = current_ids.clone()
                    trial_ids[i] = candidate
                    trial_loss = prompt_loss(embed_layer(trial_ids)).item()
                    if trial_loss < current_loss:
                        current_loss = trial_loss
                        current_ids = trial_ids
        
        self.prompt_token_ids = current_ids.cpu()
        
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.best_tokens = self.prompt_token_ids.clone()
        
        return self.prompt_token_ids
    
    def get_current_prompt_text(self) -> str:
        """Convert current token IDs to text."""
        return self.tokenizer.decode(self.prompt_token_ids.tolist())


class PTuningV2(nn.Module):
    """
    P-Tuning v2 (Liu et al., 2022) with deep prompt tuning.
    Injects trainable prompts into every layer of the transformer.
    """
    def __init__(self, base_model: nn.Module, num_layers: int = 12,
                 num_prompt_tokens: int = 20, hidden_size: int = 768,
                 num_heads: int = 12):
        super().__init__()
        self.base_model = base_model
        self.num_layers = num_layers
        self.num_prompt_tokens = num_prompt_tokens
        
        # Freeze base model
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        # Prompt encoder (LSTM + MLP) for input layer
        self.prompt_encoder = nn.LSTM(
            hidden_size, hidden_size // 2, num_layers=2,
            bidirectional=True, batch_first=True
        )
        self.prompt_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, num_prompt_tokens * hidden_size)
        )
        
        # Deep prompts for each layer: (num_layers, num_prompts, hidden_size)
        # Can also use reparameterization like Prefix Tuning
        self.deep_prompts = nn.ParameterList([
            nn.Parameter(torch.randn(num_prompt_tokens, hidden_size) * 0.02)
            for _ in range(num_layers)
        ])
        
        # Anchor tokens for prompt encoder (fixed pseudo tokens)
        self.anchor_tokens = nn.Parameter(torch.randn(10, hidden_size), requires_grad=False)
    
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        """
        Forward with deep prompt injection at every layer.
        """
        batch_size = input_ids.size(0)
        
        # Generate input prompts via LSTM encoder
        anchor_expanded = self.anchor_tokens.unsqueeze(0).expand(batch_size, -1, -1)
        lstm_out, _ = self.prompt_encoder(anchor_expanded)
        prompt_gen = self.prompt_mlp(lstm_out[:, -1, :])  # Use last hidden state
        input_prompts = prompt_gen.view(batch_size, self.num_prompt_tokens, -1)
        
        # Get input embeddings and prepend prompts
        input_embeds = self.get_input_embeddings(input_ids)
        combined_embeds = torch.cat([input_prompts, input_embeds], dim=1)
        
        # Pass through transformer layers with deep prompt injection
        hidden = combined_embeds
        for layer_idx, layer in enumerate(self.base_model.transformer_layers):
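            # Inject deep prompts. Layer 0 keeps the LSTM-encoded input prompts
            # built above (overwriting them would leave the prompt encoder with
            # no gradient); every deeper layer replaces the prompt slots with its
            # own trainable prompts, so self.deep_prompts[0] goes unused here.
            if layer_idx > 0:
                layer_prompts = self.deep_prompts[layer_idx].unsqueeze(0).expand(batch_size, -1, -1)
                hidden = torch.cat([layer_prompts, hidden[:, self.num_prompt_tokens:, :]], dim=1)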
            
            # Process through layer
            hidden = layer(hidden)
        
        # Final logits (excluding prompt positions)
        logits = self.base_model.lm_head(hidden[:, self.num_prompt_tokens:, :])
        
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1),
                                  ignore_index=-100)
        
        return {
            'loss': loss,
            'logits': logits,
            'prompt_loss': 0.0  # Could add regularization
        }
    
    def get_input_embeddings(self, input_ids):
        return self.base_model.get_input_embeddings()(input_ids)


class SentimentDataset(Dataset):
    """Simple sentiment classification dataset for prompt tuning demonstration."""
    def __init__(self, texts: List[str], labels: List[int], tokenizer, max_length: int = 128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        # Tokenize
        tokens = self.tokenizer.encode(text, max_length=self.max_length, truncation=True)
        padding = [0] * (self.max_length - len(tokens))
        input_ids = tokens + padding
        
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'labels': torch.tensor(label, dtype=torch.long),
            'attention_mask': torch.tensor([1]*len(tokens) + [0]*len(padding), dtype=torch.long)
        }


def create_synthetic_sentiment_data(num_samples: int = 1000):
    """Create synthetic sentiment analysis data."""
    positive_templates = [
        "This movie is absolutely fantastic and wonderful",
        "I really love this product, it works great",
        "The service was excellent and very satisfying",
        "Amazing quality, highly recommend to everyone",
        "Best experience ever, truly outstanding performance"
    ]
    
    negative_templates = [
        "This is terrible and completely disappointing",
        "I hate this product, it does not work at all",
        "The worst service I have ever experienced",
        "Very bad quality, do not recommend anyone",
        "Awful experience, totally waste of money"
    ]
    
    texts = []
    labels = []
    
    for _ in range(num_samples // 2):
        texts.append(random.choice(positive_templates))
        labels.append(1)
        texts.append(random.choice(negative_templates))
        labels.append(0)
    
    return texts, labels


class SimpleTokenizer:
    """Minimal tokenizer for demonstration."""
    def __init__(self, vocab_size: int = 1000):
        self.vocab_size = vocab_size
        self.word2id = {'<pad>': 0, '<unk>': 1}
        self.id2word = {0: '<pad>', 1: '<unk>'}
    
    def encode(self, text: str, max_length: int = 128, truncation: bool = True) -> List[int]:
        words = text.lower().split()
        ids = []
        for w in words:
            if w not in self.word2id:
                if len(self.word2id) < self.vocab_size:
                    self.word2id[w] = len(self.word2id)
                    self.id2word[self.word2id[w]] = w
                else:
                    ids.append(1)  # <unk>
                    continue
            ids.append(self.word2id[w])
        
        if truncation and len(ids) > max_length:
            ids = ids[:max_length]
        return ids


def train_prompt_tuning(method: str = 'soft', epochs: int = 10):
    """Training pipeline for different prompt tuning methods."""
    # Setup
    tokenizer = SimpleTokenizer(vocab_size=1000)
    texts, labels = create_synthetic_sentiment_data(500)
    
    # Simple base model
    class SimpleTransformer(nn.Module):
        def __init__(self, vocab_size: int = 1000, hidden: int = 256, num_layers: int = 4):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, hidden)
            self.transformer_layers = nn.ModuleList([
                nn.TransformerEncoderLayer(hidden, 4, 512, batch_first=True)
                for _ in range(num_layers)
            ])
            self.lm_head = nn.Linear(hidden, vocab_size)
            self.classifier = nn.Linear(hidden, 2)
            self.pooling = nn.AdaptiveAvgPool1d(1)
        
        def forward(self, input_ids, attention_mask=None, labels=None):
            x = self.embedding(input_ids)
            for layer in self.transformer_layers:
                x = layer(x)
            # Classification
            x = x.mean(dim=1)  # Average pooling
            logits = self.classifier(x)
            
            loss = None
            if labels is not None:
                loss = F.cross_entropy(logits, labels)
            return {'loss': loss, 'logits': logits, 'hidden_states': x}
        
        def get_input_embeddings(self):
            return self.embedding
    
    base_model = SimpleTransformer()
    
    # Apply PEFT method
    if method == 'soft':
        model = SoftPromptTuning(base_model, num_prompt_tokens=20, prompt_dim=256, vocab_size=1000)
        # Modify forward to return classification logits
        original_forward = model.forward
        def soft_forward(input_ids, attention_mask=None, labels=None):
            # Use base model's classifier on prompt-enhanced representations
            batch_size = input_ids.size(0)
            input_embeds = model.get_input_embeddings(input_ids)
            prompt_embeds = model.soft_prompts.unsqueeze(0).expand(batch_size, -1, -1)
            combined = torch.cat([prompt_embeds, input_embeds], dim=1)
            
            # Through transformer
            hidden = combined
            for layer in model.base_model.transformer_layers:
                hidden = layer(hidden)
            
            # Pool and classify (skip prompt positions)
            hidden = hidden[:, model.num_prompt_tokens:, :].mean(dim=1)
            logits = model.base_model.classifier(hidden)
            
            loss = F.cross_entropy(logits, labels) if labels is not None else None
            return {'loss': loss, 'logits': logits}
        model.forward = soft_forward
        
    elif method == 'p_tuning_v2':
        # Simplified P-Tuning v2
        model = base_model  # Direct modification for demo
        num_prompts = 10
        deep_prompts = nn.Parameter(torch.randn(4, num_prompts, 256) * 0.02)
        
        original_forward = base_model.forward
        def ptuning_forward(input_ids, attention_mask=None, labels=None):
            x = model.embedding(input_ids)
            # Add prompts at each layer
            for i, layer in enumerate(model.transformer_layers):
                prompt = deep_prompts[i].unsqueeze(0).expand(x.size(0), -1, -1)
                x = torch.cat([prompt, x], dim=1)
                x = layer(x)
                x = x[:, num_prompts:, :]  # Remove prompts for next layer
            
            x = x.mean(dim=1)
            logits = model.classifier(x)
            loss = F.cross_entropy(logits, labels) if labels is not None else None
            return {'loss': loss, 'logits': logits}
        base_model.forward = ptuning_forward
        base_model.deep_prompts = deep_prompts
        
        # Only optimize deep prompts
        for param in base_model.parameters():
            param.requires_grad = False
        deep_prompts.requires_grad = True
    
    # Dataset
    dataset = SentimentDataset(texts, labels, tokenizer, max_length=32)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    
    # Optimizer
    if method == 'soft':
        optimizer = torch.optim.AdamW([model.soft_prompts], lr=0.1)
    elif method == 'p_tuning_v2':
        optimizer = torch.optim.AdamW([base_model.deep_prompts], lr=0.01)
    
    # Training
    losses = []
    accuracies = []
    
    for epoch in range(epochs):
        epoch_loss = 0.0
        correct = 0
        total = 0
        
        for batch in dataloader:
            optimizer.zero_grad()
            
            outputs = model(batch['input_ids'], batch['attention_mask'], batch['labels'])
            loss = outputs['loss']
            
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            # Accuracy
            preds = outputs['logits'].argmax(dim=-1)
            correct += (preds == batch['labels']).sum().item()
            total += batch['labels'].size(0)
        
        avg_loss = epoch_loss / len(dataloader)
        acc = correct / total
        losses.append(avg_loss)
        accuracies.append(acc)
        
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Acc: {acc:.4f}")
    
    # Visualization
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    plt.plot(losses, marker='o', label='Training Loss')
    plt.title(f'{method.upper()} Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Cross-Entropy Loss')
    plt.grid(True)
    
    plt.subplot(1, 3, 2)
    plt.plot(accuracies, marker='s', color='green', label='Accuracy')
    plt.title(f'{method.upper()} Classification Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim(0, 1)
    plt.grid(True)
    
    # Visualize prompt embeddings (for soft prompt)
    if method == 'soft':
        plt.subplot(1, 3, 3)
        prompt_embeds = model.soft_prompts.detach().numpy()
        # PCA to 2D
        if prompt_embeds.shape[1] > 2:
            from sklearn.decomposition import PCA
            pca = PCA(n_components=2)
            prompt_2d = pca.fit_transform(prompt_embeds)
        else:
            prompt_2d = prompt_embeds
        
        plt.scatter(prompt_2d[:, 0], prompt_2d[:, 1], c=range(len(prompt_2d)), cmap='viridis')
        for i, (x, y) in enumerate(prompt_2d):
            plt.annotate(f'P{i}', (x, y), fontsize=8)
        plt.title('Soft Prompt Embeddings (PCA)')
        plt.xlabel('PC1')
        plt.ylabel('PC2')
    
    plt.tight_layout()
    plt.savefig(f'{method}_prompt_tuning.png', dpi=300)
    plt.show()
    
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, default='soft',
                       choices=['soft', 'p_tuning_v2'])
    parser.add_argument('--num_prompt_tokens', type=int, default=20)
    parser.add_argument('--epochs', type=int, default=10)
    
    args = parser.parse_args()
    train_prompt_tuning(args.method, args.epochs)

11.2.3 FLAN-Format Instruction Tuning and Chain-of-Thought Data Construction; PPO and Reward-Model Training in RLHF

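Instruction tuning improves a model's ability to follow user intent by training it on natural-language instructions paired with their responses. The FLAN collection of Wei et al. recasts conventional NLP tasks as instructions, with verbalized task descriptions covering many task types, so that the model learns to interpret an instruction and produce output in the appropriate format. Chain-of-thought prompting inserts intermediate reasoning steps into the demonstration examples, steering the model toward an explicit step-by-step derivation; this markedly improves arithmetic and commonsense reasoning, particularly in sufficiently large models. Chain-of-thought training data accordingly pairs a step-by-step rationale with the final answer, and supervised training on such pairs strengthens the model's deliberate, multi-step ("System 2") reasoning.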
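To make the formatting and loss-masking step concrete, here is a minimal pure-Python sketch. The two templates, the helpers `build_example` and `tokenize_with_mask`, and the arithmetic example are hypothetical illustrations rather than the actual FLAN templates, and whitespace splitting stands in for a real tokenizer. What it shows is that a chain-of-thought target holds the rationale followed by the answer, and that only target tokens are supervised.

```python
from typing import Dict, List, Tuple

# Hypothetical templates, each split into (instruction part, target part);
# they illustrate the idea and are not the actual FLAN templates.
TEMPLATES: Dict[str, Tuple[str, str]] = {
    "qa": ("Question: {input}\nAnswer:", "{output}"),
    "cot": ("Q: {input}\nA: Let's think step by step.", "{chain} The answer is {output}."),
}


def build_example(example: Dict[str, str], use_cot: bool) -> Tuple[str, str]:
    """Render (instruction, target); CoT targets hold the rationale, then the answer."""
    key = "cot" if use_cot and "chain" in example else "qa"
    prompt_tpl, target_tpl = TEMPLATES[key]
    fields = {"chain": "", **example}
    return prompt_tpl.format(**fields), target_tpl.format(**fields)


def tokenize_with_mask(prompt: str, target: str) -> Tuple[List[str], List[int]]:
    """Whitespace tokens plus a loss mask: 0 on instruction tokens, 1 on target tokens."""
    prompt_toks, target_toks = prompt.split(), target.split()
    return prompt_toks + target_toks, [0] * len(prompt_toks) + [1] * len(target_toks)


example = {
    "input": "A pen costs 3 dollars. How much do 4 pens cost?",
    "chain": "Each pen costs 3 dollars, so 4 pens cost 4 * 3 = 12 dollars.",
    "output": "12",
}
prompt, target = build_example(example, use_cot=True)
tokens, mask = tokenize_with_mask(prompt, target)
print(prompt)
print("supervised:", " ".join(t for t, m in zip(tokens, mask) if m))
```

The mask value 0 plays the role of the `-100` label used by `InstructionDataset`: instruction tokens contribute nothing to the loss, while every token of the rationale and the answer does. An example without a `chain` field falls back to the plain question-answer template.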

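Reinforcement learning from human feedback (RLHF) trains a reward model to capture human preferences and then optimizes the language model to generate responses consistent with those preferences. The InstructGPT pipeline of Ouyang et al. first fine-tunes on human-written demonstrations (SFT), then collects human comparisons between model outputs and trains a reward model to predict the preference ranking, and finally fine-tunes the policy with Proximal Policy Optimization (PPO) against that reward model. PPO's clipped surrogate objective bounds the size of each policy update for stability, and a KL-divergence penalty keeps the optimized policy close to the reference (SFT) policy so that it does not drift too far from the distribution it started from. The reward model supplies one scalar preference score per sequence, and this score is the signal that drives the improvement in generation quality.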
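The three quantities named in this paragraph can be written out for scalar inputs as below. This is a hedged sketch in plain Python rather than tensor code, and the function names are hypothetical. It assumes a Bradley-Terry pairwise loss for the reward model, a KL penalty estimated from the log-probability ratio of the sampled tokens summed over the response, and PPO's clipped surrogate for a single action with a given advantage.

```python
import math
from typing import List


def bradley_terry_loss(r_chosen: float, r_rejected: float) -> float:
    """Reward-model loss -log(sigmoid(r_chosen - r_rejected)), computed stably."""
    m = r_chosen - r_rejected
    # -log(sigmoid(m)) = log(1 + exp(-m)) = -m + log(1 + exp(m))
    return math.log1p(math.exp(-m)) if m >= 0 else -m + math.log1p(math.exp(m))


def kl_shaped_reward(rm_score: float, logp_policy: List[float],
                     logp_ref: List[float], beta: float) -> float:
    """Sequence reward minus beta times the summed log-ratio of the sampled tokens,
    a single-sample estimate of KL(pi || pi_ref) for this response."""
    kl_estimate = sum(lp - lr for lp, lr in zip(logp_policy, logp_ref))
    return rm_score - beta * kl_estimate


def ppo_clip_objective(logp_new: float, logp_old: float,
                       advantage: float, eps: float = 0.2) -> float:
    """Clipped surrogate min(r*A, clip(r, 1-eps, 1+eps)*A), r = pi_new/pi_old; maximized."""
    ratio = math.exp(logp_new - logp_old)
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
    return min(ratio * advantage, clipped * advantage)


print(bradley_terry_loss(2.0, 0.5))
print(kl_shaped_reward(1.0, [-1.0, -2.0], [-1.5, -2.5], beta=0.1))
print(ppo_clip_objective(math.log(1.5), 0.0, advantage=1.0))
```

With eps = 0.2, a probability ratio of 1.5 under a positive advantage is capped at 1.2, so the objective gains nothing from pushing the ratio further; under a negative advantage the minimum keeps the more pessimistic of the two terms.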

Python

"""
Script: instruction_tuning_rlhf.py
Content: Implementation of Instruction Tuning with FLAN formatting,
         Chain-of-Thought data augmentation, and RLHF with PPO training.
Usage: python instruction_tuning_rlhf.py --stage sft --data_path instructions.json --epochs 3
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from typing import List, Dict, Tuple, Optional
import matplotlib.pyplot as plt
import json
import argparse
from collections import deque


class InstructionDataset(Dataset):
    """
    Dataset for instruction tuning with FLAN-style templates.
    Each template is a (prompt, response) pair so the loss can be restricted to
    the response. With include_cot=True, examples that carry a 'chain' field use
    reasoning templates whose response holds the rationale plus the answer.
    """
    def __init__(self, data_path: str, tokenizer, max_length: int = 512,
                 include_cot: bool = False):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.include_cot = include_cot
        
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        
        # FLAN-style task templates: (prompt part, response part)
        self.templates = {
            'qa': [
                ("Question: {input}\nAnswer:", "{output}"),
                ("Answer the following question: {input}", "{output}"),
                ("Q: {input}\nA:", "{output}")
            ],
            'classification': [
                ("Task: Classify the text\nInput: {input}\nLabel:", "{output}"),
                ("Determine the category: {input}\nCategory:", "{output}")
            ],
            'reasoning': [
                ("Problem: {input}\nLet's think step by step.",
                 "{chain}\nTherefore, the answer is: {output}"),
                ("Q: {input}\nA:", "{chain} The answer is {output}.")
            ]
        }
    
    def format_example(self, example: Dict) -> Tuple[str, str]:
        """Apply a randomly chosen FLAN template; return (prompt, response)."""
        task_type = example.get('task_type', 'qa')
        use_cot = self.include_cot and 'chain' in example
        if use_cot:
            task_type = 'reasoning'
        elif task_type == 'reasoning':
            task_type = 'qa'  # reasoning templates need a {chain} field
        
        prompt_tpl, response_tpl = random.choice(
            self.templates.get(task_type, self.templates['qa']))
        fields = {'input': example['input'], 'output': example['output'],
                  'chain': example.get('chain', '')}
        return prompt_tpl.format(**fields), response_tpl.format(**fields)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        prompt, response = self.format_example(self.data[idx])
        
        # Tokenize the two parts separately so the prompt/response boundary is exact
        prompt_ids = self.tokenizer.encode(prompt, max_length=self.max_length)
        response_ids = self.tokenizer.encode(response, max_length=self.max_length)
        tokens = (prompt_ids + response_ids)[:self.max_length]
        
        # Compute the loss only on the response: prompt positions are ignored (-100)
        labels = ([-100] * len(prompt_ids) + response_ids)[:self.max_length]
        
        # Right-pad to a fixed length so the default collate_fn can stack a batch
        pad = self.max_length - len(tokens)
        return {
            'input_ids': torch.tensor(tokens + [0] * pad, dtype=torch.long),
            'labels': torch.tensor(labels + [-100] * pad, dtype=torch.long),
            'attention_mask': torch.tensor([1] * len(tokens) + [0] * pad, dtype=torch.long)
        }


class ChainOfThoughtBuilder:
    """
    Utility for constructing Chain-of-Thought reasoning examples.
    Augments existing datasets with intermediate reasoning steps.
    """
    def __init__(self, model: nn.Module, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
    
    def generate_cot(self, question: str, num_samples: int = 5) -> List[str]:
        """
        Sample several candidate reasoning chains with diverse decoding:
        cycle through CoT trigger prompts and draw a random temperature for each.
        """
        cot_prompts = [
            f"Q: {question}\nA: Let's think step by step.",
            f"Question: {question}\nLet's solve this carefully:",
            f"Problem: {question}\nStep-by-step solution:"
        ]
        
        chains = []
        for i in range(num_samples):
            prompt = cot_prompts[i % len(cot_prompts)]
            temp = random.uniform(0.5, 1.0)
            # _generate_text returns only the continuation, i.e. the reasoning part
            chains.append(self._generate_text(prompt, max_length=100, temperature=temp))
        
        return chains
    
    def _generate_text(self, prompt: str, max_length: int = 100, 
                       temperature: float = 1.0) -> str:
        """Temperature sampling; returns only the newly generated text."""
        prompt_ids = self.tokenizer.encode(prompt)
        input_ids = torch.tensor([prompt_ids], dtype=torch.long)
        
        with torch.no_grad():
            for _ in range(max_length):
                outputs = self.model(input_ids)
                logits = outputs['logits'][:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=1)
        
        return self.tokenizer.decode(input_ids[0, len(prompt_ids):].tolist())
    
    def filter_valid_cots(self, chains: List[str], answer: str) -> List[str]:
        """
        Filter chains that lead to correct answer.
        Uses simple string matching for demonstration.
        """
        valid = []
        for chain in chains:
            # Check if final answer matches
            if answer.lower() in chain.lower():
                valid.append(chain)
        # May be empty: unverified chains should not become training targets
        return valid


class RewardModel(nn.Module):
    """
    Reward Model for RLHF based on Bradley-Terry preference model.
    Outputs scalar reward for a given (prompt, response) pair.
    """
    def __init__(self, base_model: nn.Module, hidden_size: int = 768):
        super().__init__()
        self.base_model = base_model
        
        # Freeze base model or keep trainable depending on compute budget
        for param in self.base_model.parameters():
            param.requires_grad = False
        
        # Reward head: maps hidden state to scalar reward
        self.reward_head = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1)
        )
    
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute reward scores for input sequences.
        Uses end-of-sequence representation for scoring.
        """
        # Get hidden states from base model
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        hidden_states = outputs['hidden_states']  # (batch, seq, hidden)
        
        # Use last non-padding token representation
        if attention_mask is not None:
            # Find last real token
            last_pos = attention_mask.sum(dim=1) - 1
            batch_indices = torch.arange(hidden_states.size(0))
            pooled = hidden_states[batch_indices, last_pos, :]  # (batch, hidden)
        else:
            pooled = hidden_states[:, -1, :]
        
        reward = self.reward_head(pooled).squeeze(-1)  # (batch,)
        return reward


class PPOTrainer:
    """
    PPO (Proximal Policy Optimization) trainer for RLHF.
    Implements clipped surrogate objective and KL penalty.
    """
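    def __init__(self, policy_model: nn.Module, ref_model: nn.Module,
                 reward_model: RewardModel, tokenizer,
                 clip_epsilon: float = 0.2, 
                 value_coef: float = 0.5,
                 entropy_coef: float = 0.01,
                 kl_coef: float = 0.2,
                 gamma: float = 1.0,
                 lam: float = 0.95,
                 hidden_size: Optional[int] = None):
        self.policy_model = policy_model
        self.ref_model = ref_model  # Frozen reference model
        self.reward_model = reward_model
        self.tokenizer = tokenizer
        
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.kl_coef = kl_coef
        self.gamma = gamma
        self.lam = lam
        
        # Disable dropout so that log-probs recorded at sampling time match the
        # ones recomputed during the PPO update (the ratio then starts near 1)
        for module in (self.policy_model, self.ref_model, self.reward_model):
            module.eval()
        
        # Optimizer for policy
        self.optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-5)
        
        # Value head on the policy's final hidden state (shares the backbone);
        # its input width is inferred from the LM head unless given explicitly
        if hidden_size is None:
            hidden_size = policy_model.lm_head.in_features
        device = next(policy_model.parameters()).device
        self.value_head = nn.Linear(hidden_size, 1).to(device)
        self.value_optimizer = torch.optim.AdamW(self.value_head.parameters(), lr=1e-5)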
    
    def generate_responses(self, prompts: List[str], max_length: int = 50
                           ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Sample responses from the current policy.
        Returns the full token IDs (prompt + response), the log-probabilities
        of the sampled response tokens, and the padded prompt length.
        """
        batch_size = len(prompts)
        encoded = [self.tokenizer.encode(p) for p in prompts]
        prompt_len = max(len(p) for p in encoded)
        
        # Left-pad prompts so that every response starts at the same position
        # (right padding would put pad tokens between prompt and response)
        padded_inputs = torch.zeros(batch_size, prompt_len, dtype=torch.long)
        for i, p in enumerate(encoded):
            padded_inputs[i, prompt_len - len(p):] = torch.tensor(p, dtype=torch.long)
        
        # Generate autoregressively
        generated = padded_inputs
        log_probs_list = []
        
        with torch.no_grad():
            for _ in range(max_length):
                logits = self.policy_model(generated)['logits'][:, -1, :]
                dist = Categorical(logits=logits)
                next_token = dist.sample()
                log_probs_list.append(dist.log_prob(next_token))
                generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
        
        log_probs = torch.stack(log_probs_list, dim=1)  # (batch, gen_len)
        return generated, log_probs, prompt_len
    
    def compute_rewards(self, sequences: torch.Tensor, attention_mask: torch.Tensor,
                        prompt_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        KL-shaped sequence reward: reward = r_phi(x, y) - beta * KL(pi_theta || pi_ref).
        The KL is computed exactly over the vocabulary at each response position
        and averaged over positions. Returns (rewards, kl), both of shape (batch,).
        """
        with torch.no_grad():
            reward_scores = self.reward_model(sequences, attention_mask)
            ref_log_probs = F.log_softmax(
                self.ref_model(sequences, attention_mask)['logits'], dim=-1)
            policy_log_probs = F.log_softmax(
                self.policy_model(sequences, attention_mask)['logits'], dim=-1)
        
        # The distribution over response token t is produced at position t - 1
        resp = slice(prompt_len - 1, sequences.size(1) - 1)
        p_log, r_log = policy_log_probs[:, resp], ref_log_probs[:, resp]
        kl = (p_log.exp() * (p_log - r_log)).sum(dim=-1).mean(dim=-1)  # (batch,)
        
        return reward_scores - self.kl_coef * kl, kl
    
    def compute_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """
        Compute Generalized Advantage Estimation (GAE).
        """
        advantages = []
        gae = 0
        
        # Simple case: assume single step for demo (would be loop for full trajectories)
        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]
            
            delta = rewards[t] + self.gamma * next_value - values[t]
            gae = delta + self.gamma * self.lam * gae
            advantages.insert(0, gae)
        
        return torch.tensor(advantages)
    
    def ppo_update(self, sequences: torch.Tensor, prompt_len: int,
                   old_log_probs: torch.Tensor, rewards: torch.Tensor,
                   epochs: int = 4) -> float:
        """
        PPO update with the clipped surrogate objective on the responses that
        were actually sampled. Simplification: each whole response is treated
        as one action (sequence-level ratio, scalar reward), so the advantage
        is the reward minus the value estimate.
        """
        response = sequences[:, prompt_len:]
        old_seq_log_probs = old_log_probs.sum(dim=1)
        
        with torch.no_grad():
            values = self.value_head(
                self.policy_model(sequences)['hidden_states'][:, -1, :]
            ).squeeze(-1)
        advantages = rewards - values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        params = list(self.policy_model.parameters()) + list(self.value_head.parameters())
        for _ in range(epochs):
            # Recompute log-probs of the sampled tokens under the current policy (with grad)
            outputs = self.policy_model(sequences)
            log_probs = F.log_softmax(outputs['logits'][:, prompt_len - 1:-1, :], dim=-1)
            new_log_probs = log_probs.gather(-1, response.unsqueeze(-1)).squeeze(-1)
            ratio = torch.exp(new_log_probs.sum(dim=1) - old_seq_log_probs)
            
            # Clipped surrogate objective
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()
            
            # Value loss
            value_pred = self.value_head(outputs['hidden_states'][:, -1, :]).squeeze(-1)
            value_loss = F.mse_loss(value_pred, rewards)
            
            # Entropy bonus over the full next-token distributions
            entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()
            
            # Total loss
            loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy
            
            self.optimizer.zero_grad()
            self.value_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            self.optimizer.step()
            self.value_optimizer.step()
        
        return loss.item()
    
    def train_step(self, prompts: List[str]) -> Dict[str, float]:
        """Single RLHF training step."""
        # Sample responses and keep their log-probs as the "old" policy
        generated_ids, log_probs, prompt_len = self.generate_responses(prompts)
        
        # KL-shaped rewards (left padding puts the last real token at the end)
        attention_mask = torch.ones_like(generated_ids)
        rewards, kl = self.compute_rewards(generated_ids, attention_mask, prompt_len)
        
        # PPO update on the same sampled responses
        loss = self.ppo_update(generated_ids, prompt_len, log_probs, rewards)
        
        return {
            'loss': loss,
            'mean_reward': rewards.mean().item(),
            'kl_div': kl.mean().item()
        }


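class SimpleTokenizer:
    """
    Minimal whitespace tokenizer (mirrors the one in the prompt-tuning script),
    with decode() added for CoT generation. The vocabulary grows on the fly.
    """
    def __init__(self, vocab_size: int = 1000):
        self.vocab_size = vocab_size
        self.word2id = {'<pad>': 0, '<unk>': 1}
        self.id2word = {0: '<pad>', 1: '<unk>'}
    
    def encode(self, text: str, max_length: int = 128, truncation: bool = True) -> List[int]:
        ids = []
        for w in text.lower().split():
            if w not in self.word2id and len(self.word2id) < self.vocab_size:
                self.word2id[w] = len(self.word2id)
                self.id2word[self.word2id[w]] = w
            ids.append(self.word2id.get(w, 1))  # 1 = <unk>
        return ids[:max_length] if truncation else ids
    
    def decode(self, ids: List[int]) -> str:
        return ' '.join(self.id2word.get(i, '<unk>') for i in ids if i != 0)


class SimpleTransformer(nn.Module):
    """Simplified decoder-only transformer for instruction tuning and RLHF."""
    def __init__(self, vocab_size: int = 1000, hidden: int = 256, num_layers: int = 4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden)
        # Decoder-only: self-attention blocks with a causal mask. A
        # TransformerDecoderLayer with memory=x would cross-attend to the whole
        # unmasked sequence and leak future tokens.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(hidden, 4, 512, batch_first=True)
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(hidden)
        self.lm_head = nn.Linear(hidden, vocab_size)
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # attention_mask is accepted for interface compatibility but unused;
        # padded positions are excluded from the loss through -100 labels
        x = self.embedding(input_ids)
        seq_len = x.size(1)
        # Causal mask: True above the diagonal marks future positions as blocked
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask)
        
        hidden = self.ln(x)
        logits = self.lm_head(hidden)
        
        loss = None
        if labels is not None:
            # Shift: the logits at position t predict the token at position t + 1
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                                   labels[:, 1:].reshape(-1), ignore_index=-100)
        
        return {
            'logits': logits,
            'loss': loss,
            'hidden_states': hidden
        }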


def train_instruction_tuning(data_path: str, epochs: int = 3, include_cot: bool = False):
    """Supervised fine-tuning with instruction formatting."""
    # Setup
    tokenizer = SimpleTokenizer()
    
    # Create dummy data if not exists
    dummy_data = [
        {
            'input': 'What is 2+2?',
            'output': '4',
            'chain': 'First, we add 2 and 2. 2 plus 2 equals 4.',
            'task_type': 'reasoning'
        },
        {
            'input': 'Is this positive? "I love it"',
            'output': 'positive',
            'task_type': 'classification'
        }
    ] * 50
    
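    import os  # bound before first use; a later local import would shadow it
    if not os.path.exists(data_path):
        os.makedirs(os.path.dirname(data_path) if os.path.dirname(data_path) else '.', exist_ok=True)
        with open(data_path, 'w') as f:
            json.dump(dummy_data, f)
    
    # The examples are short, so a small max_length keeps fixed-length padding cheap
    dataset = InstructionDataset(data_path, tokenizer, max_length=64, include_cot=include_cot)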
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
    
    model = SimpleTransformer()
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
    
    losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0.0
        
        for batch in dataloader:
            optimizer.zero_grad()
            
            outputs = model(batch['input_ids'], batch['attention_mask'], batch['labels'])
            loss = outputs['loss']
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(dataloader)
        losses.append(avg_loss)
        print(f"SFT Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    # Visualization
    plt.figure(figsize=(6, 4))
    plt.plot(losses, marker='o')
    plt.title('Instruction Tuning Loss' + (' (CoT targets)' if include_cot else ''))
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    # Comparing CoT with non-CoT training needs a separate run with
    # include_cot=False; per-token losses on different target formats are not
    # directly comparable, so no synthetic comparison curve is drawn.
    
    plt.tight_layout()
    plt.savefig('instruction_tuning.png', dpi=300)
    plt.show()
    
    return model


def train_rlhf(policy_model: nn.Module, reward_model_path: str, epochs: int = 5):
    """RLHF training with PPO."""
    # Initialize models
    ref_model = SimpleTransformer()  # Frozen reference
    ref_model.load_state_dict(policy_model.state_dict())
    for param in ref_model.parameters():
        param.requires_grad = False
    
    # Mock reward model
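    reward_model = RewardModel(SimpleTransformer(), hidden_size=256)  # must match SimpleTransformer width
    
    # Fresh vocabulary for the demo; a real pipeline reuses the SFT tokenizer
    tokenizer = SimpleTokenizer()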
    ppo_trainer = PPOTrainer(policy_model, ref_model, reward_model, tokenizer)
    
    # Training prompts
    prompts = [
        "What is the capital of France?",
        "Solve: 15 * 4",
        "Write a haiku about nature."
    ] * 10  # Repeat for batch size
    
    rewards_history = []
    kl_history = []
    
    for epoch in range(epochs):
        metrics = ppo_trainer.train_step(prompts)
        rewards_history.append(metrics['mean_reward'])
        kl_history.append(metrics['kl_div'])
        
        print(f"RLHF Epoch {epoch+1}/{epochs}, "
              f"Reward: {metrics['mean_reward']:.4f}, "
              f"KL: {metrics['kl_div']:.4f}, "
              f"Loss: {metrics['loss']:.4f}")
    
    # Visualization
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(rewards_history, marker='o', color='green')
    plt.title('PPO Training: Mean Reward')
    plt.xlabel('Epoch')
    plt.ylabel('Reward')
    plt.grid(True)
    
    plt.subplot(1, 2, 2)
    # kl_coef is the penalty weight beta, not a target KL, so no reference line is drawn
    plt.plot(kl_history, marker='s', color='red', label='Mean per-token KL')
    plt.title('KL Divergence from Reference Model')
    plt.xlabel('Epoch')
    plt.ylabel('KL Divergence')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig('rlhf_training.png', dpi=300)
    plt.show()
    
    return policy_model


if __name__ == "__main__":
    import os
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--stage', type=str, choices=['sft', 'rlhf'], required=True)
    parser.add_argument('--data_path', type=str, default='data/instructions.json')
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--include_cot', action='store_true', help='Include Chain-of-Thought')
    
    args = parser.parse_args()
    
    if args.stage == 'sft':
        model = train_instruction_tuning(args.data_path, args.epochs, args.include_cot)
    else:
        # First do SFT, then RLHF
        model = train_instruction_tuning(args.data_path, args.epochs, False)
        model = train_rlhf(model, 'reward_model.pt', args.epochs)

11.3 Scaling Laws and Inference Optimization

11.3.1 Scaling Laws (Kaplan & Chinchilla): Power-Law Relationships of Loss with Compute (FLOPs), Parameter Count, and Data Size; Allocating Optimal Model Size and Training Tokens

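The performance of neural language models follows predictable power-law scaling: loss decreases smoothly as compute, parameter count, and dataset size grow. Kaplan et al. found that loss depends strongly on scale (parameters N, tokens D, compute C) and only weakly on architectural shape such as depth versus width. Under a fixed compute budget, they concluded that most additional compute should go into model size (N_opt ∝ C^0.73, D_opt ∝ C^0.27), training large models and stopping well before convergence. Hoffmann et al. recalibrated these laws in the Chinchilla study: for a given compute budget, model size and training tokens should be scaled in equal proportion (N_opt ∝ C^0.5, D_opt ∝ C^0.5), which works out to roughly 20 training tokens per parameter. By this criterion many earlier large models were undertrained, and compute-optimal training calls for smaller models trained on more tokens; Chinchilla (70B parameters, 1.4T tokens) outperformed the 280B-parameter Gopher trained with the same compute.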
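The two allocation rules can be compared with a little arithmetic. The sketch below assumes the common training-cost approximation C ≈ 6ND FLOPs for a dense decoder and the Chinchilla rule of thumb of about 20 tokens per parameter; the function names are hypothetical and only illustrate the calculation.

```python
import math


def chinchilla_allocation(compute_flops: float, tokens_per_param: float = 20.0):
    """Compute-optimal (N, D) under C = 6*N*D with D = tokens_per_param * N."""
    # Substituting D = k*N into C = 6*N*D gives C = 6*k*N^2, so N = sqrt(C / (6k))
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


def growth_factor(compute_multiplier: float, exponent: float) -> float:
    """Factor by which N_opt grows when compute grows by compute_multiplier, if N_opt ∝ C^exponent."""
    return compute_multiplier ** exponent


if __name__ == "__main__":
    # 5.76e23 FLOPs: roughly the budget shared by Gopher and Chinchilla
    n, d = chinchilla_allocation(5.76e23)
    print(f"N_opt ≈ {n / 1e9:.1f}B parameters, D_opt ≈ {d / 1e12:.2f}T tokens")
    # With 10x more compute: Kaplan-style (exponent 0.73) vs Chinchilla-style (0.5) model growth
    print(f"Kaplan: x{growth_factor(10, 0.73):.2f}, Chinchilla: x{growth_factor(10, 0.5):.2f}")
```

For the Chinchilla budget this gives about 69B parameters and 1.39T tokens, close to the published 70B/1.4T configuration; under a tenfold budget increase, a Kaplan-style allocation grows the model about 5.4x, a Chinchilla-style one about 3.2x.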

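Because loss falls as a power of compute, L(C) ∝ C^(-α) (plus an irreducible term in the Chinchilla formulation), curves fitted on small training runs can be extrapolated to predict the loss of much larger models under limited resources, guiding how a pretraining cluster's compute is allocated and how the model architecture is sized. Choosing the optimal model size means balancing the capacity gained from more parameters against the cost of acquiring enough high-quality training tokens. Scaling laws therefore give an empirical basis for engineering decisions when training models with up to trillions of parameters, helping avoid spending resources on suboptimal configurations.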
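Extrapolation works because a pure power law is a straight line in log-log space. The sketch below generates synthetic, noise-free points from a Kaplan-style law L(C) = (C_c/C)^α (the constants `C_C` and `ALPHA` are illustrative placeholders, not fitted values), fits a line to the small-compute points with `numpy.polyfit`, and extrapolates to a budget 1000 times larger than the largest run. Real loss curves are noisy and, in the Chinchilla formulation, include an irreducible term that has to be fitted separately.

```python
import numpy as np

# Illustrative constants for a synthetic power law L(C) = (C_c / C)^alpha
C_C = 1e28
ALPHA = 0.05


def true_loss(compute):
    return (C_C / compute) ** ALPHA


def fit_power_law(compute, loss):
    """Least-squares fit of log L = log a - b * log C; returns (a, b)."""
    # np.polyfit returns coefficients from the highest degree down: [slope, intercept]
    slope, intercept = np.polyfit(np.log(compute), np.log(loss), deg=1)
    return np.exp(intercept), -slope


small_runs = np.logspace(17, 20, 8)          # cheap training runs (synthetic)
a, b = fit_power_law(small_runs, true_loss(small_runs))
target = 1e23                                # budget to extrapolate to
predicted = a * target ** (-b)
print(f"fitted exponent b = {b:.4f}")
print(f"predicted L(1e23) = {predicted:.4f}, true = {true_loss(target):.4f}")
```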

Python

"""
Script: scaling_laws_analysis.py
Content: Implementation and visualization of Scaling Laws (Kaplan et al. & Chinchilla).
         Power-law relationships between loss, compute (FLOPs), parameters, and data tokens.
Usage: python scaling_laws_analysis.py --analysis_type optimal_config --target_loss 2.0 --law_type chinchilla
"""

import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
import argparse


class ScalingLaws:
    """
    Implementation of Neural Scaling Laws based on Kaplan et al. (2020) 
    and Hoffmann et al. (2022) (Chinchilla).
    
    Models the power-law relationship: L(N, D) = A/N^alpha + B/D^beta + L_inf
    where N is parameters, D is dataset size in tokens, and L is loss.
    """
    
    # Conversion factor: 1 PF-day = 1e15 FLOP/s * 86400 s
    PF_DAY_FLOPS = 8.64e19
    
    def __init__(self, law_type: str = 'chinchilla'):
        """
        Args:
            law_type: 'kaplan' or 'chinchilla' for different scaling coefficients
        """
        if law_type not in ('kaplan', 'chinchilla'):
            raise ValueError(f"Unknown law_type: {law_type}")
        self.law_type = law_type
        
        # Constants from papers (fitted to empirical results)
        if law_type == 'kaplan':
            # Kaplan et al. (OpenAI): L(N) = (N_c/N)^alpha_N with N_c = 8.8e13 non-embedding
            # parameters, and L(D) = (D_c/D)^alpha_D with D_c = 5.4e13 tokens.
            # With A = N_c^alpha and B = D_c^beta these become A/N^alpha and B/D^beta;
            # adding the two terms simplifies Kaplan's combined L(N, D) formula.
            self.alpha = 0.076  # Parameter scaling exponent
            self.beta = 0.095   # Data scaling exponent
            self.A = 8.8e13 ** self.alpha   # Parameter coefficient
            self.B = 5.4e13 ** self.beta    # Data coefficient
            self.L_inf = 0.0    # Kaplan's fits contain no irreducible-loss term
            
            # Compute scaling: L(C) = (C_c / C)^0.050 with C_c = 3.1e8 PF-days
            self.compute_coeff = 3.1e8 * self.PF_DAY_FLOPS  # C_c converted to FLOPs
            self.compute_exponent = 0.050
            
        else:  # chinchilla
            # Hoffmann et al. (DeepMind) parametric fit: L(N, D) = E + A/N^alpha + B/D^beta
            self.alpha = 0.34
            self.beta = 0.28
            self.A = 406.4
            self.B = 410.7
            self.L_inf = 1.69   # E: estimated irreducible loss for English text
            
            # Rule of thumb: N_opt ∝ C^0.50, D_opt ∝ C^0.50, about 20 tokens per parameter
            self.tokens_per_param = 20.0
    
    def loss_vs_params(self, N: np.ndarray, D: float = np.inf) -> np.ndarray:
        """
        Compute loss as function of parameters N (fixed data size D).
        L(N) = A/N^alpha + B/D^beta + L_inf
        """
        param_term = self.A / (N ** self.alpha)
        data_term = 0 if np.isinf(D) else self.B / (D ** self.beta)
        return param_term + data_term + self.L_inf
    
    def loss_vs_data(self, D: np.ndarray, N: float = np.inf) -> np.ndarray:
        """
        Compute loss as function of data size D (fixed parameters N).
        """
        param_term = 0 if np.isinf(N) else self.A / (N ** self.alpha)
        data_term = self.B / (D ** self.beta)
        return param_term + data_term + self.L_inf
    
    def loss_vs_compute(self, C: np.ndarray) -> np.ndarray:
        """
        Compute loss as function of total compute C (in FLOPs),
        assuming compute-optimal allocation between N and D.
        
        For Kaplan: L(C) = (C_c / C)^0.050 (+ L_inf, which is 0 for Kaplan)
        For Chinchilla: L(N_opt(C), D_opt(C)) evaluated with the parametric fit
        """
        C = np.asarray(C, dtype=float)
        if self.law_type == 'kaplan':
            return (self.compute_coeff / C) ** self.compute_exponent + self.L_inf
        N_opt, D_opt = self.optimal_allocation(C)
        return self.A / N_opt ** self.alpha + self.B / D_opt ** self.beta + self.L_inf
    
    def optimal_allocation(self, C):
        """
        Given compute budget C (FLOPs), return optimal model size N and tokens D,
        using the training-cost approximation C ≈ 6 * N * D (decoder-only transformers).
        
        Kaplan: N_opt ≈ 1.3e9 * C_PF-days^0.73, so D_opt ∝ C^0.27 (favors model size)
        Chinchilla: N_opt ∝ C^0.50, D_opt ∝ C^0.50, with D_opt ≈ 20 * N_opt
        """
        C = np.asarray(C, dtype=float)
        if self.law_type == 'kaplan':
            # Approximate normalization reported by Kaplan et al. (C measured in PF-days)
            N_opt = 1.3e9 * (C / self.PF_DAY_FLOPS) ** 0.73
        else:
            # C = 6 * N * (k * N)  =>  N = sqrt(C / (6k)), with k tokens per parameter
            N_opt = np.sqrt(C / (6 * self.tokens_per_param))
        D_opt = C / (6 * N_opt)
        
        return N_opt, D_opt
    
    def isoflop_profiles(self, C_values: List[float]) -> Dict:
        """
        Generate isoflop curves: for fixed compute C, vary N and compute D = C/(6*N).
        Returns dict mapping C to lists of (N, D, L) tuples.
        """
        profiles = {}
        
        for C in C_values:
            N_range = np.logspace(7, 12, 50)  # 10M to 1T parameters
            D_vals = C / (6 * N_range)  # From FLOPs = 6 * N * D approximation
            
            # Mask valid D (positive, not infinite)
            valid_mask = D_vals > 0
            N_valid = N_range[valid_mask]
            D_valid = D_vals[valid_mask]
            
            # Compute loss for each (N, D) pair
            losses = []
            for n, d in zip(N_valid, D_valid):
                param_term = self.A / (n ** self.alpha)
                data_term = self.B / (d ** self.beta)
                losses.append(param_term + data_term + self.L_inf)
            
            profiles[C] = {
                'N': N_valid,
                'D': D_valid,
                'losses': np.array(losses),
                'optimal_N': N_valid[np.argmin(losses)],
                'optimal_D': D_valid[np.argmin(losses)]
            }
        
        return profiles


def plot_scaling_laws_comparison():
    """Visualize Kaplan vs Chinchilla scaling laws."""
    kaplan = ScalingLaws('kaplan')
    chinchilla = ScalingLaws('chinchilla')
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # 1. Loss vs Parameters (fixed data)
    N_range = np.logspace(8, 12, 100)  # 100M to 1T
    ax = axes[0, 0]
    ax.loglog(N_range, kaplan.loss_vs_params(N_range, D=1e12), 
             label='Kaplan (D=1T fixed)', linewidth=2)
    ax.loglog(N_range, chinchilla.loss_vs_params(N_range, D=1e12), 
             label='Chinchilla (D=1T fixed)', linewidth=2)
    ax.set_xlabel('Parameters N')
    ax.set_ylabel('Loss L(N)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_title('Loss vs Model Size (Fixed Data)')
    
    # 2. Loss vs Data (fixed parameters)
    D_range = np.logspace(9, 13, 100)  # 1B to 10T tokens
    ax = axes[0, 1]
    ax.loglog(D_range, kaplan.loss_vs_data(D_range, N=1e10), 
             label='Kaplan (N=10B fixed)', linewidth=2)
    ax.loglog(D_range, chinchilla.loss_vs_data(D_range, N=1e10), 
             label='Chinchilla (N=10B fixed)', linewidth=2)
    ax.set_xlabel('Tokens D')
    ax.set_ylabel('Loss L(D)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_title('Loss vs Data Size (Fixed Model)')
    
    # 3. Loss vs Compute (optimal allocation)
    C_range = np.logspace(18, 24, 100)  # 1e18 to 1e24 FLOPs
    ax = axes[0, 2]
    ax.loglog(C_range, kaplan.loss_vs_compute(C_range), 
             label='Kaplan Optimal', linewidth=2)
    ax.loglog(C_range, chinchilla.loss_vs_compute(C_range), 
             label='Chinchilla Optimal', linewidth=2)
    ax.set_xlabel('Compute C (FLOPs)')
    ax.set_ylabel('Loss L(C)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_title('Loss vs Compute (Optimal Allocation)')
    
    # 4. Optimal Model Size vs Compute
    C_values = np.logspace(18, 24, 50)
    ax = axes[1, 0]
    kaplan_N = [kaplan.optimal_allocation(c)[0] for c in C_values]
    chinch_N = [chinchilla.optimal_allocation(c)[0] for c in C_values]
    ax.loglog(C_values, kaplan_N, label='Kaplan N_opt', linewidth=2)
    ax.loglog(C_values, chinch_N, label='Chinchilla N_opt', linewidth=2)
    ax.set_xlabel('Compute C (FLOPs)')
    ax.set_ylabel('Optimal Parameters N')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_title('Optimal Model Size Scaling')
    
    # 5. Optimal Data vs Compute
    ax = axes[1, 1]
    kaplan_D = [kaplan.optimal_allocation(c)[1] for c in C_values]
    chinch_D = [chinchilla.optimal_allocation(c)[1] for c in C_values]
    ax.loglog(C_values, kaplan_D, label='Kaplan D_opt', linewidth=2)
    ax.loglog(C_values, chinch_D, label='Chinchilla D_opt', linewidth=2)
    ax.set_xlabel('Compute C (FLOPs)')
    ax.set_ylabel('Optimal Tokens D')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_title('Optimal Data Size Scaling')
    
    # 6. IsoFLOP curves (Chinchilla only, as it's more accurate)
    ax = axes[1, 2]
    C_vals = [1e18, 1e20, 1e22, 1e24]
    profiles = chinchilla.isoflop_profiles(C_vals)
    colors = plt.cm.viridis(np.linspace(0, 1, len(C_vals)))
    
    for C, color in zip(C_vals, colors):
        prof = profiles[C]
        # Plot loss vs parameters for fixed compute
        # Log-log axes: N spans five orders of magnitude
        ax.loglog(prof['N'] / 1e9, prof['losses'], color=color, 
                  label=f'C=1e{int(np.log10(C))}', linewidth=2)
        # Mark minimum
        min_idx = np.argmin(prof['losses'])
        ax.scatter(prof['N'][min_idx]/1e9, prof['losses'][min_idx], 
                  color=color, s=100, zorder=5, marker='*')
    
    ax.set_xlabel('Parameters N (Billions)')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_title('IsoFLOP Curves (Chinchilla)\nStars mark optimal points')
    
    plt.tight_layout()
    plt.savefig('scaling_laws_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()


def compute_optimal_training_config(target_loss: float = 2.0, 
                                   law_type: str = 'chinchilla') -> Dict:
    """
    Calculate required compute, model size, and data for a target loss
    by numerically inverting the compute-optimal loss curve L(C).
    """
    from scipy.optimize import brentq
    
    law = ScalingLaws(law_type)
    if target_loss <= law.L_inf:
        raise ValueError(f"target_loss must exceed the irreducible loss {law.L_inf}")
    
    # L(C) decreases monotonically in C: solve L(10^x) = target_loss for x = log10(C)
    f = lambda log_c: float(law.loss_vs_compute(10.0 ** log_c)) - target_loss
    req_compute = 10.0 ** brentq(f, 0.0, 40.0)
    
    N_opt, D_opt = law.optimal_allocation(req_compute)
    
    # Assumed cluster: 1000 accelerators at 100 TFLOP/s each = 1e17 FLOP/s, 100% utilization
    cluster_flops_per_sec = 1000 * 100e12
    
    return {
        'target_loss': target_loss,
        'required_compute_FLOPs': req_compute,
        'optimal_parameters': float(N_opt),
        'optimal_tokens': float(D_opt),
        'training_budget_days': req_compute / (cluster_flops_per_sec * 86400),
        'law_type': law_type
    }


def plot_training_efficiency_frontier():
    """Visualize Pareto frontier of training efficiency."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    # Left: Compute efficiency
    compute_range = np.logspace(18, 24, 100)
    kaplan = ScalingLaws('kaplan')
    chinchilla = ScalingLaws('chinchilla')
    
    losses_kaplan = [kaplan.loss_vs_compute(c) for c in compute_range]
    losses_chinch = [chinchilla.loss_vs_compute(c) for c in compute_range]
    
    ax1.loglog(compute_range, losses_kaplan, label='Kaplan Scaling', linewidth=2)
    ax1.loglog(compute_range, losses_chinch, label='Chinchilla Scaling', linewidth=2)
    ax1.set_xlabel('Training Compute (FLOPs)')
    ax1.set_ylabel('Final Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_title('Training Efficiency Frontier')
    
    # Right: Data vs Parameters for fixed compute levels
    C_levels = [1e19, 1e21, 1e23]
    colors = ['blue', 'green', 'red']
    
    for C, color in zip(C_levels, colors):
        # Param range for this compute level
        N_range = np.logspace(8, 11, 50)
        D_vals = C / (6 * N_range)
        
        # Only plot valid region
        valid = D_vals > 1e9
        N_valid = N_range[valid]
        D_valid = D_vals[valid]
        
        ax2.loglog(N_valid / 1e9, D_valid / 1e9, color=color, 
                  label=f'C=1e{int(np.log10(C))}', linewidth=2)
        
        # Mark Chinchilla optimal
        N_opt, D_opt = chinchilla.optimal_allocation(C)
        ax2.scatter(N_opt / 1e9, D_opt / 1e9, color=color, s=200, marker='*', zorder=5)
    
    ax2.set_xlabel('Model Size (B parameters)')
    ax2.set_ylabel('Training Tokens (B tokens)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_title('Optimal Training Configurations\nStars = Chinchilla optimal')
    
    plt.tight_layout()
    plt.savefig('training_efficiency.png', dpi=300)
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Analyze Neural Scaling Laws')
    parser.add_argument('--analysis_type', type=str, default='all',
                       choices=['comparison', 'optimal_config', 'efficiency', 'all'])
    parser.add_argument('--target_loss', type=float, default=2.0,
                       help='Target loss for optimal config calculation')
    parser.add_argument('--law_type', type=str, default='chinchilla',
                       choices=['kaplan', 'chinchilla'])
    
    args = parser.parse_args()
    
    if args.analysis_type in ['comparison', 'all']:
        plot_scaling_laws_comparison()
    
    if args.analysis_type in ['optimal_config', 'all']:
        config = compute_optimal_training_config(args.target_loss, args.law_type)
        print("\nOptimal Training Configuration:")
        for k, v in config.items():
            print(f"  {k}: {v:.2e}" if isinstance(v, float) else f"  {k}: {v}")
    
    if args.analysis_type in ['efficiency', 'all']:
        plot_training_efficiency_frontier()

11.3.2 Model Compression: Quantization-Aware Training (QAT) with LLM.int8() and QLoRA's 4-bit NormalFloat; Knowledge Distillation with MiniLLM and Temperature Tuning

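Deploying large language models is constrained by GPU memory capacity and memory bandwidth; quantization shrinks models by lowering the numerical precision of their parameters. The LLM.int8() method of Dettmers et al. detects the few outlier feature dimensions in the hidden-state activations (magnitudes of about 6 or more) and handles them separately: the matrix multiplication is decomposed so that these dimensions are computed in 16-bit precision, while all remaining values use vector-wise 8-bit quantization. Applied after training, without retraining, this halves weight memory relative to FP16 while preserving model quality. QLoRA extends quantization to fine-tuning: the frozen base weights are stored in 4-bit NormalFloat (NF4), a quantile-based data type whose levels match the statistics of normally distributed weights, and gradients are backpropagated through them into LoRA adapters. Combined with double quantization of the scaling constants and paged optimizers that absorb memory spikes, this makes it possible to fine-tune a 65B-parameter model on a single 48 GB GPU.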
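The memory savings follow from arithmetic on bits per parameter. The sketch below counts weight storage only (no activations, KV cache, optimizer state, LoRA adapters, or LLM.int8() outlier storage) and assumes QLoRA-style block sizes: one FP32 absmax constant per 64-weight block, or, with double quantization, one 8-bit constant per block plus one FP32 constant per 256 of those constants. The helper names are hypothetical.

```python
def scale_overhead_bits(block_size: int = 64, double_quant: bool = False,
                        second_block: int = 256) -> float:
    """Extra bits per parameter spent on block-wise quantization constants."""
    if not double_quant:
        return 32 / block_size  # one FP32 absmax per block
    # 8-bit quantized absmax per block, plus one FP32 constant per `second_block` absmaxes
    return 8 / block_size + 32 / (block_size * second_block)


def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


N = 65e9  # a 65B-parameter model
for name, bits in [
    ("FP16", 16.0),
    ("INT8", 8.0),
    ("NF4 + FP32 absmax", 4.0 + scale_overhead_bits()),
    ("NF4 + double quantization", 4.0 + scale_overhead_bits(double_quant=True)),
]:
    print(f"{name:28s} {bits:7.3f} bits/param  {weight_memory_gb(N, bits):6.1f} GB")
```

The FP16 weights alone need 130 GB, INT8 halves that to 65 GB, and NF4 with double quantization (about 4.127 bits per parameter) brings them to roughly 33.5 GB, which leaves headroom on a 48 GB GPU.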

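Knowledge distillation transfers knowledge from a large teacher model to a compact student network by training the student to match the teacher's output probabilities (soft targets). The softmax temperature T controls how smooth these distributions are: a high temperature spreads probability mass more evenly and exposes the teacher's similarity structure among non-target classes, while a low temperature approaches one-hot hard labels. Because the gradients of the soft-target loss scale as 1/T², the distillation term is usually multiplied by T². Distillation schemes designed for generation, such as MiniLLM, minimize the reverse KL divergence KL(student || teacher) instead of the forward KL, which keeps the student from overestimating the teacher's low-probability regions; combining token-level and sequence-level terms, they train efficient small models that balance inference speed against storage cost.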
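The difference between the two KL directions is visible on a toy distribution. In the sketch below (all distributions and logits are made up for illustration), the teacher is bimodal; one candidate student spreads its mass uniformly, the other concentrates on a single mode. Forward KL prefers the spread-out student, while reverse KL prefers the mode-seeking one. The last lines show how raising T flattens a softmax.

```python
import numpy as np


def kl(p: np.ndarray, q: np.ndarray) -> float:
    """KL(p || q) = sum_i p_i * log(p_i / q_i) for strictly positive distributions."""
    return float(np.sum(p * np.log(p / q)))


def softmax_with_temperature(logits: np.ndarray, T: float) -> np.ndarray:
    z = logits / T
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


teacher = np.array([0.45, 0.05, 0.45, 0.05])         # two modes
spread_student = np.array([0.25, 0.25, 0.25, 0.25])  # covers every outcome
mode_student = np.array([0.85, 0.05, 0.05, 0.05])    # sits on one mode

for name, q in [("spread", spread_student), ("mode-seeking", mode_student)]:
    print(f"{name:13s} forward KL(p||q) = {kl(teacher, q):.3f}   "
          f"reverse KL(q||p) = {kl(q, teacher):.3f}")

# Higher temperature flattens the teacher's soft targets
logits = np.array([2.0, 1.0, 0.1])
for T in (1.0, 4.0):
    print(T, np.round(softmax_with_temperature(logits, T), 3))
```

Here forward KL is about 0.368 for the spread student versus 0.703 for the mode-seeking one, while reverse KL is about 0.511 versus 0.431, so the two objectives rank the students in opposite orders.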

Python

"""
Script: model_compression_quantization_distillation.py
Content: Implementation of 8-bit quantization (LLM.int8()), 4-bit QLoRA with NormalFloat,
         and Knowledge Distillation with temperature scaling and MiniLLM objectives.
Usage: python model_compression_quantization_distillation.py --method int8 --model_path model.pt
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional
import matplotlib.pyplot as plt
import argparse


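class Int8Quantization:
    """
    Simplified LLM.int8() simulation following Dettmers et al. (2022).
    Weights are stored as INT8 with per-column absmax scales (kept simple here; LLM.int8()
    uses per-output-row weight scales so rescaling can follow a true INT8 GEMM).
    At matmul time, input feature dimensions containing activation outliers
    (|x| > threshold) are computed in floating point, while all other dimensions use
    row-wise INT8-quantized activations (mixed-precision decomposition).
    Arithmetic is emulated in FP32; no INT8 kernels are used.
    """
    def __init__(self, threshold: float = 6.0):
        self.threshold = threshold  # Outlier magnitude threshold (6.0 in the paper)
    
    def quantize_matrix(self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize weight matrix to INT8.
        
        Returns:
            int8_weights: Quantized weights (int8)
            scales: Per-column (input-dimension) scaling factors
            outliers: Weight entries with |w| > threshold, kept in full precision
                      (zeros elsewhere). This is an extra safeguard of this sketch;
                      the paper detects outliers in the activations instead.
        """
        weight = weight.float()
        # Element-wise mask of large-magnitude weights, kept outside the INT8 grid
        outlier_mask = weight.abs() > self.threshold
        zeros = torch.zeros_like(weight)
        outliers = torch.where(outlier_mask, weight, zeros)
        weight_int8_part = torch.where(outlier_mask, zeros, weight)
        
        # Per-column absmax scaling onto the symmetric range [-127, 127]
        abs_max = weight_int8_part.abs().amax(dim=0)
        abs_max[abs_max == 0] = 1.0  # Avoid division by zero for all-zero columns
        scales = abs_max / 127.0
        int8_weights = torch.round(weight_int8_part / scales.unsqueeze(0)).to(torch.int8)
        
        return int8_weights, scales, outliers
    
    def dequantize_matmul(self, int8_weights: torch.Tensor, scales: torch.Tensor,
                         outliers: torch.Tensor, input_activations: torch.Tensor) -> torch.Tensor:
        """
        Mixed-precision matmul: outlier feature dimensions of the input in floating
        point, all remaining dimensions with INT8-quantized activations and weights.
        """
        x = input_activations.float()
        # Reconstruct weights: rescaled INT8 part plus full-precision weight outliers
        weight = int8_weights.float() * scales.float().unsqueeze(0) + outliers.float()
        
        # Input feature dimensions (columns of x) where any token exceeds the threshold
        outlier_dims = (x.abs() > self.threshold).reshape(-1, x.size(-1)).any(dim=0)
        x_regular = x.masked_fill(outlier_dims, 0.0)
        x_outlier = x - x_regular
        
        # Row-wise (per-token) absmax INT8 quantization of the regular activations
        x_scale = x_regular.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_int8 = torch.round(x_regular / x_scale).clamp(-127, 127)
        
        regular_out = torch.matmul(x_int8 * x_scale, weight.t())  # INT8 path (emulated)
        outlier_out = torch.matmul(x_outlier, weight.t())         # floating-point path
        return regular_out + outlier_out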


class Linear8bit(nn.Module):
    """Linear layer with INT8 weights and full-precision outlier handling."""
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Full-precision reference weights (randomly initialized for demonstration)
        self.register_buffer('weight_fp16', torch.randn(out_features, in_features, dtype=torch.float16))
        if bias:
            self.register_buffer('bias_fp16', torch.zeros(out_features, dtype=torch.float16))
        else:
            self.register_buffer('bias_fp16', None)
        
        self.quantizer = Int8Quantization(threshold=6.0)
        self.quantized = False
        
        # Quantized storage; scales are per input column, matching quantize_matrix
        self.register_buffer('int8_weight', torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer('scales', torch.ones(in_features, dtype=torch.float16))
        self.register_buffer('outliers', torch.zeros(out_features, in_features, dtype=torch.float16))
    
    def quantize(self):
        """Convert FP16 weights to INT8 representation."""
        int8_w, scales, outliers = self.quantizer.quantize_matrix(self.weight_fp16.float())
        self.int8_weight = int8_w
        self.scales = scales.half()
        self.outliers = outliers.half()
        self.quantized = True
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = None if self.bias_fp16 is None else self.bias_fp16.to(x.dtype)
        if not self.quantized:
            # Cast weights to the input dtype so FP32 inputs work too (e.g. on CPU)
            return F.linear(x, self.weight_fp16.to(x.dtype), bias)
        
        # Mixed-precision matmul (emulated in FP32), cast back to the input dtype
        output = self.quantizer.dequantize_matmul(
            self.int8_weight, self.scales, self.outliers, x
        ).to(x.dtype)
        if bias is not None:
            output = output + bias
        
        return output


class NormalFloat4:
    """
    4-bit NormalFloat quantization for QLoRA.
    Optimized for normally distributed weights.
    """
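    def __init__(self, num_bits: int = 4, block_size: int = 64):
        self.num_bits = num_bits
        self.num_levels = 2 ** num_bits  # 16 levels for 4-bit
        self.block_size = block_size
        
        # NormalFloat levels: quantiles of the standard normal distribution, built
        # asymmetrically (8 positive, 7 negative, plus an exact zero) so that all 16
        # levels are used and 0 is exactly representable. The offset value follows the
        # construction used in bitsandbytes (an assumption of this sketch).
        from scipy.stats import norm
        offset = 0.9677083
        half = self.num_levels // 2
        positive = norm.ppf(np.linspace(offset, 0.5, half + 1)[:-1])
        negative = -norm.ppf(np.linspace(offset, 0.5, half)[:-1])
        levels = np.sort(np.concatenate([negative, [0.0], positive]))
        
        # Normalize to [-1, 1]; block-wise absmax scaling maps weights onto this range
        self.quantization_levels = torch.tensor(levels / levels.max(), dtype=torch.float32)
        
        # Range of the normalized codebook
        self.min_val = -1.0
        self.max_val = 1.0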
    
    def quantize_block(self, weight_block: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize a block of weights to 4-bit codebook indices.
        Returns flattened (unpacked) indices, the block absmax, and the codebook.
        """
        # Block-wise absmax for dynamic scaling (clamped to avoid division by zero)
        absmax = weight_block.abs().max().clamp(min=1e-12)
        
        # Normalize to [-1, 1]
        normalized = weight_block / absmax
        
        # Nearest quantization level for each weight: distances have shape (levels, num_weights)
        expanded_levels = self.quantization_levels.to(weight_block.device).unsqueeze(1)
        expanded_weights = normalized.flatten().unsqueeze(0)
        
        distances = torch.abs(expanded_levels - expanded_weights)
        indices = distances.argmin(dim=0)  # Closest level for each weight
        
        # Indices are returned unpacked; compact storage packs two 4-bit indices per byte
        return indices, absmax, self.quantization_levels
    
    def dequantize_block(self, indices: torch.Tensor, absmax: torch.Tensor,
                        codebook: torch.Tensor) -> torch.Tensor:
        """Dequantize 4-bit indices back to FP16."""
        # Lookup values
        values = codebook[indices]
        # Rescale
        return values * absmax


class QLoRALinear(nn.Module):
    """
    QLoRA linear layer: 4-bit quantized base weights + LoRA adapters.
    Following Dettmers et al. (2023).
    """
    def __init__(self, in_features: int, out_features: int, 
                 r: int = 64, lora_alpha: int = 16, 
                 quantize_base: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / r
        self.quantize_base = quantize_base
        
        if quantize_base:
            # 4-bit storage for base weights
            self.nf4 = NormalFloat4()
            self.register_buffer('quantized_weight', torch.zeros(
                (out_features, in_features // 2), dtype=torch.uint8  # Packed 4-bit
            ))
            self.register_buffer('quantization_absmax', 
                               torch.zeros((out_features, (in_features + self.nf4.block_size - 1) // self.nf4.block_size)))
            self.register_buffer('codebook', self.nf4.quantization_levels)
        else:
            self.register_buffer('base_weight', torch.zeros(out_features, in_features, dtype=torch.float16))
        
        # LoRA adapters (trainable)
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        
        # Initialize LoRA layers
        nn.init.kaiming_uniform_(self.lora_A, a=np.sqrt(5))
        nn.init.zeros_(self.lora_B)
    
    def quantize_weight(self, fp16_weight: torch.Tensor):
        """Convert the base weight to block-wise 4-bit NormalFloat with packed storage."""
        if not self.quantize_base:
            self.base_weight = fp16_weight.half()
            return
        
        bs = self.nf4.block_size
        # This sketch assumes blocks tile each row exactly and an even block size for packing
        assert self.in_features % bs == 0 and bs % 2 == 0
        
        w = fp16_weight.float().reshape(self.out_features, -1, bs)  # (out, num_blocks, bs)
        absmax = w.abs().amax(dim=-1).clamp(min=1e-12)            # one scale per block
        normalized = w / absmax.unsqueeze(-1)
        
        # Nearest NF4 level for every weight
        codebook = self.codebook.to(w.device)
        idx = (normalized.unsqueeze(-1) - codebook).abs().argmin(dim=-1)
        idx = idx.reshape(self.out_features, self.in_features).to(torch.uint8)
        
        # Pack two 4-bit indices per byte: even columns in the high nibble, odd in the low
        self.quantized_weight = (idx[:, 0::2] << 4) | idx[:, 1::2]
        self.quantization_absmax = absmax
    
    def get_dequantized_weight(self) -> torch.Tensor:
        """Unpack the 4-bit indices and rescale them into a full-precision base weight."""
        if not self.quantize_base:
            return self.base_weight
        
        high = (self.quantized_weight >> 4).long()
        low = (self.quantized_weight & 0x0F).long()
        # Re-interleave high/low nibbles into the original column order
        idx = torch.stack([high, low], dim=-1).reshape(self.out_features, self.in_features)
        values = self.codebook[idx].reshape(self.out_features, -1, self.nf4.block_size)
        weight = values * self.quantization_absmax.unsqueeze(-1)
        return weight.reshape(self.out_features, self.in_features)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize base weight for the forward pass (a custom CUDA kernel would fuse this)
        base = self.get_dequantized_weight().to(x.dtype)
        
        # The base weight is a buffer and receives no gradient. No torch.no_grad() here:
        # gradients must still flow through the base path to earlier layers.
        base_output = F.linear(x, base)
        
        # LoRA path
        lora_output = F.linear(F.linear(x, self.lora_A), self.lora_B) * self.scaling
        
        return base_output + lora_output


class KnowledgeDistillation:
    """
    Knowledge distillation with temperature scaling and MiniLLM objectives.
    """
    def __init__(self, teacher_model: nn.Module, student_model: nn.Module,
                 temperature: float = 4.0, alpha: float = 0.5):
        self.teacher = teacher_model
        self.student = student_model
        self.T = temperature  # Softmax temperature
        self.alpha = alpha    # Weight for distillation vs hard label loss
        
        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()
    
    def forward_kl_loss(self, student_logits: torch.Tensor, 
                       teacher_logits: torch.Tensor) -> torch.Tensor:
        """
        Standard forward KL divergence: KL(teacher || student).
        Mode-covering: pushes the student to place mass wherever the teacher does,
        including the teacher's low-probability regions.
        """
        vocab = student_logits.size(-1)
        # Soften with temperature; flatten (batch, seq) so 'batchmean' averages over tokens
        student_log_probs = F.log_softmax(student_logits / self.T, dim=-1).reshape(-1, vocab)
        teacher_probs = F.softmax(teacher_logits / self.T, dim=-1).reshape(-1, vocab)
        
        # Multiply by T^2 to keep gradient magnitudes comparable across temperatures
        return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (self.T ** 2)
    
    def _reverse_kl_per_token(self, student_logits: torch.Tensor,
                              teacher_logits: torch.Tensor) -> torch.Tensor:
        """Token-level reverse KL(student || teacher), summed over the vocabulary."""
        student_log_probs = F.log_softmax(student_logits / self.T, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / self.T, dim=-1)
        return (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum(dim=-1)
    
    def reverse_kl_loss(self, student_logits: torch.Tensor,
                       teacher_logits: torch.Tensor) -> torch.Tensor:
        """
        Reverse KL: KL(student || teacher), averaged over all positions.
        Mode-seeking: penalizes student mass where the teacher has little,
        the motivation MiniLLM gives for using it on generative tasks.
        """
        return self._reverse_kl_per_token(student_logits, teacher_logits).mean() * (self.T ** 2)
    
    def mini_llm_loss(self, student_logits: torch.Tensor,
                     teacher_logits: torch.Tensor,
                     input_ids: torch.Tensor,
                     attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Simplified MiniLLM-style objective: token-level reverse KL on teacher-forced
        inputs, averaged per sequence over non-padding positions (length-normalized).
        MiniLLM itself optimizes reverse KL with policy-gradient updates on sequences
        sampled from the student; that sampling loop is not implemented here.
        """
        token_kl = self._reverse_kl_per_token(student_logits, teacher_logits) * (self.T ** 2)
        if attention_mask is None:
            return token_kl.mean()
        
        mask = attention_mask.float()
        # Length-normalized per-sequence KL, then averaged over the batch
        seq_kl = (token_kl * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return seq_kl.mean()
    
    def compute_loss(self, input_ids: torch.Tensor, 
                    labels: torch.Tensor,
                    attention_mask: Optional[torch.Tensor] = None,
                    use_reverse_kl: bool = True) -> Dict[str, torch.Tensor]:
        """
        Combined distillation loss with hard labels.
        """
        # Teacher forward (no grad)
        with torch.no_grad():
            teacher_outputs = self.teacher(input_ids, attention_mask)
            teacher_logits = teacher_outputs['logits']
        
        # Student forward
        student_outputs = self.student(input_ids, attention_mask)
        student_logits = student_outputs['logits']
        
        # Distillation loss
        if use_reverse_kl:
            distill_loss = self.mini_llm_loss(student_logits, teacher_logits, 
                                             input_ids, attention_mask)
        else:
            distill_loss = self.forward_kl_loss(student_logits, teacher_logits)
        
        # Hard label loss (standard cross-entropy)
        hard_loss = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )
        
        # Combined
        total_loss = self.alpha * distill_loss + (1 - self.alpha) * hard_loss
        
        return {
            'total_loss': total_loss,
            'distill_loss': distill_loss,
            'hard_loss': hard_loss,
            'student_logits': student_logits
        }


class DistillationTrainer:
    """Trainer for knowledge distillation with various objectives."""
    
    def __init__(self, teacher: nn.Module, student: nn.Module, 
                 temperature: float = 4.0, lr: float = 5e-5):
        self.distiller = KnowledgeDistillation(teacher, student, temperature)
        self.optimizer = torch.optim.AdamW(student.parameters(), lr=lr)
        self.temperature_schedule = lambda epoch: max(1.0, temperature - epoch)
    
    def train_step(self, batch: Dict, use_reverse_kl: bool = True) -> Dict[str, float]:
        self.optimizer.zero_grad()
        
        # Dynamic temperature annealing could be applied here
        losses = self.distiller.compute_loss(
            batch['input_ids'],
            batch['labels'],
            batch.get('attention_mask'),
            use_reverse_kl
        )
        
        losses['total_loss'].backward()
        torch.nn.utils.clip_grad_norm_(self.distiller.student.parameters(), 1.0)
        self.optimizer.step()
        
        return {
            'total': losses['total_loss'].item(),
            'distill': losses['distill_loss'].item(),
            'hard': losses['hard_loss'].item()
        }


def visualize_quantization_effects():
    """Compare FP16 vs INT8 vs NF4 quantization distributions."""
    # Generate synthetic weights (normal distribution, typical for LLMs)
    torch.manual_seed(42)
    weights = torch.randn(1000, 4096) * 0.02  # Standard transformer scale
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Original distribution
    ax = axes[0, 0]
    ax.hist(weights.flatten().numpy(), bins=100, alpha=0.7, color='blue')
    ax.set_title('Original FP16 Weight Distribution')
    ax.set_xlabel('Weight Value')
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)
    
    # INT8 quantization
    int8_quantizer = Int8Quantization(threshold=6.0)
    int8_w, scales, outliers = int8_quantizer.quantize_matrix(weights[:100, :100])  # Sample
    
    ax = axes[0, 1]
    dequant = int8_w.float() * scales.unsqueeze(0)
    ax.hist(dequant.flatten().numpy(), bins=100, alpha=0.7, color='green', label='INT8 Dequantized')
    ax.hist(weights[:100, :100].flatten().numpy(), bins=100, alpha=0.5, color='blue', label='Original')
    ax.set_title('INT8 Quantization (with outliers)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # NF4 quantization
    nf4 = NormalFloat4()
    ax = axes[1, 0]
    ax.hist(nf4.quantization_levels.numpy(), bins=16, color='red', alpha=0.7, edgecolor='black')
    ax.set_title('NF4 Quantization Levels (16 values)')
    ax.set_xlabel('Quantized Value')
    ax.set_ylabel('Count')
    ax.grid(True, alpha=0.3)
    
    # Error comparison (illustrative only: the curves below are random noise with
    # assumed scales, not errors measured from the quantizers above)
    ax = axes[1, 1]
    
    rng = np.random.default_rng(0)
    fp16_vals = np.linspace(-0.1, 0.1, 1000)
    int8_errors = np.abs(rng.standard_normal(1000) * 0.001)  # Assumed INT8 error scale
    nf4_errors = np.abs(rng.standard_normal(1000) * 0.002)   # Assumed NF4 error scale (4-bit: 4x smaller than FP16)
    
    ax.semilogy(fp16_vals, int8_errors, label='INT8 Error', alpha=0.7)
    ax.semilogy(fp16_vals, nf4_errors, label='NF4 Error', alpha=0.7)
    ax.set_xlabel('Weight Value')
    ax.set_ylabel('Quantization Error (log scale)')
    ax.set_title('Quantization Error Comparison (simulated)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('quantization_comparison.png', dpi=300)
    plt.show()


def visualize_distillation_temperature():
    """Visualize effect of temperature on probability distributions."""
    # Synthetic logits
    logits = torch.tensor([2.0, 1.0, 0.1, 0.0, -0.5])
    
    temperatures = [0.5, 1.0, 2.0, 4.0]
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    axes = axes.flatten()
    
    for idx, T in enumerate(temperatures):
        probs = F.softmax(logits / T, dim=-1)
        
        ax = axes[idx]
        bars = ax.bar(range(len(logits)), probs.numpy(), color=plt.cm.viridis(np.linspace(0, 1, len(logits))))
        ax.set_title(f'Temperature = {T}\nEntropy = {-torch.sum(probs * torch.log(probs + 1e-10)):.3f}')
        ax.set_xlabel('Token ID')
        ax.set_ylabel('Probability')
        ax.set_ylim(0, 1)
        ax.grid(True, alpha=0.3, axis='y')
        
        # Add value labels on bars
        for bar, prob in zip(bars, probs):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                   f'{prob:.3f}', ha='center', va='bottom', fontsize=8)
    
    plt.tight_layout()
    plt.savefig('temperature_scaling.png', dpi=300)
    plt.show()


def visualize_kl_divergence_comparison():
    """Compare Forward KL vs Reverse KL in distillation."""
    # Teacher distribution (sharp, peaked)
    teacher_logits = torch.tensor([5.0, 1.0, 0.5, 0.2])
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    
    # Student distributions to compare
    student_modes = [
        torch.tensor([4.5, 1.2, 0.6, 0.3]),  # Close to teacher
        torch.tensor([3.0, 2.5, 2.0, 1.5]),  # Flat, high entropy
        torch.tensor([5.5, 0.5, 0.2, 0.1]),  # Sharper than teacher
    ]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    labels = ['Mode Matching', 'High Entropy', 'Sharp Peak']
    colors = ['blue', 'green', 'red']
    
    forward_kls = []
    reverse_kls = []
    
    # Plot the teacher once so it appears a single time in the legend
    x = np.arange(len(teacher_logits))
    ax1.plot(x, teacher_probs.numpy(), 'ko--', label='Teacher', alpha=0.6, linewidth=2)
    
    for student_logits, label, color in zip(student_modes, labels, colors):
        student_probs = F.softmax(student_logits, dim=-1)
        
        # Forward KL: KL(teacher || student)
        f_kl = torch.sum(teacher_probs * (torch.log(teacher_probs + 1e-10) - torch.log(student_probs + 1e-10)))
        forward_kls.append(f_kl.item())
        
        # Reverse KL: KL(student || teacher)
        r_kl = torch.sum(student_probs * (torch.log(student_probs + 1e-10) - torch.log(teacher_probs + 1e-10)))
        reverse_kls.append(r_kl.item())
        
        # Plot student distribution
        ax1.plot(x, student_probs.numpy(), 's-', label=f'Student: {label}', color=color, alpha=0.7)
    
    ax1.set_title('Teacher vs Student Distributions')
    ax1.set_xlabel('Token Index')
    ax1.set_ylabel('Probability')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # KL comparison bar chart
    x_pos = np.arange(len(labels))
    width = 0.35
    
    ax2.bar(x_pos - width/2, forward_kls, width, label='Forward KL (Teacher||Student)', color='skyblue', edgecolor='black')
    ax2.bar(x_pos + width/2, reverse_kls, width, label='Reverse KL (Student||Teacher)', color='lightcoral', edgecolor='black')
    
    ax2.set_ylabel('KL Divergence')
    ax2.set_title('Forward vs Reverse KL Divergence')
    ax2.set_xticks(x_pos)
    ax2.set_xticklabels(labels, rotation=15)
    ax2.legend()
    ax2.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.savefig('kl_divergence_comparison.png', dpi=300)
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--visualization', type=str, default='all',
                       choices=['quantization', 'temperature', 'kl', 'all'])
    parser.add_argument('--demo_training', action='store_true',
                       help='Run demonstration training loop')
    
    args = parser.parse_args()
    
    if args.visualization in ['quantization', 'all']:
        visualize_quantization_effects()
    
    if args.visualization in ['temperature', 'all']:
        visualize_distillation_temperature()
    
    if args.visualization in ['kl', 'all']:
        visualize_kl_divergence_comparison()
    
    if args.demo_training:
        print("\nRunning demonstration training...")
        # Create dummy models
        class DummyModel(nn.Module):
            def __init__(self, vocab=1000):
                super().__init__()
                self.embed = nn.Embedding(vocab, 256)
                self.transformer = nn.TransformerEncoderLayer(256, 4, 512, batch_first=True)
                self.head = nn.Linear(256, vocab)
            
            def forward(self, input_ids, attention_mask=None):
                x = self.embed(input_ids)
                x = self.transformer(x)
                logits = self.head(x)
                return {'logits': logits, 'hidden_states': x}
        
        teacher = DummyModel()
        student = DummyModel()
        
        trainer = DistillationTrainer(teacher, student, temperature=4.0)
        
        # Dummy batch
        batch = {
            'input_ids': torch.randint(0, 1000, (4, 32)),
            'labels': torch.randint(0, 1000, (4, 32)),
            'attention_mask': torch.ones(4, 32)
        }
        
        for i in range(10):
            losses = trainer.train_step(batch, use_reverse_kl=True)
            print(f"Step {i+1}: Total={losses['total']:.4f}, "
                  f"Distill={losses['distill']:.4f}, Hard={losses['hard']:.4f}")
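
The forward/reverse KL comparison in `visualize_kl_divergence_comparison` can be reduced to a few numbers with a bimodal teacher. The NumPy sketch below uses made-up illustrative distributions (not outputs of a real model) to show the asymmetry behind reverse-KL distillation as used in MiniLLM: forward KL(T||S) favors a student that spreads mass over every teacher mode, while reverse KL(S||T) favors a student that commits to one mode and avoids placing mass where the teacher assigns little.

```python
import numpy as np


def kl(a: np.ndarray, b: np.ndarray) -> float:
    """KL(a || b) for discrete distributions with strictly positive entries."""
    return float(np.sum(a * np.log(a / b)))


# Bimodal "teacher" over 4 tokens (illustrative numbers only)
teacher = np.array([0.49, 0.49, 0.01, 0.01])

# Two candidate students, neither of which can match the teacher exactly
students = {
    "mode-seeking": np.array([0.97, 0.01, 0.01, 0.01]),   # commits to one mode
    "mass-covering": np.array([0.25, 0.25, 0.25, 0.25]),  # spreads mass everywhere
}

for name, s in students.items():
    print(f"{name:14s} forward KL(T||S) = {kl(teacher, s):.3f}   "
          f"reverse KL(S||T) = {kl(s, teacher):.3f}")
```

Forward KL ranks the mass-covering student as better, reverse KL ranks the mode-seeking student as better; which behavior is preferable depends on whether the student should avoid hallucinating low-probability teacher tokens.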

11.3.3 Speculative Decoding with Small-Model Drafting and Large-Model Verification, KV-Cache Memory Management and PagedAttention Block-Table Scheduling

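Autoregressive generation is limited by its strictly serial decoding: every new token requires another forward pass of the model. Speculative decoding uses a small draft model to quickly propose a short sequence of candidate tokens, which the large target model then verifies in a single parallel forward pass. In the method of Leviathan et al., each draft token x is accepted with probability min(1, q(x)/p(x)), where q is the target distribution and p the draft distribution; on rejection, a replacement token is sampled from the normalized residual max(0, q - p). This modified rejection sampling guarantees that the output has exactly the same distribution as sampling from the target model directly, and the authors report 2-3x decoding speedups. The speedup is governed by how well the draft model matches the target (which sets the acceptance rate) and by how cheap the draft model is to run; small Transformers or even n-gram models can serve as the drafter.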
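
The accept/reject rule can be checked numerically on a single token. The NumPy sketch below uses two hypothetical 4-token distributions, `p` for the draft model and `q` for the target model; `speculative_sample_one` is an illustrative helper, not a library API. It draws x from p, accepts with probability min(1, q(x)/p(x)), and otherwise resamples from the normalized residual max(0, q - p). The empirical frequencies should match q rather than p.

```python
import numpy as np


def speculative_sample_one(p_draft: np.ndarray, q_target: np.ndarray,
                           rng: np.random.Generator) -> int:
    """Draw one token with the accept/reject rule of speculative sampling."""
    x = rng.choice(len(p_draft), p=p_draft)                 # draft proposes x ~ p
    if rng.random() < min(1.0, q_target[x] / p_draft[x]):
        return int(x)                                        # accept with prob min(1, q/p)
    residual = np.maximum(q_target - p_draft, 0.0)           # reject: resample from norm(max(0, q - p))
    return int(rng.choice(len(q_target), p=residual / residual.sum()))


rng = np.random.default_rng(0)
p = np.array([0.6, 0.2, 0.1, 0.1])   # hypothetical draft distribution
q = np.array([0.3, 0.4, 0.2, 0.1])   # hypothetical target distribution

samples = [speculative_sample_one(p, q, rng) for _ in range(50_000)]
empirical = np.bincount(samples, minlength=len(q)) / len(samples)
print("target   :", q)
print("empirical:", empirical.round(3))
# Probability that a single draft token is accepted: sum_x min(p(x), q(x))
print("acceptance rate:", np.minimum(p, q).sum())
```

Note that the draft distribution only affects how often tokens are accepted, not which distribution the output follows.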

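The KV cache stores the attention keys and values already computed for previous tokens so they need not be recomputed at every decoding step. PagedAttention borrows the virtual-memory paging idea from operating systems: the cache is divided into fixed-size blocks, and a per-request block table maps logical token positions to physical blocks. The vLLM system of Kwon et al. allocates these blocks on demand, which eliminates external fragmentation and avoids reserving memory for the maximum possible sequence length; the only waste is the unfilled tail of each sequence's last block. This enables efficient memory sharing under continuous batching and makes high-throughput serving possible. Because block tables add a level of indirection, requests can share the physical blocks that hold a common prompt prefix; reference counting combined with copy-on-write (a shared block is copied only when one of its users must write into it) further improves GPU memory utilization when serving many users.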
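
The bookkeeping behind a block table, reference counting and copy-on-write can be sketched without any tensors. `BlockTableSketch` below is a hypothetical toy class, not vLLM's API: each sequence owns a list of physical block IDs, logical position `pos` lives in block `table[pos // block_size]` at offset `pos % block_size`, `fork` shares blocks by incrementing reference counts, and the first write into a shared, partially filled block triggers a copy.

```python
from typing import Dict, List, Tuple


class BlockTableSketch:
    """Toy block manager (hypothetical): tracks bookkeeping only, no KV tensors."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free: List[int] = list(range(num_blocks))
        self.ref_count: Dict[int, int] = {}
        self.tables: Dict[str, List[int]] = {}
        self.lengths: Dict[str, int] = {}

    def _new_block(self) -> int:
        if not self.free:
            raise MemoryError("no free KV blocks")
        block = self.free.pop(0)
        self.ref_count[block] = 1
        return block

    def append(self, seq: str) -> Tuple[int, int]:
        """Reserve a slot for the next token of `seq`; returns (physical block, offset)."""
        table = self.tables.setdefault(seq, [])
        pos = self.lengths.get(seq, 0)
        offset = pos % self.block_size
        if offset == 0:                       # last block full (or none yet): take a new one
            table.append(self._new_block())
        elif self.ref_count[table[-1]] > 1:   # copy-on-write for a shared partial block
            self.ref_count[table[-1]] -= 1
            table[-1] = self._new_block()     # a real system would also copy the KV data here
        self.lengths[seq] = pos + 1
        return table[-1], offset

    def fork(self, parent: str, child: str) -> None:
        """The child shares all of the parent's blocks; only reference counts change."""
        self.tables[child] = list(self.tables[parent])
        self.lengths[child] = self.lengths[parent]
        for block in self.tables[child]:
            self.ref_count[block] += 1

    def physical_slot(self, seq: str, pos: int) -> Tuple[int, int]:
        """Translate a logical position into (physical block, offset) via the block table."""
        return self.tables[seq][pos // self.block_size], pos % self.block_size


mgr = BlockTableSketch(num_blocks=8, block_size=4)
for _ in range(6):          # a 6-token prompt fills block 0 and half of block 1
    mgr.append("A")
mgr.fork("A", "B")          # B shares both blocks with A (reference counts become 2)
print(mgr.tables)           # {'A': [0, 1], 'B': [0, 1]}
print(mgr.append("B"))      # (2, 2): B's write into the shared partial block triggers a copy
print(mgr.tables)           # {'A': [0, 1], 'B': [0, 2]}
```

Full blocks stay shared after the copy (block 0 above), which is where the prefix-sharing savings come from.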

Python

"""
Script: inference_optimization_speculative_kv.py
Content: Implementation of Speculative Decoding with draft model and target model verification,
         KV-Cache management, and PagedAttention block table scheduling.
Usage: python inference_optimization_speculative_kv.py --method speculative --prompt "The future of AI"
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Optional, Set
import numpy as np
import matplotlib.pyplot as plt
import argparse
from dataclasses import dataclass
import time


class DraftModel(nn.Module):
    """
    Small draft model for Speculative Decoding.
    Typically 10-100x smaller than target model.
    """
    def __init__(self, vocab_size: int = 50000, embed_dim: int = 256, 
                 num_layers: int = 4, num_heads: int = 4):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Decoder-only LM: causal self-attention blocks without cross-attention
        # (TransformerEncoderLayer + causal mask)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, num_heads, 512, batch_first=True)
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
    
    def forward(self, input_ids: torch.Tensor, kv_cache: Optional[List[torch.Tensor]] = None):
        """
        Full (uncached) forward pass. `kv_cache` is accepted for interface
        compatibility but ignored: this simplified model recomputes the whole prefix.
        Returns (logits, None).
        """
        x = self.embedding(input_ids)
        
        # Causal mask: True marks future positions that may not be attended to
        seq_len = input_ids.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()
        
        for layer in self.layers:
            x = layer(x, src_mask=mask)
        
        logits = self.head(self.ln(x))
        return logits, None


class TargetModel(nn.Module):
    """
    Large target model for Speculative Decoding verification.
    """
    def __init__(self, vocab_size: int = 50000, embed_dim: int = 4096,
                 num_layers: int = 32, num_heads: int = 32):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Decoder-only LM: causal self-attention blocks without cross-attention
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, num_heads, 16384, batch_first=True)
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
    
    def forward(self, input_ids: torch.Tensor, kv_cache: Optional[List] = None,
                return_hidden: bool = False):
        # kv_cache is unused: this simplified model recomputes the full sequence
        x = self.embedding(input_ids)
        seq_len = input_ids.size(1)
        # Causal mask: True marks future positions that may not be attended to
        mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()
        
        for layer in self.layers:
            x = layer(x, src_mask=mask)
        
        hidden = self.ln(x)
        logits = self.head(hidden)
        
        if return_hidden:
            return logits, hidden
        return logits


class SpeculativeDecoder:
    """
    Speculative Decoding implementation following Leviathan et al. (2022).
    Draft model generates K tokens, target model verifies in parallel.
    """
    def __init__(self, draft_model: DraftModel, target_model: TargetModel,
                 gamma: int = 4, device: str = 'cuda'):
        self.draft_model = draft_model.to(device)
        self.target_model = target_model.to(device)
        self.gamma = gamma  # Number of draft tokens to generate
        self.device = device
        
        self.draft_model.eval()
        self.target_model.eval()
    
    @torch.no_grad()
    def generate(self, prompt_ids: torch.Tensor, max_length: int = 100,
                 temperature: float = 1.0, top_k: Optional[int] = None) -> Tuple[torch.Tensor, Dict]:
        """
        Generate a sequence using speculative sampling (batch size 1).
        
        Returns:
            generated_ids: Final sequence (prompt + new tokens), truncated to max_length
            stats: Dict with acceptance rate, target calls, tokens per target call
        """
        assert prompt_ids.size(0) == 1, "This implementation supports batch size 1 only"
        generated = prompt_ids.clone()
        prompt_len = prompt_ids.size(1)
        num_draft_tokens = 0
        num_accepted_tokens = 0
        target_calls = 0
        
        while generated.size(1) < max_length:
            prefix_len = generated.size(1)
            
            # Step 1: Draft model proposes gamma tokens autoregressively,
            # keeping the draft distribution each token was sampled from
            draft_tokens, draft_probs = [], []
            current = generated
            for _ in range(self.gamma):
                logits, _ = self.draft_model(current)
                probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
                
                if top_k is not None:
                    v, _ = torch.topk(probs, min(top_k, probs.size(-1)))
                    probs = torch.where(probs < v[:, [-1]], torch.zeros_like(probs), probs)
                    probs = probs / probs.sum(dim=-1, keepdim=True)
                
                next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
                draft_tokens.append(next_token)
                draft_probs.append(probs)
                current = torch.cat([current, next_token], dim=1)
                num_draft_tokens += 1
            
            # Step 2: Target model scores prefix + all draft tokens in one parallel pass.
            # Logit index t predicts the token at position t + 1, so the target
            # distribution for draft token i is at index prefix_len - 1 + i.
            target_logits = self.target_model(current)
            target_probs = F.softmax(target_logits / temperature, dim=-1)
            target_calls += 1
            
            # Step 3: Verify draft tokens left to right with rejection sampling
            accepted_count = 0
            for i in range(self.gamma):
                token = draft_tokens[i]                     # (1, 1)
                q = target_probs[:, prefix_len - 1 + i, :]  # Target distribution
                p = draft_probs[i]                          # Draft distribution
                tok = token[0, 0]
                
                # Acceptance probability: min(1, q(x)/p(x))
                accept_prob = torch.clamp(q[0, tok] / (p[0, tok] + 1e-10), max=1.0)
                
                if torch.rand((), device=q.device) < accept_prob:
                    generated = torch.cat([generated, token], dim=1)
                    accepted_count += 1
                    num_accepted_tokens += 1
                else:
                    # Reject: sample from the normalized residual max(0, q - p)
                    residual = torch.clamp(q - p, min=0)
                    total = residual.sum(dim=-1, keepdim=True)
                    residual = residual / total if total.item() > 0 else q
                    new_token = torch.multinomial(residual, num_samples=1)
                    generated = torch.cat([generated, new_token], dim=1)
                    break
            
            # If all draft tokens were accepted, sample one bonus token from the target
            if accepted_count == self.gamma:
                next_token = torch.multinomial(target_probs[:, -1, :], num_samples=1)
                generated = torch.cat([generated, next_token], dim=1)
        
        generated = generated[:, :max_length]
        new_tokens = generated.size(1) - prompt_len
        stats = {
            'num_draft_tokens': num_draft_tokens,
            'num_accepted_tokens': num_accepted_tokens,
            'acceptance_rate': num_accepted_tokens / num_draft_tokens if num_draft_tokens > 0 else 0.0,
            'target_model_calls': target_calls,
            # Tokens produced per target forward pass (speedup bound ignoring draft cost)
            'effective_speedup': new_tokens / target_calls if target_calls > 0 else 1.0
        }
        
        return generated, stats


class Block:
    """Memory block for PagedAttention."""
    def __init__(self, block_id: int, block_size: int, num_layers: int, 
                 num_heads: int, head_dim: int):
        self.block_id = block_id
        self.block_size = block_size
        self.ref_count = 0  # Number of block tables (requests) pointing at this block
        
        # Storage: (num_layers, 2, num_heads, block_size, head_dim)
        # 2 for Key and Value
        self.kv_cache = torch.zeros(num_layers, 2, num_heads, block_size, head_dim)
        self.occupied = 0  # Number of positions used in this block
    
    def to(self, device) -> 'Block':
        """Move storage to `device` (Block is not an nn.Module, so this is explicit)."""
        self.kv_cache = self.kv_cache.to(device)
        return self
    
    def allocate(self, count: int = 1) -> bool:
        """Reserve `count` token slots. Returns False if the block would overflow."""
        if self.occupied + count > self.block_size:
            return False
        self.occupied += count
        return True


class PagedAttentionKVCache:
    """
    PagedAttention KV-Cache implementation following vLLM (Kwon et al., 2023).
    Manages KV cache as fixed-size blocks with block table mapping.
    """
    def __init__(self, num_blocks: int = 1000, block_size: int = 16,
                 num_layers: int = 32, num_heads: int = 32, 
                 head_dim: int = 128, device: str = 'cuda'):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device
        
        # Pre-allocate all blocks
        self.blocks = [
            Block(i, block_size, num_layers, num_heads, head_dim).to(device)
            for i in range(num_blocks)
        ]
        
        # Track free blocks
        self.free_blocks = set(range(num_blocks))
        
        # Request-specific mapping: request_id -> list of block_ids
        self.block_tables: Dict[int, List[int]] = {}
    
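    def allocate(self, request_id: int, num_tokens: int) -> List[int]:
        """
        Allocate blocks for a new request's prompt and mark its slots as occupied.
        Returns list of allocated block IDs.
        """
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
        # Check up front so a failed request does not leak partially allocated blocks
        if num_blocks_needed > len(self.free_blocks):
            raise MemoryError("Out of memory: no free blocks available")
        
        allocated = []
        remaining = num_tokens
        for _ in range(num_blocks_needed):
            block_id = self.free_blocks.pop()
            block = self.blocks[block_id]
            block.ref_count = 1
            block.occupied = min(self.block_size, remaining)  # All but the last block are full
            remaining -= block.occupied
            allocated.append(block_id)
        
        self.block_tables[request_id] = allocated
        return allocated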
    
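    def append_token(self, request_id: int) -> Optional[int]:
        """
        Reserve a slot for one additional token.
        Allocates a new block if the last one is full, and first copies the last
        block if it is shared with another request (copy-on-write).
        Returns the physical block ID holding the new slot, or None if out of blocks.
        """
        if request_id not in self.block_tables:
            return None
        
        block_table = self.block_tables[request_id]
        if block_table and self.blocks[block_table[-1]].occupied < self.block_size:
            last_block = self.blocks[block_table[-1]]
            if last_block.ref_count > 1:
                # Copy-on-write: never write into a block another request still reads
                if not self.free_blocks:
                    return None
                new_block_id = self.free_blocks.pop()
                new_block = self.blocks[new_block_id]
                new_block.kv_cache.copy_(last_block.kv_cache)
                new_block.occupied = last_block.occupied
                new_block.ref_count = 1
                last_block.ref_count -= 1
                block_table[-1] = new_block_id
                last_block = new_block
            last_block.occupied += 1
            return block_table[-1]
        
        # Last block is full (or there is none yet): take a fresh block
        if not self.free_blocks:
            return None
        new_block_id = self.free_blocks.pop()
        self.blocks[new_block_id].ref_count = 1
        self.blocks[new_block_id].occupied = 1
        block_table.append(new_block_id)
        return new_block_id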
    
    def get_kv_cache(self, request_id: int, layer_id: int) -> torch.Tensor:
        """
        Retrieve KV cache for specific request and layer.
        Gathers from all blocks in block table.
        """
        if not self.block_tables.get(request_id):  # Unknown request or no blocks yet
            return torch.zeros(0)
        
        block_ids = self.block_tables[request_id]
        
        # Gather KV from all blocks
        kvs = []
        for bid in block_ids:
            block = self.blocks[bid]
            # (2, heads, occupied, head_dim)
            layer_kv = block.kv_cache[layer_id, :, :, :block.occupied, :]
            kvs.append(layer_kv)
        
        # Concatenate along sequence dimension
        return torch.cat(kvs, dim=2)
    
    def fork(self, parent_id: int, child_id: int):
        """
        Copy-on-write fork: share parent blocks with child.
        Increments reference counts.
        """
        if parent_id not in self.block_tables:
            return
        
        parent_blocks = self.block_tables[parent_id]
        self.block_tables[child_id] = parent_blocks.copy()
        
        for bid in parent_blocks:
            self.blocks[bid].ref_count += 1
    
    def free(self, request_id: int):
        """Free all blocks associated with request."""
        if request_id not in self.block_tables:
            return
        
        for bid in self.block_tables[request_id]:
            self.blocks[bid].ref_count -= 1
            if self.blocks[bid].ref_count == 0:
                self.blocks[bid].occupied = 0
                self.free_blocks.add(bid)
        
        del self.block_tables[request_id]


class OptimizedTransformer(nn.Module):
    """
    Transformer with PagedAttention KV-cache support.
    """
    def __init__(self, vocab_size: int = 50000, embed_dim: int = 4096,
                 num_layers: int = 32, num_heads: int = 32,
                 block_size: int = 16, device: str = 'cuda'):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # Initialize PagedAttention cache
        self.kv_cache = PagedAttentionKVCache(
            num_blocks=1000,
            block_size=block_size,
            num_layers=num_layers,
            num_heads=num_heads,
            head_dim=self.head_dim,
            device=device
        )
        
        # Per-layer fused Q/K/V projection and output projection
        # (explicit projections so K and V can be written into the paged cache)
        self.qkv_proj = nn.ModuleList([
            nn.Linear(embed_dim, 3 * embed_dim) for _ in range(num_layers)
        ])
        self.out_proj = nn.ModuleList([
            nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)
    
    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, request_id: int,
                is_prompt: bool = False):
        """
        Inference-only forward pass with PagedAttention KV-cache.
        
        Args:
            input_ids: (1, seq_len); one request_id corresponds to one sequence
            request_id: Unique identifier for this sequence
            is_prompt: True if processing prompt (prefill), False if decoding
        """
        batch_size, seq_len = input_ids.shape
        assert batch_size == 1, "Each request_id maps to a single sequence"
        x = self.embedding(input_ids)
        cache, bs = self.kv_cache, self.kv_cache.block_size
        
        # Reserve cache slots for the new tokens before computing attention
        if is_prompt:
            cache.allocate(request_id, seq_len)
            # Mark prompt slots as occupied (idempotent): block j holds tokens j*bs .. j*bs+bs-1
            for j, bid in enumerate(cache.block_tables[request_id]):
                cache.blocks[bid].occupied = min(bs, seq_len - j * bs)
        else:
            for _ in range(seq_len):
                if cache.append_token(request_id) is None:
                    raise MemoryError("Out of KV-cache blocks")
        
        block_table = cache.block_tables[request_id]
        total_len = sum(cache.blocks[bid].occupied for bid in block_table)
        past_len = total_len - seq_len
        
        # Causal mask over [cached + new] positions: new token i sees positions <= past_len + i
        q_pos = torch.arange(past_len, total_len, device=x.device).unsqueeze(1)
        k_pos = torch.arange(total_len, device=x.device).unsqueeze(0)
        causal = k_pos <= q_pos  # (seq_len, total_len)
        
        for layer_id in range(self.num_layers):
            qkv = self.qkv_proj[layer_id](x)  # (1, seq_len, 3 * embed_dim)
            q, k, v = qkv.view(seq_len, 3, self.num_heads, self.head_dim).unbind(dim=1)
            
            # Write new K/V into their physical slots: logical position pos lives
            # in block block_table[pos // bs] at offset pos % bs
            for i in range(seq_len):
                pos = past_len + i
                block = cache.blocks[block_table[pos // bs]]
                block.kv_cache[layer_id, 0, :, pos % bs] = k[i]
                block.kv_cache[layer_id, 1, :, pos % bs] = v[i]
            
            # Gather cached + new K/V for this layer: (2, heads, total_len, head_dim)
            kv = cache.get_kv_cache(request_id, layer_id)
            keys, values = kv[0], kv[1]
            
            scores = torch.einsum('shd,htd->hst', q, keys) / self.head_dim ** 0.5
            scores = scores.masked_fill(~causal, float('-inf'))
            attn = F.softmax(scores, dim=-1)
            out = torch.einsum('hst,htd->shd', attn, values).reshape(1, seq_len, self.embed_dim)
            
            x = x + self.out_proj[layer_id](out)  # Residual (simplified, no FFN for brevity)
        
        logits = self.head(self.ln(x))
        return logits


def benchmark_speculative_decoding():
    """Benchmark speculative decoding vs standard autoregressive generation."""
    vocab_size = 1000
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Create models
    draft = DraftModel(vocab_size, embed_dim=256, num_layers=4)
    target = TargetModel(vocab_size, embed_dim=1024, num_layers=12)  # Smaller for demo
    
    decoder = SpeculativeDecoder(draft, target, gamma=4, device=device)
    
    # Test prompts
    prompt = torch.randint(0, vocab_size, (1, 10)).to(device)
    
    # Standard generation (no speculation)
    def standard_generate(model, prompt, max_len):
        generated = prompt.clone()
        start = time.time()
        with torch.no_grad():
            for _ in range(max_len):
                logits = model(generated)
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token], dim=1)
        return generated, time.time() - start
    
    # Benchmark standard
    standard_seq, standard_time = standard_generate(target, prompt, 50)
    
    # Benchmark speculative
    start = time.time()
    spec_seq, stats = decoder.generate(prompt, max_length=60)
    spec_time = time.time() - start
    
    print(f"\nSpeculative Decoding Benchmark:")
    print(f"  Draft tokens generated: {stats['num_draft_tokens']}")
    print(f"  Accepted tokens: {stats['num_accepted_tokens']}")
    print(f"  Acceptance rate: {stats['acceptance_rate']:.2%}")
    print(f"  Target model calls: {stats['target_model_calls']}")
    print(f"  Theoretical speedup: {stats['effective_speedup']:.2f}x")
    print(f"  Actual speedup: {standard_time / spec_time:.2f}x")
    
    # Visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    # Acceptance vs. draft position under the i.i.d. model of Leviathan et al.:
    # draft position i is kept only if positions 1..i are all accepted, i.e. alpha**i.
    # alpha = 0.6 is an assumed illustrative value, not a measured one.
    alpha = 0.6
    positions = np.arange(1, decoder.gamma + 1)
    accept_probs = alpha ** positions
    
    ax1.plot(positions, accept_probs, marker='o', label=r'$\alpha^i$ ($\alpha$ = 0.6)')
    ax1.axhline(y=0.5, color='r', linestyle='--', label='50% threshold')
    ax1.set_xlabel('Position in Draft Sequence')
    ax1.set_ylabel('Probability of Acceptance')
    ax1.set_title('Draft Token Acceptance Decay (illustrative)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Expected walltime improvement (Leviathan et al.):
    #   (1 - alpha**(gamma + 1)) / ((1 - alpha) * (gamma * c + 1))
    # alpha: per-token acceptance rate; c: cost of one draft step relative to one target step.
    # alpha = 0.6 and c = 0.1 are assumed values for illustration.
    gamma_values = [1, 2, 4, 6, 8, 10]
    alpha, c = 0.6, 0.1
    theoretical_speedups = [
        (1 - alpha ** (g + 1)) / ((1 - alpha) * (g * c + 1)) for g in gamma_values
    ]
    
    ax2.plot(gamma_values, theoretical_speedups, marker='o', label='Theoretical Speedup')
    ax2.axhline(y=1.0, color='r', linestyle='--', label='Baseline (no speculation)')
    ax2.set_xlabel('Gamma (Draft Tokens)')
    ax2.set_ylabel('Speedup Factor')
    ax2.set_title('Speculative Decoding Speedup vs Draft Length')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('speculative_decoding_benchmark.png', dpi=300)
    plt.show()


def visualize_paged_attention():
    """Visualize PagedAttention memory management."""
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    
    # Block allocation visualization
    num_blocks = 20
    block_size = 4
    
    # Simulate requests
    requests = {
        'Request A': [0, 1, 2],      # 3 blocks, up to 12 tokens
        'Request B': [3, 4],         # 2 blocks
        'Request C': [3, 5],         # Shares prefix block 3 with B (e.g. a common prompt)
        'Request D': [6, 7, 8, 9],   # 4 blocks
    }
    
    colors = plt.cm.Set3(np.linspace(0, 1, len(requests)))
    
    # Left: Block table mapping
    ax1.set_title('Block Table Mapping (Logical -> Physical)')
    y_pos = 0
    for (req_name, blocks), color in zip(requests.items(), colors):
        for i, block_id in enumerate(blocks):
            rect = plt.Rectangle((i, y_pos), 0.9, 0.9, 
                               facecolor=color, edgecolor='black')
            ax1.add_patch(rect)
            ax1.text(i + 0.45, y_pos + 0.45, f'P{block_id}', 
                    ha='center', va='center', fontsize=9)
        y_pos += 1.5
    
    ax1.set_xlim(-0.5, 5)
    ax1.set_ylim(-0.5, y_pos)
    ax1.set_yticks([i * 1.5 + 0.45 for i in range(len(requests))])
    ax1.set_yticklabels(requests.keys())
    ax1.set_xlabel('Logical Block Index')
    ax1.set_aspect('equal')
    
    # Middle: Physical memory layout
    ax2.set_title('Physical Memory Blocks')
    for i in range(num_blocks):
        x = i % 5
        y = i // 5
        is_used = any(i in blocks for blocks in requests.values())
        color = 'lightgreen' if is_used else 'white'
        rect = plt.Rectangle((x, y), 0.9, 0.9, 
                           facecolor=color, edgecolor='black')
        ax2.add_patch(rect)
        ax2.text(x + 0.45, y + 0.45, str(i), ha='center', va='center', fontsize=9)
        
        # Show sharing
        refs = sum(1 for blocks in requests.values() if i in blocks)
        if refs > 1:
            ax2.text(x + 0.45, y + 0.2, f'ref:{refs}', ha='center', va='center', 
                    fontsize=7, color='red')
    
    ax2.set_xlim(-0.5, 5)
    ax2.set_ylim(-0.5, 4)
    ax2.set_aspect('equal')
    
    # Right: Memory efficiency comparison
    ax3.set_title('Memory Efficiency: Paged vs Standard')
    
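    seq_lengths = np.array([1024, 2048, 4096, 8192, 16384])
    # Bytes per token for one layer: (K and V) x 32 heads x 128 head_dim x 4 bytes (FP32)
    bytes_per_token = 2 * 32 * 128 * 4
    
    # Standard cache: every request reserves the maximum context length up front
    max_len = 16384  # Assumed maximum context length
    standard_memory = np.full(seq_lengths.shape, max_len * bytes_per_token / (1024**2))  # MB
    
    # PagedAttention: blocks allocated on demand, waste is at most one partial block
    block_size = 16
    blocks_needed = np.ceil(seq_lengths / block_size) * block_size
    paged_memory = blocks_needed * bytes_per_token / (1024**2)  # MB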
    
    ax3.plot(seq_lengths, standard_memory, marker='o', label='Standard Cache', linewidth=2)
    ax3.plot(seq_lengths, paged_memory, marker='s', label='PagedAttention', linewidth=2)
    ax3.set_xlabel('Sequence Length')
    ax3.set_ylabel('Memory Usage (MB)')
    ax3.set_xscale('log', base=2)
    ax3.set_yscale('log', base=2)
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('paged_attention_memory.png', dpi=300)
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, default='benchmark',
                       choices=['speculative', 'paged', 'benchmark'])
    parser.add_argument('--prompt', type=str, default='The future of artificial intelligence',
                        help='Currently unused: the demo models are untrained and use random token IDs')
    
    args = parser.parse_args()
    
    if args.method in ['speculative', 'benchmark']:
        print("Running Speculative Decoding benchmark...")
        benchmark_speculative_decoding()
    
    if args.method in ['paged', 'benchmark']:
        print("Visualizing PagedAttention memory management...")
        visualize_paged_attention()
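
The speedup achievable by speculative decoding can be estimated in closed form. Under the simplifying assumption used by Leviathan et al. that each draft token is accepted independently with probability α, one target-model call yields (1 - α^(γ+1)) / (1 - α) tokens on average, and if one draft step costs c times one target step, the expected walltime improvement is that quantity divided by (γc + 1). The NumPy sketch below implements both expressions and checks the first against a Monte Carlo simulation; α, γ and c are free parameters, not measured values, and the function names are illustrative.

```python
import numpy as np


def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    """Expected tokens per target call: (1 - alpha^(gamma+1)) / (1 - alpha)."""
    if alpha == 1.0:
        return gamma + 1.0
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def walltime_improvement(alpha: float, gamma: int, c: float) -> float:
    """Expected speedup when one draft step costs c times one target step."""
    return expected_tokens_per_step(alpha, gamma) / (gamma * c + 1)


def simulate_tokens_per_step(alpha: float, gamma: int, trials: int, seed: int = 0) -> float:
    """Monte Carlo check: accept drafts i.i.d. with prob alpha until the first rejection."""
    rng = np.random.default_rng(seed)
    accepted = rng.random((trials, gamma)) < alpha
    # Leading accepted drafts = index of the first rejection (gamma if none)
    n_accepted = np.where(accepted.all(axis=1), gamma, np.argmin(accepted, axis=1))
    # Every step also emits one token from the target (resampled or bonus token)
    return float(np.mean(n_accepted + 1))


for gamma in (1, 2, 4, 8):
    print(gamma,
          round(expected_tokens_per_step(0.6, gamma), 3),
          round(simulate_tokens_per_step(0.6, gamma, 200_000), 3),
          round(walltime_improvement(0.6, gamma, c=0.1), 3))
```

Longer drafts raise the expected tokens per call but also the draft cost γc, so for a given α and c the walltime improvement peaks at a moderate γ.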