模型剪枝:让神经网络“瘦身”的艺术
在深度学习中,大模型往往意味着更好的效果,但同时也带来了更高的计算成本和部署难度。模型剪枝通过移除“冗余”参数,让模型在不显著损失精度的情况下实现“瘦身”,加速推理、减小体积。本文将详细介绍模型剪枝的原理、方法与实践。
一、为什么需要模型剪枝?
随着深度学习模型规模的不断膨胀,BERT-base 有 1.1 亿参数,GPT-3 有 1750 亿参数,这些大模型在取得优异效果的同时,也带来了三个实际问题:
| 问题 | 具体表现 | 影响 |
|---|---|---|
| 推理延迟高 | 一次前向传播可能耗时数百毫秒甚至数秒 | 不适合实时交互场景(如语音助手、在线搜索) |
| 内存/显存占用大 | 参数多,中间激活值大 | 难以部署在移动端、嵌入式设备或边缘服务器 |
| 能耗高 | 每次推理消耗大量电能 | 大规模部署成本高昂,不符合绿色计算趋势 |
模型剪枝正是为了解决这些问题而生:通过移除模型中“不重要”的参数或结构,在尽量不牺牲精度的前提下,获得一个更小、更快的模型。
我们可以把深度学习模型想象成一棵茂盛的大树,剪枝就是剪掉那些枯枝、弱枝,让树变得更加精干,而不影响主要枝干的功能。
二、模型剪枝的类型
1. 非结构化剪枝(Unstructured Pruning)
做法:移除单个权重(神经元之间的连接),不改变网络结构,但会形成稀疏矩阵(大量参数为零)。
优点:
-
灵活,可以精细地移除不重要的权重
-
精度保留通常较高
缺点:
-
稀疏矩阵在标准硬件(GPU/CPU)上不易获得加速,需要专门的稀疏计算库(如 NVIDIA cuSPARSE)才能发挥速度优势
示例:对全连接层的权重矩阵,将绝对值小于某个阈值的元素置为零。
2. 结构化剪枝(Structured Pruning)
做法:移除整个结构单元,例如:
-
卷积层的一个通道
-
全连接层的一个神经元
-
Transformer 的一个注意力头
-
整个前馈网络(FFN)层或维度
优点:
-
移除后模型结构依然规则,可以直接在现有硬件上获得加速,无需特殊库
-
部署方便
缺点:
-
粗粒度,可能精度损失稍大
示例:BERT 中,可以移除某些注意力头,或者减少 FFN 的中间维度(从 3072 减到 2048)。
三、剪枝的经典流程
典型的剪枝流程通常包括以下步骤:
-
训练(或加载)原始大模型:达到较高的精度基准。
-
评估参数重要性:根据某种准则(如权重大小、梯度信息等)对参数进行排序。
-
剪枝:移除排名靠后(即“不重要”)的部分。
-
微调:用原始训练数据对剩余参数进行少量训练,恢复因剪枝损失的精度。
-
迭代:有时会反复多次进行“剪枝-微调”循环,以取得更好的压缩效果。
四、如何判断哪些参数“不重要”?
剪枝的核心问题是:如何确定哪些参数是冗余的? 以下是几种常见的重要性指标:
| 方法 | 说明 | 适用场景 |
|---|---|---|
| 基于权重大小 | 认为绝对值越小的权重越不重要。最简单直接。 | 非结构化剪枝常用,但并非总是最优。 |
| 基于梯度 | 结合权重和梯度,如 weight * gradient 表示该参数对损失的影响程度。 |
需要梯度信息,通常与训练结合。 |
| 基于激活值 | 移除那些激活值接近零的神经元。 | 结构化剪枝常用。 |
| 基于泰勒展开 | 计算移除某个参数后损失函数的近似变化,选择变化最小的参数移除。 | 理论较严谨,但计算量大。 |
| 基于正则化 | 在训练中加入稀疏正则化(如 L1 正则),使参数自动趋向于零,然后直接剪掉。 | 训练时即可稀疏化。 |
五、模型剪枝与其他压缩技术的关系
在实际工程中,剪枝往往与量化和知识蒸馏结合使用,以达到极致的压缩效果:
| 技术 | 作用 | 与剪枝的关系 |
|---|---|---|
| 量化 | 降低数值精度(如 FP32→INT8) | 可以组合使用:先剪枝,再量化,进一步压缩体积、加速推理。 |
| 知识蒸馏 | 用大模型教小模型 | 可视为一种“隐式剪枝”(小模型结构天然精简),两者可互补。 |
| 低秩分解 | 将权重矩阵分解为小矩阵乘积 | 也属于结构化压缩,但侧重点不同。 |
常见组合路线:
-
先用蒸馏得到一个结构紧凑的学生模型(如 6 层 Transformer)。
-
再对这个学生模型进行剪枝(如移除一些注意力头)。
-
最后量化到 INT8 部署。
六、在大模型(LLM)中的应用
大语言模型(如 LLaMA、GPT)的参数量巨大,剪枝对它们的部署至关重要。但 LLM 的剪枝面临一些独特挑战:
-
模型巨大,剪枝成本高:需要高效的剪枝算法。
-
精度敏感:即使移除少量参数,也可能导致模型出现“灾难性遗忘”。
目前主流的 LLM 剪枝方法倾向于结构化剪枝,以获取实际加速效果:
-
注意力头剪枝:移除一些低重要性的注意力头。
-
FFN 维度剪枝:缩减前馈网络的中间维度。
-
层剪枝:直接移除整个 Transformer 层。
这些剪枝通常需要结合微调(甚至重新预训练)来恢复精度。
七、代码示例:PyTorch 非结构化剪枝
以下示例使用 PyTorch 内置的剪枝工具,对 BERT 模型的一个线性层进行 L1 非结构化剪枝,剪掉 30% 的权重。
python
import torch
import torch.nn.utils.prune as prune
from transformers import BertForSequenceClassification
# 加载一个预训练 BERT 模型
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
# 对第一层注意力输出的全连接层进行 L1 剪枝,剪掉 30% 的权重
prune.l1_unstructured(
model.bert.encoder.layer[0].attention.output.dense,
name="weight",
amount=0.3
)
# 查看剪枝后的权重(weight_mask 记录了哪些权重被保留)
print(model.bert.encoder.layer[0].attention.output.dense.weight)
# 永久移除剪枝掩码,真正减小模型体积
prune.remove(
model.bert.encoder.layer[0].attention.output.dense,
"weight"
)
# 现在权重矩阵已经变为稀疏张量
print(model.bert.encoder.layer[0].attention.output.dense.weight)
注意:非结构化剪枝后的稀疏矩阵在标准硬件上不一定能直接加速。如需实际加速,推荐使用结构化剪枝或配合稀疏推理库(如 DeepSparse、TensorRT)。
八、剪枝的优缺点总结
| 优点 | 缺点 |
|---|---|
| 显著减小模型体积(2-4倍) | 非结构化剪枝需要特殊硬件/库才能加速 |
| 推理速度提升(尤其是结构化剪枝) | 可能造成精度损失,需要微调恢复 |
| 降低内存带宽和功耗 | 剪枝策略选择复杂,需实验调优 |
| 可与其他压缩技术叠加使用 | 对大型模型(LLM)剪枝风险较高,需谨慎 |
九、学习建议与实践方向
如果你正在学习大模型应用开发,剪枝是模型部署环节的重要技能。你可以按以下步骤循序渐进:
-
基础实践:用 PyTorch 对一个小型分类模型(如 ResNet-18 或 BERT-base)进行非结构化剪枝,观察模型大小和精度的变化。
-
深入理解:尝试结构化剪枝(如移除 Transformer 的注意力头),并评估实际推理加速效果。
-
结合部署:将剪枝后的模型导出为 ONNX 格式,用 ONNX Runtime 或 TensorRT 测试 CPU/GPU 推理速度。
-
挑战 LLM:使用开源库(如
llm-pruner、ShearedLLaMA)尝试对 7B 模型进行层剪枝或头剪枝,体验大模型压缩。
模型剪枝是一项非常实用的工程技能,掌握它能让你在资源受限场景下更灵活地部署 AI 应用。如果在实践中遇到具体问题,欢迎留言交流!
参考资料:
-
Han, S., et al. (2015). Learning both Weights and Connections for Efficient Neural Networks.
-
Fang, G., et al. (2023). DepGraph: Towards Any Structural Pruning.
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)