【GAN 系列·第九篇】GAN 的评估指标:FID、IS 与 Precision-Recall,如何量化生成质量
【GAN 系列·第九篇】GAN 的评估指标:FID、IS 与 Precision-Recall,如何量化生成质量
作者:技术博主 | 更新时间:2026-05-24 | 阅读时长:约 22 分钟
系列:GAN 从入门到精通(共 12 篇)
环境:Python 3.12,PyTorch 2.x,NumPy,SciPy
标签:FIDInception ScoreGAN评估生成质量Precision-Recall感知质量多样性

🔥 本篇目标:训练好一个 GAN 后,怎么知道它好不好?不能只看生成的图像样例——需要定量指标。FID(Fréchet Inception Distance)是当前最广泛使用的 GAN 评估指标,几乎所有论文都用它。但 FID 有什么局限?IS(Inception Score)测量的又是什么?为什么说 FID 比 IS 更可靠?Precision 和 Recall 如何同时评估质量和多样性?本篇把这些指标的数学原理、实现细节、局限性全部讲清楚,让你读论文时不再被数字迷惑。
系列进度
| 篇次 | 主题 | 状态 |
|---|---|---|
| 第一篇~第八篇 | GAN基础→StyleGAN | ✅ |
| 第九篇(本篇) | GAN 的评估指标:FID、IS | — |
| 第十篇 | GAN vs VAE vs 扩散模型 | 即将发布 |
| 第十一篇 | 工业界应用:超分、合成、增强 | 即将发布 |
| 第十二篇 | 收官:GAN 的前沿与未来 | 即将发布 |
目录
- 一、为什么 GAN 的评估很难
- 二、Inception Score(IS):早期标准
- 三、FID:当前最主流的评估指标
- 四、FID 的实现细节与代码
- 五、Precision 和 Recall:质量与多样性的分离
- 六、其他补充指标
- 七、代码实战:完整评估 Pipeline
- 八、最佳实践:如何正确使用这些指标
一、为什么 GAN 的评估很难
import numpy as np
import torch
import scipy.linalg
import warnings
warnings.filterwarnings('ignore')
print("GAN 评估的根本挑战")
print()
print(" 监督学习的评估很简单:")
print(" 分类任务 → 准确率(预测是否与标签匹配)")
print(" 回归任务 → MSE(预测值与真实值的差距)")
print()
print(" GAN 评估为什么难?")
print()
challenges = [
("没有明确的真实标签",
"生成的图像没有'正确答案',无法算准确率",
"需要衡量生成分布 P_G 与真实分布 P_data 的接近程度"),
("像素级误差不适用",
"MSE/PSNR 对感知质量不敏感(模糊图像 MSE 低但感知差)",
"需要基于感知特征(Perceptual Features)的度量"),
("质量 vs 多样性的权衡",
"模式崩溃的 GAN 只生成少数几种图像,但每种都很逼真",
"需要同时衡量单张图像质量和分布多样性"),
("样本量依赖",
"评估需要大量生成样本(通常 50K+)才能稳定",
"少量样本的 FID 方差很大,不可靠"),
("Inception 模型的偏差",
"FID/IS 都依赖 InceptionV3,对 ImageNet 类别有偏好",
"用于非自然图像(医疗/卫星)时可靠性下降"),
]
for name, problem, solution in challenges:
print(f" ⚠️ [{name}]")
print(f" 问题:{problem}")
print(f" 应对:{solution}")
print()
print(" 现有评估指标的分类:")
print()
metric_categories = [
("基于统计距离", "FID、KID", "衡量生成分布与真实分布的距离"),
("基于分类器", "IS(Inception Score)","衡量图像的可识别性和多样性"),
("基于精确率/召回率","P&R、Density&Coverage","分离质量和多样性"),
("基于感知相似度", "LPIPS、SSIM", "衡量单张图像的感知质量"),
("基于人工评估", "人工 MOS、Chatbot Arena", "最接近真实体验,但成本高"),
]
print(f" {'类别':^16} {'代表指标':^22} {'衡量内容':^30}")
print(" " + "─" * 72)
for cat, metrics, what in metric_categories:
print(f" {cat:^16} {metrics:^22} {what:^30}")
二、Inception Score(IS):早期标准
print("\nInception Score(IS):第一个被广泛使用的 GAN 评估指标")
print()
print(" IS 的提出:Salimans et al., 2016(OpenAI)")
print()
print(" IS 的两个核心要求:")
print(" ① 每张生成图像应该看起来像某个具体类别(高质量)")
print(" → Inception 对生成图像的预测分布应该是尖锐的(低熵)")
print(" ② 整体生成的图像应该包含多种类别(高多样性)")
print(" → 对所有生成图像的边际预测分布应该均匀(高熵)")
print()
print(" 数学公式:")
print(" IS(G) = exp(E_{x~P_G}[KL(p(y|x) || p(y))])")
print()
print(" 其中:")
print(" p(y|x):InceptionV3 对生成图像 x 的条件类别预测")
print(" p(y) = E_{x~P_G}[p(y|x)]:所有生成图像的边际预测分布")
print()
print(" 展开 KL 散度:")
print(" KL(p(y|x) || p(y)) = E_y[log p(y|x) - log p(y)]")
print(" = H(p(y)) - H(p(y|x))")
print()
print(" IS ∝ exp(H(边际分布) - H(条件分布))")
print(" = exp(多样性高 - 单张清晰度高)")
print()
import numpy as np
def compute_is(p_yx: np.ndarray, splits: int = 10) -> tuple:
"""
计算 Inception Score
p_yx: (N, num_classes) 每张图像的 Inception 预测概率
splits: 将 N 个样本分成几份计算(减小方差)
returns: (mean_IS, std_IS)
"""
N = len(p_yx)
split_scores = []
for k in range(splits):
# 第 k 份的样本
part = p_yx[k * (N // splits): (k+1) * (N // splits)]
# 边际分布 p(y) = 均值
p_y = part.mean(axis=0, keepdims=True) # (1, num_classes)
# KL 散度:每张图像的 KL(p(y|x) || p(y))
kl = part * (np.log(part + 1e-10) - np.log(p_y + 1e-10))
kl = kl.sum(axis=1) # 对类别求和 → (N//splits,)
# 每份的 IS
split_scores.append(np.exp(kl.mean()))
return float(np.mean(split_scores)), float(np.std(split_scores))
# 数值演示
np.random.seed(42)
num_classes = 1000 # ImageNet 的类别数
N_samples = 5000
# 场景1:完美的生成器(每张图像都是某一类,且各类均匀)
# 每张图像都是 one-hot(完全清晰)
perfect_pyx = np.zeros((N_samples, num_classes))
for i in range(N_samples):
perfect_pyx[i, i % num_classes] = 1.0
# 场景2:模式崩溃(所有图像都是同一类)
collapsed_pyx = np.zeros((N_samples, num_classes))
collapsed_pyx[:, 0] = 1.0 # 全部预测为类别 0
# 场景3:模糊图像(所有预测都是均匀分布)
blurry_pyx = np.ones((N_samples, num_classes)) / num_classes
# 场景4:真实 GAN(混合,加噪声)
realistic_pyx = np.random.dirichlet(
np.ones(num_classes) * 0.1, size=N_samples
)
print(" IS 数值演示(num_classes=1000):")
print()
print(f" {'场景':^28} {'IS 均值':^12} {'IS 标准差':^14} {'理论上界':^12}")
print(" " + "─" * 70)
# 理论上界:exp(log(num_classes)) = num_classes
IS_max = num_classes
for name, pyx in [
("完美生成器(one-hot)", perfect_pyx),
("模式崩溃(只有1类)", collapsed_pyx),
("极度模糊(均匀分布)", blurry_pyx),
("真实 GAN(混合噪声)", realistic_pyx),
]:
mean_is, std_is = compute_is(pyx, splits=10)
print(f" {name:^28} {mean_is:^12.2f} {std_is:^14.4f} {IS_max:^12}")
print()
print(" IS 的问题(重要!):")
is_problems = [
("不比较真实数据",
"IS 只看生成图像,不和真实数据对比",
"→ 生成的假图像只要被 Inception 识别就能得高分"),
("依赖 ImageNet 类别",
"Inception 在 ImageNet 上训练,对其他域有偏差",
"→ 人脸生成的 IS 通常很低(不是 ImageNet 类别)"),
("模式内多样性不感知",
"生成的 1000 只猫(各种颜色)和 1 只猫(同样的猫)的 IS 可能相同",
"→ 不区分'类间多样性'和'类内多样性'"),
("高 IS ≠ 高质量",
"对抗攻击的图像对 Inception 分类清晰,但人眼看来是噪声",
"→ 可以'刷' IS 分数"),
]
for name, problem, result in is_problems:
print(f" ⚠️ [{name}]")
print(f" 问题:{problem}")
print(f" 后果:{result}")
print()
三、FID:当前最主流的评估指标
print("\nFID(Fréchet Inception Distance):当前标准")
print()
print(" FID 的提出:Heusel et al., 2017(TU Graz)")
print()
print(" 核心改进:同时考虑真实数据和生成数据的分布!")
print()
print(" 步骤:")
print(" ① 用 InceptionV3 提取真实图像的特征(pool3 层,2048维)")
print(" ② 用同一网络提取生成图像的特征")
print(" ③ 假设两组特征都服从多元高斯分布")
print(" ④ 计算两个高斯分布之间的 Fréchet 距离(= W₂ 距离)")
print()
print(" 数学公式:")
print(" FID = ||μ_r - μ_g||² + Tr(Σ_r + Σ_g - 2(Σ_r Σ_g)^{1/2})")
print()
print(" 其中:")
print(" μ_r, Σ_r:真实数据特征的均值和协方差")
print(" μ_g, Σ_g:生成数据特征的均值和协方差")
print(" Tr(...):矩阵的迹(对角线元素之和)")
print(" (Σ_r Σ_g)^{1/2}:矩阵平方根(用于 Fréchet 距离)")
print()
print(" FID 越小 = 两个分布越接近 = 生成质量越好")
print(" FID = 0 意味着 P_G 与 P_data 完全相同(不可能达到)")
print()
# FID 与 W₂ 距离的关系
print(" ⭐ FID 的统计意义:")
print(" Fréchet 距离 = W₂ 距离(Wasserstein-2 距离)的特殊情形")
print(" (在多元高斯假设下)")
print()
print(" 与 IS 的关键区别:")
print(" IS 只看生成图像(P_G 内部)")
print(" FID 同时看真实图像(P_data)和生成图像(P_G)")
print(" → FID 直接衡量两个分布的距离,更有意义")
print()
# 典型 FID 数值参考
print(" 各模型在 CIFAR-10 和 FFHQ 1024×1024 上的 FID:")
print()
fid_reference = [
("模型", "CIFAR-10 FID", "FFHQ 1024 FID"),
("─────", "─────────────", "──────────────"),
("真实数据(上界)", "~0(完美)", "~0(完美)"),
("DCGAN(2015)", "~37", "N/A"),
("WGAN-GP(2017)", "~29", "N/A"),
("ProGAN(2018)", "~8", "~8"),
("BigGAN(2018)", "~14*", "N/A"),
("StyleGAN(2019)", "N/A", "4.40"),
("StyleGAN v2(2020)", "2.32", "2.84"),
("StyleGAN v3(2021)", "N/A", "2.79"),
("扩散模型 DDPM(2020)", "3.17", "N/A"),
("扩散模型 ADM(2021)", "1.14", "N/A"),
]
print(f" {'模型':^28} {'CIFAR-10 FID':^16} {'FFHQ 1024 FID':^16}")
for row in fid_reference:
print(f" {row[0]:^28} {row[1]:^16} {row[2]:^16}")
print()
print(" 注意:FID 值越低越好,跨数据集不可直接比较")
四、FID 的实现细节与代码
import numpy as np
import scipy.linalg
import torch
print("\nFID 的完整实现")
print()
def compute_fid(mu_real: np.ndarray, sigma_real: np.ndarray,
mu_fake: np.ndarray, sigma_fake: np.ndarray,
eps: float = 1e-6) -> float:
"""
计算 FID 分数
mu_real, mu_fake: (d,) 特征均值向量
sigma_real, sigma_fake: (d, d) 特征协方差矩阵
"""
mu_diff = mu_real - mu_fake
mu_sq = float(mu_diff @ mu_diff) # ||μ_r - μ_g||²
# 计算矩阵乘积的平方根 (Σ_r · Σ_g)^{1/2}
# 使用数值稳定的 scipy 实现
cov_prod = sigma_real @ sigma_fake
# 矩阵平方根(可能有数值误差,需要处理复数)
sqrt_cov, _ = scipy.linalg.sqrtm(cov_prod, disp=False)
# 处理数值误差:若结果含极小虚部,取实部
if np.iscomplexobj(sqrt_cov):
if not np.allclose(np.imag(sqrt_cov), 0, atol=1e-3):
raise ValueError("Matrix square root has significant imaginary part")
sqrt_cov = np.real(sqrt_cov)
# 正则化(防止奇异矩阵)
offset = np.eye(sigma_real.shape[0]) * eps
if np.isnan(sqrt_cov).any():
sqrt_cov, _ = scipy.linalg.sqrtm(
(sigma_real + offset) @ (sigma_fake + offset), disp=False
)
sqrt_cov = np.real(sqrt_cov)
# FID = ||μ_r - μ_g||² + Tr(Σ_r + Σ_g - 2·sqrt(Σ_r·Σ_g))
trace_term = np.trace(sigma_real + sigma_fake - 2 * sqrt_cov)
return float(mu_sq + trace_term)
def extract_features(images: torch.Tensor,
inception_model,
batch_size: int = 32) -> np.ndarray:
"""
提取 InceptionV3 的 pool3 特征(2048 维)
images: (N, 3, 299, 299) 图像张量,值域 [-1, 1] 或 [0, 1]
"""
inception_model.eval()
features = []
with torch.no_grad():
for i in range(0, len(images), batch_size):
batch = images[i:i+batch_size]
# InceptionV3 需要 [-1, 1] 范围,resize 到 299×299
if batch.shape[-1] != 299:
batch = torch.nn.functional.interpolate(
batch, size=299, mode='bilinear', align_corners=False
)
feat = inception_model(batch) # 通常是 (B, 2048)
features.append(feat.cpu().numpy())
return np.concatenate(features, axis=0)
def compute_stats(features: np.ndarray) -> tuple:
"""计算特征的均值和协方差"""
mu = np.mean(features, axis=0)
sigma = np.cov(features, rowvar=False)
return mu, sigma
# 数值验证 FID 计算
print(" FID 计算验证:")
print()
np.random.seed(42)
d = 128 # 特征维度(简化,真实是 2048)
# 情况1:完全相同的分布 → FID = 0
mu1 = np.zeros(d)
sigma1 = np.eye(d)
fid_same = compute_fid(mu1, sigma1, mu1, sigma1)
print(f" 相同分布(FID 应为 0):{fid_same:.6f} ✓")
# 情况2:均值不同
mu2 = np.ones(d) * 0.1
fid_mu = compute_fid(mu1, sigma1, mu2, sigma1)
print(f" 均值偏移 0.1(FID 应 > 0):{fid_mu:.4f}")
# 情况3:协方差不同
sigma2 = np.eye(d) * 2
fid_cov = compute_fid(mu1, sigma1, mu1, sigma2)
print(f" 方差扩大 2×(FID 应 > 0):{fid_cov:.4f}")
# 情况4:均值和协方差都不同
fid_both = compute_fid(mu1, sigma1, mu2, sigma2)
print(f" 均值和方差都不同:{fid_both:.4f}")
print()
# FID 的关键影响因素
print(" FID 的关键影响因素:")
print()
influence_factors = [
("样本数量 N",
"N 越大,FID 估计越稳定",
"建议 N ≥ 50,000(论文通常用 50K)",
"N=1000 时方差可达 ±5,N=50K 时 ±0.1"),
("InceptionV3 版本",
"不同实现的 InceptionV3 输出维度不同",
"必须用完全相同的 Inception 实现才可比较",
"常见实现:pytorch-fid、clean-fid"),
("图像预处理",
"像素范围、resize 方法影响特征",
"必须与基准论文使用相同的预处理",
"clean-fid 提出了标准化的预处理流程"),
("真实数据集",
"用不同的真实数据集计算的 FID 不可直接比较",
"必须说明用哪个数据集作为真实分布",
"CIFAR-10 FID ≠ FFHQ FID,不可跨比"),
]
for name, what, how, note in influence_factors:
print(f" ⭐ [{name}]")
print(f" 说明:{what}")
print(f" 建议:{how}")
print(f" 注意:{note}")
print()
# 样本数量对 FID 稳定性的影响
print(" 样本数量对 FID 估计的影响(模拟):")
print()
def simulate_fid_variance(n_samples: int, d: int = 64,
n_trials: int = 20) -> tuple:
"""模拟不同样本量时 FID 的方差"""
np.random.seed(0)
mu_true = np.zeros(d)
sigma_true = np.eye(d)
# 生成分布有轻微偏移
mu_gen = np.ones(d) * 0.05
sigma_gen = np.eye(d) * 1.05
fids = []
for _ in range(n_trials):
# 从两个分布采样
real_feat = np.random.multivariate_normal(mu_true, sigma_true, n_samples)
fake_feat = np.random.multivariate_normal(mu_gen, sigma_gen, n_samples)
mu_r, sig_r = compute_stats(real_feat)
mu_g, sig_g = compute_stats(fake_feat)
fids.append(compute_fid(mu_r, sig_r, mu_g, sig_g))
return float(np.mean(fids)), float(np.std(fids))
print(f" {'样本数 N':^12} {'FID 均值':^12} {'FID 标准差':^14} {'稳定性':^12}")
print(" " + "─" * 54)
for N in [100, 500, 1000, 5000, 10000, 50000]:
mean_fid, std_fid = simulate_fid_variance(N, d=64)
stable = "✓ 稳定" if std_fid < 1 else "⚠️ 不稳定"
print(f" {N:^12,} {mean_fid:^12.4f} {std_fid:^14.4f} {stable:^12}")
print()
print(" 结论:N < 5000 时 FID 方差大,建议至少 N = 10,000,最好 50,000")
五、Precision 和 Recall:质量与多样性的分离
print("\nPrecision & Recall:同时衡量质量和多样性")
print()
print(" FID 的局限:把质量和多样性混在一起")
print(" FID 低可以是:① 高质量+高多样性(理想)")
print(" ② 中等质量+中等多样性(也能低)")
print()
print(" 需要一个能分离'质量'和'多样性'的指标!")
print()
print(" Kynkäänniemi et al., 2019 的 Precision & Recall:")
print()
print(" Precision:生成图像中有多少是高质量的?")
print(" = P(生成图像 ∈ 真实数据的支撑)")
print(" 高 Precision = 大多数生成图像都是真实可信的")
print()
print(" Recall:真实数据的多样性有多少被覆盖?")
print(" = P(真实图像 ∈ 生成数据的支撑)")
print(" 高 Recall = 真实数据的大部分模式都能被生成")
print()
print(" 极端情况:")
print()
extreme_cases = [
("完美生成", 1.0, 1.0, "P_G = P_data,完美匹配"),
("模式崩溃", 1.0, 0.1, "只生成一种模式,但那种很真实"),
("模糊生成", 0.1, 1.0, "能生成所有模式,但都很模糊"),
("随机噪声", 0.0, 0.0, "既不质量好,也不覆盖真实分布"),
]
print(f" {'场景':^16} {'Precision':^12} {'Recall':^12} {'说明':^28}")
print(" " + "─" * 72)
for name, prec, rec, desc in extreme_cases:
print(f" {name:^16} {prec:^12.1f} {rec:^12.1f} {desc:^28}")
print()
# 实现:基于 k-NN 的 Precision 和 Recall
import numpy as np
from scipy.spatial.distance import cdist
def compute_precision_recall(
real_features: np.ndarray,
fake_features: np.ndarray,
k: int = 3,
batch_size: int = 1000
) -> tuple:
"""
计算基于 k-NN 流形估计的 Precision 和 Recall
实现细节:
① 对真实特征,用 k-NN 估计流形(每个点的 k-NN 半径)
② Precision:有多少生成点落在真实流形内?
③ Recall:有多少真实点落在生成流形内?
"""
def knn_radii(features, k):
"""计算每个点到第 k 个最近邻的距离"""
n = len(features)
radii = np.zeros(n)
for i in range(0, n, batch_size):
batch = features[i:i+batch_size]
# 计算批次与所有点的距离
dists = cdist(batch, features, metric='euclidean')
# 排除自身(距离为 0 的那个)
dists.sort(axis=1)
radii[i:i+batch_size] = dists[:, k] # 第 k 个最近邻距离
return radii
# 计算真实和生成特征的 k-NN 半径
real_radii = knn_radii(real_features, k)
fake_radii = knn_radii(fake_features, k)
# Precision:生成点是否在某个真实点的 k-NN 球内
# 即:生成点 f 是否满足 ||f - r|| ≤ radius_r 对某个真实点 r
n_fake = len(fake_features)
n_real = len(real_features)
# 批量计算(避免内存溢出)
precision_count = 0
for i in range(0, n_fake, batch_size):
batch_fake = fake_features[i:i+batch_size]
# 计算这批假特征与所有真实特征的距离
dists = cdist(batch_fake, real_features) # (B, n_real)
# 对每个假样本,检查是否有真实样本在其半径内
in_real_ball = (dists <= real_radii[np.newaxis, :]).any(axis=1)
precision_count += in_real_ball.sum()
precision = precision_count / n_fake
# Recall:真实点是否在某个生成点的 k-NN 球内
recall_count = 0
for i in range(0, n_real, batch_size):
batch_real = real_features[i:i+batch_size]
dists = cdist(batch_real, fake_features)
in_fake_ball = (dists <= fake_radii[np.newaxis, :]).any(axis=1)
recall_count += in_fake_ball.sum()
recall = recall_count / n_real
return float(precision), float(recall)
# 数值演示
np.random.seed(42)
d = 32 # 低维演示
# 真实数据:两个高斯混合
n_real = 500
real_feat = np.vstack([
np.random.randn(n_real//2, d) + np.ones(d) * 2,
np.random.randn(n_real//2, d) - np.ones(d) * 2,
])
print(" Precision & Recall 数值演示:")
print(f" 真实数据:双峰高斯,n={n_real},d={d}")
print()
print(f" {'场景':^28} {'Precision':^12} {'Recall':^12} {'解读':^22}")
print(" " + "─" * 78)
# 各种场景的生成分布
n_fake = 500
scenarios_pr = [
("完美匹配",
np.vstack([
np.random.randn(n_fake//2, d) + np.ones(d)*2,
np.random.randn(n_fake//2, d) - np.ones(d)*2,
])),
("模式崩溃(只有一个峰)",
np.random.randn(n_fake, d) + np.ones(d)*2),
("低质量(均匀噪声)",
np.random.randn(n_fake, d) * 5),
("覆盖但质量差(分散)",
np.vstack([
np.random.randn(n_fake//2, d)*3 + np.ones(d)*2,
np.random.randn(n_fake//2, d)*3 - np.ones(d)*2,
])),
]
for name, fake_feat in scenarios_pr:
prec, rec = compute_precision_recall(real_feat, fake_feat, k=3)
if prec > 0.8 and rec > 0.8:
interp = "质量好+多样性好 ✓"
elif prec > 0.7 and rec < 0.5:
interp = "质量好但多样性差 ⚠️"
elif prec < 0.5 and rec > 0.7:
interp = "覆盖广但质量差 ⚠️"
else:
interp = "质量和多样性都差 ✗"
print(f" {name:^28} {prec:^12.4f} {rec:^12.4f} {interp:^22}")
print()
print(" Precision-Recall 的优势:")
pr_advantages = [
"能区分'高质量但模式崩溃'和'多样但模糊'这两种截然不同的失败模式",
"帮助诊断训练问题:Precision 低 → 改进判别器;Recall 低 → 减少模式崩溃",
"与信息检索领域的 P&R 语义一致,直觉上易理解",
"不依赖高斯假设(FID 依赖),对非高斯分布也有效",
]
for adv in pr_advantages:
print(f" ✅ {adv}")
六、其他补充指标
print("\n其他补充指标")
print()
other_metrics = [
{
"name": "KID(Kernel Inception Distance)",
"paper": "Bińkowski et al., 2018",
"what": "类似 FID,但用核最大均值差异(MMD)替代 Fréchet 距离",
"pros": "无偏估计(FID 在小样本下有偏),N=1000 就能稳定",
"cons": "计算比 FID 慢,不如 FID 广泛使用",
"when": "样本数量少(<5000)时比 FID 更可靠",
"formula":"KID = MMD²(f(real), f(fake))",
},
{
"name": "LPIPS(Learned Perceptual Image Patch Similarity)",
"paper": "Zhang et al., 2018",
"what": "用 VGG/AlexNet 特征计算感知相似度",
"pros": "衡量单张图像对的感知相似度,适合配对图像翻译",
"cons": "需要配对图像(真实+生成对),不能评估无配对生成",
"when": "Pix2Pix、超分辨率等配对任务的评估",
"formula":"LPIPS(x,y) = ||φ(x) - φ(y)||²(特征空间距离)",
},
{
"name": "SSIM(Structural Similarity Index)",
"paper": "Wang et al., 2004",
"what": "基于亮度、对比度、结构的感知相似度",
"pros": "快速,不需要深度网络",
"cons": "与人类感知相关性弱(不如 LPIPS)",
"when": "超分辨率、图像压缩的基准测试",
"formula":"SSIM ∈ [-1, 1],越高越好",
},
{
"name": "Density & Coverage",
"paper": "Naeem et al., 2020",
"what": "改进版 P&R,更鲁棒地处理离群点",
"pros": "对离群假样本更鲁棒,更好地反映分布覆盖",
"cons": "较新,不如 P&R 普及",
"when": "P&R 的替代,当有离群样本时更可靠",
"formula":"Density = P&R 的平滑版本",
},
]
for m in other_metrics:
print(f" ⭐ [{m['name']}]({m['paper']})")
print(f" 原理:{m['what']}")
print(f" 优势:{m['pros']}")
print(f" 局限:{m['cons']}")
print(f" 适用:{m['when']}")
print(f" 公式:{m['formula']}")
print()
# 指标综合对比
print(" 指标综合对比:")
print()
print(f" {'指标':^12} {'质量':^8} {'多样性':^8} {'无需配对':^10} {'标准化':^10} {'推荐度':^10}")
print(" " + "─" * 62)
metrics_summary = [
("IS", "✓", "部分", "✓", "中", "⭐⭐⭐"),
("FID", "✓", "✓", "✓", "高", "⭐⭐⭐⭐⭐"),
("KID", "✓", "✓", "✓", "中", "⭐⭐⭐⭐"),
("P&R", "分离","分离", "✓", "中", "⭐⭐⭐⭐"),
("LPIPS", "✓", "✗", "✗", "高", "⭐⭐⭐(配对)"),
("SSIM", "部分","✗", "✗", "高", "⭐⭐(配对)"),
]
for row in metrics_summary:
name, qual, div, unpaired, std, rec = row
print(f" {name:^12} {qual:^8} {div:^8} {unpaired:^10} {std:^10} {rec:^10}")
七、代码实战:完整评估 Pipeline
import numpy as np
import torch
import torch.nn as nn
import scipy.linalg
print("\n完整评估 Pipeline")
print()
class FIDCalculator:
"""
FID 计算器(不依赖真实 InceptionV3,用随机特征演示)
真实使用时,替换 extract_features 为 InceptionV3 提取
"""
def __init__(self, feature_dim: int = 2048):
self.feature_dim = feature_dim
self.real_features = None
self.real_mu = None
self.real_sigma = None
def compute_real_stats(self, real_images_or_features: np.ndarray):
"""
计算并缓存真实数据的统计量(只需算一次!)
real_images_or_features: (N, feature_dim) 特征或 (N, C, H, W) 图像
"""
if real_images_or_features.ndim == 4:
# 图像输入:需要先提取特征
features = self._extract_inception_features(
real_images_or_features
)
else:
features = real_images_or_features
self.real_features = features
self.real_mu = np.mean(features, axis=0)
self.real_sigma = np.cov(features, rowvar=False)
print(f" 真实数据统计量已计算({len(features)} 个样本,{features.shape[1]}维)")
def compute_fid(self, fake_features: np.ndarray) -> float:
"""计算 FID(真实统计量已预计算)"""
assert self.real_mu is not None, "先调用 compute_real_stats()"
mu_fake = np.mean(fake_features, axis=0)
sigma_fake = np.cov(fake_features, rowvar=False)
return compute_fid(self.real_mu, self.real_sigma,
mu_fake, sigma_fake)
def _extract_inception_features(self, images: np.ndarray) -> np.ndarray:
"""(占位)真实使用时替换为 InceptionV3 特征提取"""
# 真实代码:
# from torchvision.models import inception_v3
# model = inception_v3(pretrained=True, transform_input=False)
# ... 提取 pool3 层特征 ...
return np.random.randn(len(images), self.feature_dim) # 占位
class GANEvaluator:
"""GAN 综合评估器"""
def __init__(self, n_samples: int = 5000,
feature_dim: int = 128): # 演示用小维度
self.n_samples = n_samples
self.feature_dim = feature_dim
self.fid_calc = FIDCalculator(feature_dim)
def setup_real_stats(self, real_features: np.ndarray):
"""设置真实数据统计量(只做一次)"""
self.fid_calc.compute_real_stats(real_features)
def evaluate(self, generator, z_dim: int = 100,
compute_pr: bool = True) -> dict:
"""综合评估生成器"""
results = {}
# 生成特征(真实项目用 InceptionV3 提取)
print(f" 生成 {self.n_samples} 个样本并提取特征...")
with torch.no_grad():
z = torch.randn(self.n_samples, z_dim)
fake_feat = np.random.randn(
self.n_samples, self.feature_dim
) # 演示:随机特征
# FID
results['FID'] = self.fid_calc.compute_fid(fake_feat)
# Precision & Recall
if compute_pr and self.fid_calc.real_features is not None:
prec, rec = compute_precision_recall(
self.fid_calc.real_features[:1000],
fake_feat[:1000],
k=3
)
results['Precision'] = prec
results['Recall'] = rec
return results
def report(self, results: dict, model_name: str = "Model"):
"""打印评估报告"""
print(f"\n ═══ {model_name} 评估报告 ═══")
print()
if 'FID' in results:
fid = results['FID']
grade = "优秀" if fid < 5 else "良好" if fid < 15 else "一般" if fid < 30 else "差"
print(f" FID:{fid:.4f} [{grade}]")
if 'Precision' in results:
print(f" Precision:{results['Precision']:.4f} (生成质量)")
if 'Recall' in results:
print(f" Recall:{results['Recall']:.4f} (多样性覆盖)")
print()
# 使用演示
np.random.seed(42)
d = 128
# 初始化评估器
evaluator = GANEvaluator(n_samples=2000, feature_dim=d)
# 设置真实数据统计量(真实项目:从验证集提取 InceptionV3 特征)
real_feat = np.random.randn(2000, d) * 1.0 # 真实数据特征(模拟)
evaluator.setup_real_stats(real_feat)
# 模拟三个不同质量的生成器
models = {
"差的 GAN(高 FID)": real_feat + np.random.randn(2000, d) * 3.0,
"中等 GAN": real_feat + np.random.randn(2000, d) * 1.0,
"好的 GAN(低 FID)": real_feat + np.random.randn(2000, d) * 0.3,
"模式崩溃的 GAN": np.zeros((2000, d)) + np.random.randn(2000, d) * 0.1,
}
print(" 综合评估演示:")
for model_name, fake_features in models.items():
fid = evaluator.fid_calc.compute_fid(fake_features)
prec, rec = compute_precision_recall(
real_feat[:500], fake_features[:500], k=3
)
print(f" [{model_name}]")
print(f" FID={fid:.4f},Precision={prec:.4f},Recall={rec:.4f}")
print()
八、最佳实践:如何正确使用这些指标
print("\n最佳实践:正确使用 GAN 评估指标")
print()
best_practices = [
{
"topic": "样本数量",
"dos": [
"使用至少 10,000 个生成样本计算 FID(最好 50,000)",
"固定随机种子,报告结果的方差",
"多次计算取均值(FID 估计有方差)",
],
"donts": [
"不要用 1,000 个样本报告 FID(方差太大)",
"不要只报告一次计算的结果",
],
},
{
"topic": "预处理标准化",
"dos": [
"使用 clean-fid 库(标准化了 resize 和范围)",
"明确说明使用的 InceptionV3 实现版本",
"与对比方法使用完全相同的预处理",
],
"donts": [
"不要自己随意实现 FID(容易出现细节错误)",
"不要跨不同实现的 FID 比较(如 pytorch-fid vs tensorflow-fid)",
],
},
{
"topic": "指标组合",
"dos": [
"同时报告 FID 和 Precision/Recall",
"如果任务是配对翻译,补充 LPIPS",
"大规模实验时报告 IS 作为参考",
],
"donts": [
"不要只报告 IS(已知有严重局限)",
"不要把 FID 作为唯一评估标准",
],
},
{
"topic": "结果解读",
"dos": [
"FID 差异 < 1.0 时可能不显著(考虑置信区间)",
"同时展示样本图像(定量+定性结合)",
"报告 FID 时说明样本数 N",
],
"donts": [
"不要跨数据集比较 FID(CIFAR FID ≠ FFHQ FID)",
"不要只看 FID 忽略生成图像的视觉质量",
],
},
]
for bp in best_practices:
print(f" ── [{bp['topic']}] ──")
print(f" ✅ 应该:")
for do in bp['dos']:
print(f" · {do}")
print(f" ❌ 不要:")
for dont in bp['donts']:
print(f" · {dont}")
print()
# 标准评估 Pipeline 推荐
print(" 推荐的标准评估 Pipeline:")
print()
pipeline_code = '''
# 推荐:使用 clean-fid 库(最标准化的实现)
# pip install clean-fid
from cleanfid import fid
# 计算 FID(自动处理预处理)
fid_score = fid.compute_fid(
fdir1 = "path/to/real/images", # 真实图像目录
fdir2 = "path/to/generated/imgs", # 生成图像目录
mode = "clean", # 标准化预处理
num_workers = 4,
batch_size = 32,
)
print(f"FID: {fid_score:.4f}")
# 计算 KID(小样本时更稳定)
kid_score = fid.compute_kid(
fdir1 = "path/to/real/images",
fdir2 = "path/to/generated/imgs",
)
print(f"KID: {kid_score:.6f}")
# 生成样本的 Precision & Recall
from prdc import compute_prdc # pip install prdc
real_features = extract_features("path/to/real/images")
fake_features = extract_features("path/to/generated/imgs")
metrics = compute_prdc(
real_features = real_features,
fake_features = fake_features,
nearest_k = 5,
)
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall: {metrics['recall']:.4f}")
print(f"Density: {metrics['density']:.4f}")
print(f"Coverage: {metrics['coverage']:.4f}")
'''
print(pipeline_code)
# 最终总结
print("=" * 65)
print(" GAN 评估指标速查:")
print("=" * 65)
print()
summary = [
("FID", "最主流", "越低越好", "≥10K 样本,clean-fid"),
("IS", "仅参考", "越高越好", "不推荐作为主要指标"),
("KID", "补充", "越低越好", "小样本时更稳定"),
("Precision", "质量指标", "越高越好", "与 Recall 配合使用"),
("Recall", "多样性", "越高越好", "与 Precision 配合使用"),
("LPIPS", "配对任务", "越低越好", "Pix2Pix、超分辨率"),
]
print(f" {'指标':^12} {'使用频率':^10} {'方向':^10} {'注意事项':^22}")
print(" " + "─" * 58)
for metric, freq, direction, note in summary:
print(f" {metric:^12} {freq:^10} {direction:^10} {note:^22}")
总结
本篇三个核心收获:
① FID 是最可靠的主要指标,但需要正确使用
FID = ∥ μ r − μ g ∥ 2 + Tr ( Σ r + Σ g − 2 ( Σ r Σ g ) 1 / 2 ) \text{FID} = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2}) FID=∥μr−μg∥2+Tr(Σr+Σg−2(ΣrΣg)1/2)
至少 10K 样本,使用 clean-fid,同一数据集比较,报告置信区间。
② IS 有根本性局限,应只作参考
IS 不与真实数据对比,对 ImageNet 以外的域有严重偏差,高 IS ≠ 高生成质量。
③ Precision & Recall 能分离质量和多样性
| 指标 | 衡量 | 低的含义 |
|---|---|---|
| Precision | 生成质量 | 大量生成图像是虚假/模糊的 |
| Recall | 多样性覆盖 | 真实数据的某些模式没有被生成(模式崩溃) |
| FID | 两者综合 | 质量差或多样性差,无法区分 |
下一篇预告:GAN vs VAE vs 扩散模型——三大生成范式的完整对比。从信息论视角(第五篇信息论系列)看三者的目标函数,从实验数据看各自的最佳场景,帮你在实际项目中选择正确的生成模型。
💬 你在实际项目中用过 FID 来评估生成质量吗?有没有遇到 FID 低但生成图像不好看的情况? 欢迎评论区分享!
🙏 如果这篇帮到你,点赞 + 收藏,系列持续更新!
本文为原创技术分享。代码在 Python 3.12 + PyTorch + SciPy 下验证。最后更新:2026-05-24
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)