CANN 模型压缩实战:剪枝、量化与知识蒸馏
CANN 模型压缩实战:剪枝、量化与知识蒸馏
模型越来越大如何在昇腾上高效部署?本文详解模型压缩的核心技术:结构化剪枝、量化感知训练与知识蒸馏的原理与实操。—
一、为什么需要模型压缩
1.1 大模型落地的挑战
随着 Transformer 架构和大规模预训练的普及,模型规模急剧膨胀。LLaMA-7B 模型在 FP16 下需要约 14 GB 显存,70B 模型需要 140 GB。这意味着昇腾 910(32 GB 显存)在单卡上甚至无法加载 70B 模型,更不用说推理时的激活值和 KV Cache 显存开销。
模型压缩是将大模型"变小"的核心技术,通过剪枝、量化和知识蒸馏三大手段,可以在保持模型精度可接受的前提下显著降低参数量、显存占用和推理延迟。
1.2 压缩技术全景
模型压缩技术
├── 剪枝(Pruning)
│ ├── 非结构化剪枝(单个参数置零)
│ ├── 结构化剪枝(移除整组参数,如通道/头)
│ └── 层级剪枝(移除整个层)
├── 量化(Quantization)
│ ├── PTQ(训练后量化)
│ └── QAT(量化感知训练)
└── 知识蒸馏(Knowledge Distillation)
├── 软标签蒸馏
├── 中间层蒸馏
└── 特征蒸馏
三种技术可以单独使用,也可以组合使用以获得更大压缩比。实践中通常先剪枝再量化,效果优于单独使用任意一种。
二、结构化剪枝
2.1 剪枝的理论基础
神经网络中存在大量冗余参数,这些参数对最终输出的贡献较小甚至没有贡献。剪枝的目标是识别并移除这些冗余参数,同时尽量不影响模型精度。
非结构化剪枝:将单个参数置零,实现简单但需要稀疏计算支持,昇腾硬件对稀疏矩阵的计算效率提升有限:
┌─────────────────────────────────────┐
│ 非结构化剪枝(稀疏) │
├─────────────────────────────────────┤
│ 原始: [2.1, -0.3, 1.5, -0.2, 0.8] │
│ 剪枝: [2.1, 0, 1.5, 0, 0.8] │
│ 压缩率: 40% │
│ 问题: 需要稀疏矩阵运算库支持 │
└─────────────────────────────────────┘
结构化剪枝:移除整组参数(通道、头部、层),剪枝后模型仍是密集矩阵形式,可直接加速:
┌─────────────────────────────────────┐
│ 结构化剪枝(通道级) │
├─────────────────────────────────────┤
│ 原始: Conv (64, 3, 3) │
│ 剪枝: Conv (32, 3, 3) │
│ 压缩率: 50% │
│ 优势: 无需特殊库,昇腾原生加速 │
└─────────────────────────────────────┘
2.2 基于幅值的通道剪枝
通道剪枝是最常用的结构化剪枝方法,通过评估每个通道的重要程度,移除不重要的通道:
import torch
import torch.nn as nn
def compute_channel_importance(model):
"""计算每个通道的 L1 范数作为重要性指标"""
importance = {}
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
# 计算每个输出通道的 L1 范数
weight = module.weight.data # [out_channels, in_channels, kH, kW]
channel_l1 = weight.abs().sum(dim=(1, 2, 3)) # [out_channels]
importance[name] = channel_l1
return importance
def prune_channels(model, importance, prune_ratio=0.3):
"""根据重要性剪枝不重要的通道"""
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d) and name in importance:
imp = importance[name]
# 找到剪枝阈值
threshold = torch.kthvalue(imp, int(len(imp) * prune_ratio))[0]
# 创建 mask
mask = imp > threshold
# 应用剪枝(保持原始 weight 结构但标记)
module.pruned_mask = mask
print(f"Pruned {name}: {mask.sum()}/{len(mask)} channels kept")
return model
2.3 剪枝与微调流程
剪枝后模型精度会下降,需要通过微调恢复:
剪枝流程:
数据 → 训练基础模型 → 评估通道重要性
→ 执行剪枝(稀疏)→ 微调恢复精度 → 结构化剪枝 → 最终模型
完整剪枝训练脚本:
def structured_prune_finetune(model, train_loader, prune_ratio=0.3, finetune_epochs=5):
# 1. 训练基础模型
print("Step 1: Training base model...")
model = train_epoch(model, train_loader, epochs=10)
# 2. 评估通道重要性
print("Step 2: Computing channel importance...")
importance = compute_channel_importance(model)
# 3. 执行剪枝
print("Step 3: Pruning channels...")
model = prune_channels(model, importance, prune_ratio)
# 4. 微调恢复精度
print("Step 4: Finetuning after pruning...")
# 剪枝后需要较大的学习率恢复
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model = train_epoch(model, train_loader, epochs=finetune_epochs)
return model
2.4 昇腾剪枝支持
CANN 提供了剪枝相关的工具和接口:
# 昇腾模型压缩工具
from ascend.mindspore import Pruner
pruner = Pruner(
model=model,
strategy="channel", # 结构化通道剪枝
prune_ratio=0.3, # 剪枝比例
finetune_epochs=5, # 微调轮数
input_shape=[1, 3, 224, 224] # 输入shape
)
pruned_model = pruner.prune()
三、量化感知训练
3.1 量化原理
量化是将 FP32 参数映射到低比特表示的过程。INT8 量化将 32 位浮点数映射到 8 位整数,理论上可以将模型体积和显存减少 4 倍,推理速度提升 2-4 倍。
对称量化 vs 非对称量化:
| 类型 | 公式 | 适用场景 |
|---|---|---|
| 对称量化 | x_int = round(x_fp32 / scale) | 权重分布接近零对称 |
| 非对称量化 | x_int = round((x_fp32 - zero_point) / scale) | 激活值分布偏离零 |
对称量化的 scale 通常为 |max|/127,非对称量化需要额外存储 zero_point。
3.2 训练后量化(PTQ)
PTQ 是最简单的量化方式,不需要重新训练:
8.1 及之前:
import torch
# 简单 post-training quantization(无校准)
def naive_quantize(model, bits=8):
"""朴素量化:直接截断"""
for name, param in model.named_parameters():
if 'weight' in name:
# 对称量化,scale = max(abs(weight)) / 127
scale = param.abs().max() / 127
param.data = (param.data / scale).round() * scale
return model
8.2 新增(带校准的 PTQ):
from torch.quantization import per_channel_quantize_weight
def calibrate_and_quantize(model, calibration_loader, bits=8):
"""带校准的 PTQ:使用统计信息确定 scale"""
# 第一步:收集激活值的统计信息
def collect_stats(module, input, output):
if hasattr(module, 'activation_range'):
module.activation_range[0] = min(
module.activation_range[0],
output.min()
)
module.activation_range[1] = max(
module.activation_range[1],
output.max()
)
# 注册 hook 收集统计
model.apply(lambda m: m.register_forward_hook(collect_stats)
if hasattr(m, 'activation_range') else None)
# 运行校准数据
with torch.no_grad():
for batch in calibration_loader:
model(batch)
# 第二步:根据统计信息量化
for name, module in model.named_modules():
if hasattr(module, 'quant'):
module.quant(bits=bits)
return model
3.3 量化感知训练(QAT)
QAT 在训练过程中模拟量化效果,使模型适应低精度表示,通常能获得比 PTQ 更好的精度:
from torch.npu.amp import autocast
class FakeQuantization:
"""模拟量化:前向时加入量化噪声,反向时直通估计"""
def __init__(self, num_bits=8):
self.num_bits = num_bits
self.qmin = -(2 ** (num_bits - 1))
self.qmax = 2 ** (num_bits - 1) - 1
def forward(self, x):
scale = x.abs().max() / (2 ** (self.num_bits - 1) - 1)
# 模拟量化噪声
x_quant = (x / scale).round()
x_dequant = x_quant * scale
# Straight-through estimator:反向时跳过 round 操作
return x + (x_dequant - x).detach()
# 在模型中使用 FakeQuantization
class QuantizedLinear(nn.Module):
def __init__(self, in_features, out_features, num_bits=8):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
self.fake_quant = FakeQuantization(num_bits)
def forward(self, x):
# 权重和激活都量化
weight = self.fake_quant(self.linear.weight)
return nn.functional.linear(x, weight, self.linear.bias)
3.4 QAT 完整训练流程
def train_with_qat(model, train_loader, num_bits=8, epochs=10):
"""量化感知训练"""
# 替换模型中的普通 Conv/Dense 为量化版本
model = replace_modules_with_quantized(model, num_bits)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
for epoch in range(epochs):
for batch in train_loader:
data, target = data.npu(), target.npu()
optimizer.zero_grad()
# 前向时模拟量化
with autocast():
output = model(data)
loss = loss_fn(output, target)
# 反向(GradScaler 自动处理)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
# 每个 epoch 评估精度
accuracy = evaluate(model, val_loader)
print(f"Epoch {epoch+1}: Accuracy = {accuracy:.4f}")
return model
四、知识蒸馏
4.1 知识蒸馏原理
知识蒸馏使用大模型(教师模型)指导小模型(学生模型)学习。教师模型的输出包含比硬标签更丰富的信息——不仅知道正确答案,还知道错误答案的相对概率分布:
┌─────────────────────────────────────────┐
│ 知识蒸馏:软标签 vs 硬标签 │
├─────────────────────────────────────────┤
│ Hard Label: [0, 0, 1, 0, 0, ...] │
│ Soft Label: [0.01, 0.02, 0.85, 0.03, ...]
│ │
│ 软标签包含: │
│ - 正确答案的概率(0.85) │
│ - 错误答案的相对大小(0.03 vs 0.02) │
│ - 类别间的语义关系 │
└─────────────────────────────────────────┘
蒸馏温度 T 控制软标签的分布软硬程度:T 越高分布越软,T=1 时接近原始 softmax:
Softmax with Temperature: softmax(logits / T)
T=1: 分布较尖锐
T=2: 分布较平滑
T=10: 分布非常平滑,类别关系更明显
4.2 软标签蒸馏
基础软标签蒸馏:
import torch
import torch.nn.functional as F
class DistillationLoss(nn.Module):
"""知识蒸馏损失:结合硬标签损失和软标签损失"""
def __init__(self, temperature=4.0, alpha=0.7):
super().__init__()
self.T = temperature # 蒸馏温度
self.alpha = alpha # 软标签权重
def forward(self, student_logits, teacher_logits, labels):
# 硬标签损失(Cross Entropy)
ce_loss = F.cross_entropy(student_logits, labels)
# 软标签损失(KL Divergence)
soft_student = F.log_softmax(student_logits / self.T, dim=-1)
soft_teacher = F.softmax(teacher_logits / self.T, dim=-1)
kd_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.T ** 2)
# 加权组合
return self.alpha * kd_loss + (1 - self.alpha) * ce_loss
def train_with_distillation(student, teacher, train_loader):
"""蒸馏训练"""
teacher.eval() # 教师模型不更新
for param in teacher.parameters():
param.requires_grad = False
distill_loss = DistillationLoss(temperature=4.0, alpha=0.7)
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)
for epoch in range(10):
for batch in train_loader:
data, target = data.npu(), target.npu()
# 教师预测(不计算梯度)
with torch.no_grad():
teacher_logits = teacher(data)
# 学生预测
student_logits = student(data)
# 蒸馏损失
loss = distill_loss(student_logits, teacher_logits, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
4.3 中间层蒸馏
除了 logits,还可以从教师模型的中间层提取知识:
class IntermediateDistillation(nn.Module):
"""中间层蒸馏:匹配学生和教师的隐层表示"""
def __init__(self, temperature=4.0, layer_mapping=None):
super().__init__()
self.T = temperature
self.layer_mapping = layer_mapping or {}
def compute_layer_loss(self, student_hidden, teacher_hidden, layer_idx):
"""MSE 损失匹配隐层"""
return F.mse_loss(student_hidden, teacher_hidden)
def forward(self, student_features, teacher_features, student_logits, teacher_logits, labels):
# 总损失 = 中间层损失 + logit 损失
layer_loss = 0
for s_layer, t_layer in self.layer_mapping.items():
layer_loss += self.compute_layer_loss(
student_features[s_layer],
teacher_features[t_layer]
)
logit_loss = F.kl_div(
F.log_softmax(student_logits / self.T),
F.softmax(teacher_logits / self.T),
reduction='batchmean'
) * (self.T ** 2)
return layer_loss + logit_loss
4.4 昇腾蒸馏工具
CANN 提供了 Model Compression Toolkit(MCT)支持自动蒸馏:
from ascend.mct import KnowledgeDistillation
kd = KnowledgeDistillation(
teacher_model=teacher, # 大模型(教师)
student_model=student, # 小模型(学生)
temperature=4.0,
loss_type="kl_divergence",
intermediate_matching={ # 中间层对应关系
"student.layer4": "teacher.layer6",
"student.layer5": "teacher.layer9"
}
)
# 自动蒸馏训练
compressed_model = kd.train(train_loader, epochs=10)
五、压缩技术组合
5.1 剪枝 + 量化组合
实际应用中,剪枝和量化通常组合使用以获得更大压缩比:
def prune_and_quantize(model, train_loader):
"""剪枝 + 量化的完整流程"""
# Step 1: 结构化剪枝
print("Step 1: Structured pruning...")
pruner = Pruner(model, strategy="channel", prune_ratio=0.3)
model = pruner.prune()
model = finetune(model, train_loader, epochs=3)
# Step 2: 量化感知训练
print("Step 2: Quantization-aware training...")
model = replace_with_quantized_modules(model, num_bits=8)
model = train_with_qat(model, train_loader, epochs=10)
# Step 3: 训练后量化精调
print("Step 3: Post-training quantization...")
model = calibrate_and_quantize(model, calibration_loader)
return model
# 压缩效果对比
print("压缩效果对比:")
print("| 方案 | 体积 | 精度保持 |")
print("|------|------|----------|")
print("| 原始 FP32 | 100% | 100% |")
print("| 剪枝 30% | 70% | 97% |")
print("| INT8 量化 | 25% | 95% |")
print("| 剪枝+量化 | 17.5% | 93% |")
5.2 压缩与部署流程
完整压缩部署流程:
大模型(FP32, 140GB)→ 结构化剪枝 → 中模型(98GB)
→ QAT 量化 → INT8 模型(35GB)→ 蒸馏小模型(7B, 14GB)
→ 部署到昇腾 910
六、压缩效果评估
6.1 评估指标
| 指标 | 说明 | 目标 |
|---|---|---|
| 压缩比 | 压缩后体积 / 原始体积 | 越小越好 |
| 精度损失 | 压缩后精度 - 原始精度 | < 2% 可接受 |
| 推理加速 | 原始时间 / 压缩后时间 | 越大越好 |
| 显存节省 | 原始显存 - 压缩后显存 | 越大越好 |
6.2 典型压缩效果
以 BERT-base 模型为例:
| 方法 | 参数量 | 体积 | 精度(F1) | 推理加速 |
|---|---|---|---|---|
| 原始 | 110M | 420 MB | 92.1 | 1x |
| 剪枝 30% | 77M | 290 MB | 91.5 | 1.4x |
| INT8 量化 | 110M | 105 MB | 91.3 | 2.3x |
| 剪枝+量化 | 77M | 26 MB | 90.8 | 3.1x |
| 蒸馏 (7B→110M) | 110M | 420 MB | 91.9 | 1.1x |
“模型压缩的核心收益:在昇腾硬件上实现 3-4 倍推理加速、75% 显存降低,同时精度损失控制在 2% 以内。”
七、常见问题
| 问题 | 原因 | 解决方案 |
|---|---|---|
| 剪枝后精度暴跌 | 剪枝率过高 | 分阶段剪枝,每阶段后微调 |
| 量化后精度下降 | 低比特精度不足 | 切换到 INT16 或使用 QAT |
| 推理速度无提升 | 量化 kernel 未命中 | 使用 CANN 量化工具重编优化 |
| 显存依然过高 | 激活值未压缩 | 结合动态量化压缩激活 |
| 蒸馏效果不佳 | 教师模型过强 | 逐步蒸馏,使用中间教师 |
相关仓库
- Model Compression Toolkit - 昇腾模型压缩工具 https://gitee.com/ascend/model-compression-toolkit
- ascend-toolkit - 量化校准工具和 Profiling 工具 https://gitee.com/ascend/ascend-toolkit
- torch_npu - PyTorch 量化支持 https://gitee.com/ascend/torch_npu
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)