知识体系篇-数据标注与处理(04)模型测试与评估:模型鲁棒性测试
·
模型鲁棒性测试
专栏:人工智能训练师(三级)备考全攻略
模块:卷三·知识体系 — 第四部分·模型测试与评估
难度:⭐⭐⭐☆☆
考试权重:中高频(选择+简答)
一、什么是鲁棒性?
鲁棒性(Robustness)定义:
模型在面对输入扰动、噪声、分布偏移、对抗攻击等
不理想条件时,仍能维持稳定性能的能力。
直觉理解:
┌─────────────────────────────────────────────┐
│ "这个产品很好用" → 正面 ✅ │
│ "这个产品很好用!"→ 正面 ✅(加感叹号) │
│ "这个产晶很好用" → 正面 ✅(错别字) │
│ "这个产品很好用😊"→ 正面 ✅(Emoji) │
│ "zhege chanpin hen haoyong" → 正面 ✅ │
└─────────────────────────────────────────────┘
以上扰动对人类毫无影响,鲁棒的模型也应如此。
鲁棒性 vs 准确率
准确率(Accuracy):干净数据上的表现
鲁棒性(Robustness):扰动数据上的表现稳定性
两者关系:
高准确率 ≠ 高鲁棒性
例:对抗样本攻击可使准确率从99%瞬间降到<10%
而普通人根本看不出样本有何异常
二、鲁棒性测试的四大类型
┌──────────────────────────────────────────────────┐
│ 鲁棒性测试全景 │
├─────────────────┬────────────────────────────────┤
│ 自然扰动测试 │ 对抗攻击测试 │
│ ───────────── │ ───────────────────────── │
│ 拼写错误 │ FGSM(梯度方向攻击) │
│ 多余标点 │ PGD(迭代攻击) │
│ 大小写变化 │ C&W攻击 │
│ 噪声注入 │ TextFooler(文本对抗) │
├─────────────────┼────────────────────────────────┤
│ 分布偏移测试 │ 边界条件测试 │
│ ───────────── │ ───────────────────────── │
│ 时间域偏移 │ 超长/超短输入 │
│ 领域偏移 │ 空输入/None值 │
│ 设备/传感器变化 │ 特殊字符/Unicode │
└─────────────────┴────────────────────────────────┘
2.1 各类型详解
| 类型 | 定义 | 示例 | 理想结果 |
|---|---|---|---|
| 自然扰动 | 现实中自然发生的输入变化 | 打字错误、口语表达 | 分类不变 |
| 对抗攻击 | 刻意构造的"人眼难辨"的输入 | 图像像素微调、文本改词 | 检测并拒绝/保持正确 |
| 分布偏移 | 测试数据与训练数据来自不同分布 | 模型训练于正式文本,测试于网络用语 | 性能优雅降级 |
| 边界条件 | 极端输入值 | 空字符串、超长文本、全符号文本 | 不崩溃,给出合理输出 |
三、文本鲁棒性测试实战
import re
import random
import string
# =========================================
# 文本扰动函数库
# =========================================
class TextPerturbator:
"""文本扰动工具,用于鲁棒性测试"""
# 同音/形近字映射
TYPO_MAP = {
'的': '地', '地': '的', '得': '的',
'再': '在', '在': '再',
'做': '作', '作': '做',
'是': '识', '识': '是',
'已': '以', '以': '已',
}
def add_punctuation(self, text: str) -> str:
"""在随机位置插入多余标点"""
puncs = ['!', ',', '…', '~', '、']
pos = random.randint(len(text)//2, len(text))
return text[:pos] + random.choice(puncs) + text[pos:]
def add_whitespace(self, text: str) -> str:
"""在字符间插入随机空格"""
return ' '.join(list(text))
def typo(self, text: str) -> str:
"""同音/形近字替换"""
result = list(text)
for i, char in enumerate(result):
if char in self.TYPO_MAP and random.random() < 0.2:
result[i] = self.TYPO_MAP[char]
return ''.join(result)
def delete_char(self, text: str, ratio: float = 0.1) -> str:
"""随机删除字符"""
n_delete = max(1, int(len(text) * ratio))
indices = sorted(random.sample(range(len(text)), n_delete), reverse=True)
result = list(text)
for idx in indices:
result.pop(idx)
return ''.join(result)
def add_emoji(self, text: str) -> str:
"""在末尾添加Emoji"""
emojis = ['😊', '👍', '❌', '🔥', '💯']
return text + random.choice(emojis)
def repeat_chars(self, text: str) -> str:
"""随机重复某个字符"""
if not text:
return text
pos = random.randint(0, len(text) - 1)
return text[:pos] + text[pos] * random.randint(2, 4) + text[pos+1:]
def uppercase(self, text: str) -> str:
"""转为全大写(英文)"""
return text.upper()
# =========================================
# 鲁棒性测试框架
# =========================================
class RobustnessEvaluator:
"""
系统化评估模型对各类文本扰动的鲁棒性
"""
def __init__(self, model, perturbator=None):
self.model = model
self.p = perturbator or TextPerturbator()
def evaluate(self, test_samples: list, n_perturbations: int = 10):
"""
对每条测试样本生成多种扰动,评估预测一致性
Args:
test_samples: [{"text": "...", "label": "..."}]
n_perturbations: 每条样本生成多少种扰动
Returns:
robustness_report: dict
"""
perturbation_methods = {
"add_punctuation": self.p.add_punctuation,
"add_whitespace": self.p.add_whitespace,
"typo": self.p.typo,
"delete_char": self.p.delete_char,
"add_emoji": self.p.add_emoji,
"repeat_chars": self.p.repeat_chars,
}
results = {method: {"total": 0, "consistent": 0}
for method in perturbation_methods}
for sample in test_samples:
original_text = sample["text"]
original_pred = self.model.predict(original_text)
for method_name, method_fn in perturbation_methods.items():
for _ in range(n_perturbations):
perturbed_text = method_fn(original_text)
perturbed_pred = self.model.predict(perturbed_text)
results[method_name]["total"] += 1
if perturbed_pred == original_pred:
results[method_name]["consistent"] += 1
# 计算每类扰动的一致性率
report = {}
for method, counts in results.items():
consistency = counts["consistent"] / counts["total"]
report[method] = {
"consistency_rate": consistency,
"pass": consistency >= 0.90 # 阈值:90%一致
}
overall = sum(r["consistent"] for r in results.values()) / \
sum(r["total"] for r in results.values())
report["overall"] = {"consistency_rate": overall, "pass": overall >= 0.90}
return report
def print_report(self, report: dict):
"""打印鲁棒性评估报告"""
print("=" * 50)
print(" 模型鲁棒性测试报告")
print("=" * 50)
for method, result in report.items():
status = "✅ PASS" if result["pass"] else "❌ FAIL"
rate = result["consistency_rate"]
print(f"{method:<25} {rate:.1%} {status}")
print("=" * 50)
四、图像鲁棒性测试
import numpy as np
from PIL import Image, ImageFilter
import torch
import torchvision.transforms.functional as TF
# =========================================
# 图像扰动函数
# =========================================
def add_gaussian_noise(image: np.ndarray, std: float = 0.05) -> np.ndarray:
"""添加高斯噪声"""
noise = np.random.normal(0, std, image.shape)
return np.clip(image + noise, 0, 1)
def add_salt_pepper_noise(image: np.ndarray, prob: float = 0.02) -> np.ndarray:
"""添加椒盐噪声"""
output = image.copy()
# 椒(黑点)
mask = np.random.random(image.shape[:2]) < prob / 2
output[mask] = 0
# 盐(白点)
mask = np.random.random(image.shape[:2]) < prob / 2
output[mask] = 1
return output
def blur_image(image: np.ndarray, radius: float = 2.0) -> np.ndarray:
"""高斯模糊"""
pil_img = Image.fromarray((image * 255).astype(np.uint8))
blurred = pil_img.filter(ImageFilter.GaussianBlur(radius=radius))
return np.array(blurred) / 255.0
def adjust_brightness(image: np.ndarray, factor: float = 0.5) -> np.ndarray:
"""调整亮度(factor<1变暗,factor>1变亮)"""
return np.clip(image * factor, 0, 1)
def jpeg_compression(image: np.ndarray, quality: int = 20) -> np.ndarray:
"""JPEG压缩模拟"""
import io
pil_img = Image.fromarray((image * 255).astype(np.uint8))
buffer = io.BytesIO()
pil_img.save(buffer, format="JPEG", quality=quality)
buffer.seek(0)
compressed = Image.open(buffer)
return np.array(compressed) / 255.0
# =========================================
# 图像鲁棒性系统测试
# =========================================
def evaluate_image_robustness(model, test_loader, device="cpu"):
"""
评估模型对图像扰动的鲁棒性
返回各扰动类型下的准确率
"""
perturbations = {
"clean": lambda x: x,
"gaussian_noise_weak": lambda x: add_gaussian_noise(x, std=0.05),
"gaussian_noise_strong":lambda x: add_gaussian_noise(x, std=0.15),
"blur": lambda x: blur_image(x, radius=2.0),
"brightness_dark": lambda x: adjust_brightness(x, 0.5),
"brightness_bright": lambda x: adjust_brightness(x, 1.5),
"jpeg_compress": lambda x: jpeg_compression(x, quality=20),
}
results = {k: {"correct": 0, "total": 0} for k in perturbations}
model.eval()
with torch.no_grad():
for images, labels in test_loader:
for name, perturb_fn in perturbations.items():
# 应用扰动
perturbed = torch.stack([
torch.FloatTensor(perturb_fn(img.numpy()))
for img in images
])
outputs = model(perturbed.to(device))
preds = outputs.argmax(dim=1)
results[name]["correct"] += (preds.cpu() == labels).sum().item()
results[name]["total"] += len(labels)
# 打印报告
print(f"\n{'扰动类型':<25} {'准确率':>10} {'vs 干净数据':>10}")
print("-" * 48)
clean_acc = results["clean"]["correct"] / results["clean"]["total"]
for name, counts in results.items():
acc = counts["correct"] / counts["total"]
diff = acc - clean_acc
diff_str = f"{diff:+.1%}"
print(f"{name:<25} {acc:.1%} {diff_str:>10}")
return results
五、对抗攻击基础
5.1 FGSM 对抗攻击原理
FGSM(Fast Gradient Sign Method)
核心思想:沿梯度方向微小扰动,欺骗模型
扰动公式:
x_adv = x + ε × sign(∇ₓ J(θ, x, y))
其中:
x = 原始输入
ε = 扰动幅度(如0.01,人眼不可见)
J = 损失函数
∇ₓJ = 损失函数对输入的梯度
sign = 符号函数(正→1,负→-1)
直觉:
普通训练最小化 J(让模型预测正确)
FGSM 沿最大化 J 的方向微调输入
→ 模型分类从"猫"变"狗",人眼仍看是"猫"
import torch
import torch.nn as nn
def fgsm_attack(model, loss_fn, images, labels, epsilon=0.01):
"""
FGSM 对抗样本生成
Args:
epsilon: 扰动幅度(通常0.01~0.1)
Returns:
adv_images: 对抗样本
"""
images.requires_grad = True
# 正向传播
outputs = model(images)
loss = loss_fn(outputs, labels)
# 反向传播,计算梯度
model.zero_grad()
loss.backward()
# 沿梯度符号方向扰动
gradient_sign = images.grad.sign()
adv_images = images + epsilon * gradient_sign
# 裁剪到合法范围[0,1]
adv_images = torch.clamp(adv_images, 0, 1)
return adv_images.detach()
# 评估模型在对抗样本上的准确率
def evaluate_adversarial_robustness(model, test_loader, epsilon=0.01):
"""评估FGSM攻击下的准确率"""
loss_fn = nn.CrossEntropyLoss()
clean_correct = 0
adv_correct = 0
total = 0
for images, labels in test_loader:
# 干净样本准确率
outputs = model(images)
clean_correct += (outputs.argmax(dim=1) == labels).sum().item()
# 对抗样本准确率
adv_images = fgsm_attack(model, loss_fn, images, labels, epsilon)
adv_outputs = model(adv_images)
adv_correct += (adv_outputs.argmax(dim=1) == labels).sum().item()
total += len(labels)
print(f"干净样本准确率: {clean_correct/total:.2%}")
print(f"对抗样本准确率: {adv_correct/total:.2%}")
print(f"鲁棒性下降: {(clean_correct - adv_correct)/total:.2%}")
六、鲁棒性提升方法
| 方法 | 原理 | 效果 | 代价 |
|---|---|---|---|
| 对抗训练 | 在训练中混入对抗样本 | 最有效 | 训练时间×2~3 |
| 数据增强 | 各种扰动的训练数据 | 提升自然鲁棒性 | 较小 |
| 输入预处理 | 推理前去噪/规范化 | 对已知扰动有效 | 增加延迟 |
| 集成模型 | 多模型投票 | 抗对抗攻击 | 部署成本大 |
| 随机化推理 | 推理时加随机噪声 | 对梯度类攻击有效 | 结果有噪声 |
七、考试重点总结
7.1 核心概念
| 概念 | 关键点 |
|---|---|
| 鲁棒性 | 面对扰动/攻击维持稳定性能的能力 |
| 自然扰动 vs 对抗攻击 | 前者自然发生,后者人为刻意构造 |
| FGSM | 沿梯度符号方向微扰,欺骗模型 |
| 对抗训练 | 最有效的鲁棒性提升方法,训练中加入对抗样本 |
| 分布偏移 | 训练与测试数据来自不同分布,导致性能下降 |
7.2 高频选择题
Q: 以下哪种方法最能提升模型对对抗攻击的鲁棒性?
A: 对抗训练(在训练集中混入对抗样本)✅
Q: FGSM攻击的核心是?
A: 沿损失函数对输入的梯度符号方向添加微小扰动 ✅
Q: 鲁棒性测试中,"自然扰动"不包括?
A: FGSM生成的对抗样本(FGSM是人工构造的对抗攻击)✅
Q: 高准确率的模型一定有高鲁棒性吗?
A: 否。模型可能在干净数据上表现很好,但在扰动下大幅下降 ✅
八、思维导图
📌 备考贴士:鲁棒性测试重点记住四大类型(自然扰动/对抗攻击/分布偏移/边界条件)的区别,以及"对抗训练是最有效的提升方法"。FGSM 公式不需要手推,但要理解"沿梯度方向微扰"的核心思想。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)