模型剪枝与知识蒸馏:压缩大模型的两种路径与工程取舍

一、模型压缩的必要性:精度与效率的永恒博弈

大模型的推理成本与参数量成正比。一个 7B 参数的模型在 FP16 下需要 14GB 显存存储权重,推理时还需要额外的 KV Cache 和激活值空间。在边缘设备或低成本服务器上部署时,模型必须被压缩。两种主流压缩路径是剪枝(Pruning)和知识蒸馏(Knowledge Distillation)。

剪枝直接移除模型中的冗余参数,减少计算量和内存占用。知识蒸馏训练一个小模型(Student)来模仿大模型(Teacher)的输出分布,不改变模型结构但减少参数量。两者的适用场景不同:剪枝适合保留原模型架构的场景,蒸馏适合可以接受更小模型架构的场景。

二、剪枝与蒸馏的机制对比:结构压缩 vs 知识迁移

剪枝分为非结构化剪枝(将个别权重置零)和结构化剪枝(移除整个通道或注意力头)。非结构化剪枝的压缩率高但需要稀疏计算硬件支持,结构化剪枝的压缩率低但可以直接在标准 GPU 上加速。知识蒸馏的核心是让 Student 学习 Teacher 的软标签(Soft Labels)——Teacher 输出的概率分布比硬标签包含更多信息。

flowchart TB
    A[大模型 Teacher] --> B{压缩路径}

    B --> C[剪枝]
    B --> D[知识蒸馏]

    C --> C1[非结构化剪枝<br/>权重级稀疏]
    C --> C2[结构化剪枝<br/>通道/头级移除]

    C1 --> C3[优点: 压缩率高 90%+<br/>缺点: 需稀疏硬件]
    C2 --> C4[优点: 通用 GPU 加速<br/>缺点: 压缩率有限 50-70%]

    D --> D1[软标签蒸馏<br/>学习概率分布]
    D --> D2[特征蒸馏<br/>学习中间表示]
    D --> D3[关系蒸馏<br/>学习样本间关系]

    D1 --> D5[优点: 灵活选择 Student<br/>缺点: 需要训练资源]
    D2 --> D5
    D3 --> D5

    C3 --> E[部署: 稀疏推理引擎]
    C4 --> F[部署: 标准 ONNX Runtime]
    D5 --> G[部署: 标准推理引擎]

关键差异:剪枝是"做减法",保留原模型结构但移除部分参数;蒸馏是"做迁移",用小模型继承大模型的知识。两者可以组合——先剪枝再蒸馏,或先蒸馏再剪枝,但组合的收益不一定叠加。

三、生产级代码实现:结构化剪枝与知识蒸馏

3.1 幅度剪枝:基于权重绝对值的结构化剪枝

import torch
import torch.nn as nn
import numpy as np

class MagnitudePruner:
    """幅度剪枝器:移除绝对值最小的权重"""

    def __init__(self, model, pruning_ratio=0.5):
        self.model = model
        self.pruning_ratio = pruning_ratio
        self.masks = {}

    def compute_masks(self):
        """计算剪枝掩码"""
        for name, param in self.model.named_parameters():
            if "weight" not in name:
                continue

            # 计算每个输出通道的 L1 范数
            # 为什么用 L1 范数而非 L2:L1 范数对
            # 小权重更敏感,更适合识别"不重要"的通道;
            # L2 范数会被少数大权重主导
            if param.dim() >= 2:
                # 对卷积/线性层:按输出通道计算重要性
                importance = param.abs().sum(
                    dim=tuple(range(1, param.dim())))
            else:
                importance = param.abs()

            # 确定阈值:保留 top-k 通道
            k = int(len(importance) * (1 - self.pruning_ratio))
            if k <= 0:
                k = 1

            threshold = torch.topk(importance, k).values[-1]

            # 创建掩码
            if param.dim() >= 2:
                channel_mask = (importance >= threshold).float()
                # 扩展掩码到所有维度
                expand_shape = [-1] + [1] * (param.dim() - 1)
                mask = channel_mask.view(*expand_shape).expand_as(
                    param)
            else:
                mask = (importance >= threshold).float()

            self.masks[name] = mask

    def apply_masks(self):
        """应用剪枝掩码"""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                # 用掩码将不重要的权重置零
                param.data.mul_(self.masks[name])

    def fine_tune(self, train_loader, epochs=5, lr=1e-4):
        """剪枝后微调恢复精度"""
        # 为什么剪枝后需要微调:直接剪枝会导致
        # 精度大幅下降,微调让剩余权重重新适应
        # 被移除通道的功能
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr)

        for epoch in range(epochs):
            for batch in train_loader:
                optimizer.zero_grad()
                output = self.model(batch["input"])
                loss = nn.CrossEntropyLoss()(
                    output, batch["label"])
                loss.backward()
                optimizer.step()

                # 每步训练后重新应用掩码
                # 为什么每步都应用:梯度更新可能
                # 让被剪枝的权重变为非零,
                # 必须持续掩码才能维持稀疏结构
                self.apply_masks()

            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

3.2 知识蒸馏:软标签与温度调节

class KnowledgeDistiller:
    """知识蒸馏训练器"""

    def __init__(self, teacher, student,
                 temperature=4.0, alpha=0.7):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
        # alpha: 蒸馏损失权重, 1-alpha: 硬标签损失权重
        # 为什么需要两个损失:纯蒸馏损失可能忽略
        # 真实标签的信息,混合损失兼顾两者
        self.alpha = alpha

        # Teacher 冻结参数
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

    def distillation_loss(self, student_logits,
                          teacher_logits, labels):
        """计算蒸馏损失"""
        # 软标签损失:KL 散度
        # 为什么用 KL 散度而非 MSE:KL 散度衡量
        # 两个概率分布的差异,与交叉熵等价;
        # MSE 衡量 logits 的数值差异,不保证
        # 概率分布的语义一致性
        soft_targets = nn.functional.softmax(
            teacher_logits / self.temperature, dim=-1)
        soft_student = nn.functional.log_softmax(
            student_logits / self.temperature, dim=-1)

        # KL 散度 × T^2:补偿温度缩放导致的梯度缩小
        # 为什么乘 T^2:温度 T 使概率分布变平滑,
        # 梯度被缩小 T 倍;乘 T^2 恢复梯度量级
        kd_loss = nn.functional.kl_div(
            soft_student, soft_targets,
            reduction="batchmean"
        ) * (self.temperature ** 2)

        # 硬标签损失:标准交叉熵
        ce_loss = nn.functional.cross_entropy(
            student_logits, labels)

        # 加权组合
        total_loss = (
            self.alpha * kd_loss +
            (1 - self.alpha) * ce_loss
        )
        return total_loss

    def train_step(self, batch, optimizer):
        """单步蒸馏训练"""
        inputs = batch["input"]
        labels = batch["label"]

        # Teacher 推理(不计算梯度)
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)

        # Student 推理
        student_logits = self.student(inputs)

        # 计算蒸馏损失
        loss = self.distillation_loss(
            student_logits, teacher_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    def train(self, train_loader, val_loader,
              epochs=20, lr=3e-4):
        """完整蒸馏训练流程"""
        optimizer = torch.optim.AdamW(
            self.student.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs)

        best_val_acc = 0.0
        for epoch in range(epochs):
            self.student.train()
            total_loss = 0

            for batch in train_loader:
                loss = self.train_step(batch, optimizer)
                total_loss += loss

            scheduler.step()

            # 验证
            val_acc = self._evaluate(val_loader)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(self.student.state_dict(),
                    "best_student.pt")

            avg_loss = total_loss / len(train_loader)
            print(f"Epoch {epoch}: Loss={avg_loss:.4f}, "
                  f"Val Acc={val_acc:.4f}")

        return best_val_acc

    def _evaluate(self, val_loader):
        self.student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                outputs = self.student(batch["input"])
                preds = outputs.argmax(dim=-1)
                correct += (preds == batch["label"]).sum().item()
                total += len(batch["label"])
        return correct / total

3.3 剪枝效果评估

def evaluate_pruning(original_model, pruned_model,
                     test_loader):
    """评估剪枝效果"""
    # 精度对比
    orig_acc = evaluate_accuracy(original_model, test_loader)
    pruned_acc = evaluate_accuracy(pruned_model, test_loader)

    # 参数量对比
    orig_params = sum(p.numel()
        for p in original_model.parameters())
    pruned_params = sum(p.numel()
        for p in pruned_model.parameters())
    # 非零参数量
    nonzero_params = sum(
        (p != 0).sum().item()
        for p in pruned_model.parameters())

    # 推理速度对比
    orig_latency = benchmark_latency(original_model)
    pruned_latency = benchmark_latency(pruned_model)

    print(f"原始模型: 参数={orig_params/1e6:.1f}M, "
          f"精度={orig_acc:.4f}, 延迟={orig_latency:.2f}ms")
    print(f"剪枝模型: 参数={nonzero_params/1e6:.1f}M, "
          f"精度={pruned_acc:.4f}, 延迟={pruned_latency:.2f}ms")
    print(f"压缩率: {1 - nonzero_params/orig_params:.2%}")
    print(f"精度损失: {orig_acc - pruned_acc:.4f}")
    print(f"加速比: {orig_latency/pruned_latency:.2f}x")

四、模型压缩的架构权衡:精度、加速比与部署复杂度

剪枝的精度恢复瓶颈:50% 剪枝率下,微调通常能恢复大部分精度;70% 以上剪枝率时,微调的精度恢复越来越困难,因为被移除的参数中包含了不可替代的信息。建议从 30% 剪枝率开始,逐步增加直到精度下降不可接受。

蒸馏的 Student 架构选择:Student 太小(如 2 层 Transformer)无法学习 Teacher 的复杂表示,精度损失大;Student 太大(如与 Teacher 同构)则压缩效果有限。经验法则:Student 参数量约为 Teacher 的 1/4 到 1/2,层数约为 Teacher 的 1/2。

剪枝与蒸馏的组合顺序:先剪枝再蒸馏,Teacher 是原始大模型,Student 是剪枝后的模型——蒸馏帮助恢复剪枝损失的精度。先蒸馏再剪枝,先得到一个中等大小的 Student,再对 Student 剪枝——最终模型更小但精度损失可能更大。建议优先尝试"先剪枝再蒸馏"的路径。

部署端的稀疏推理支持:非结构化剪枝的加速依赖稀疏矩阵运算,但大多数推理引擎(ONNX Runtime、TensorRT)对稀疏运算的优化有限。实际加速比可能远低于理论压缩率。结构化剪枝虽然压缩率低,但加速比更可预测。

五、总结

模型压缩的两种路径各有适用场景。剪枝适合需要保留原模型架构的场景,结构化剪枝的加速比更可预测;蒸馏适合可以接受更小模型架构的场景,灵活性更高。落地时建议先尝试知识蒸馏(实现更简单、风险更低),如果压缩比不够再叠加结构化剪枝。压缩后的模型必须在实际业务数据上验证精度,不能只看公开数据集的结果。温度参数和 alpha 权重是蒸馏效果的关键超参数,需要网格搜索确定。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐