【GAN 系列·第九篇】GAN 的评估指标:FID、IS 与 Precision-Recall,如何量化生成质量

作者:技术博主 | 更新时间:2026-05-24 | 阅读时长:约 22 分钟
系列:GAN 从入门到精通(共 12 篇)
环境:Python 3.12,PyTorch 2.x,NumPy,SciPy
标签FID Inception Score GAN评估 生成质量 Precision-Recall 感知质量 多样性


在这里插入图片描述

🔥 本篇目标:训练好一个 GAN 后,怎么知道它好不好?不能只看生成的图像样例——需要定量指标。FID(Fréchet Inception Distance)是当前最广泛使用的 GAN 评估指标,几乎所有论文都用它。但 FID 有什么局限?IS(Inception Score)测量的又是什么?为什么说 FID 比 IS 更可靠?Precision 和 Recall 如何同时评估质量和多样性?本篇把这些指标的数学原理、实现细节、局限性全部讲清楚,让你读论文时不再被数字迷惑。


系列进度

篇次 主题 状态
第一篇~第八篇 GAN基础→StyleGAN
第九篇(本篇) GAN 的评估指标:FID、IS
第十篇 GAN vs VAE vs 扩散模型 即将发布
第十一篇 工业界应用:超分、合成、增强 即将发布
第十二篇 收官:GAN 的前沿与未来 即将发布

目录


一、为什么 GAN 的评估很难

import numpy as np
import torch
import scipy.linalg
import warnings
warnings.filterwarnings('ignore')

print("GAN 评估的根本挑战")
print()
print("  监督学习的评估很简单:")
print("  分类任务 → 准确率(预测是否与标签匹配)")
print("  回归任务 → MSE(预测值与真实值的差距)")
print()
print("  GAN 评估为什么难?")
print()

challenges = [
    ("没有明确的真实标签",
     "生成的图像没有'正确答案',无法算准确率",
     "需要衡量生成分布 P_G 与真实分布 P_data 的接近程度"),
    ("像素级误差不适用",
     "MSE/PSNR 对感知质量不敏感(模糊图像 MSE 低但感知差)",
     "需要基于感知特征(Perceptual Features)的度量"),
    ("质量 vs 多样性的权衡",
     "模式崩溃的 GAN 只生成少数几种图像,但每种都很逼真",
     "需要同时衡量单张图像质量和分布多样性"),
    ("样本量依赖",
     "评估需要大量生成样本(通常 50K+)才能稳定",
     "少量样本的 FID 方差很大,不可靠"),
    ("Inception 模型的偏差",
     "FID/IS 都依赖 InceptionV3,对 ImageNet 类别有偏好",
     "用于非自然图像(医疗/卫星)时可靠性下降"),
]

for name, problem, solution in challenges:
    print(f"  ⚠️  [{name}]")
    print(f"     问题:{problem}")
    print(f"     应对:{solution}")
    print()

print("  现有评估指标的分类:")
print()
metric_categories = [
    ("基于统计距离",   "FID、KID",          "衡量生成分布与真实分布的距离"),
    ("基于分类器",     "IS(Inception Score)","衡量图像的可识别性和多样性"),
    ("基于精确率/召回率","P&R、Density&Coverage","分离质量和多样性"),
    ("基于感知相似度", "LPIPS、SSIM",        "衡量单张图像的感知质量"),
    ("基于人工评估",   "人工 MOS、Chatbot Arena", "最接近真实体验,但成本高"),
]

print(f"  {'类别':^16} {'代表指标':^22} {'衡量内容':^30}")
print("  " + "─" * 72)
for cat, metrics, what in metric_categories:
    print(f"  {cat:^16} {metrics:^22} {what:^30}")

二、Inception Score(IS):早期标准

print("\nInception Score(IS):第一个被广泛使用的 GAN 评估指标")
print()
print("  IS 的提出:Salimans et al., 2016(OpenAI)")
print()
print("  IS 的两个核心要求:")
print("  ① 每张生成图像应该看起来像某个具体类别(高质量)")
print("     → Inception 对生成图像的预测分布应该是尖锐的(低熵)")
print("  ② 整体生成的图像应该包含多种类别(高多样性)")
print("     → 对所有生成图像的边际预测分布应该均匀(高熵)")
print()
print("  数学公式:")
print("  IS(G) = exp(E_{x~P_G}[KL(p(y|x) || p(y))])")
print()
print("  其中:")
print("  p(y|x):InceptionV3 对生成图像 x 的条件类别预测")
print("  p(y)   = E_{x~P_G}[p(y|x)]:所有生成图像的边际预测分布")
print()
print("  展开 KL 散度:")
print("  KL(p(y|x) || p(y)) = E_y[log p(y|x) - log p(y)]")
print("  = H(p(y)) - H(p(y|x))")
print()
print("  IS ∝ exp(H(边际分布) - H(条件分布))")
print("      = exp(多样性高 - 单张清晰度高)")
print()

import numpy as np

def compute_is(p_yx: np.ndarray, splits: int = 10) -> tuple:
    """
    计算 Inception Score
    p_yx: (N, num_classes) 每张图像的 Inception 预测概率
    splits: 将 N 个样本分成几份计算(减小方差)
    returns: (mean_IS, std_IS)
    """
    N = len(p_yx)
    split_scores = []

    for k in range(splits):
        # 第 k 份的样本
        part = p_yx[k * (N // splits): (k+1) * (N // splits)]

        # 边际分布 p(y) = 均值
        p_y  = part.mean(axis=0, keepdims=True)   # (1, num_classes)

        # KL 散度:每张图像的 KL(p(y|x) || p(y))
        kl   = part * (np.log(part + 1e-10) - np.log(p_y + 1e-10))
        kl   = kl.sum(axis=1)   # 对类别求和 → (N//splits,)

        # 每份的 IS
        split_scores.append(np.exp(kl.mean()))

    return float(np.mean(split_scores)), float(np.std(split_scores))


# 数值演示
np.random.seed(42)
num_classes = 1000   # ImageNet 的类别数
N_samples   = 5000

# 场景1:完美的生成器(每张图像都是某一类,且各类均匀)
# 每张图像都是 one-hot(完全清晰)
perfect_pyx = np.zeros((N_samples, num_classes))
for i in range(N_samples):
    perfect_pyx[i, i % num_classes] = 1.0

# 场景2:模式崩溃(所有图像都是同一类)
collapsed_pyx = np.zeros((N_samples, num_classes))
collapsed_pyx[:, 0] = 1.0   # 全部预测为类别 0

# 场景3:模糊图像(所有预测都是均匀分布)
blurry_pyx = np.ones((N_samples, num_classes)) / num_classes

# 场景4:真实 GAN(混合,加噪声)
realistic_pyx = np.random.dirichlet(
    np.ones(num_classes) * 0.1, size=N_samples
)

print("  IS 数值演示(num_classes=1000):")
print()
print(f"  {'场景':^28} {'IS 均值':^12} {'IS 标准差':^14} {'理论上界':^12}")
print("  " + "─" * 70)

# 理论上界:exp(log(num_classes)) = num_classes
IS_max = num_classes
for name, pyx in [
    ("完美生成器(one-hot)",     perfect_pyx),
    ("模式崩溃(只有1类)",       collapsed_pyx),
    ("极度模糊(均匀分布)",       blurry_pyx),
    ("真实 GAN(混合噪声)",      realistic_pyx),
]:
    mean_is, std_is = compute_is(pyx, splits=10)
    print(f"  {name:^28} {mean_is:^12.2f} {std_is:^14.4f} {IS_max:^12}")

print()
print("  IS 的问题(重要!):")
is_problems = [
    ("不比较真实数据",
     "IS 只看生成图像,不和真实数据对比",
     "→ 生成的假图像只要被 Inception 识别就能得高分"),
    ("依赖 ImageNet 类别",
     "Inception 在 ImageNet 上训练,对其他域有偏差",
     "→ 人脸生成的 IS 通常很低(不是 ImageNet 类别)"),
    ("模式内多样性不感知",
     "生成的 1000 只猫(各种颜色)和 1 只猫(同样的猫)的 IS 可能相同",
     "→ 不区分'类间多样性'和'类内多样性'"),
    ("高 IS ≠ 高质量",
     "对抗攻击的图像对 Inception 分类清晰,但人眼看来是噪声",
     "→ 可以'刷' IS 分数"),
]
for name, problem, result in is_problems:
    print(f"  ⚠️  [{name}]")
    print(f"     问题:{problem}")
    print(f"     后果:{result}")
    print()

三、FID:当前最主流的评估指标

print("\nFID(Fréchet Inception Distance):当前标准")
print()
print("  FID 的提出:Heusel et al., 2017(TU Graz)")
print()
print("  核心改进:同时考虑真实数据和生成数据的分布!")
print()
print("  步骤:")
print("  ① 用 InceptionV3 提取真实图像的特征(pool3 层,2048维)")
print("  ② 用同一网络提取生成图像的特征")
print("  ③ 假设两组特征都服从多元高斯分布")
print("  ④ 计算两个高斯分布之间的 Fréchet 距离(= W₂ 距离)")
print()
print("  数学公式:")
print("  FID = ||μ_r - μ_g||² + Tr(Σ_r + Σ_g - 2(Σ_r Σ_g)^{1/2})")
print()
print("  其中:")
print("  μ_r, Σ_r:真实数据特征的均值和协方差")
print("  μ_g, Σ_g:生成数据特征的均值和协方差")
print("  Tr(...):矩阵的迹(对角线元素之和)")
print("  (Σ_r Σ_g)^{1/2}:矩阵平方根(用于 Fréchet 距离)")
print()
print("  FID 越小 = 两个分布越接近 = 生成质量越好")
print("  FID = 0 意味着 P_G 与 P_data 完全相同(不可能达到)")
print()

# FID 与 W₂ 距离的关系
print("  ⭐ FID 的统计意义:")
print("  Fréchet 距离 = W₂ 距离(Wasserstein-2 距离)的特殊情形")
print("  (在多元高斯假设下)")
print()
print("  与 IS 的关键区别:")
print("  IS 只看生成图像(P_G 内部)")
print("  FID 同时看真实图像(P_data)和生成图像(P_G)")
print("  → FID 直接衡量两个分布的距离,更有意义")
print()

# 典型 FID 数值参考
print("  各模型在 CIFAR-10 和 FFHQ 1024×1024 上的 FID:")
print()
fid_reference = [
    ("模型",                      "CIFAR-10 FID", "FFHQ 1024 FID"),
    ("─────",                     "─────────────", "──────────────"),
    ("真实数据(上界)",           "~0(完美)",   "~0(完美)"),
    ("DCGAN(2015)",              "~37",          "N/A"),
    ("WGAN-GP(2017)",            "~29",          "N/A"),
    ("ProGAN(2018)",             "~8",           "~8"),
    ("BigGAN(2018)",             "~14*",         "N/A"),
    ("StyleGAN(2019)",           "N/A",          "4.40"),
    ("StyleGAN v2(2020)",        "2.32",         "2.84"),
    ("StyleGAN v3(2021)",        "N/A",          "2.79"),
    ("扩散模型 DDPM(2020)",      "3.17",         "N/A"),
    ("扩散模型 ADM(2021)",       "1.14",         "N/A"),
]

print(f"  {'模型':^28} {'CIFAR-10 FID':^16} {'FFHQ 1024 FID':^16}")
for row in fid_reference:
    print(f"  {row[0]:^28} {row[1]:^16} {row[2]:^16}")

print()
print("  注意:FID 值越低越好,跨数据集不可直接比较")

四、FID 的实现细节与代码

import numpy as np
import scipy.linalg
import torch

print("\nFID 的完整实现")
print()

def compute_fid(mu_real: np.ndarray, sigma_real: np.ndarray,
                mu_fake: np.ndarray, sigma_fake: np.ndarray,
                eps: float = 1e-6) -> float:
    """
    计算 FID 分数
    mu_real, mu_fake:       (d,) 特征均值向量
    sigma_real, sigma_fake: (d, d) 特征协方差矩阵
    """
    mu_diff  = mu_real - mu_fake
    mu_sq    = float(mu_diff @ mu_diff)   # ||μ_r - μ_g||²

    # 计算矩阵乘积的平方根 (Σ_r · Σ_g)^{1/2}
    # 使用数值稳定的 scipy 实现
    cov_prod = sigma_real @ sigma_fake

    # 矩阵平方根(可能有数值误差,需要处理复数)
    sqrt_cov, _ = scipy.linalg.sqrtm(cov_prod, disp=False)

    # 处理数值误差:若结果含极小虚部,取实部
    if np.iscomplexobj(sqrt_cov):
        if not np.allclose(np.imag(sqrt_cov), 0, atol=1e-3):
            raise ValueError("Matrix square root has significant imaginary part")
        sqrt_cov = np.real(sqrt_cov)

    # 正则化(防止奇异矩阵)
    offset = np.eye(sigma_real.shape[0]) * eps
    if np.isnan(sqrt_cov).any():
        sqrt_cov, _ = scipy.linalg.sqrtm(
            (sigma_real + offset) @ (sigma_fake + offset), disp=False
        )
        sqrt_cov = np.real(sqrt_cov)

    # FID = ||μ_r - μ_g||² + Tr(Σ_r + Σ_g - 2·sqrt(Σ_r·Σ_g))
    trace_term = np.trace(sigma_real + sigma_fake - 2 * sqrt_cov)

    return float(mu_sq + trace_term)


def extract_features(images: torch.Tensor,
                     inception_model,
                     batch_size: int = 32) -> np.ndarray:
    """
    提取 InceptionV3 的 pool3 特征(2048 维)
    images: (N, 3, 299, 299) 图像张量,值域 [-1, 1] 或 [0, 1]
    """
    inception_model.eval()
    features = []

    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            batch = images[i:i+batch_size]
            # InceptionV3 需要 [-1, 1] 范围,resize 到 299×299
            if batch.shape[-1] != 299:
                batch = torch.nn.functional.interpolate(
                    batch, size=299, mode='bilinear', align_corners=False
                )
            feat = inception_model(batch)   # 通常是 (B, 2048)
            features.append(feat.cpu().numpy())

    return np.concatenate(features, axis=0)


def compute_stats(features: np.ndarray) -> tuple:
    """计算特征的均值和协方差"""
    mu    = np.mean(features, axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma


# 数值验证 FID 计算
print("  FID 计算验证:")
print()
np.random.seed(42)

d = 128   # 特征维度(简化,真实是 2048)

# 情况1:完全相同的分布 → FID = 0
mu1    = np.zeros(d)
sigma1 = np.eye(d)
fid_same = compute_fid(mu1, sigma1, mu1, sigma1)
print(f"  相同分布(FID 应为 0):{fid_same:.6f} ✓")

# 情况2:均值不同
mu2     = np.ones(d) * 0.1
fid_mu  = compute_fid(mu1, sigma1, mu2, sigma1)
print(f"  均值偏移 0.1(FID 应 > 0):{fid_mu:.4f}")

# 情况3:协方差不同
sigma2  = np.eye(d) * 2
fid_cov = compute_fid(mu1, sigma1, mu1, sigma2)
print(f"  方差扩大 2×(FID 应 > 0):{fid_cov:.4f}")

# 情况4:均值和协方差都不同
fid_both = compute_fid(mu1, sigma1, mu2, sigma2)
print(f"  均值和方差都不同:{fid_both:.4f}")
print()

# FID 的关键影响因素
print("  FID 的关键影响因素:")
print()

influence_factors = [
    ("样本数量 N",
     "N 越大,FID 估计越稳定",
     "建议 N ≥ 50,000(论文通常用 50K)",
     "N=1000 时方差可达 ±5,N=50K 时 ±0.1"),
    ("InceptionV3 版本",
     "不同实现的 InceptionV3 输出维度不同",
     "必须用完全相同的 Inception 实现才可比较",
     "常见实现:pytorch-fid、clean-fid"),
    ("图像预处理",
     "像素范围、resize 方法影响特征",
     "必须与基准论文使用相同的预处理",
     "clean-fid 提出了标准化的预处理流程"),
    ("真实数据集",
     "用不同的真实数据集计算的 FID 不可直接比较",
     "必须说明用哪个数据集作为真实分布",
     "CIFAR-10 FID ≠ FFHQ FID,不可跨比"),
]

for name, what, how, note in influence_factors:
    print(f"  ⭐ [{name}]")
    print(f"     说明:{what}")
    print(f"     建议:{how}")
    print(f"     注意:{note}")
    print()

# 样本数量对 FID 稳定性的影响
print("  样本数量对 FID 估计的影响(模拟):")
print()

def simulate_fid_variance(n_samples: int, d: int = 64,
                           n_trials: int = 20) -> tuple:
    """模拟不同样本量时 FID 的方差"""
    np.random.seed(0)
    mu_true    = np.zeros(d)
    sigma_true = np.eye(d)
    # 生成分布有轻微偏移
    mu_gen     = np.ones(d) * 0.05
    sigma_gen  = np.eye(d) * 1.05

    fids = []
    for _ in range(n_trials):
        # 从两个分布采样
        real_feat = np.random.multivariate_normal(mu_true, sigma_true, n_samples)
        fake_feat = np.random.multivariate_normal(mu_gen, sigma_gen, n_samples)

        mu_r, sig_r = compute_stats(real_feat)
        mu_g, sig_g = compute_stats(fake_feat)
        fids.append(compute_fid(mu_r, sig_r, mu_g, sig_g))

    return float(np.mean(fids)), float(np.std(fids))

print(f"  {'样本数 N':^12} {'FID 均值':^12} {'FID 标准差':^14} {'稳定性':^12}")
print("  " + "─" * 54)
for N in [100, 500, 1000, 5000, 10000, 50000]:
    mean_fid, std_fid = simulate_fid_variance(N, d=64)
    stable = "✓ 稳定" if std_fid < 1 else "⚠️ 不稳定"
    print(f"  {N:^12,} {mean_fid:^12.4f} {std_fid:^14.4f} {stable:^12}")

print()
print("  结论:N < 5000 时 FID 方差大,建议至少 N = 10,000,最好 50,000")

五、Precision 和 Recall:质量与多样性的分离

print("\nPrecision & Recall:同时衡量质量和多样性")
print()
print("  FID 的局限:把质量和多样性混在一起")
print("  FID 低可以是:① 高质量+高多样性(理想)")
print("               ② 中等质量+中等多样性(也能低)")
print()
print("  需要一个能分离'质量'和'多样性'的指标!")
print()
print("  Kynkäänniemi et al., 2019 的 Precision & Recall:")
print()
print("  Precision:生成图像中有多少是高质量的?")
print("  = P(生成图像 ∈ 真实数据的支撑)")
print("  高 Precision = 大多数生成图像都是真实可信的")
print()
print("  Recall:真实数据的多样性有多少被覆盖?")
print("  = P(真实图像 ∈ 生成数据的支撑)")
print("  高 Recall = 真实数据的大部分模式都能被生成")
print()
print("  极端情况:")
print()
extreme_cases = [
    ("完美生成",   1.0, 1.0,  "P_G = P_data,完美匹配"),
    ("模式崩溃",   1.0, 0.1,  "只生成一种模式,但那种很真实"),
    ("模糊生成",   0.1, 1.0,  "能生成所有模式,但都很模糊"),
    ("随机噪声",   0.0, 0.0,  "既不质量好,也不覆盖真实分布"),
]

print(f"  {'场景':^16} {'Precision':^12} {'Recall':^12} {'说明':^28}")
print("  " + "─" * 72)
for name, prec, rec, desc in extreme_cases:
    print(f"  {name:^16} {prec:^12.1f} {rec:^12.1f} {desc:^28}")

print()

# 实现:基于 k-NN 的 Precision 和 Recall
import numpy as np
from scipy.spatial.distance import cdist

def compute_precision_recall(
    real_features: np.ndarray,
    fake_features: np.ndarray,
    k: int = 3,
    batch_size: int = 1000
) -> tuple:
    """
    计算基于 k-NN 流形估计的 Precision 和 Recall

    实现细节:
    ① 对真实特征,用 k-NN 估计流形(每个点的 k-NN 半径)
    ② Precision:有多少生成点落在真实流形内?
    ③ Recall:有多少真实点落在生成流形内?
    """
    def knn_radii(features, k):
        """计算每个点到第 k 个最近邻的距离"""
        n = len(features)
        radii = np.zeros(n)

        for i in range(0, n, batch_size):
            batch = features[i:i+batch_size]
            # 计算批次与所有点的距离
            dists = cdist(batch, features, metric='euclidean')
            # 排除自身(距离为 0 的那个)
            dists.sort(axis=1)
            radii[i:i+batch_size] = dists[:, k]  # 第 k 个最近邻距离

        return radii

    # 计算真实和生成特征的 k-NN 半径
    real_radii = knn_radii(real_features, k)
    fake_radii = knn_radii(fake_features, k)

    # Precision:生成点是否在某个真实点的 k-NN 球内
    # 即:生成点 f 是否满足 ||f - r|| ≤ radius_r 对某个真实点 r
    n_fake = len(fake_features)
    n_real = len(real_features)

    # 批量计算(避免内存溢出)
    precision_count = 0
    for i in range(0, n_fake, batch_size):
        batch_fake = fake_features[i:i+batch_size]
        # 计算这批假特征与所有真实特征的距离
        dists      = cdist(batch_fake, real_features)   # (B, n_real)
        # 对每个假样本,检查是否有真实样本在其半径内
        in_real_ball = (dists <= real_radii[np.newaxis, :]).any(axis=1)
        precision_count += in_real_ball.sum()

    precision = precision_count / n_fake

    # Recall:真实点是否在某个生成点的 k-NN 球内
    recall_count = 0
    for i in range(0, n_real, batch_size):
        batch_real = real_features[i:i+batch_size]
        dists      = cdist(batch_real, fake_features)
        in_fake_ball = (dists <= fake_radii[np.newaxis, :]).any(axis=1)
        recall_count += in_fake_ball.sum()

    recall = recall_count / n_real

    return float(precision), float(recall)


# 数值演示
np.random.seed(42)
d = 32    # 低维演示

# 真实数据:两个高斯混合
n_real = 500
real_feat = np.vstack([
    np.random.randn(n_real//2, d) + np.ones(d) * 2,
    np.random.randn(n_real//2, d) - np.ones(d) * 2,
])

print("  Precision & Recall 数值演示:")
print(f"  真实数据:双峰高斯,n={n_real},d={d}")
print()
print(f"  {'场景':^28} {'Precision':^12} {'Recall':^12} {'解读':^22}")
print("  " + "─" * 78)

# 各种场景的生成分布
n_fake = 500
scenarios_pr = [
    ("完美匹配",
     np.vstack([
         np.random.randn(n_fake//2, d) + np.ones(d)*2,
         np.random.randn(n_fake//2, d) - np.ones(d)*2,
     ])),
    ("模式崩溃(只有一个峰)",
     np.random.randn(n_fake, d) + np.ones(d)*2),
    ("低质量(均匀噪声)",
     np.random.randn(n_fake, d) * 5),
    ("覆盖但质量差(分散)",
     np.vstack([
         np.random.randn(n_fake//2, d)*3 + np.ones(d)*2,
         np.random.randn(n_fake//2, d)*3 - np.ones(d)*2,
     ])),
]

for name, fake_feat in scenarios_pr:
    prec, rec = compute_precision_recall(real_feat, fake_feat, k=3)
    if prec > 0.8 and rec > 0.8:
        interp = "质量好+多样性好 ✓"
    elif prec > 0.7 and rec < 0.5:
        interp = "质量好但多样性差 ⚠️"
    elif prec < 0.5 and rec > 0.7:
        interp = "覆盖广但质量差 ⚠️"
    else:
        interp = "质量和多样性都差 ✗"
    print(f"  {name:^28} {prec:^12.4f} {rec:^12.4f} {interp:^22}")

print()
print("  Precision-Recall 的优势:")
pr_advantages = [
    "能区分'高质量但模式崩溃'和'多样但模糊'这两种截然不同的失败模式",
    "帮助诊断训练问题:Precision 低 → 改进判别器;Recall 低 → 减少模式崩溃",
    "与信息检索领域的 P&R 语义一致,直觉上易理解",
    "不依赖高斯假设(FID 依赖),对非高斯分布也有效",
]
for adv in pr_advantages:
    print(f"  ✅ {adv}")

六、其他补充指标

print("\n其他补充指标")
print()

other_metrics = [
    {
        "name":   "KID(Kernel Inception Distance)",
        "paper":  "Bińkowski et al., 2018",
        "what":   "类似 FID,但用核最大均值差异(MMD)替代 Fréchet 距离",
        "pros":   "无偏估计(FID 在小样本下有偏),N=1000 就能稳定",
        "cons":   "计算比 FID 慢,不如 FID 广泛使用",
        "when":   "样本数量少(<5000)时比 FID 更可靠",
        "formula":"KID = MMD²(f(real), f(fake))",
    },
    {
        "name":   "LPIPS(Learned Perceptual Image Patch Similarity)",
        "paper":  "Zhang et al., 2018",
        "what":   "用 VGG/AlexNet 特征计算感知相似度",
        "pros":   "衡量单张图像对的感知相似度,适合配对图像翻译",
        "cons":   "需要配对图像(真实+生成对),不能评估无配对生成",
        "when":   "Pix2Pix、超分辨率等配对任务的评估",
        "formula":"LPIPS(x,y) = ||φ(x) - φ(y)||²(特征空间距离)",
    },
    {
        "name":   "SSIM(Structural Similarity Index)",
        "paper":  "Wang et al., 2004",
        "what":   "基于亮度、对比度、结构的感知相似度",
        "pros":   "快速,不需要深度网络",
        "cons":   "与人类感知相关性弱(不如 LPIPS)",
        "when":   "超分辨率、图像压缩的基准测试",
        "formula":"SSIM ∈ [-1, 1],越高越好",
    },
    {
        "name":   "Density & Coverage",
        "paper":  "Naeem et al., 2020",
        "what":   "改进版 P&R,更鲁棒地处理离群点",
        "pros":   "对离群假样本更鲁棒,更好地反映分布覆盖",
        "cons":   "较新,不如 P&R 普及",
        "when":   "P&R 的替代,当有离群样本时更可靠",
        "formula":"Density = P&R 的平滑版本",
    },
]

for m in other_metrics:
    print(f"  ⭐ [{m['name']}]({m['paper']})")
    print(f"     原理:{m['what']}")
    print(f"     优势:{m['pros']}")
    print(f"     局限:{m['cons']}")
    print(f"     适用:{m['when']}")
    print(f"     公式:{m['formula']}")
    print()

# 指标综合对比
print("  指标综合对比:")
print()
print(f"  {'指标':^12} {'质量':^8} {'多样性':^8} {'无需配对':^10} {'标准化':^10} {'推荐度':^10}")
print("  " + "─" * 62)
metrics_summary = [
    ("IS",         "✓",  "部分",  "✓", "中",  "⭐⭐⭐"),
    ("FID",        "✓",  "✓",    "✓", "高",  "⭐⭐⭐⭐⭐"),
    ("KID",        "✓",  "✓",    "✓", "中",  "⭐⭐⭐⭐"),
    ("P&R",        "分离","分离",  "✓", "中",  "⭐⭐⭐⭐"),
    ("LPIPS",      "✓",  "✗",    "✗", "高",  "⭐⭐⭐(配对)"),
    ("SSIM",       "部分","✗",    "✗", "高",  "⭐⭐(配对)"),
]
for row in metrics_summary:
    name, qual, div, unpaired, std, rec = row
    print(f"  {name:^12} {qual:^8} {div:^8} {unpaired:^10} {std:^10} {rec:^10}")

七、代码实战:完整评估 Pipeline

import numpy as np
import torch
import torch.nn as nn
import scipy.linalg

print("\n完整评估 Pipeline")
print()

class FIDCalculator:
    """
    FID 计算器(不依赖真实 InceptionV3,用随机特征演示)
    真实使用时,替换 extract_features 为 InceptionV3 提取
    """

    def __init__(self, feature_dim: int = 2048):
        self.feature_dim = feature_dim
        self.real_features = None
        self.real_mu = None
        self.real_sigma = None

    def compute_real_stats(self, real_images_or_features: np.ndarray):
        """
        计算并缓存真实数据的统计量(只需算一次!)
        real_images_or_features: (N, feature_dim) 特征或 (N, C, H, W) 图像
        """
        if real_images_or_features.ndim == 4:
            # 图像输入:需要先提取特征
            features = self._extract_inception_features(
                real_images_or_features
            )
        else:
            features = real_images_or_features

        self.real_features = features
        self.real_mu       = np.mean(features, axis=0)
        self.real_sigma    = np.cov(features, rowvar=False)
        print(f"  真实数据统计量已计算({len(features)} 个样本,{features.shape[1]}维)")

    def compute_fid(self, fake_features: np.ndarray) -> float:
        """计算 FID(真实统计量已预计算)"""
        assert self.real_mu is not None, "先调用 compute_real_stats()"

        mu_fake    = np.mean(fake_features, axis=0)
        sigma_fake = np.cov(fake_features, rowvar=False)

        return compute_fid(self.real_mu, self.real_sigma,
                            mu_fake, sigma_fake)

    def _extract_inception_features(self, images: np.ndarray) -> np.ndarray:
        """(占位)真实使用时替换为 InceptionV3 特征提取"""
        # 真实代码:
        # from torchvision.models import inception_v3
        # model = inception_v3(pretrained=True, transform_input=False)
        # ... 提取 pool3 层特征 ...
        return np.random.randn(len(images), self.feature_dim)  # 占位


class GANEvaluator:
    """GAN 综合评估器"""

    def __init__(self, n_samples: int = 5000,
                 feature_dim: int = 128):  # 演示用小维度
        self.n_samples   = n_samples
        self.feature_dim = feature_dim
        self.fid_calc    = FIDCalculator(feature_dim)

    def setup_real_stats(self, real_features: np.ndarray):
        """设置真实数据统计量(只做一次)"""
        self.fid_calc.compute_real_stats(real_features)

    def evaluate(self, generator, z_dim: int = 100,
                 compute_pr: bool = True) -> dict:
        """综合评估生成器"""
        results = {}

        # 生成特征(真实项目用 InceptionV3 提取)
        print(f"  生成 {self.n_samples} 个样本并提取特征...")
        with torch.no_grad():
            z          = torch.randn(self.n_samples, z_dim)
            fake_feat  = np.random.randn(
                self.n_samples, self.feature_dim
            )  # 演示:随机特征

        # FID
        results['FID'] = self.fid_calc.compute_fid(fake_feat)

        # Precision & Recall
        if compute_pr and self.fid_calc.real_features is not None:
            prec, rec = compute_precision_recall(
                self.fid_calc.real_features[:1000],
                fake_feat[:1000],
                k=3
            )
            results['Precision'] = prec
            results['Recall']    = rec

        return results

    def report(self, results: dict, model_name: str = "Model"):
        """打印评估报告"""
        print(f"\n  ═══ {model_name} 评估报告 ═══")
        print()
        if 'FID' in results:
            fid   = results['FID']
            grade = "优秀" if fid < 5 else "良好" if fid < 15 else "一般" if fid < 30 else "差"
            print(f"  FID:{fid:.4f}  [{grade}]")
        if 'Precision' in results:
            print(f"  Precision:{results['Precision']:.4f}  (生成质量)")
        if 'Recall' in results:
            print(f"  Recall:{results['Recall']:.4f}  (多样性覆盖)")
        print()


# 使用演示
np.random.seed(42)
d = 128

# 初始化评估器
evaluator = GANEvaluator(n_samples=2000, feature_dim=d)

# 设置真实数据统计量(真实项目:从验证集提取 InceptionV3 特征)
real_feat = np.random.randn(2000, d) * 1.0   # 真实数据特征(模拟)
evaluator.setup_real_stats(real_feat)

# 模拟三个不同质量的生成器
models = {
    "差的 GAN(高 FID)":   real_feat + np.random.randn(2000, d) * 3.0,
    "中等 GAN":             real_feat + np.random.randn(2000, d) * 1.0,
    "好的 GAN(低 FID)":   real_feat + np.random.randn(2000, d) * 0.3,
    "模式崩溃的 GAN":       np.zeros((2000, d)) + np.random.randn(2000, d) * 0.1,
}

print("  综合评估演示:")
for model_name, fake_features in models.items():
    fid   = evaluator.fid_calc.compute_fid(fake_features)
    prec, rec = compute_precision_recall(
        real_feat[:500], fake_features[:500], k=3
    )
    print(f"  [{model_name}]")
    print(f"    FID={fid:.4f},Precision={prec:.4f},Recall={rec:.4f}")
    print()

八、最佳实践:如何正确使用这些指标

print("\n最佳实践:正确使用 GAN 评估指标")
print()

best_practices = [
    {
        "topic": "样本数量",
        "dos": [
            "使用至少 10,000 个生成样本计算 FID(最好 50,000)",
            "固定随机种子,报告结果的方差",
            "多次计算取均值(FID 估计有方差)",
        ],
        "donts": [
            "不要用 1,000 个样本报告 FID(方差太大)",
            "不要只报告一次计算的结果",
        ],
    },
    {
        "topic": "预处理标准化",
        "dos": [
            "使用 clean-fid 库(标准化了 resize 和范围)",
            "明确说明使用的 InceptionV3 实现版本",
            "与对比方法使用完全相同的预处理",
        ],
        "donts": [
            "不要自己随意实现 FID(容易出现细节错误)",
            "不要跨不同实现的 FID 比较(如 pytorch-fid vs tensorflow-fid)",
        ],
    },
    {
        "topic": "指标组合",
        "dos": [
            "同时报告 FID 和 Precision/Recall",
            "如果任务是配对翻译,补充 LPIPS",
            "大规模实验时报告 IS 作为参考",
        ],
        "donts": [
            "不要只报告 IS(已知有严重局限)",
            "不要把 FID 作为唯一评估标准",
        ],
    },
    {
        "topic": "结果解读",
        "dos": [
            "FID 差异 < 1.0 时可能不显著(考虑置信区间)",
            "同时展示样本图像(定量+定性结合)",
            "报告 FID 时说明样本数 N",
        ],
        "donts": [
            "不要跨数据集比较 FID(CIFAR FID ≠ FFHQ FID)",
            "不要只看 FID 忽略生成图像的视觉质量",
        ],
    },
]

for bp in best_practices:
    print(f"  ── [{bp['topic']}] ──")
    print(f"  ✅ 应该:")
    for do in bp['dos']:
        print(f"     · {do}")
    print(f"  ❌ 不要:")
    for dont in bp['donts']:
        print(f"     · {dont}")
    print()

# 标准评估 Pipeline 推荐
print("  推荐的标准评估 Pipeline:")
print()
pipeline_code = '''
# 推荐:使用 clean-fid 库(最标准化的实现)
# pip install clean-fid

from cleanfid import fid

# 计算 FID(自动处理预处理)
fid_score = fid.compute_fid(
    fdir1 = "path/to/real/images",    # 真实图像目录
    fdir2 = "path/to/generated/imgs", # 生成图像目录
    mode  = "clean",                   # 标准化预处理
    num_workers = 4,
    batch_size  = 32,
)
print(f"FID: {fid_score:.4f}")

# 计算 KID(小样本时更稳定)
kid_score = fid.compute_kid(
    fdir1 = "path/to/real/images",
    fdir2 = "path/to/generated/imgs",
)
print(f"KID: {kid_score:.6f}")

# 生成样本的 Precision & Recall
from prdc import compute_prdc   # pip install prdc
real_features = extract_features("path/to/real/images")
fake_features = extract_features("path/to/generated/imgs")

metrics = compute_prdc(
    real_features = real_features,
    fake_features = fake_features,
    nearest_k     = 5,
)
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall:    {metrics['recall']:.4f}")
print(f"Density:   {metrics['density']:.4f}")
print(f"Coverage:  {metrics['coverage']:.4f}")
'''
print(pipeline_code)

# 最终总结
print("=" * 65)
print("  GAN 评估指标速查:")
print("=" * 65)
print()
summary = [
    ("FID",       "最主流",   "越低越好",  "≥10K 样本,clean-fid"),
    ("IS",        "仅参考",   "越高越好",  "不推荐作为主要指标"),
    ("KID",       "补充",     "越低越好",  "小样本时更稳定"),
    ("Precision", "质量指标", "越高越好",  "与 Recall 配合使用"),
    ("Recall",    "多样性",   "越高越好",  "与 Precision 配合使用"),
    ("LPIPS",     "配对任务", "越低越好",  "Pix2Pix、超分辨率"),
]
print(f"  {'指标':^12} {'使用频率':^10} {'方向':^10} {'注意事项':^22}")
print("  " + "─" * 58)
for metric, freq, direction, note in summary:
    print(f"  {metric:^12} {freq:^10} {direction:^10} {note:^22}")

总结

本篇三个核心收获:

① FID 是最可靠的主要指标,但需要正确使用

FID = ∥ μ r − μ g ∥ 2 + Tr ( Σ r + Σ g − 2 ( Σ r Σ g ) 1 / 2 ) \text{FID} = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r\Sigma_g)^{1/2}) FID=μrμg2+Tr(Σr+Σg2(ΣrΣg)1/2)

至少 10K 样本,使用 clean-fid,同一数据集比较,报告置信区间。

② IS 有根本性局限,应只作参考

IS 不与真实数据对比,对 ImageNet 以外的域有严重偏差,高 IS ≠ 高生成质量。

③ Precision & Recall 能分离质量和多样性

指标 衡量 低的含义
Precision 生成质量 大量生成图像是虚假/模糊的
Recall 多样性覆盖 真实数据的某些模式没有被生成(模式崩溃)
FID 两者综合 质量差或多样性差,无法区分

下一篇预告:GAN vs VAE vs 扩散模型——三大生成范式的完整对比。从信息论视角(第五篇信息论系列)看三者的目标函数,从实验数据看各自的最佳场景,帮你在实际项目中选择正确的生成模型。


💬 你在实际项目中用过 FID 来评估生成质量吗?有没有遇到 FID 低但生成图像不好看的情况? 欢迎评论区分享!

🙏 如果这篇帮到你,点赞 + 收藏,系列持续更新!


本文为原创技术分享。代码在 Python 3.12 + PyTorch + SciPy 下验证。最后更新:2026-05-24

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐