基于 ConvNeXt-Tiny 与序数回归的糖尿病视网膜病变自动分级:从代码到论文的完整拆解
源码
# -*- coding: utf-8 -*-
"""
APTOS 2019 糖尿病视网膜病变自动分级训练脚本(强注释版)
核心任务:
输入一张眼底图像,预测糖尿病视网膜病变等级:
0:无 DR
1:轻度 NPDR
2:中度 NPDR
3:重度 NPDR
4:增殖性 DR
核心思想:
1. DR 分级不是普通五分类,而是天然有序等级:0 < 1 < 2 < 3 < 4
2. 因此本代码使用“序数回归”:
- 不直接预测 0/1/2/3/4
- 而是预测:
y > 0 ?
y > 1 ?
y > 2 ?
y > 3 ?
3. 模型结构:
ConvNeXt-Tiny backbone + Neck + 双头输出
- ordinal_head:4 维输出,用于序数回归
- class_head:5 维输出,用于普通分类辅助监督
4. 损失函数:
- ordinal BCE loss
- focal classification loss
5. 验证阶段:
- 不固定使用 0.5 阈值
- 在验证集上搜索最优 4 个阈值
6. 推理阶段:
- 使用 TTA:原图、水平翻转、垂直翻转、水平+垂直翻转
运行前目录要求:
当前脚本所在目录下需要有:
- train.csv
- test.csv
- train_images_320_crop_clahe/
- test_images_320_crop_clahe/
train.csv 至少包含:
- id_code
- diagnosis
test.csv 至少包含:
- id_code
"""
import os
import random
import warnings
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image, ImageFile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms, models
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
# =========================================================
# 0. 基础设置
# =========================================================
# 忽略不影响主流程的警告,避免终端输出太乱
warnings.filterwarnings("ignore")
# 允许 PIL 读取部分损坏或截断的图像
# 医学图像数据集中偶尔可能存在异常图片,这样可以避免程序直接中断
ImageFile.LOAD_TRUNCATED_IMAGES = True
# =========================================================
# 1. 全局配置
# =========================================================
# 固定随机种子,尽量保证实验可复现
SEED = 42
# -------------------------
# 运行模式
# -------------------------
# False:只跑 1 个 fold,用于快速验证模型是否能正常训练
# True :正式跑 5-fold,用于论文/报告实验结果
USE_5FOLD = False
# 5 折交叉验证
N_SPLITS = 5
# 当 USE_5FOLD=False 时,只跑第几个 fold
# 0 表示第 1 折,1 表示第 2 折
QUICK_FOLD_INDEX = 0
# -------------------------
# 数据和训练超参数
# -------------------------
# 输入图像大小
# 眼底病灶通常较小,448 比 224 能保留更多细节
IMG_SIZE = 448
# batch size
# 医学图像较大,ConvNeXt 又比较吃显存,所以这里设为 8
BATCH_SIZE = 8
# 训练轮数
EPOCHS = 20
# APTOS 是 5 分类:0、1、2、3、4
NUM_CLASSES = 5
# 序数回归输出维度 = 类别数 - 1
# 5 个等级需要 4 个边界:
# y>0, y>1, y>2, y>3
NUM_ORDINAL_OUTPUTS = NUM_CLASSES - 1
# 学习率
LEARNING_RATE = 2e-4
# 权重衰减,抑制过拟合
WEIGHT_DECAY = 1e-4
# Windows 下 DataLoader 多进程容易出问题,设为 0 最稳
NUM_WORKERS = 0
# 如果有 GPU,则开启 pin_memory,加速 CPU 到 GPU 的数据拷贝
PIN_MEMORY = torch.cuda.is_available()
# 是否加载 ImageNet 预训练权重
# 小数据集强烈建议使用预训练
USE_PRETRAINED = True
# 是否在验证/推理阶段使用 TTA
USE_TTA = True
# -------------------------
# 损失函数相关配置
# -------------------------
# 序数回归损失权重
LAMBDA_ORD = 1.0
# 分类辅助损失权重
LAMBDA_CLS = 0.7
# 分类头是否使用 Focal Loss
USE_FOCAL_FOR_CLS = True
# Focal Loss 的 gamma
# gamma 越大,越强调困难样本
FOCAL_GAMMA = 2.0
# 4 个 ordinal head 的额外权重
# 越靠后的 head 对应越严重的等级,正样本越少,所以给更大权重
HEAD_WEIGHTS = [1.0, 1.1, 1.35, 1.6]
# ordinal 标签平滑
ORDINAL_LABEL_SMOOTHING = 0.02
# 普通分类标签平滑
CLASS_LABEL_SMOOTHING = 0.05
# 梯度裁剪阈值,防止梯度爆炸
GRAD_CLIP_NORM = 1.0
# 阈值粗搜范围
# 从 0.20 到 0.84,步长 0.02
COARSE_THRESHOLD_CANDIDATES = np.arange(0.20, 0.86, 0.02)
# 自动选择设备
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# =========================================================
# 2. 固定随机种子
# =========================================================
def set_seed(seed=42):
"""
固定常见随机源,使实验尽量可复现。
注意:
深度学习中完全复现很难,因为 GPU 某些算子本身存在非确定性。
但固定随机种子至少能减少实验波动。
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed(SEED)
print("当前设备:", DEVICE)
if torch.cuda.is_available():
print("GPU 名称:", torch.cuda.get_device_name(0))
# 输入尺寸固定时,cudnn.benchmark=True 可以自动寻找更快的卷积实现
torch.backends.cudnn.benchmark = True
# =========================================================
# 3. 数据路径
# =========================================================
# 当前脚本所在目录
BASE_DIR = Path(__file__).resolve().parent
# CSV 文件
TRAIN_CSV = BASE_DIR / "train.csv"
TEST_CSV = BASE_DIR / "test.csv"
# 预处理后的图像目录
# 这里默认你已经提前完成:
# 1. 黑边裁切
# 2. CLAHE 对比度增强
TRAIN_DIR = BASE_DIR / "train_images_320_crop_clahe"
TEST_DIR = BASE_DIR / "test_images_320_crop_clahe"
# 输出目录
OUTPUT_DIR = BASE_DIR / "output_convnext_hybrid_strong"
OUTPUT_DIR.mkdir(exist_ok=True)
# 路径检查:提前报错,避免训练到一半才发现文件缺失
assert TRAIN_CSV.exists(), f"找不到文件: {TRAIN_CSV}"
assert TEST_CSV.exists(), f"找不到文件: {TEST_CSV}"
assert TRAIN_DIR.exists(), f"找不到目录: {TRAIN_DIR}"
assert TEST_DIR.exists(), f"找不到目录: {TEST_DIR}"
print(f"项目目录: {BASE_DIR}")
print(f"训练 CSV: {TRAIN_CSV}")
print(f"测试 CSV: {TEST_CSV}")
print(f"训练图片目录: {TRAIN_DIR}")
print(f"测试图片目录: {TEST_DIR}")
print(f"输出目录: {OUTPUT_DIR}")
# =========================================================
# 4. 读取数据
# =========================================================
train_df = pd.read_csv(TRAIN_CSV)
test_df = pd.read_csv(TEST_CSV)
# 检查必要字段
assert "id_code" in train_df.columns, "train.csv 缺少 id_code 列"
assert "diagnosis" in train_df.columns, "train.csv 缺少 diagnosis 列"
assert "id_code" in test_df.columns, "test.csv 缺少 id_code 列"
# 根据 id_code 拼接图片路径
# 例如 id_code = abc123,则图片路径为 train_images_320_crop_clahe/abc123.png
train_df["image_path"] = train_df["id_code"].astype(str).apply(
lambda x: str(TRAIN_DIR / f"{x}.png")
)
test_df["image_path"] = test_df["id_code"].astype(str).apply(
lambda x: str(TEST_DIR / f"{x}.png")
)
# 只保留真实存在的图片
# 防止 CSV 中有记录但图片不存在导致训练时报错
train_df = train_df[train_df["image_path"].apply(os.path.exists)].reset_index(drop=True)
test_df = test_df[test_df["image_path"].apply(os.path.exists)].reset_index(drop=True)
print(f"\n有效训练样本数: {len(train_df)}")
print(f"有效测试样本数: {len(test_df)}")
print("\n训练集类别分布:")
print(train_df["diagnosis"].value_counts().sort_index())
# =========================================================
# 5. 图像增强
# =========================================================
# 训练集增强
train_transform = transforms.Compose([
# 先放大,再随机裁剪,增强模型对位置变化的鲁棒性
transforms.Resize((IMG_SIZE + 32, IMG_SIZE + 32)),
transforms.RandomResizedCrop(
IMG_SIZE,
scale=(0.88, 1.0),
ratio=(0.95, 1.05)
),
# 翻转增强
# 眼底病灶翻转后语义通常不变
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.2),
# 小角度旋转,模拟拍摄角度偏差
transforms.RandomRotation(degrees=12),
# 颜色扰动,模拟不同相机和光照条件
transforms.ColorJitter(
brightness=0.10,
contrast=0.10,
saturation=0.06,
hue=0.015
),
# PIL Image -> Tensor
transforms.ToTensor(),
# 使用 ImageNet 均值方差
# 因为 ConvNeXt-Tiny 使用 ImageNet 预训练权重
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
# 验证集不做随机增强,只做确定性预处理
val_transform = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
# =========================================================
# 6. 序数编码与阈值解码
# =========================================================
def label_to_ordinal(label, num_classes=5):
"""
普通标签 -> 序数标签。
原始标签:
0, 1, 2, 3, 4
序数标签:
label=0 -> [0, 0, 0, 0]
label=1 -> [1, 0, 0, 0]
label=2 -> [1, 1, 0, 0]
label=3 -> [1, 1, 1, 0]
label=4 -> [1, 1, 1, 1]
医学含义:
第 k 位表示是否 y > k。
"""
ordinal = np.zeros(num_classes - 1, dtype=np.float32)
ordinal[:label] = 1.0
return ordinal
def smooth_ordinal_targets(targets, smoothing=0.02):
"""
对 ordinal 标签做标签平滑。
原本:
1 -> 1.00
0 -> 0.00
平滑后:
1 -> 0.98
0 -> 0.02
作用:
1. 防止模型过度自信
2. 缓解少量标注噪声
3. 提升泛化能力
"""
if smoothing <= 0:
return targets
return targets * (1.0 - smoothing) + (1.0 - targets) * smoothing
def ordinal_logits_to_score_probs(logits):
"""
ordinal head 输出的是 logits,不是概率。
sigmoid 可以把任意实数映射到 0~1:
负数 -> 接近 0
0 -> 0.5
正数 -> 接近 1
输出概率含义:
prob[0] = P(y > 0)
prob[1] = P(y > 1)
prob[2] = P(y > 2)
prob[3] = P(y > 3)
"""
return torch.sigmoid(logits)
def ordinal_probs_to_class(probs, thresholds):
"""
4 个 ordinal 概率 -> 最终类别。
解码规则:
有几个概率超过对应阈值,最终类别就是几。
例:
probs = [0.91, 0.72, 0.33, 0.11]
thresholds = [0.5, 0.5, 0.5, 0.5]
前两个超过阈值,后两个没有超过。
所以预测类别 = 2。
"""
thresholds = np.array(thresholds).reshape(1, -1)
passed = (probs > thresholds).astype(np.int32)
return passed.sum(axis=1)
def local_candidates(center, low=0.20, high=0.85, radius=0.08, step=0.01):
"""
细粒度阈值搜索候选值生成。
先粗搜得到一个大概的最优阈值 center。
然后在 center 左右 radius 范围内,以 step 为步长继续细搜。
"""
start = max(low, center - radius)
end = min(high, center + radius)
vals = np.arange(start, end + 1e-9, step)
return np.round(vals, 4)
def search_best_thresholds(y_true, prob_preds, coarse_candidates):
"""
双重阈值搜索:
1. 粗搜:在较大范围内寻找较优阈值
2. 细搜:围绕粗搜结果做局部精修
为什么不直接用 0.5?
因为 APTOS 类别不平衡,固定 0.5 不一定适合所有等级边界。
优化目标:
1. Macro F1 最大
2. 如果 Macro F1 相同,则 Accuracy 更高者优先
单调性约束:
t1 <= t2 <= t3 <= t4
医学含义:
判断是否超过更严重等级,应当更谨慎。
"""
best_f1 = -1.0
best_acc = -1.0
best_thresholds = [0.5, 0.5, 0.5, 0.5]
# -------------------------
# 第一步:粗搜
# -------------------------
for t1 in coarse_candidates:
for t2 in coarse_candidates:
if t2 < t1:
continue
for t3 in coarse_candidates:
if t3 < t2:
continue
for t4 in coarse_candidates:
if t4 < t3:
continue
thresholds = [float(t1), float(t2), float(t3), float(t4)]
pred = ordinal_probs_to_class(prob_preds, thresholds)
macro_f1 = f1_score(y_true, pred, average="macro")
acc = accuracy_score(y_true, pred)
if macro_f1 > best_f1 or (macro_f1 == best_f1 and acc > best_acc):
best_f1 = macro_f1
best_acc = acc
best_thresholds = thresholds
# -------------------------
# 第二步:细搜
# -------------------------
c1 = local_candidates(best_thresholds[0], radius=0.08, step=0.01)
c2 = local_candidates(best_thresholds[1], radius=0.08, step=0.01)
c3 = local_candidates(best_thresholds[2], radius=0.08, step=0.01)
c4 = local_candidates(best_thresholds[3], radius=0.08, step=0.01)
for t1 in c1:
for t2 in c2:
if t2 < t1:
continue
for t3 in c3:
if t3 < t2:
continue
for t4 in c4:
if t4 < t3:
continue
thresholds = [float(t1), float(t2), float(t3), float(t4)]
pred = ordinal_probs_to_class(prob_preds, thresholds)
macro_f1 = f1_score(y_true, pred, average="macro")
acc = accuracy_score(y_true, pred)
if macro_f1 > best_f1 or (macro_f1 == best_f1 and acc > best_acc):
best_f1 = macro_f1
best_acc = acc
best_thresholds = thresholds
return best_thresholds, best_f1, best_acc
# =========================================================
# 7. 数据集类
# =========================================================
class AptosDataset(Dataset):
"""
自定义 APTOS 数据集。
Dataset 的职责:
1. 根据 index 找到对应样本
2. 读取图片
3. 应用图像增强
4. 返回模型训练需要的数据
训练/验证模式返回:
image, label, ordinal_target, id_code
测试模式返回:
image, id_code
"""
def __init__(self, df, transform=None, is_test=False):
self.df = df.reset_index(drop=True)
self.transform = transform
self.is_test = is_test
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
# 读取图片并转成 RGB 三通道
image = Image.open(row["image_path"]).convert("RGB")
# 应用 transform
if self.transform:
image = self.transform(image)
# 测试集没有标签
if self.is_test:
return image, row["id_code"]
# 普通标签
label = int(row["diagnosis"])
# 序数标签
ordinal_target = label_to_ordinal(label, NUM_CLASSES)
return (
image,
torch.tensor(label, dtype=torch.long),
torch.tensor(ordinal_target, dtype=torch.float32),
row["id_code"]
)
# =========================================================
# 8. 模型:ConvNeXt-Tiny + 双头输出
# =========================================================
class ConvNeXtTinyHybrid(nn.Module):
"""
ConvNeXt-Tiny 双头混合模型。
整体结构:
输入眼底图像
-> ConvNeXt-Tiny features
-> avgpool
-> LayerNorm + Flatten
-> Neck
-> ordinal_head + class_head
两个输出头:
1. ordinal_head:
输出 4 维,对应:
y>0, y>1, y>2, y>3
2. class_head:
输出 5 维,对应普通 5 分类:
0, 1, 2, 3, 4
最终预测主要使用 ordinal_head。
class_head 主要作为辅助监督,帮助模型更好地区分困难样本。
"""
def __init__(self, num_classes=5, num_ord=4, use_pretrained=True):
super().__init__()
# -------------------------
# 1. 加载 ConvNeXt-Tiny
# -------------------------
try:
if use_pretrained:
weights = models.ConvNeXt_Tiny_Weights.DEFAULT
backbone = models.convnext_tiny(weights=weights)
print("成功加载 ConvNeXt-Tiny 预训练权重。")
else:
backbone = models.convnext_tiny(weights=None)
print("使用随机初始化 ConvNeXt-Tiny。")
except Exception as e:
# 如果因为网络或 torchvision 版本问题无法加载预训练权重,则退化为随机初始化
print(f"加载 ConvNeXt-Tiny 预训练失败,改用随机初始化。原因: {e}")
backbone = models.convnext_tiny(weights=None)
# ConvNeXt 原始分类头大致为:
# classifier[0] -> LayerNorm2d
# classifier[1] -> Flatten
# classifier[2] -> Linear(768 -> 1000)
#
# 我们不要 1000 分类头,只保留前面的特征提取部分。
in_features = backbone.classifier[2].in_features
self.features = backbone.features
self.avgpool = backbone.avgpool
# 保留 LayerNorm2d + Flatten
self.norm_flatten = nn.Sequential(
backbone.classifier[0],
backbone.classifier[1],
)
# -------------------------
# 2. 自定义 Neck
# -------------------------
# 作用:
# 把 768 维 backbone 特征进一步变成 512 维任务特征。
self.neck = nn.Sequential(
nn.LayerNorm(in_features),
nn.Dropout(0.3),
nn.Linear(in_features, 512),
nn.GELU(),
nn.Dropout(0.25)
)
# -------------------------
# 3. 双头输出
# -------------------------
# 序数回归头:512 -> 4
self.ordinal_head = nn.Linear(512, num_ord)
# 普通分类辅助头:512 -> 5
self.class_head = nn.Linear(512, num_classes)
def forward(self, x):
"""
前向传播。
输入:
x: [B, 3, 448, 448]
输出:
ord_logits: [B, 4]
cls_logits: [B, 5]
"""
# ConvNeXt 特征提取
x = self.features(x)
# 全局平均池化
x = self.avgpool(x)
# LayerNorm + Flatten
feat = self.norm_flatten(x)
# Neck 特征精炼
feat = self.neck(feat)
# 两个输出头
ord_logits = self.ordinal_head(feat)
cls_logits = self.class_head(feat)
return ord_logits, cls_logits
# =========================================================
# 9. 损失函数
# =========================================================
def compute_pos_weights(train_labels, num_classes=5):
"""
计算每个 ordinal head 的正样本权重。
对第 k 个 ordinal head:
正样本:label > k
负样本:label <= k
pos_weight = 负样本数 / 正样本数
作用:
如果某个 head 正样本很少,就提高正样本损失权重,
防止模型忽视少数严重类别。
"""
train_labels = np.array(train_labels)
pos_weights = []
for k in range(num_classes - 1):
positives = np.sum(train_labels > k)
negatives = np.sum(train_labels <= k)
if positives == 0:
pw = 1.0
else:
pw = negatives / positives
pos_weights.append(pw)
return torch.tensor(pos_weights, dtype=torch.float32)
def focal_ce_loss(logits, targets, gamma=2.0, smoothing=0.0):
"""
Focal Cross Entropy Loss。
普通交叉熵的问题:
简单样本太多时,训练会被简单样本主导。
Focal Loss 的思想:
简单样本权重降低;
困难样本权重提高。
gamma:
控制困难样本聚焦强度。
"""
num_classes = logits.size(1)
if smoothing > 0:
# 构造标签平滑后的 soft label
with torch.no_grad():
true_dist = torch.zeros_like(logits)
true_dist.fill_(smoothing / (num_classes - 1))
true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - smoothing)
log_probs = F.log_softmax(logits, dim=1)
probs = torch.exp(log_probs)
# Focal 权重
focal_weight = (1 - probs) ** gamma
loss = -(true_dist * focal_weight * log_probs).sum(dim=1).mean()
else:
ce = F.cross_entropy(logits, targets, reduction="none")
pt = torch.exp(-ce)
loss = ((1 - pt) ** gamma * ce).mean()
return loss
def hybrid_loss_fn(ord_logits, cls_logits, labels, ord_targets, pos_weights, head_weights):
"""
混合损失函数。
总损失:
total_loss = LAMBDA_ORD * ordinal_loss
+ LAMBDA_CLS * classification_loss
ordinal_loss:
BCEWithLogitsLoss
+ pos_weight
+ head_weight
+ ordinal label smoothing
classification_loss:
Focal CE
+ class label smoothing
"""
# 1. ordinal 标签平滑
ord_targets = smooth_ordinal_targets(ord_targets, ORDINAL_LABEL_SMOOTHING)
# 2. 每个 ordinal head 的 BCE loss
bce_per_head = F.binary_cross_entropy_with_logits(
ord_logits,
ord_targets,
pos_weight=pos_weights,
reduction="none"
)
# 3. head 权重
# head_weights 形状是 [4]
# view(1, -1) 变成 [1, 4],方便和 [B, 4] 广播相乘
ord_loss = (bce_per_head * head_weights.view(1, -1)).mean()
# 4. 分类辅助头 loss
if USE_FOCAL_FOR_CLS:
cls_loss = focal_ce_loss(
cls_logits,
labels,
gamma=FOCAL_GAMMA,
smoothing=CLASS_LABEL_SMOOTHING
)
else:
cls_loss = F.cross_entropy(
cls_logits,
labels,
label_smoothing=CLASS_LABEL_SMOOTHING
)
# 5. 总损失
total_loss = LAMBDA_ORD * ord_loss + LAMBDA_CLS * cls_loss
# detach 后的 ord_loss 和 cls_loss 仅用于日志记录,不参与梯度传播
return total_loss, ord_loss.detach(), cls_loss.detach()
# =========================================================
# 10. TTA:测试时增强
# =========================================================
def tta_forward(model, images):
"""
Test-Time Augmentation。
对同一批图像做 4 次推理:
1. 原图
2. 水平翻转
3. 垂直翻转
4. 水平 + 垂直翻转
然后对 logits 求平均。
为什么在 logits 层平均?
logits 保留了模型原始置信度信息,通常比概率平均更稳定。
"""
ord_list = []
cls_list = []
# images 的形状:[B, C, H, W]
# dims=[3]:宽度方向翻转,即水平翻转
# dims=[2]:高度方向翻转,即垂直翻转
image_versions = [
images,
torch.flip(images, dims=[3]),
torch.flip(images, dims=[2]),
torch.flip(images, dims=[2, 3]),
]
for img in image_versions:
ord_logits, cls_logits = model(img)
ord_list.append(ord_logits)
cls_list.append(cls_logits)
ord_mean = torch.mean(torch.stack(ord_list, dim=0), dim=0)
cls_mean = torch.mean(torch.stack(cls_list, dim=0), dim=0)
return ord_mean, cls_mean
# 是否开启混合精度
use_amp = torch.cuda.is_available()
# =========================================================
# 11. 训练一个 epoch
# =========================================================
def train_one_epoch(model, loader, optimizer, device, scaler, pos_weights, head_weights):
"""
训练一个 epoch。
流程:
1. model.train()
2. 遍历 DataLoader
3. 前向传播
4. 计算损失
5. 反向传播
6. 梯度裁剪
7. optimizer 更新参数
8. 统计训练 loss、accuracy、macro F1
"""
model.train()
running_loss = 0.0
all_true = []
all_prob = []
for images, labels, ord_targets, _ in loader:
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
ord_targets = ord_targets.to(device, non_blocking=True)
# 清空上一轮梯度
optimizer.zero_grad(set_to_none=True)
# 混合精度前向传播
with autocast(enabled=use_amp):
ord_logits, cls_logits = model(images)
loss, ord_loss, cls_loss = hybrid_loss_fn(
ord_logits,
cls_logits,
labels,
ord_targets,
pos_weights,
head_weights
)
# 反向传播
scaler.scale(loss).backward()
# 梯度裁剪前需要先 unscale
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
# 更新参数
scaler.step(optimizer)
scaler.update()
running_loss += loss.item() * images.size(0)
# 训练集指标这里用固定 0.5 阈值粗略计算
probs = ordinal_logits_to_score_probs(ord_logits).detach().cpu().numpy()
all_prob.append(probs)
all_true.extend(labels.detach().cpu().numpy())
epoch_loss = running_loss / len(loader.dataset)
all_prob = np.concatenate(all_prob, axis=0)
all_true = np.array(all_true)
train_pred = ordinal_probs_to_class(
all_prob,
thresholds=[0.5, 0.5, 0.5, 0.5]
)
epoch_acc = accuracy_score(all_true, train_pred)
epoch_f1 = f1_score(all_true, train_pred, average="macro")
return epoch_loss, epoch_acc, epoch_f1
# =========================================================
# 12. 验证一个 epoch
# =========================================================
@torch.no_grad()
def validate_one_epoch(model, loader, device, pos_weights, head_weights, use_tta=True):
"""
验证一个 epoch。
注意:
验证阶段不更新参数,所以使用 @torch.no_grad()。
验证流程:
1. model.eval()
2. 前向传播,可选 TTA
3. 计算 validation loss
4. 收集所有样本 ordinal 概率
5. 在验证集上搜索最佳阈值
6. 返回最佳 acc、macro F1、预测结果、概率、阈值
"""
model.eval()
running_loss = 0.0
all_true = []
all_prob = []
all_ids = []
for images, labels, ord_targets, ids in loader:
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
ord_targets = ord_targets.to(device, non_blocking=True)
with autocast(enabled=use_amp):
if use_tta:
ord_logits, cls_logits = tta_forward(model, images)
else:
ord_logits, cls_logits = model(images)
loss, _, _ = hybrid_loss_fn(
ord_logits,
cls_logits,
labels,
ord_targets,
pos_weights,
head_weights
)
running_loss += loss.item() * images.size(0)
probs = ordinal_logits_to_score_probs(ord_logits).cpu().numpy()
all_prob.append(probs)
all_true.extend(labels.cpu().numpy())
all_ids.extend(ids)
epoch_loss = running_loss / len(loader.dataset)
all_prob = np.concatenate(all_prob, axis=0)
all_true = np.array(all_true)
# 验证集上搜索最佳阈值
best_thresholds, best_f1, best_acc = search_best_thresholds(
all_true,
all_prob,
COARSE_THRESHOLD_CANDIDATES
)
pred = ordinal_probs_to_class(all_prob, best_thresholds)
return (
epoch_loss,
best_acc,
best_f1,
all_true,
pred,
all_prob,
all_ids,
best_thresholds
)
# =========================================================
# 13. 完整训练一个 split/fold
# =========================================================
def run_one_split(train_data, val_data, tag="quick"):
"""
完整训练一个数据划分。
train_data:
当前 fold 的训练集 DataFrame
val_data:
当前 fold 的验证集 DataFrame
tag:
输出目录名称,例如:
quick_fold_1
fold_1
fold_2
"""
# 当前 fold 输出目录
fold_dir = OUTPUT_DIR / tag
fold_dir.mkdir(exist_ok=True)
best_model_path = fold_dir / "best_model.pth"
best_thr_path = fold_dir / "best_thresholds.npy"
log_path = fold_dir / "train_log.csv"
val_pred_path = fold_dir / "val_predictions.csv"
# -------------------------
# 1. Dataset
# -------------------------
train_ds = AptosDataset(
train_data,
transform=train_transform,
is_test=False
)
val_ds = AptosDataset(
val_data,
transform=val_transform,
is_test=False
)
# -------------------------
# 2. DataLoader
# -------------------------
train_loader = DataLoader(
train_ds,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=NUM_WORKERS,
pin_memory=PIN_MEMORY
)
val_loader = DataLoader(
val_ds,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
pin_memory=PIN_MEMORY
)
# -------------------------
# 3. 模型
# -------------------------
model = ConvNeXtTinyHybrid(
num_classes=NUM_CLASSES,
num_ord=NUM_ORDINAL_OUTPUTS,
use_pretrained=USE_PRETRAINED
).to(DEVICE)
# -------------------------
# 4. 类别不平衡权重
# -------------------------
pos_weights = compute_pos_weights(
train_data["diagnosis"].values,
NUM_CLASSES
).to(DEVICE)
head_weights = torch.tensor(
HEAD_WEIGHTS,
dtype=torch.float32
).to(DEVICE)
print("Ordinal pos_weight:", pos_weights)
print("Head weights:", head_weights)
# -------------------------
# 5. 优化器
# -------------------------
optimizer = torch.optim.AdamW(
model.parameters(),
lr=LEARNING_RATE,
weight_decay=WEIGHT_DECAY
)
# -------------------------
# 6. 学习率调度器
# -------------------------
# 当 val macro F1 不提升时,自动降低学习率
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="max",
factor=0.5,
patience=2
)
# -------------------------
# 7. 混合精度 scaler
# -------------------------
scaler = GradScaler(enabled=use_amp)
# -------------------------
# 8. 最优模型记录
# -------------------------
best_val_macro_f1 = -1.0
best_val_acc = -1.0
best_thresholds_global = [0.5, 0.5, 0.5, 0.5]
# 早停参数
patience = 8
patience_counter = 0
# 保存每轮训练日志
history = []
# =====================================================
# 开始训练
# =====================================================
for epoch in range(EPOCHS):
print(f"\n========== {tag} Epoch [{epoch + 1}/{EPOCHS}] ==========")
# -------------------------
# 训练
# -------------------------
train_loss, train_acc, train_f1 = train_one_epoch(
model,
train_loader,
optimizer,
DEVICE,
scaler,
pos_weights,
head_weights
)
# -------------------------
# 验证
# -------------------------
(
val_loss,
val_acc,
val_f1,
val_true,
val_pred,
val_prob,
val_ids,
best_thr_epoch
) = validate_one_epoch(
model,
val_loader,
DEVICE,
pos_weights,
head_weights,
use_tta=USE_TTA
)
# 根据 val macro F1 调整学习率
scheduler.step(val_f1)
print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train Macro F1: {train_f1:.4f}")
print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val Macro F1: {val_f1:.4f}")
print(f"Best Thresholds 当前 epoch: {best_thr_epoch}")
print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.8f}")
# -------------------------
# 判断是否刷新最佳模型
# -------------------------
improved = False
if val_f1 > best_val_macro_f1:
improved = True
elif val_f1 == best_val_macro_f1 and val_acc > best_val_acc:
improved = True
if improved:
best_val_macro_f1 = val_f1
best_val_acc = val_acc
best_thresholds_global = best_thr_epoch
patience_counter = 0
# 保存最佳模型权重
torch.save(model.state_dict(), best_model_path)
# 保存最佳阈值
np.save(best_thr_path, np.array(best_thresholds_global))
print(f"保存最佳模型到: {best_model_path}")
print(f"保存最佳阈值到: {best_thr_path}")
else:
patience_counter += 1
print(f"EarlyStopping counter: {patience_counter}/{patience}")
# -------------------------
# 记录日志
# -------------------------
history.append({
"epoch": epoch + 1,
"train_loss": train_loss,
"train_acc": train_acc,
"train_macro_f1": train_f1,
"val_loss": val_loss,
"val_acc": val_acc,
"val_macro_f1": val_f1,
"lr": optimizer.param_groups[0]["lr"],
"thr1": best_thr_epoch[0],
"thr2": best_thr_epoch[1],
"thr3": best_thr_epoch[2],
"thr4": best_thr_epoch[3],
})
# -------------------------
# 早停
# -------------------------
if patience_counter >= patience:
print("触发早停,训练结束。")
break
# 保存训练日志
pd.DataFrame(history).to_csv(
log_path,
index=False,
encoding="utf-8-sig"
)
# =====================================================
# 加载最佳模型做最终验证
# =====================================================
print("\n加载最佳模型做最终验证...")
model.load_state_dict(
torch.load(best_model_path, map_location=DEVICE)
)
best_thresholds_global = np.load(best_thr_path).tolist()
(
val_loss,
val_acc,
val_f1,
val_true,
val_pred,
val_prob,
val_ids,
_
) = validate_one_epoch(
model,
val_loader,
DEVICE,
pos_weights,
head_weights,
use_tta=USE_TTA
)
# 用保存的最佳阈值重新解码
val_pred = ordinal_probs_to_class(
val_prob,
best_thresholds_global
)
val_acc = accuracy_score(val_true, val_pred)
val_f1 = f1_score(val_true, val_pred, average="macro")
val_weighted_f1 = f1_score(val_true, val_pred, average="weighted")
print("\n===== 最终验证结果 =====")
print(f"Accuracy : {val_acc:.4f}")
print(f"Macro F1 : {val_f1:.4f}")
print(f"Weighted F1 : {val_weighted_f1:.4f}")
print(f"Thresholds : {best_thresholds_global}")
print("\n分类报告:")
print(classification_report(val_true, val_pred, digits=4))
print("\n混淆矩阵:")
print(confusion_matrix(val_true, val_pred))
# =====================================================
# 保存验证集逐样本预测结果
# =====================================================
val_result = pd.DataFrame({
"id_code": val_ids,
"diagnosis": val_true,
"pred": val_pred,
"correct": (val_true == val_pred).astype(int),
"prob_gt_0": val_prob[:, 0],
"prob_gt_1": val_prob[:, 1],
"prob_gt_2": val_prob[:, 2],
"prob_gt_3": val_prob[:, 3],
})
val_result.to_csv(
val_pred_path,
index=False,
encoding="utf-8-sig"
)
print(f"\n验证集预测结果已保存到: {val_pred_path}")
return {
"acc": val_acc,
"macro_f1": val_f1,
"weighted_f1": val_weighted_f1,
"thresholds": best_thresholds_global,
"dir": str(fold_dir)
}
# =========================================================
# 14. 主入口
# =========================================================
if USE_5FOLD:
"""
正式 5-fold 模式。
适合:
1. 写论文
2. 写报告
3. 做正式实验结果
4. 计算均值和标准差
"""
print("\n===== 运行模式:5-Fold =====")
skf = StratifiedKFold(
n_splits=N_SPLITS,
shuffle=True,
random_state=SEED
)
all_results = []
for fold, (tr_idx, va_idx) in enumerate(
skf.split(train_df, train_df["diagnosis"]),
start=1
):
print(f"\n{'=' * 20} Fold {fold}/{N_SPLITS} {'=' * 20}")
tr_df = train_df.iloc[tr_idx].reset_index(drop=True)
va_df = train_df.iloc[va_idx].reset_index(drop=True)
print(f"训练集: {len(tr_df)} | 验证集: {len(va_df)}")
print("训练集类别分布:")
print(tr_df["diagnosis"].value_counts().sort_index())
print("验证集类别分布:")
print(va_df["diagnosis"].value_counts().sort_index())
result = run_one_split(
tr_df,
va_df,
tag=f"fold_{fold}"
)
result["fold"] = fold
all_results.append(result)
# 汇总所有 fold 结果
comp = pd.DataFrame(all_results)
summary_path = OUTPUT_DIR / "fold_summary.csv"
comp.to_csv(summary_path, index=False, encoding="utf-8-sig")
print("\n===== 5-Fold 汇总 =====")
print(comp)
print("\n平均结果:")
print(comp[["acc", "macro_f1", "weighted_f1"]].mean())
print("\n标准差:")
print(comp[["acc", "macro_f1", "weighted_f1"]].std())
print(f"\n5-Fold 汇总结果已保存到: {summary_path}")
else:
"""
Quick Validation 模式。
只跑一个 fold。
适合:
1. 快速检查代码能否跑通
2. 快速判断模型是否有效
3. 避免一上来就跑完整 5-fold 浪费时间
"""
print("\n===== 运行模式:1-Fold Quick Validation =====")
skf = StratifiedKFold(
n_splits=N_SPLITS,
shuffle=True,
random_state=SEED
)
splits = list(skf.split(train_df, train_df["diagnosis"]))
tr_idx, va_idx = splits[QUICK_FOLD_INDEX]
tr_df = train_df.iloc[tr_idx].reset_index(drop=True)
va_df = train_df.iloc[va_idx].reset_index(drop=True)
print(f"Quick Fold Index: {QUICK_FOLD_INDEX}")
print(f"训练集: {len(tr_df)} | 验证集: {len(va_df)}")
print("训练集类别分布:")
print(tr_df["diagnosis"].value_counts().sort_index())
print("验证集类别分布:")
print(va_df["diagnosis"].value_counts().sort_index())
result = run_one_split(
tr_df,
va_df,
tag=f"quick_fold_{QUICK_FOLD_INDEX + 1}"
)
print("\n===== Quick Validation Result =====")
print(result)
这篇文章适合两类人:
第一类,想看懂一个完整医学图像深度学习项目的人;
第二类,想把本科项目包装成“像样科研工作”的人。我们要讲的不是“调一个模型跑分类”这么简单,而是一个从任务本质出发,围绕医学分级、类别不平衡、阈值决策、推理鲁棒性系统设计出来的深度学习方案。
一、项目到底在做什么?
本项目研究的是 糖尿病视网膜病变自动分级,英文是 Diabetic Retinopathy Grading,简称 DR 分级。
输入是一张眼底图像,输出是病变等级:
| 等级 | 含义 |
|---|---|
| 0 | 无 DR |
| 1 | 轻度非增殖性 DR |
| 2 | 中度非增殖性 DR |
| 3 | 重度非增殖性 DR |
| 4 | 增殖性 DR |
从表面看,这是一个五分类任务。
但如果你真的把它当成普通五分类,这个项目就落入了“会跑模型但不懂任务”的层次。
因为 DR 分级不是五个互不相关的类别,而是一个严重程度逐级递增的医学分级体系。论文里也明确指出,DR 的五个等级具有天然序数结构,将 4 级误判为 3 级和将 4 级误判为 0 级,在临床风险上完全不是一回事。
这就是本项目最核心的出发点:
不要把医学分级问题粗暴地当成普通分类问题,而要让模型理解“等级之间的顺序关系”。
二、为什么普通 Softmax 五分类不够好?
普通五分类模型通常这样做:
输出 5 个分数 → softmax → 预测概率最高的类别
比如:
类别 0:0.10
类别 1:0.15
类别 2:0.20
类别 3:0.25
类别 4:0.30
模型预测类别 4。
这种方法的问题是:
Softmax 把 0、1、2、3、4 当成五个彼此独立的类别。
但是医学分级不是这样。
对于 DR 来说:
0 → 1 → 2 → 3 → 4
这是从轻到重的连续严重程度。
4 级比 3 级严重,3 级比 2 级严重,2 级比 1 级严重。
所以更合理的问题不是:
这张图属于 0、1、2、3、4 哪一类?
而是:
它是否超过 0 级?
它是否超过 1 级?
它是否超过 2 级?
它是否超过 3 级?
这就是序数回归。
三、序数回归:把五分类改造成四个二分类
在这份代码中,最关键的函数之一是:
def label_to_ordinal(label, num_classes=5):
ordinal = np.zeros(num_classes - 1, dtype=np.float32)
ordinal[:label] = 1.0
return ordinal
它的作用是把普通标签转换成序数标签。
| 原始标签 | 序数标签 |
|---|---|
| 0 | [0, 0, 0, 0] |
| 1 | [1, 0, 0, 0] |
| 2 | [1, 1, 0, 0] |
| 3 | [1, 1, 1, 0] |
| 4 | [1, 1, 1, 1] |
举个例子,如果一张图真实标签是 3,那么它的序数标签是:
[1, 1, 1, 0]
这句话的医学含义是:
超过 0 级:是
超过 1 级:是
超过 2 级:是
超过 3 级:否
所以它最终属于 3 级。
这种设计非常优雅。
它不是直接让模型记住“这是 3 类”,而是让模型学习“这张图的病变严重程度已经超过了哪些等级”。
论文中也把这个建模方式称为累积概率建模,即学习 P(y>0)、P(y>1)、P(y>2)、P(y>3)。
四、整体代码架构:不是脚本,是一条完整训练流水线
整份代码可以分成 14 个模块:
1. 导入依赖库
2. 设置随机种子和超参数
3. 配置数据路径
4. 读取 CSV 与图片路径
5. 定义训练集与验证集图像增强
6. 定义序数编码与阈值解码
7. 定义阈值搜索算法
8. 定义 AptosDataset 数据集类
9. 定义 ConvNeXt-Tiny 双头模型
10. 定义损失函数
11. 定义 TTA 推理策略
12. 定义训练一个 epoch
13. 定义验证一个 epoch
14. 定义 quick fold / 5-fold 主训练流程
你可以把它看成一个标准深度学习工程模板:
数据 → 增强 → 模型 → 损失 → 训练 → 验证 → 阈值搜索 → 保存结果
但它的高级之处在于,每一步都不是随便写的,而是围绕 DR 分级任务的特点设计的。
五、配置区:高手写代码,先把实验开关集中管理
代码一开始定义了大量超参数:
SEED = 42
USE_5FOLD = False
N_SPLITS = 5
QUICK_FOLD_INDEX = 0
IMG_SIZE = 448
BATCH_SIZE = 8
EPOCHS = 20
NUM_CLASSES = 5
NUM_ORDINAL_OUTPUTS = NUM_CLASSES - 1
LEARNING_RATE = 2e-4
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 0
PIN_MEMORY = torch.cuda.is_available()
USE_PRETRAINED = True
USE_TTA = True
这是一种很好的工程习惯。
初学者常犯的错误是把参数散落在代码各处。
这样一旦要修改实验,就需要到处找。
而高手会把所有关键实验设置放在开头。这样代码具备可控性、可复现性和可调试性。
这里几个参数尤其重要:
1. IMG_SIZE = 448
模型输入图片大小是 448×448。
医学图像里的病灶通常比较细小,如果输入太小,微动脉瘤、渗出物、出血点等细节可能被压没。
2. BATCH_SIZE = 8
batch size 设为 8,说明训练受显存限制。
这也是为什么模型选择 ConvNeXt-Tiny,而不是更大的 ConvNeXt-Base 或 ConvNeXt-Large。
3. USE_5FOLD = False
默认只跑一个 fold。
这适合快速验证模型是否能跑通。
如果要正式写论文结果,就应该改成:
USE_5FOLD = True
否则论文里写 5 折交叉验证,代码却只跑了 1 折,答辩时容易被问住。
4. USE_PRETRAINED = True
使用 ImageNet 预训练权重。
这对医学图像小数据集非常重要。
从零训练一个深度网络,需要大量数据。
而医学数据通常很少,所以更常见的做法是:使用自然图像上的预训练模型作为视觉特征提取基础,再迁移到医学任务。
六、路径设计:代码默认读取预处理后的图像
代码中路径配置如下:
BASE_DIR = Path(__file__).resolve().parent
TRAIN_CSV = BASE_DIR / "train.csv"
TEST_CSV = BASE_DIR / "test.csv"
TRAIN_DIR = BASE_DIR / "train_images_320_crop_clahe"
TEST_DIR = BASE_DIR / "test_images_320_crop_clahe"
OUTPUT_DIR = BASE_DIR / "output_convnext_hybrid_strong"
OUTPUT_DIR.mkdir(exist_ok=True)
这说明项目目录应该类似:
APTOS_Project/
├── aptos_convnext_hybrid_strong.py
├── train.csv
├── test.csv
├── train_images_320_crop_clahe/
├── test_images_320_crop_clahe/
└── output_convnext_hybrid_strong/
注意这里的图片目录名:
train_images_320_crop_clahe
test_images_320_crop_clahe
说明代码读取的是已经经过:
黑边裁切 + CLAHE 对比度增强
的图片。
也就是说,这份训练代码不负责图像预处理,它默认预处理已经完成。
论文中对应的描述是:预处理阶段对原始眼底图像进行自适应黑边裁切和 CLAHE 对比度增强,从而去除无效黑色区域并强化病变特征。
七、数据读取:从 CSV 到真实图片路径
代码读取 CSV:
train_df = pd.read_csv(TRAIN_CSV)
test_df = pd.read_csv(TEST_CSV)
然后检查必要列:
assert "id_code" in train_df.columns
assert "diagnosis" in train_df.columns
assert "id_code" in test_df.columns
训练集必须有:
id_code:图片 ID
diagnosis:诊断标签
接着拼接图片路径:
train_df["image_path"] = train_df["id_code"].astype(str).apply(
lambda x: str(TRAIN_DIR / f"{x}.png")
)
例如:
id_code = 000c1434d8d7
就会变成:
train_images_320_crop_clahe/000c1434d8d7.png
然后代码过滤不存在的图片:
train_df = train_df[train_df["image_path"].apply(os.path.exists)].reset_index(drop=True)
这个细节很实用。
真实项目里经常会出现:
CSV 里有记录,但图片文件丢失
图片文件名和 CSV 不一致
预处理时部分图片失败
如果不提前过滤,训练到一半就会报错。
八、图像增强:让模型见过更多“变化后的世界”
训练集增强代码如下:
train_transform = transforms.Compose([
transforms.Resize((IMG_SIZE + 32, IMG_SIZE + 32)),
transforms.RandomResizedCrop(IMG_SIZE, scale=(0.88, 1.0), ratio=(0.95, 1.05)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.2),
transforms.RandomRotation(degrees=12),
transforms.ColorJitter(brightness=0.10, contrast=0.10, saturation=0.06, hue=0.015),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
这段代码做了几件事。
1. 先放大,再随机裁剪
transforms.Resize((IMG_SIZE + 32, IMG_SIZE + 32))
transforms.RandomResizedCrop(IMG_SIZE, scale=(0.88, 1.0), ratio=(0.95, 1.05))
先把图变成 480×480,再随机裁剪到 448×448。
这可以模拟眼底图像中的轻微位置变化。
2. 随机翻转
transforms.RandomHorizontalFlip(p=0.5)
transforms.RandomVerticalFlip(p=0.2)
眼底图像的病灶语义通常不因翻转而改变。
出血点翻转后仍然是出血点,渗出物翻转后仍然是渗出物。
3. 随机旋转
transforms.RandomRotation(degrees=12)
模拟拍摄角度轻微变化。
4. 色彩扰动
transforms.ColorJitter(...)
模拟不同相机、不同曝光、不同光照条件带来的图像差异。
5. 归一化
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
这里使用 ImageNet 的均值和方差,因为模型使用了 ImageNet 预训练权重。
验证集则只做:
val_transform = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(...)
])
验证集不做随机增强。
因为验证指标必须稳定,否则每次验证输入都变,指标会抖动。
九、AptosDataset:PyTorch 训练的数据入口
代码定义了自定义数据集:
class AptosDataset(Dataset):
def __init__(self, df, transform=None, is_test=False):
self.df = df.reset_index(drop=True)
self.transform = transform
self.is_test = is_test
一个 PyTorch Dataset 至少要实现两个方法:
__len__()
__getitem__()
1. __len__
def __len__(self):
return len(self.df)
告诉 PyTorch:
这个数据集有多少张图片。
2. __getitem__
def __getitem__(self, idx):
row = self.df.iloc[idx]
image = Image.open(row["image_path"]).convert("RGB")
根据索引读取一张图片,并转成 RGB 三通道。
然后应用图像增强:
if self.transform:
image = self.transform(image)
如果是测试集:
if self.is_test:
return image, row["id_code"]
如果是训练或验证集:
label = int(row["diagnosis"])
ordinal_target = label_to_ordinal(label, NUM_CLASSES)
return (
image,
torch.tensor(label, dtype=torch.long),
torch.tensor(ordinal_target, dtype=torch.float32),
row["id_code"]
)
也就是说,训练时每张图片返回四个东西:
image:图像张量
label:普通标签 0~4
ordinal_target:序数标签,例如 [1,1,0,0]
id_code:图片 ID
这一步非常关键。
因为模型训练时同时需要普通标签和序数标签。
十、模型设计:ConvNeXt-Tiny + 双头输出
模型类是:
class ConvNeXtTinyHybrid(nn.Module):
它的整体结构可以写成:
输入图像
↓
ConvNeXt-Tiny backbone
↓
Global Average Pooling
↓
LayerNorm + Flatten
↓
Neck
↓
ordinal_head + class_head
论文中也明确说明,模型以 ConvNeXt-Tiny 为骨干网络,输入眼底图像经过特征提取、全局平均池化、Neck 精炼后,分别送入序数回归头和分类辅助头。
十一、为什么选择 ConvNeXt-Tiny?
代码加载模型:
weights = models.ConvNeXt_Tiny_Weights.DEFAULT
backbone = models.convnext_tiny(weights=weights)
ConvNeXt-Tiny 的优势主要有三个:
1. 小 batch 更稳定
医学图像训练通常 batch size 很小。
这份代码中:
BATCH_SIZE = 8
传统 BatchNorm 依赖 batch 内统计量,batch 太小时不稳定。
ConvNeXt 使用 LayerNorm,更适合小 batch 场景。
2. 感受野更大
ConvNeXt 使用 7×7 Depthwise Convolution。
相比传统 3×3 卷积,它可以在单层中看到更大的空间区域。
眼底图像中的病变尺度跨度很大:
微动脉瘤:很小
出血斑:中等
大面积渗出:更大
整体病变分布:需要全局感知
所以多尺度感知很重要。
3. 精度和效率折中
ConvNeXt-Tiny 不像大模型那么吃显存,也比过小模型表达能力更强。
对于本科项目和 8GB 级别显存机器来说,它是一个现实可行的选择。
论文中也从小 batch 稳定性、多尺度病变建模、精度效率平衡三个方面解释了 ConvNeXt-Tiny 的适配性。
十二、拆掉原始分类头,换成自己的双头结构
ConvNeXt-Tiny 原本是 ImageNet 1000 分类模型。
它最后输出的是 1000 类。
但我们现在只需要 DR 五级分级,所以要拆掉原来的分类头。
代码中保留了主干特征:
self.features = backbone.features
self.avgpool = backbone.avgpool
self.norm_flatten = nn.Sequential(
backbone.classifier[0],
backbone.classifier[1],
)
然后定义自己的 Neck:
self.neck = nn.Sequential(
nn.LayerNorm(in_features),
nn.Dropout(0.3),
nn.Linear(in_features, 512),
nn.GELU(),
nn.Dropout(0.25)
)
Neck 的作用是:
对 ConvNeXt 输出的 768 维特征再做一次精炼,变成 512 维特征。
最后接两个输出头:
self.ordinal_head = nn.Linear(512, num_ord)
self.class_head = nn.Linear(512, num_classes)
其中:
ordinal_head:512 → 4
class_head:512 → 5
这就是模型的核心设计。
十三、双头结构:一个负责顺序,一个负责分类
1. 序数回归头
self.ordinal_head = nn.Linear(512, 4)
它输出 4 个 logits,分别对应:
是否超过 0 级
是否超过 1 级
是否超过 2 级
是否超过 3 级
这个头负责学习 DR 等级的递增关系。
2. 分类辅助头
self.class_head = nn.Linear(512, 5)
它输出 5 个 logits,对应普通五分类:
0、1、2、3、4
这个头不是最终决策的主角,而是训练时的辅助监督。
为什么需要它?
因为单独的序数头虽然能建模顺序关系,但在类别极度不平衡时,严重等级样本少,边界可能学不清楚。
分类头配合 Focal Loss,可以让模型更加关注困难样本和少数类。
这也是论文创新点之一:序数回归头负责建模等级结构,分类辅助头负责缓解类别不平衡下的困难样本学习问题。
十四、前向传播:一张图片如何变成两个输出?
模型的 forward 函数如下:
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
feat = self.norm_flatten(x)
feat = self.neck(feat)
ord_logits = self.ordinal_head(feat)
cls_logits = self.class_head(feat)
return ord_logits, cls_logits
假设输入 batch 是:
[B, 3, 448, 448]
其中 B 是 batch size。
经过 ConvNeXt 特征提取后,大致变成:
[B, 768, H, W]
经过全局平均池化:
[B, 768, 1, 1]
展平后:
[B, 768]
经过 Neck:
[B, 512]
最后得到:
ord_logits: [B, 4]
cls_logits: [B, 5]
一个输入,两套监督。
这就是“双头混合架构”。
十五、损失函数:这个项目真正“懂医学数据”的地方
模型有两个输出头,所以损失函数也由两部分组成:
total = LAMBDA_ORD * ord_loss + LAMBDA_CLS * cls_loss
其中:
LAMBDA_ORD = 1.0
LAMBDA_CLS = 0.7
也就是:
总损失 = 1.0 × 序数损失 + 0.7 × 分类损失
十六、序数损失:BCE + 正样本权重 + head 权重
序数头的 4 个输出,本质上是 4 个二分类问题:
y > 0 ?
y > 1 ?
y > 2 ?
y > 3 ?
所以使用二元交叉熵:
bce_per_head = F.binary_cross_entropy_with_logits(
ord_logits,
ord_targets,
pos_weight=pos_weights,
reduction="none"
)
这里有一个非常关键的参数:
pos_weight=pos_weights
它用来处理正负样本不平衡。
例如第 4 个序数头判断:
y > 3 ?
只有 4 类样本才是正样本。
而 4 类样本通常非常少。
如果不加权,模型可能直接学会:
所有样本都预测不是 y>3
这样在训练集上损失也可能不大,但临床上非常危险。
因此代码中通过:
def compute_pos_weights(train_labels, num_classes=5):
...
自动根据当前训练集类别分布计算每个序数头的正样本权重。
此外,代码还加入了手动 head 权重:
HEAD_WEIGHTS = [1.0, 1.1, 1.35, 1.6]
越严重的边界,权重越高。
这背后的思想是:
越严重的病变越不能漏诊,所以训练时应该给更高等级边界更大的关注。
十七、分类损失:Focal Loss 聚焦困难样本
分类头使用的是 Focal Loss:
def focal_ce_loss(logits, targets, gamma=2.0, smoothing=0.0):
Focal Loss 的核心思想可以用一句话解释:
简单样本少管一点,困难样本多管一点。
普通交叉熵会让大量简单样本主导训练。
在 APTOS 这种类别不平衡数据中,0 类样本非常多,模型很容易偏向预测 0。
Focal Loss 通过:
focal_weight = (1 - probs) ** gamma
降低高置信度简单样本的权重,提高难分类样本的相对影响。
这对医学影像尤其重要,因为少数类往往恰恰是临床上最危险、最需要识别的类别。
十八、标签平滑:别让模型过度自信
代码中对序数标签和分类标签都做了平滑:
ORDINAL_LABEL_SMOOTHING = 0.02
CLASS_LABEL_SMOOTHING = 0.05
序数标签平滑函数:
def smooth_ordinal_targets(targets, smoothing=0.02):
if smoothing <= 0:
return targets
return targets * (1.0 - smoothing) + (1.0 - targets) * smoothing
原本:
1 → 1.00
0 → 0.00
平滑后:
1 → 0.98
0 → 0.02
为什么要这样?
医学图像标注不是绝对无误的。
不同医生之间可能存在判读差异,某些边界病例也很难定义。
标签平滑可以防止模型对训练标签过度自信,提高泛化能力。
十九、阈值解码:从四个概率回到最终类别
序数头输出的 4 个 logits 会先经过 sigmoid:
def ordinal_logits_to_score_probs(logits):
return torch.sigmoid(logits)
得到 4 个概率:
P(y>0), P(y>1), P(y>2), P(y>3)
然后通过阈值解码:
def ordinal_probs_to_class(probs, thresholds):
thresholds = np.array(thresholds).reshape(1, -1)
passed = (probs > thresholds).astype(np.int32)
return passed.sum(axis=1)
举例:
probs = [0.91, 0.72, 0.33, 0.11]
thresholds = [0.50, 0.50, 0.50, 0.50]
比较结果:
0.91 > 0.50,是
0.72 > 0.50,是
0.33 > 0.50,否
0.11 > 0.50,否
所以通过了 2 个阈值,最终预测为:
类别 2
这个设计非常直观:
通过几个等级边界,就属于几级。
二十、为什么不能固定使用 0.5 阈值?
初学者最容易想当然:
thresholds = [0.5, 0.5, 0.5, 0.5]
但是在类别不平衡数据中,0.5 不一定合理。
比如严重病变样本很少,模型对严重边界输出的概率可能整体偏低。
如果死板用 0.5,就会导致少数严重类别召回率下降。
论文中也指出,固定 0.5 阈值在高度不平衡数据下可能造成非对称分类边界,对大类过于宽松,对少数类过于严苛。
所以代码设计了一个阈值搜索函数:
def search_best_thresholds(y_true, prob_preds, coarse_candidates):
目标是在验证集上找到最优阈值组合。
二十一、双重阈值搜索:先粗搜,再细搜
阈值搜索分两步。
第一步:粗粒度全局搜索
候选范围:
COARSE_THRESHOLD_CANDIDATES = np.arange(0.20, 0.86, 0.02)
也就是从 0.20 到 0.84,步长 0.02。
代码遍历四个阈值:
for t1 in coarse_candidates:
for t2 in coarse_candidates:
if t2 < t1:
continue
for t3 in coarse_candidates:
if t3 < t2:
continue
for t4 in coarse_candidates:
if t4 < t3:
continue
这里还有一个重要约束:
t1 <= t2 <= t3 <= t4
为什么要有这个约束?
因为判断越严重的等级,标准应该越严格。
判断是否超过 0 级:可以相对宽松
判断是否超过 3 级:应该更加谨慎
这叫单调性约束。
第二步:局部细搜索
粗搜找到大致最优阈值后,代码再围绕这个结果细搜:
def local_candidates(center, low=0.20, high=0.85, radius=0.08, step=0.01):
比如粗搜得到:
0.42
那么细搜就围绕:
0.34 ~ 0.50
以 0.01 为步长继续搜索。
这就是论文中的“双重阈值搜索”:
先做粗粒度全局探索,再做细粒度局部精调。
二十二、验证指标:为什么主要看 Macro F1?
代码中阈值搜索的目标是:
macro_f1 = f1_score(y_true, pred, average="macro")
acc = accuracy_score(y_true, pred)
并且优先选择 Macro F1 更高的阈值:
if macro_f1 > best_f1 or (macro_f1 == best_f1 and acc > best_acc):
为什么不是只看 accuracy?
因为类别不平衡。
假设 0 类占 70%,模型全部预测 0,也能有很高准确率。
但这样的模型完全没有临床价值。
Macro F1 会对每个类别一视同仁:
先分别计算 0、1、2、3、4 类的 F1
再做平均
因此它更能反映模型对少数类的识别能力。
医学影像分类任务中,Macro F1 往往比 Accuracy 更值得关注。
二十三、TTA:测试时增强,让模型在推理阶段更稳
代码定义了:
def tta_forward(model, images):
TTA 是 Test-Time Augmentation,即测试时增强。
训练时增强是为了让模型学得更鲁棒。
TTA 则是在验证或测试阶段,对同一张图做多个版本的预测,然后平均结果。
代码中使用四种版本:
images
torch.flip(images, dims=[3])
torch.flip(images, dims=[2])
torch.flip(images, dims=[2, 3])
对应:
| 版本 | 含义 |
|---|---|
| 原图 | 不变 |
dims=[3] |
左右翻转 |
dims=[2] |
上下翻转 |
dims=[2,3] |
上下左右翻转 |
图片张量形状是:
[B, C, H, W]
所以:
H 是高度方向
W 是宽度方向
最后代码把四次预测取平均:
ord_mean = torch.mean(torch.stack(ord_list, dim=0), dim=0)
cls_mean = torch.mean(torch.stack(cls_list, dim=0), dim=0)
为什么选择翻转,而不是亮度增强、对比度增强?
因为翻转通常不会改变病灶语义。
微动脉瘤翻转后还是微动脉瘤,出血点翻转后还是出血点。
但如果改变亮度或对比度,可能让细小病灶消失,反而引入噪声。
论文中也强调,几何翻转 TTA 不改变病变病理语义,可以在不重新训练的情况下提升推理鲁棒性。
二十四、训练一个 epoch:模型如何真正学习?
训练函数是:
def train_one_epoch(model, loader, optimizer, device, scaler, pos_weights, head_weights):
一个 epoch 的含义是:
模型把整个训练集完整看一遍。
训练流程如下:
1. model.train()
2. 遍历训练 DataLoader
3. 图片和标签送入 GPU
4. 清空梯度
5. 前向传播
6. 计算损失
7. 反向传播
8. 梯度裁剪
9. optimizer 更新参数
10. 记录 loss、accuracy、macro F1
核心代码:
with autocast(enabled=use_amp):
ord_logits, cls_logits = model(images)
loss, ord_loss, cls_loss = hybrid_loss_fn(
ord_logits, cls_logits, labels, ord_targets,
pos_weights, head_weights
)
这里用了 autocast,表示开启混合精度训练。
有 GPU 时,它可以加速训练并节省显存。
反向传播:
scaler.scale(loss).backward()
梯度裁剪:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
参数更新:
scaler.step(optimizer)
scaler.update()
这一套是比较标准的 PyTorch AMP 训练写法。
二十五、验证一个 epoch:不更新参数,只评估模型
验证函数是:
@torch.no_grad()
def validate_one_epoch(...):
@torch.no_grad() 表示:
不计算梯度,不更新模型参数。
验证阶段流程:
1. model.eval()
2. 遍历验证集
3. 使用普通推理或 TTA 推理
4. 计算验证 loss
5. 收集所有样本的 ordinal 概率
6. 在验证集上搜索最佳阈值
7. 计算 accuracy 和 macro F1
关键代码:
if use_tta:
ord_logits, cls_logits = tta_forward(model, images)
else:
ord_logits, cls_logits = model(images)
然后搜索最佳阈值:
best_thresholds, best_f1, best_acc = search_best_thresholds(
all_true, all_prob, COARSE_THRESHOLD_CANDIDATES
)
所以验证阶段并不是简单看模型输出,而是:
模型输出概率 → 阈值搜索 → 最优解码 → 计算指标
这也解释了为什么论文中会强调“决策层”的设计。
二十六、run_one_split:完整训练一折
代码中的核心训练封装是:
def run_one_split(train_data, val_data, tag="quick"):
这个函数完成一整套训练:
1. 创建当前 fold 输出目录
2. 构建 Dataset
3. 构建 DataLoader
4. 创建模型
5. 计算 pos_weight 和 head_weight
6. 定义优化器 AdamW
7. 定义学习率调度器
8. 开始 epoch 训练
9. 每轮验证并搜索阈值
10. 保存最佳模型和最佳阈值
11. 触发早停
12. 保存训练日志和验证预测结果
其中优化器是:
optimizer = torch.optim.AdamW(
model.parameters(),
lr=LEARNING_RATE,
weight_decay=WEIGHT_DECAY
)
学习率调度器是:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.5, patience=2
)
含义是:
如果验证 Macro F1 连续 2 轮没有提升,就把学习率减半。
早停设置:
patience = 8
如果连续 8 轮没有提升,就停止训练。
这是非常实用的训练控制策略,既节省时间,也降低过拟合风险。
二十七、模型保存:不仅保存权重,还保存阈值
当验证集 Macro F1 提升时,代码保存:
torch.save(model.state_dict(), best_model_path)
np.save(best_thr_path, np.array(best_thresholds_global))
也就是保存:
best_model.pth
best_thresholds.npy
这一点非常重要。
因为对于这个项目来说,模型权重不是全部。
最终预测还依赖阈值。
如果只保存模型,不保存阈值,那么推理结果可能无法复现。
这体现了一个很好的工程意识:
对于带后处理决策的模型,必须保存模型参数和决策参数。
二十八、验证预测结果:为错误分析留下证据
训练结束后,代码会保存验证集每张图的预测结果:
val_result = pd.DataFrame({
"id_code": val_ids,
"diagnosis": val_true,
"pred": val_pred,
"correct": (val_true == val_pred).astype(int),
"prob_gt_0": val_prob[:, 0],
"prob_gt_1": val_prob[:, 1],
"prob_gt_2": val_prob[:, 2],
"prob_gt_3": val_prob[:, 3],
})
这个 CSV 非常有价值。
它不仅告诉你模型对不对,还告诉你模型的四个序数概率是多少。
你可以进一步分析:
哪些样本错了?
错在相邻等级还是跨级错误?
3 类为什么容易被预测成 2?
4 类有没有被漏诊成 0?
哪些样本卡在阈值附近?
这正是论文中错误分析、混淆矩阵分析、阈值分析的基础。
二十九、quick fold 与 5-fold:开发模式和论文模式分开
代码最后有两个模式。
1. quick validation 模式
USE_5FOLD = False
只跑一个 fold。
适合:
检查代码能不能跑通
验证 loss 是否下降
快速判断模型是否有潜力
2. 正式 5-fold 模式
USE_5FOLD = True
跑完整五折交叉验证。
适合:
写论文
做正式实验
报告均值和标准差
这是一种很好的实验设计习惯。
开发阶段不要一上来就跑 5-fold。
先 quick fold 验证思路,确认有效后再跑完整实验。
三十、这份代码和论文如何对应?
| 论文模块 | 代码实现 |
|---|---|
| 黑边裁切 + CLAHE | 读取 train_images_320_crop_clahe |
| ConvNeXt-Tiny 骨干网络 | models.convnext_tiny() |
| Neck 模块 | LayerNorm + Dropout + Linear + GELU |
| 序数回归头 | ordinal_head = nn.Linear(512, 4) |
| 分类辅助头 | class_head = nn.Linear(512, 5) |
| 序数编码 | label_to_ordinal() |
| 累积概率解码 | ordinal_probs_to_class() |
| 双重阈值搜索 | search_best_thresholds() |
| 混合损失函数 | hybrid_loss_fn() |
| Focal Loss | focal_ce_loss() |
| TTA 推理 | tta_forward() |
| 5 折交叉验证 | StratifiedKFold |
| 实验日志保存 | train_log.csv |
| 错误分析数据 | val_predictions.csv |
代码不是孤立的工程脚本。
它是论文方法的具体落地。
三十一、这份项目真正高级在哪里?
很多本科项目的问题是:
用了一个模型
跑了一个数据集
调了几个参数
出了一个准确率
这样的项目很难打动老师。
而这个项目更像一个“围绕任务本质设计的系统方案”。
它的逻辑链条是:
DR 分级有天然顺序
→ 使用序数回归
APTOS 类别不平衡
→ 使用 pos_weight、head_weight、Focal Loss
固定 0.5 阈值不适合不平衡数据
→ 使用双重阈值搜索
部署图像存在变化
→ 使用 TTA 提高鲁棒性
医学任务需要可分析
→ 保存每张图的序数概率和预测结果
这就从“调参”升级成了“方法设计”。
三十二、答辩时可以这样介绍这个项目
你可以这样讲:
本项目针对糖尿病视网膜病变五级自动分级任务,考虑到 DR 分级具有天然的严重程度递增关系,我没有将其简单建模为普通五分类,而是采用序数回归思想,将五级分类转化为四个累积二分类问题,即分别判断病变程度是否超过 0、1、2、3 级。模型以 ImageNet 预训练的 ConvNeXt-Tiny 为骨干网络,经过 Neck 模块后分为两个输出头:序数回归头负责建模等级关系,分类辅助头配合 Focal Loss 强化困难样本和少数类学习。训练时使用加权 BCE、head 权重和标签平滑缓解类别不平衡;验证阶段不采用固定 0.5 阈值,而是在验证集上进行粗搜加细搜的双重阈值搜索,并加入单调性约束,优化 Macro F1。推理阶段使用四向翻转 TTA 提升模型鲁棒性。整体上,该方法从建模、损失、决策和推理四个层面针对医学分级任务进行了系统优化。
这段话就是你项目的精华。
三十三、老师最可能问的问题
问题 1:为什么不用普通五分类?
答:
因为 DR 等级具有天然序数关系。普通五分类只关心预测是否正确,不关心错得有多远。但医学上将 4 级预测成 3 级和预测成 0 级风险完全不同。序数回归可以让模型学习病变严重程度是否超过各级边界,更符合 DR 分级本质。
问题 2:为什么序数头输出 4 维?
答:
因为 5 个等级只需要 4 个边界,分别是 y>0、y>1、y>2、y>3。最终通过的边界数量就是预测等级。
问题 3:为什么还要分类辅助头?
答:
序数头擅长建模等级关系,但在类别极度不平衡时,高等级样本少,边界学习可能不足。分类辅助头使用 Focal Loss,可以强化困难样本和少数类学习,与序数头形成互补。
问题 4:最终预测用哪个头?
答:
当前代码最终预测主要使用序数头输出的四个累积概率,再结合搜索得到的阈值进行解码。分类头主要作为辅助监督参与训练,不直接作为最终决策输出。
问题 5:为什么要搜索阈值?
答:
固定 0.5 阈值不一定适合类别不平衡数据。不同等级边界的概率分布不同,通过验证集搜索阈值可以自适应调整决策边界,提升 Macro F1,尤其改善少数类表现。
问题 6:为什么 TTA 用翻转?
答:
因为翻转不会改变眼底病灶的病理语义,微动脉瘤、出血点、渗出物翻转后仍然是相同病变。但亮度、对比度等光度变化可能改变细小病灶的可见性,所以这里选择几何保持的翻转 TTA。
三十四、这个项目还可以怎么继续增强?
如果继续打磨,这个项目可以往以下方向升级:
1. 加入 QWK 指标
DR 分级常用 Quadratic Weighted Kappa。
它比 Accuracy 更适合有序分级任务。
因为它会惩罚远距离错误:
4 → 3:轻度惩罚
4 → 0:严重惩罚
这和医学风险更一致。
2. 做错误距离分析
不仅看预测错没错,还要看错了几级:
相邻等级错误:可接受
跨 2 级错误:较严重
跨 3~4 级错误:危险
论文中已经提到预测错误中大量集中在相邻等级之间,这类分析很适合医学分级任务。
3. 加入可解释性热力图
可以用 Grad-CAM 或 LayerCAM 看模型关注区域。
如果模型关注微动脉瘤、出血点、渗出物等病变区域,说明模型具有一定医学合理性。
4. 做多骨干对比
可以对比:
ResNet-50
DenseNet-121
EfficientNet-B3
ConvNeXt-Tiny
这样论文对比实验更完整。
5. 做消融实验
逐步移除模块:
无 ordinal,只用 softmax
无 class_head
无 focal loss
无阈值搜索
无 TTA
完整方法
这样才能证明每个模块真的有效。
论文中也设置了不同方法对比和消融实验,这是让项目从“工程实现”走向“科研论证”的关键。
三十五、总结:好代码不是能跑,而是每一行都有理由
这份代码真正值得学习的地方,不是它用了多少复杂技术,而是它的设计逻辑很完整:
任务本质:DR 是有序分级
模型设计:ConvNeXt-Tiny + 双头输出
训练目标:序数 BCE + Focal CE
类别不平衡:pos_weight + head_weight
决策优化:双重阈值搜索
推理增强:四向翻转 TTA
实验管理:quick fold + 5-fold
结果分析:保存日志、阈值、预测概率
一个普通项目只会说:
我用了深度学习做糖尿病视网膜病变分类。
而一个更成熟的项目会说:
我从 DR 分级的序数结构出发,重新设计了建模方式、损失函数、决策阈值和推理策略,使模型不仅能分类,而且能更符合医学分级任务的特点。
这就是“会跑代码”和“会做研究”的区别。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)