决策树模型笔记

1. 基础知识

基本模型形式

决策树是一种常见的监督学习模型,既可以做分类,也可以做回归。它通过一系列“如果…那么…”的规则不断划分特征空间,最终在叶子节点给出预测结果。

对于分类任务,模型会根据样本特征逐层判断,例如:

  • 如果花瓣长度小于某个阈值,进入左子树
  • 否则进入右子树

最终到达某个叶子节点后,叶子节点中占比最高的类别就是预测类别。

核心目标

决策树的核心目标是:在每一次节点划分时,找到一个最优特征和最优切分点,让划分后的子节点尽可能“纯”。

分类任务中常见目标包括:

  • 让同一类别样本尽量落到同一个叶子节点
  • 降低节点的不确定性
  • 提升整体分类准确率

损失函数

决策树通常不直接写成统一的全局损失函数最小化问题,而是在每个节点上贪心地选择最优划分标准。

常见划分指标有:

  • 基尼指数(Gini Index)
  • 信息熵(Entropy)

以基尼指数为例:

Gini(D) = 1 - Σ(p_k)^2

其中 p_k 表示样本集合 D 中第 k 类样本所占比例。基尼指数越小,说明节点越纯。

参数求解

决策树的参数求解过程本质上是一个递归划分过程:

  1. 在当前节点中遍历候选特征
  2. 为每个特征尝试不同划分阈值
  3. 计算划分后的不纯度下降
  4. 选择收益最大的划分方式
  5. 递归生成左右子树,直到满足停止条件

常见停止条件包括:

  • 达到最大树深度
  • 节点样本数过少
  • 节点已经足够纯

应用示例(Python实现)

本项目使用 scikit-learn 中的 DecisionTreeClassifier 实现一个经典的鸢尾花三分类任务:

from sklearn.tree import DecisionTreeClassifier

model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(x_train, y_train)
y_pred = model.predict(x_test)

注意要点

  • 决策树容易过拟合,需要通过 max_depthmin_samples_split 等参数控制复杂度
  • 决策树对特征缩放不敏感,一般不强制要求标准化
  • 树结构可解释性强,适合教学演示和规则分析
  • 单棵树性能通常不如集成模型,但更容易理解

2. 代码实践

model.py

model.py 负责定义决策树模型、训练模型和预测接口。这里统一封装了:

  • build_model():创建模型
  • train_model():拟合训练数据
  • predict():执行预测
from sklearn.tree import DecisionTreeClassifier


def build_model(
    criterion: str = "gini",
    max_depth: int = 3,
    random_state: int = 42,
) -> DecisionTreeClassifier:
    """创建决策树分类模型。"""
    return DecisionTreeClassifier(
        criterion=criterion,
        max_depth=max_depth,
        random_state=random_state,
    )


def train_model(
    x_train,
    y_train,
    criterion: str = "gini",
    max_depth: int = 3,
    random_state: int = 42,
) -> DecisionTreeClassifier:
    """训练决策树分类模型。"""
    model = build_model(
        criterion=criterion,
        max_depth=max_depth,
        random_state=random_state,
    )
    model.fit(x_train, y_train)
    return model


def predict(model: DecisionTreeClassifier, x_test):
    """使用训练好的模型进行预测。"""
    return model.predict(x_test)

train.py

train.py 负责训练流程,包括:

  • 训练集和测试集划分
  • 调用 train_model() 完成训练

代码中使用了 stratify=y,保证分类任务中训练集和测试集的类别分布更加稳定。

from sklearn.model_selection import train_test_split

from model import train_model


def split_data(
    x,
    y,
    test_size: float = 0.2,
    random_state: int = 42,
):
    """划分训练集和测试集。"""
    return train_test_split(
        x,
        y,
        test_size=test_size,
        random_state=random_state,
        stratify=y,
    )


def run_train(
    x,
    y,
    test_size: float = 0.2,
    random_state: int = 42,
    criterion: str = "gini",
    max_depth: int = 3,
):
    """完成数据划分和模型训练。"""
    x_train, x_test, y_train, y_test = split_data(
        x,
        y,
        test_size=test_size,
        random_state=random_state,
    )
    model = train_model(
        x_train,
        y_train,
        criterion=criterion,
        max_depth=max_depth,
        random_state=random_state,
    )
    return model, x_train, x_test, y_train, y_test

eval.py

eval.py 负责评估模型效果,输出:

  • 准确率 accuracy
  • 混淆矩阵 confusion_matrix
  • 分类报告 classification_report

这些指标能帮助我们同时观察总体表现和各类别的精确率、召回率、F1 值。

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from model import predict


def evaluate_model(model, x_test, y_test) -> dict:
    """评估决策树分类模型效果。"""
    y_pred = predict(model, x_test)
    return {
        "accuracy": accuracy_score(y_test, y_pred),
        "confusion_matrix": confusion_matrix(y_test, y_pred),
        "classification_report": classification_report(y_test, y_pred),
    }

dataload.py

dataload.pysklearn.datasets 中加载鸢尾花数据集:

  • 特征 x:4 个花萼/花瓣数值特征
  • 标签 y:3 个类别标签
  • target_names:类别名称,用于可视化展示
import pandas as pd
from sklearn.datasets import load_iris


def load_data():
    """加载 sklearn 自带的 iris 分类数据集。"""
    dataset = load_iris()
    x = pd.DataFrame(dataset.data, columns=dataset.feature_names)
    y = pd.Series(dataset.target, name="target")
    return x, y, dataset.target_names

run.py

run.py 是项目入口,负责串联整个流程:

  1. 加载数据
  2. 训练模型
  3. 评估模型
  4. 保存可视化结果

可视化部分包含:

  • 决策树结构图
  • 混淆矩阵图
from pathlib import Path

import matplotlib

matplotlib.use("Agg")

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.tree import plot_tree

from dataload import load_data
from eval import evaluate_model
from model import predict
from train import run_train


def save_plots(model, x_test, y_test, class_names) -> list[Path]:
    """保存决策树结构图和混淆矩阵图。"""
    current_dir = Path(__file__).resolve().parent
    output_dir = current_dir / "figure"
    output_dir.mkdir(exist_ok=True)

    tree_path = output_dir / "decision_tree_structure.png"
    cm_path = output_dir / "decision_tree_confusion_matrix.png"

    fig, ax = plt.subplots(figsize=(16, 10))
    plot_tree(
        model,
        feature_names=list(x_test.columns),
        class_names=list(class_names),
        filled=True,
        rounded=True,
        ax=ax,
    )
    fig.tight_layout()
    fig.savefig(tree_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(6, 5))
    ConfusionMatrixDisplay.from_predictions(
        y_test,
        predict(model, x_test),
        display_labels=class_names,
        cmap="Blues",
        ax=ax,
    )
    fig.tight_layout()
    fig.savefig(cm_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

    return [tree_path, cm_path]


def main() -> None:
    x, y, class_names = load_data()
    model, x_train, x_test, y_train, y_test = run_train(x, y)
    metrics = evaluate_model(model, x_test, y_test)
    plot_paths = save_plots(model, x_test, y_test, class_names)

    print("Decision Tree Demo")
    print(f"Train size: {len(x_train)}, Test size: {len(x_test)}")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print("Confusion Matrix:")
    print(metrics["confusion_matrix"])
    print("Classification Report:")
    print(metrics["classification_report"])
    print("Saved plots:")
    for plot_path in plot_paths:
        print(plot_path)


if __name__ == "__main__":
    main()

运行结果

运行 python run.py 后,终端会输出训练集/测试集大小、准确率、混淆矩阵和分类报告。

图片会保存在当前目录下的 figure/ 文件夹中,通常包括:

  • decision_tree_structure.png
  • decision_tree_confusion_matrix.png

如果分类结果接近满分,这是因为鸢尾花数据集本身比较经典且较容易划分,适合作为决策树入门 demo。
在这里插入图片描述
在这里插入图片描述

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐