从零跑通一个医学预测模型流程:基于Python机器学习
临床研究里有一类问题反复出现:哪些患者入院后风险更高? 靠经验判断固然重要,但如果能让数据说话,风险评估会更客观、更有说服力。
这篇文章以 MIMIC-IV 数据库中的肺炎患者为例,用 Python 搭建一套完整的机器学习预测流程,目标是预测患者是否会发生院内死亡。整个过程分 12 个步骤,提供部分参考代码。
第一步:把需要的工具都备好
做任何数据分析之前,先把相应库导入。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score
from xgboost import XGBClassifier
import shap
几个关键库的作用:
-
• pandas:数据处理的核心,用于表格数据加载
-
• scikit-learn:机器学习工具箱,模型训练、性能评估
-
• XGBoost:目前在结构化数据上表现最稳定的模型之一
-
• SHAP:让模型"说清楚"预测因子及其预测力度
第二步:读取数据
data = pd.read_csv("肺炎院内死亡.csv", encoding="GBK")
print(data.head()) # 查看前5行
print(data.info()) # 查看列名和数据类型
data.info() 可以初步探索数据,哪些列有缺失、哪些是数值型、哪些是字符串。
为方便大家学习 这里给大家整理了一份学习资料包 需要的同学 根据下图自取即可

第三步:清洗数据
原始数据库字段通常包含大量 ID 和时间信息,这些数据不能直接用于预测,可以删除:
-
• ID 列(如患者编号):纯标识符
-
• 时间列(如入院时间):不常用作为预测变量
# 批量匹配并删除无效列
cols_to_remove = data.columns[data.columns.str.contains(
r"hadm_id|stay_id|admittime|dischtime|icu_intime|icu_outtime",
case=False, regex=True
)]
data = data.drop(columns=cols_to_remove)
用正则表达式批量匹配,一行代码搞定所有不需要的预测变量。
第四步:处理缺失值
缺失值是医学数据的常态,处理策略分两层:
第一层——直接删掉缺失太多的列(超过20%缺失的变量通常删除,5~20%直接用插补):
missing_ratio = data.isna().mean()
cols_to_drop = missing_ratio[missing_ratio > 0.2].index # 缺失超过20%就删
data = data.drop(columns=cols_to_drop)
第二层——剩余缺失用 KNN 插补:
from sklearn.impute import KNNImputer
imputer = KNNImputer(n_neighbors=5)
data_imputed = imputer.fit_transform(data[continuous_vars])
KNN 插补的逻辑:找到 5 个特征最相近的患者,用他们的均值填补缺失,比直接填均值更合理,能保留患者之间的相似性结构。
机器学习中还有另一个常用方法是 MissForest(随机森林插补),特点是精度更高,速度慢,数据量大时间长
第五步:处理异常值——温莎化
医学数据里的"异常值"是真实存在的极端数值,比如的急性肾衰竭患者肌酐值通常极高。不能简单把这些极端值删除,最常用方法是进行极端值处理,即温莎处理。
温莎化(Winsorize)的做法:不删数据,而是把超出合理范围的值"截断"到边界:
def winsorize_series(series, lower_pct=0.01, upper_pct=0.99):
lower = series.quantile(lower_pct)
upper = series.quantile(upper_pct)
return series.clip(lower=lower, upper=upper)
超出第 1 百分位的值统一变成第 1 百分位,超出第 99 百分位的统一变成第 99 百分位。在尽可能保留样本量没情况下,极端值导致模型的“不稳”情况得以改善。
第六步:拆分数据集
按照 7:3 的比例把数据拆成训练集和测试集,分层拆分是关键:
train_data, test_data = train_test_split(
data,
test_size=0.3,
stratify=data["hosp_expire_flag"], # 保证两组死亡率相同
random_state=2025
)
分层拆分成训练集及测试集,是机器学习关键一部,测试集用于模型评估及SHAP分析结果展示
为方便大家学习 这里给大家整理了一份学习资料包 需要的同学 根据下图自取即可

第七步:检查训练集和测试集是否"可比"
拆完之后要验证一件事:训练集和测试集的患者基线特征是否相似?这在医学研究中是论文必须汇报的内容,通常用 Table 1 呈现。
from tableone import TableOne
train_data["group"] = "训练集"
test_data["group"] = "测试集"
total = pd.concat([train_data, test_data])
table = TableOne(
data=total,
columns=all_vars,
categorical=categorical_vars,
groupby="group",
pval=True
)
输出的 Table 1 中,如果所有 p 值都大于 0.05,说明两组基线无统计学差异,拆分是均衡的。
第八步:标准化连续变量
SVM、KNN、神经网络这类模型对特征的数值范围非常敏感。需要用标准化数据进行模型训练,主要目的是减少模型训练时间及模型“噪声”。
scaler = StandardScaler()
train_data[continuous_vars] = scaler.fit_transform(train_data[continuous_vars])
test_data[continuous_vars] = scaler.transform(test_data[continuous_vars]) # 注意:只用transform
第九步:特征筛选(用R语言实现效果更佳)
原始数据可能有几十个变量,并并非每个都对模型预测有帮助。这里用两种方法筛选:
LASSO 回归:通过惩罚不重要变量的系数,强迫它们趋近于零,自动淘汰弱变量。适合处理线性关系。
Boruta 算法:基于随机森林,通过与随机生成的"幽灵特征"比较,判断每个变量是否显著。能捕捉非线性关系。
两种方法都跑完后,取交集或综合判断,最终保留显著变量:
significant_vars = [
"age", "malignant_cancer", "hosp_los_days",
"heart_rate", "sofa", "spo2", "resp_rate",
"bun", "platelet", "temperature", "hemoglobin",
"wbc", "sbp_ni"
]
第十步:训练 9 个模型
用网格搜索 + 5 折交叉验证的方式,依次训练 9 个模型:Logistic 回归、决策树、随机森林、XGBoost、LightGBM、SVM、AdaBoost、KNN、神经网络。
以随机森林为例:
param_grid = {
'n_estimators': [50, 100, 200, 300],
'max_features': [2, 3, 4]
}
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
grid_search = GridSearchCV(
estimator=RandomForestClassifier(random_state=42),
param_grid=param_grid,
cv=skf,
scoring='accuracy',
n_jobs=-1 # 用所有CPU核心并行,加速搜索
)
grid_search.fit(X_train, y_train)
交叉验证的逻辑:把训练集分成 5 份,轮流用其中 1 份验证,另外 4 份训练,最后取 5 次验证结果的均值,避免偶然性干扰参数选择。
网格搜索:枚举所有参数组合,找出让验证准确率最高的那一组。
为方便大家学习 这里给大家整理了一份学习资料包 需要的同学 根据下图自取即可

第十一步:评估模型性能
评估分两轮:先在训练集上看,再在测试集上看,对比两组结果能判断模型有没有过拟合。
主要指标:
|
指标 |
含义 |
|---|---|
| AUC |
模型区分死亡/存活的整体能力,越接近 1 越好 |
| 灵敏度 |
真正会死亡的患者,有多少被模型识别出来 |
| 特异度 |
真正存活的患者,有多少没被误判为死亡 |
| 校准曲线 |
预测的概率是否与真实发生率吻合 |
| DCA 曲线 |
在不同决策阈值下,模型带来的临床净获益 |
AUC 是最常报告的指标,参考标准:0.7~0.8 中等,0.8~0.9 良好,0.9 以上优秀。
第十二步:用 SHAP 解释模型在"想什么"
模型做出预测后,还需要回答一个问题:它凭什么这么判断?
SHAP 基于博弈论中的 Shapley 值,量化每个特征对每次预测的贡献大小。
explainer = shap.KernelExplainer(
lambda X: model.predict_proba(X),
X_train
)
shap_values = explainer(X_test.iloc[:100, :])
shap.plots.beeswarm(shap_values) # 蜂群图:全局特征重要性
shap.plots.waterfall(shap_values[5]) # 瀑布图:第5个患者的单例解释
常用的 SHAP 图:
-
• 蜂群图:横轴是特征对预测的推动方向,每个点是一个患者,颜色代表该特征的实际值。可以看出哪些特征对预测最关键
-
• 瀑布图:单个患者的预测分解,从"平均基准"出发,逐步叠加各个特征的贡献,最终到达该患者的预测概率
-
• 依赖图:某个特征值与其 SHAP 贡献之间的关系,揭示非线性效应
流程全貌
原始数据
↓
① 导入与初步查看
↓
② 删除 ID/时间列,重命名字段
↓
③ 缺失值:删除缺失>20%的列 → KNN/MissForest 插补
↓
④ 异常值:Winsorize 温莎化
↓
⑤ 7:3 分层拆分 → 训练集 / 测试集
↓
⑥ TableOne 验证两组均衡性
↓
⑦ 连续变量标准化(仅训练集 fit)
↓
⑧ LASSO + Boruta 特征筛选 → 13 个变量
↓
⑨ GridSearchCV × 5折CV → 训练 9 个模型
↓
⑩ 训练集 + 测试集双重评估(AUC、ROC、校准曲线、DCA)
↓
⑪ SHAP 可解释性分析(全局 + 单例)
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)