1. 引言:为什么要重视混淆矩阵?

在机器学习和深度学习项目中,模型评估是至关重要的一环。混淆矩阵(Confusion Matrix)作为分类任务中最直观、最全面的性能展示工具,不仅能呈现模型的预测结果分布,还能衍生出精确率、召回率、F1分数等核心指标。然而,很多开发者仅停留在调用sklearn.metrics.confusion_matrix的层面,缺乏对可视化细节和指标深入计算的理解。

本文将带你从零开始,用Python绘制一张专业级、可定制、功能强大的混淆矩阵图。代码支持任意大小的矩阵,内置精确率/召回率/F1计算,并提供宏平均、微平均、加权平均等总体统计,同时包含进度条可视化、归一化选项、表格化输出等高级特性。最终成品可直接用于论文、技术报告或商业演示。


2. 混淆矩阵基础回顾

对于一个二分类问题,混淆矩阵是一个2×2表格:

预测为正 预测为负
实际为正 TP FN
实际为负 FP TN
  • TP:真正例,正确预测为正的样本数

  • FP:假正例,错误预测为正的样本数

  • FN:假反例,错误预测为负的样本数

  • TN:真反例,正确预测为负的样本数

基于此,可计算核心指标:

  • 精确率 (Precision) = TP / (TP + FP) —— 预测为正的样本中有多少是真正的正类

  • 召回率 (Recall) = TP / (TP + FN) —— 实际为正的样本中有多少被正确预测

  • F1分数 = 2 × (Precision × Recall) / (Precision + Recall) —— 精确率和召回率的调和平均

  • 准确率 (Accuracy) = (TP + TN) / 总样本数

对于多分类问题,通常采用宏平均(每个类别指标简单平均)、微平均(全局统计计算)、加权平均(按各类样本数加权)来汇总性能。


3. 基础绘制:7x7混淆矩阵

我们先从一个简单的7x7混淆矩阵绘制开始,使用Matplotlib和Seaborn。

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# 示例7x7混淆矩阵(7个类别)
cm = np.array([
    [85, 2, 1, 0, 0, 1, 3],
    [1, 92, 0, 0, 2, 0, 0],
    [0, 1, 78, 4, 0, 0, 0],
    [0, 0, 3, 88, 1, 0, 0],
    [1, 0, 0, 2, 82, 0, 0],
    [2, 1, 0, 0, 0, 79, 1],
    [1, 0, 0, 0, 0, 2, 91]
])

class_names = ['Action', 'Adventure', 'Comedy', 'Drama', 'Horror', 'Romance', 'Sci-Fi']

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

这段代码能生成一个带数值标签的热力图,但缺少指标信息,且样式较为朴素。下面我们将逐步增强。


4. 进阶美化:集成指标计算与专业配色

我们希望右侧直接显示每个类别的精确率和召回率,并采用统一的蓝色主题(文字黑色,确保打印清晰)。

4.1 指标计算函数

def calculate_metrics(cm):
    n = cm.shape[0]
    precision, recall, f1, support = [], [], [], []
    total_tp = total_fp = total_fn = 0
    for i in range(n):
        tp = cm[i, i]
        fp = cm[:, i].sum() - tp
        fn = cm[i, :].sum() - tp
        sup = cm[i, :].sum()
        p = tp / (tp + fp) if (tp + fp) > 0 else np.nan
        r = tp / (tp + fn) if (tp + fn) > 0 else np.nan
        f = 2 * p * r / (p + r) if (p + r) > 0 else np.nan
        precision.append(p)
        recall.append(r)
        f1.append(f)
        support.append(sup)
        total_tp += tp
        total_fp += fp
        total_fn += fn
    accuracy = total_tp / cm.sum() if cm.sum() > 0 else np.nan
    return precision, recall, f1, support, accuracy, total_tp, total_fp, total_fn

4.2 右侧指标文本面板(早期版本)

早期版本中,我们使用ax.text在右侧垂直排列指标,并添加进度条(每个▇代表5%)。

# 生成指标文本
metrics_text = []
for cls, p, r in zip(class_names, precision, recall):
    p_str = f"{p:.2%}" if not np.isnan(p) else "N/A"
    r_str = f"{r:.2%}" if not np.isnan(r) else "N/A"
    p_bars = "▇" * int(np.floor(p*20)) if not np.isnan(p) else ""
    r_bars = "▇" * int(np.floor(r*20)) if not np.isnan(r) else ""
    metrics_text.append(
        f"{cls}\nPrecision: {p_str} {p_bars}\nRecall:    {r_str} {r_bars}\n{'-'*30}"
    )
# 在ax1上绘制...

但这种方式在类别较多时易导致文本溢出或对齐问题。因此我们最终采用表格化布局


5. 高级功能:表格化多指标展示

最终版代码:混淆矩阵热力图,两个表格(类别指标表 + 总体统计表)。支持以下特性:

  • 任意大小的混淆矩阵(n×n),自动适配。

  • 三类平均:宏平均、微平均、加权平均。

  • 进度条:每个指标后附可视化进度条(可选)。

  • 归一化:可选择显示行归一化百分比。

  • 字体兼容:使用sans-serifmonospace,避免字体缺失。

  • 异常处理:除零或空值显示“N/A”。

  • 保存图片:支持指定路径保存高清图。


6. 使用示例及效果展示

6.1 示例代码

if __name__ == "__main__":
    # 7x7 混淆矩阵
    cm = np.array([
        [85, 2, 1, 0, 0, 1, 3],
        [1, 92, 0, 0, 2, 0, 0],
        [0, 1, 78, 4, 0, 0, 0],
        [0, 0, 3, 88, 1, 0, 0],
        [1, 0, 0, 2, 82, 0, 0],
        [2, 1, 0, 0, 0, 79, 1],
        [1, 0, 0, 0, 0, 2, 91]
    ])

    class_names = ['Action', 'Adventure', 'Comedy', 'Drama', 'Horror', 'Romance', 'Sci-Fi']

    advanced_confusion_matrix(
        cm=cm,
        classes=class_names,
        normalize=False,
        show_bars=True,
        cmap='Blues',
        title='Movie Genre Classification Performance',
        figsize=(24, 12),
        dpi=100,
        save_path='confusion_matrix.png'
    )

6.2 输出效果描述

运行上述代码后,你将看到一幅结构清晰、信息丰富的图表:

  • 左侧:蓝色渐变热力图,每个单元格显示原始计数,对角线深色表示正确分类。

  • 右上:类别指标表,每行包含类名、精确率(带进度条)、召回率(带进度条)、F1分数(带进度条)、支持度(该类真实样本数)。进度条由“▇”组成,每个▇代表5%,直观反映指标高低。

  • 右下:总体统计表,包含准确率、宏平均、加权平均、微平均的精确率/召回率/F1。空白列用于对齐。

  • 底部:进度条图例说明。

所有文字均为黑色,表格边框为浅灰色,标题和表头有浅蓝色背景,整体风格专业、清晰。


7. 参数详解与自定义指南

参数名 类型 默认值 说明
cm numpy.ndarray 必填 混淆矩阵,必须是方阵
classes list 必填 类别名称列表,长度需等于矩阵阶数
normalize bool False 是否对混淆矩阵按行归一化(显示百分比)
show_bars bool True 是否在指标后显示进度条
cmap str/cmap 'Blues' 热力图颜色映射,可选 'Blues','Greens','Oranges','Reds' 等
title str 'Enhanced Confusion Matrix' 图表主标题
figsize tuple (22,12) 画布尺寸,可根据类别数调整
dpi int 100 输出图像分辨率
save_path str or None None 若提供路径,则保存图像到该文件

自定义建议

  • 对于类别较多的矩阵(>10),可适当增大figsize,减小表格字体(代码中已有自适应,但可手动调整table.set_fontsize())。

  • 如需修改颜色主题,可将cmap更换为喜欢的配色,如'Reds''Purples'

  • 若不想显示进度条,设置show_bars=False即可。


8. 常见问题及解决方案

Q1: 运行时出现 Font family 'Arial' not found.

原因:系统中未安装Arial字体。
解决:代码中已使用sans-serifmonospace通用字体族,无需指定具体字体。若仍有警告,可在代码开头添加:

plt.rcParams['font.family'] = 'sans-serif'

Q2: 表格中的进度条显示不完整或过短

原因:进度条长度计算为int(p*20),即每5%一个块。如果指标值较低(如12%),将显示2个块,符合设计。
调整:如需改变粒度,修改代码中的乘数(如p*10为10%一个块)。

Q3: 混淆矩阵中有全零行或全零列,导致除零错误

解决:代码已使用np.nan处理,并最终显示“N/A”,不会崩溃。

Q4: 图像保存后文字模糊

解决:增加dpi参数,如dpi=300,同时确保figsize足够大。


9. 总结与展望

本文从零开始,逐步构建了一个功能全面、样式专业、健壮可靠的混淆矩阵可视化工具。它不仅支持7x7矩阵,也能无缝扩展到任意大小的分类任务。通过集成精确率、召回率、F1、支持度、宏/微/加权平均等指标,并以表格化+进度条的形式清晰展示,帮助开发者快速评估模型性能。

未来可扩展方向:

  • 支持多输出(multi-output)混淆矩阵。

  • 添加交互式功能(如鼠标悬停显示详细信息)。

  • 集成到机器学习框架(如TensorBoard)中。

希望本文能成为你模型评估工具箱中的得力助手。如果你有任何改进建议或遇到问题,欢迎在评论区留言交流!


版权声明:本文为CSDN博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐