在深度学习中,大模型往往意味着更好的效果,但同时也带来了更高的计算成本和部署难度。模型剪枝通过移除“冗余”参数,让模型在不显著损失精度的情况下实现“瘦身”,加速推理、减小体积。本文将详细介绍模型剪枝的原理、方法与实践。


一、为什么需要模型剪枝?

随着深度学习模型规模的不断膨胀,BERT-base 有 1.1 亿参数,GPT-3 有 1750 亿参数,这些大模型在取得优异效果的同时,也带来了三个实际问题:

问题 具体表现 影响
推理延迟高 一次前向传播可能耗时数百毫秒甚至数秒 不适合实时交互场景(如语音助手、在线搜索)
内存/显存占用大 参数多,中间激活值大 难以部署在移动端、嵌入式设备或边缘服务器
能耗高 每次推理消耗大量电能 大规模部署成本高昂,不符合绿色计算趋势

模型剪枝正是为了解决这些问题而生:通过移除模型中“不重要”的参数或结构,在尽量不牺牲精度的前提下,获得一个更小、更快的模型。

我们可以把深度学习模型想象成一棵茂盛的大树,剪枝就是剪掉那些枯枝、弱枝,让树变得更加精干,而不影响主要枝干的功能。


二、模型剪枝的类型

1. 非结构化剪枝(Unstructured Pruning)

做法:移除单个权重(神经元之间的连接),不改变网络结构,但会形成稀疏矩阵(大量参数为零)。

优点

  • 灵活,可以精细地移除不重要的权重

  • 精度保留通常较高

缺点

  • 稀疏矩阵在标准硬件(GPU/CPU)上不易获得加速,需要专门的稀疏计算库(如 NVIDIA cuSPARSE)才能发挥速度优势

示例:对全连接层的权重矩阵,将绝对值小于某个阈值的元素置为零。

2. 结构化剪枝(Structured Pruning)

做法:移除整个结构单元,例如:

  • 卷积层的一个通道

  • 全连接层的一个神经元

  • Transformer 的一个注意力头

  • 整个前馈网络(FFN)层或维度

优点

  • 移除后模型结构依然规则,可以直接在现有硬件上获得加速,无需特殊库

  • 部署方便

缺点

  • 粗粒度,可能精度损失稍大

示例:BERT 中,可以移除某些注意力头,或者减少 FFN 的中间维度(从 3072 减到 2048)。


三、剪枝的经典流程

典型的剪枝流程通常包括以下步骤:

  1. 训练(或加载)原始大模型:达到较高的精度基准。

  2. 评估参数重要性:根据某种准则(如权重大小、梯度信息等)对参数进行排序。

  3. 剪枝:移除排名靠后(即“不重要”)的部分。

  4. 微调:用原始训练数据对剩余参数进行少量训练,恢复因剪枝损失的精度。

  5. 迭代:有时会反复多次进行“剪枝-微调”循环,以取得更好的压缩效果。


四、如何判断哪些参数“不重要”?

剪枝的核心问题是:如何确定哪些参数是冗余的? 以下是几种常见的重要性指标:

方法 说明 适用场景
基于权重大小 认为绝对值越小的权重越不重要。最简单直接。 非结构化剪枝常用,但并非总是最优。
基于梯度 结合权重和梯度,如 weight * gradient 表示该参数对损失的影响程度。 需要梯度信息,通常与训练结合。
基于激活值 移除那些激活值接近零的神经元。 结构化剪枝常用。
基于泰勒展开 计算移除某个参数后损失函数的近似变化,选择变化最小的参数移除。 理论较严谨,但计算量大。
基于正则化 在训练中加入稀疏正则化(如 L1 正则),使参数自动趋向于零,然后直接剪掉。 训练时即可稀疏化。

五、模型剪枝与其他压缩技术的关系

在实际工程中,剪枝往往与量化知识蒸馏结合使用,以达到极致的压缩效果:

技术 作用 与剪枝的关系
量化 降低数值精度(如 FP32→INT8) 可以组合使用:先剪枝,再量化,进一步压缩体积、加速推理。
知识蒸馏 用大模型教小模型 可视为一种“隐式剪枝”(小模型结构天然精简),两者可互补。
低秩分解 将权重矩阵分解为小矩阵乘积 也属于结构化压缩,但侧重点不同。

常见组合路线

  1. 先用蒸馏得到一个结构紧凑的学生模型(如 6 层 Transformer)。

  2. 再对这个学生模型进行剪枝(如移除一些注意力头)。

  3. 最后量化到 INT8 部署。


六、在大模型(LLM)中的应用

大语言模型(如 LLaMA、GPT)的参数量巨大,剪枝对它们的部署至关重要。但 LLM 的剪枝面临一些独特挑战:

  • 模型巨大,剪枝成本高:需要高效的剪枝算法。

  • 精度敏感:即使移除少量参数,也可能导致模型出现“灾难性遗忘”。

目前主流的 LLM 剪枝方法倾向于结构化剪枝,以获取实际加速效果:

  • 注意力头剪枝:移除一些低重要性的注意力头。

  • FFN 维度剪枝:缩减前馈网络的中间维度。

  • 层剪枝:直接移除整个 Transformer 层。

这些剪枝通常需要结合微调(甚至重新预训练)来恢复精度。


七、代码示例:PyTorch 非结构化剪枝

以下示例使用 PyTorch 内置的剪枝工具,对 BERT 模型的一个线性层进行 L1 非结构化剪枝,剪掉 30% 的权重。

python

import torch
import torch.nn.utils.prune as prune
from transformers import BertForSequenceClassification

# 加载一个预训练 BERT 模型
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# 对第一层注意力输出的全连接层进行 L1 剪枝,剪掉 30% 的权重
prune.l1_unstructured(
    model.bert.encoder.layer[0].attention.output.dense, 
    name="weight", 
    amount=0.3
)

# 查看剪枝后的权重(weight_mask 记录了哪些权重被保留)
print(model.bert.encoder.layer[0].attention.output.dense.weight)

# 永久移除剪枝掩码,真正减小模型体积
prune.remove(
    model.bert.encoder.layer[0].attention.output.dense, 
    "weight"
)

# 现在权重矩阵已经变为稀疏张量
print(model.bert.encoder.layer[0].attention.output.dense.weight)

注意:非结构化剪枝后的稀疏矩阵在标准硬件上不一定能直接加速。如需实际加速,推荐使用结构化剪枝或配合稀疏推理库(如 DeepSparse、TensorRT)。


八、剪枝的优缺点总结

优点 缺点
显著减小模型体积(2-4倍) 非结构化剪枝需要特殊硬件/库才能加速
推理速度提升(尤其是结构化剪枝) 可能造成精度损失,需要微调恢复
降低内存带宽和功耗 剪枝策略选择复杂,需实验调优
可与其他压缩技术叠加使用 对大型模型(LLM)剪枝风险较高,需谨慎

九、学习建议与实践方向

如果你正在学习大模型应用开发,剪枝是模型部署环节的重要技能。你可以按以下步骤循序渐进:

  1. 基础实践:用 PyTorch 对一个小型分类模型(如 ResNet-18 或 BERT-base)进行非结构化剪枝,观察模型大小和精度的变化。

  2. 深入理解:尝试结构化剪枝(如移除 Transformer 的注意力头),并评估实际推理加速效果。

  3. 结合部署:将剪枝后的模型导出为 ONNX 格式,用 ONNX Runtime 或 TensorRT 测试 CPU/GPU 推理速度。

  4. 挑战 LLM:使用开源库(如 llm-prunerShearedLLaMA)尝试对 7B 模型进行层剪枝或头剪枝,体验大模型压缩。

模型剪枝是一项非常实用的工程技能,掌握它能让你在资源受限场景下更灵活地部署 AI 应用。如果在实践中遇到具体问题,欢迎留言交流!


参考资料

  • PyTorch 剪枝官方教程

  • Han, S., et al. (2015). Learning both Weights and Connections for Efficient Neural Networks.

  • Fang, G., et al. (2023). DepGraph: Towards Any Structural Pruning.

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐