Java AI 之 DJL 实战(第 9 篇):模型训练与优化
模型训练与优化(DJL 核心流程)
训练流程全解析
训练循环核心组件
训练循环的本质是 “数据输入→模型预测→计算损失→反向传播更新参数→评估效果” 的循环过程,DJL 将核心组件封装为模块化接口,新手可快速组装
四大核心组件
| 组件 | 作用 | 类比 |
|---|---|---|
| 数据迭代器 | 按批次加载、预处理数据 | 工厂的原料输送带 |
| 损失函数 | 衡量预测值与真实值的差距 | 产品的质检标准 |
| 优化器 | 根据损失梯度更新参数 | 生产线上的调参工 |
| 评估指标 | 客观衡量模型效果 | 最终产品的验收标准 |
数据迭代器(Dataset/Iterator):负责按批次加载、预处理数据,是训练的 “原料输送管道”,DJL 支持RandomAccessDataset(随机访问)、SequenceDataset(序列数据)等,需实现get(NDList)和size()方法。
损失函数选择:衡量模型预测值与真实值的差距,是反向传播的 “目标导向”:
-
分类任务:SoftmaxCrossEntropyLoss(多分类)、SigmoidBinaryCrossEntropyLoss(二分类);
-
回归任务:L2Loss(均方误差)、L1Loss(平均绝对误差)。
优化器:根据损失的梯度更新模型参数,核心是 “调整参数让损失变小”,DJL 封装在Optimizer接口中。
评估指标:客观衡量模型效果,与损失函数互补(损失用于训练,指标用于评估):
-
分类:Accuracy(准确率)、F1Score(F1 值);
-
回归:MAE(平均绝对误差)、MSE(均方误差)。
让我们通过一个完整的例子,看看这四大组件如何协同工作。
示例代码
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.ParameterTracker;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
* @Author XiangWei
* @Date 2026/1/12 10:32
* @Description: 训练核心组件
*/
public class DJLTrainingCoreComponents {
static String MODEL_DIR = "./models/mnist-core-components";
static String MODEL_NAME = "mnist-basic-classifier";
static int batchSize = 32;
public static void main(String[] args) throws TranslateException, IOException {
// 1.初始化模型
Model model = initModel();
// 2.数据加载与预处理
// 作用:加载数据(训练集、测试集),按批次预处理、采用数据,为循环持续提供数据
// 加载训练数据集
Dataset trainDataset = Mnist.builder()
.optUsage(Dataset.Usage.TRAIN) // 训练集
.setSampling(batchSize, true) // 随机采样,每个批次包含batchSize个样本
.build();
// 加载测试数据集
Dataset testDataset = Mnist.builder()
.optUsage(Dataset.Usage.TEST) // 测试集
.setSampling(batchSize, false) // 顺序采样,每个批次包含batchSize个样本
.build();
// 数据预处理(自动下载、解压、转换为模型可识别的格式)
trainDataset.prepare();
testDataset.prepare();
// 3.训练配置
TrainingConfig config = getTrainingConfig();
// 4.创建训练器
try(Trainer trainer = model.newTrainer(config)){
// 初始化模型参数
Shape input = new Shape(batchSize, 28, 28);
trainer.initialize(input);
// 执行训练,参数1:训练器 参数2:训练轮数参数3:训练数据集 参数4:测试数据集
// 训练循环内部逻辑:数据输入→模型预测→计算损失→反向传播更新参数→评估效果
EasyTrain.fit(trainer, 5, trainDataset, testDataset);
}
// 5.保存模型
Path modelDir = Paths.get(MODEL_DIR);
Files.createDirectories(modelDir);
model.save(modelDir, MODEL_NAME);
model.close();
}
private static TrainingConfig getTrainingConfig() {
// 1.损失函数
// 作用:评估模型预测与真实标签的差异
// 常用损失函数:softmax交叉熵损失(多分类任务)
// 用SigmoidBinaryCrossEntropyLoss(二分类任务)
// L2Loss/L1Loss (回归任务)
Loss loss = Loss.softmaxCrossEntropyLoss();
// 2.优化器
// 作用:根据损失函数的梯度更新模型参数,调整参数让损失函数值变小
// 常用优化器:Adam(带权重衰减)、SGD(基础)、RMSProp(适合序列任务)
// 常用参数:学习率(0.001-0.01)、动量(0.9-0.99)、权重衰减(1e-4-1e-2)
ParameterTracker tracker = Tracker.fixed(0.001f);
Optimizer optimizer = Adam.builder()
.optLearningRateTracker(tracker)
.build();
// 3.评估指标
// 作用:评估模型在测试集上的性能
// 常用指标:准确率(Accuracy)、F1-Score(多分类任务)、AUC(二分类任务)
// 用Precision/Recall(多分类任务)
// 用MeanSquaredError/RootMeanSquaredError(回归任务)
Evaluator accuracyEvaluator = new Accuracy();
// 组装训练配置
TrainingConfig config = new DefaultTrainingConfig(loss)
.optOptimizer(optimizer) // 绑定优化器
.addEvaluator(accuracyEvaluator) // 绑定评估指标
.addTrainingListeners(TrainingListener.Defaults.logging()); // 绑定训练监听器(打印训练日志)
return config;
}
// 初始化模型
private static Model initModel() {
// 1.初始化模型
Model model = Model.newInstance(MODEL_NAME);
// 2.构建全连接网络
SequentialBlock block = new SequentialBlock();
// 输入层
block.add(Blocks.batchFlattenBlock(784));
// 隐藏层
block.add(Linear.builder().setUnits(128).build());
// 激活函数
block.add(Activation::relu);
// 输出层
block.add(Linear.builder().setUnits(10).build());
// 3.将网络结构添加到模型中
model.setBlock(block);
return model;
}
}
DJL 训练器(Trainer)配置之 Tracker 详解
Tracker 是 DJL 中实现学习率调度的核心组件,直接决定训练过程中学习率的变化策略,是 Trainer 配置里影响模型收敛效果的关键要素,结合 MNIST 训练场景的典型应用如下:
1. Tracker 的核心作用
Tracker 专门用于控制训练全程的学习率变化规律,解决固定学习率易导致的 “前期收敛慢、后期震荡不收敛” 问题,是 Trainer 优化器配置的核心依赖,通过绑定到 Adam/SGD 等优化器生效。
2. 常见 Tracker 类型及应用(以 MNIST 训练为例)
(1)固定学习率 Tracker(基础款)
通过 Tracker.fixed() 实现,训练全程学习率保持恒定,适合入门级训练流程演示:
// 示例:MNIST 训练中固定学习率 0.001
Tracker fixedTracker = Tracker.fixed(0.001f);
Optimizer optimizer = Adam.builder()
.optLearningRateTracker(fixedTracker)
.build();
- 特点:逻辑简单,无需调整学习率节点,适合快速验证模型基础流程;
- 局限:无法适配训练后期的精细优化,易出现损失值震荡。
(2)动态衰减学习率 Tracker(进阶款)
通过 Tracker.multiFactor() 实现 “阶梯式衰减”,是工业界主流策略,可在指定 Epoch 后按系数降低学习率:
// 示例:MNIST 训练中学习率按 Epoch 阶梯衰减
int[] lrSteps = {2, 4}; // 第2、4个 Epoch 后调整学习率
Tracker decayTracker = Tracker.multiFactor()
.setBaseValue(0.001f) // 初始学习率
.setSteps(lrSteps) // 学习率调整的 Epoch 节点
.optFactor(0.5f) // 调整系数(每次×0.5)
.build();
- 特点:前期用较大学习率快速收敛,后期用小学习率精细优化,适配 MNIST 等分类任务的训练规律;
- 优势:相比固定学习率,能有效提升模型最终准确率,减少训练震荡。
示例代码
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.ParameterTracker;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
* @Author XiangWei
* @Date 2026/1/12 10:32
* @Description: 训练核心组件
*/
public class DJLTrainingCoreComponents {
static String MODEL_DIR = "./models/mnist-core-components";
static String MODEL_NAME = "mnist-basic-classifier";
static int batchSize = 32;
public static void main(String[] args) throws TranslateException, IOException {
// 1.初始化模型
Model model = initModel();
// 2.数据加载与预处理
Dataset trainDataset = Mnist.builder()
.optUsage(Dataset.Usage.TRAIN)
.setSampling(batchSize, true)
.build();
// 加载测试数据集
Dataset testDataset = Mnist.builder()
.optUsage(Dataset.Usage.TEST)
.setSampling(batchSize, false)
.build();
// 数据预处理(自动下载、解压、转换为模型可识别的格式)
trainDataset.prepare();
testDataset.prepare();
// 3.训练配置
TrainingConfig config = getTrainingConfig();
// 4.创建训练器
try(Trainer trainer = model.newTrainer(config)){
// 初始化模型参数
Shape input = new Shape(batchSize, 28, 28);
trainer.initialize(input);
// 执行训练
EasyTrain.fit(trainer, 5, trainDataset, testDataset);
}
// 5.保存模型
Path modelDir = Paths.get(MODEL_DIR);
Files.createDirectories(modelDir);
model.save(modelDir, MODEL_NAME);
model.close();
}
private static TrainingConfig getTrainingConfig() {
// 1.优化器
int[] lrSteps = {2 , 4}; // 第2、4个Epoch后调整学习率
ParameterTracker tracker = Tracker.multiFactor()
.setBaseValue(0.001f)
.setSteps(lrSteps) // 调整学习率的Epoch索引
.optFactor(0.5f) // 调整系数:每次×0.5
.build();
Optimizer optimizer = Adam.builder()
.optLearningRateTracker(tracker)
.build();
// 2.损失函数
Loss loss = Loss.softmaxCrossEntropyLoss();
// 3.评估指标
Evaluator accuracyEvaluator = new Accuracy();
// 4.组装训练配置
TrainingConfig config = new DefaultTrainingConfig(loss)
.optOptimizer(optimizer)
.addEvaluator(accuracyEvaluator)
.addTrainingListeners(TrainingListener.Defaults.logging());
return config;
}
// 初始化模型
private static Model initModel() {
// 1.初始化模型
Model model = Model.newInstance(MODEL_NAME);
// 2.构建全连接网络
SequentialBlock block = new SequentialBlock();
// 输入层
block.add(Blocks.batchFlattenBlock(784));
// 隐藏层
block.add(Linear.builder().setUnits(128).build());
// 激活函数
block.add(Activation::relu);
// 输出层
block.add(Linear.builder().setUnits(10).build());
// 3.将网络结构添加到模型中
model.setBlock(block);
return model;
}
}
3. Tracker 与 Trainer 其他配置的关联
(1)与 Epochs(迭代次数)的联动
Tracker 的 setSteps() 需基于 Epochs 配置(如 MNIST 训练设 5 个 Epoch 时,可在 2、4 轮后衰减),迭代次数决定了学习率调整的节点和次数。
(2)与 Batch Size(批量大小)的适配
Tracker 控制的学习率值需适配批量大小:小批量(32/64,如 MNIST 示例中的 32)通常搭配 0.001 左右的初始学习率,批量越大,初始学习率可适当调高。
4. 核心配置逻辑总结
在 DJL 的 Trainer 配置中,Tracker 是学习率调度的 “核心开关”:
- 基础训练场景(如演示核心流程):用
fixed()固定学习率,简化配置; - 实战优化场景(如提升 MNIST 训练效果):用
multiFactor()动态衰减学习率,适配训练全周期; - 所有 Tracker 最终通过
optLearningRateTracker()绑定到优化器,再纳入 Trainer 的 TrainingConfig 完成配置。
学习率调度策略
学习率是优化器最关键的超参数,固定学习率难以适配全训练周期,DJL 支持多种调度策略:
| 策略 | 核心逻辑 | DJL 实现代码 |
|---|---|---|
| 固定学习率 | 全程使用同一学习率,简单但效果差 | optimizer.setLearningRate(0.001)(直接设置) |
| 阶梯式下降 | 达到指定 Epoch 后,学习率按比例降低(如每 5 个 Epoch 降为原来的 0.1) | StepScheduler.builder().setStep(5).setGamma(0.1).build() |
| 余弦退火 | 学习率按余弦函数周期性变化,先升后降,适配后期精细优化 | CosineScheduler.builder().setBaseLearningRate(0.001).setTMax(10).build() |
| 自适应学习率 | 根据验证集损失调整学习率(损失不下降则降低学习率) | ReduceOnPlateauScheduler.builder().setPatience(3).setFactor(0.5).build() |
// 1. 阶梯式下降
LearningRateScheduler stepScheduler = StepScheduler.builder()
.setBaseLearningRate(0.001) // 初始学习率
.setStep(5) // 每5个Epoch调整一次
.setGamma(0.1) // 调整系数(学习率 *= 0.1)
.build();
// 2. 余弦退火 + 预热
LearningRateScheduler cosineScheduler = CosineScheduler.builder()
.setBaseLearningRate(0.001)
.setWarmUpSteps(100) // 前100步学习率从0线性升到0.001(预热)
.setTMax(10) // 余弦周期(10个Epoch)
.setMinLearningRate(1e-6) // 最小学习率,防止降为0
.build();
// 3. 绑定到优化器
Optimizer adam = Optimizer.adam();
adam.setLearningRateScheduler(cosineScheduler);
优化器深度解析
优化器核心原理
优化器的本质是梯度下降算法的工程实现,核心目标是最小化损失函数L(θ)(θ为模型参数):
梯度下降基本公式: θ t + 1 = θ t − η ⋅ ∇ L ( θ t ) \theta_{t+1}=\theta_t-\eta\cdot\nabla L(\theta_t) θt+1=θt−η⋅∇L(θt)
- η:学习率(控制参数更新步长);
- ∇ L ( θ t ) \nabla L(\theta_t) ∇L(θt):损失函数在θt处的梯度(参数更新方向)。
参数更新逻辑:
- 前向传播计算预测值,得到损失L;
- 反向传播计算每个参数的梯度 ∇ L ( ) \nabla L() ∇L();
- 优化器根据梯度和自身策略,更新参数θ。
常用优化器对比与实现
| 优化器 | 核心特点 | DJL 实现代码 | 适用场景 |
|---|---|---|---|
| SGD | 基础梯度下降,收敛慢,易陷入局部最优 | Optimizer.sgd().setLearningRate(0.01) |
简单模型、数据量小 |
| Momentum | 引入 “动量”,模拟物理惯性,加速收敛,减少震荡 | Optimizer.sgd().optMomentum(0.9) |
大部分通用场景 |
| RMSprop | 自适应调整学习率,解决 SGD 学习率单一问题,适合非平稳目标函数 | Optimizer.rmsProp() |
序列任务(如 NLP、语音) |
| Adam | 结合 Momentum(动量)和 RMSprop(自适应学习率),目前最主流 | Optimizer.adam() |
分类 / 回归 / 深度学习通用 |
| AdamW | Adam + 权重衰减(L2 正则),解决 Adam 正则化效果差的问题,大模型首选 | Optimizer.adamW() |
大模型训练、防止过拟合 |
代码示例(不同优化器配置)
// 1. SGD + 动量
Optimizer sgdWithMomentum = Optimizer.sgd()
.setLearningRate(0.01)
.optMomentum(0.9) // 动量系数,通常0.9
.build();
// 2. Adam(默认参数:learningRate=0.001, beta1=0.9, beta2=0.999)
Optimizer adam = Optimizer.adam()
.setLearningRate(0.001)
.optBeta1(0.9) // 一阶矩系数
.optBeta2(0.999) // 二阶矩系数
.build();
// 3. AdamW + 权重衰减
Optimizer adamW = Optimizer.adamW()
.setLearningRate(0.001)
.optWeightDecay(0.0001) // 权重衰减(L2正则系数)
.build();
Adam 优化器
Adam 是目前最常用的优化器之一,全称是 Adaptive Moment Estimation(自适应矩估计),它结合了「动量(Momentum)」和「自适应学习率(RMSProp)」的优点。
而 beta1(一阶矩系数)和 beta2(二阶矩系数)就是控制这两个核心特性的超参数,你的代码中:
Optimizer adam = Optimizer.adam()
.setLearningRate(0.001)
.optBeta1(0.9) // 一阶矩系数(动量项)
.optBeta2(0.999) // 二阶矩系数(自适应学习率项)
.build();
这两个参数分别对应 Adam 算法中对「梯度的一阶矩(均值)」和「二阶矩(未中心化方差)」的指数衰减率。
-
optBeta1(β₁):一阶矩系数(动量系数)
-
(1)含义
beta1控制梯度的一阶矩(梯度均值) 的衰减程度,本质是「动量项」—— 模拟物理中的 “惯性”,让参数更新不仅考虑当前梯度,还保留之前梯度的 “记忆”。 -
(2)取值与影响
- 默认值 / 推荐值:0.9(你代码中也是这个值),是工业界通用的最优经验值;
-
取值范围:
- 0 ≤ β₁ <1,越接近 1,保留的历史梯度越多,更新越 “平滑”(惯性越强);
- β₁ 太小(如 0.5):几乎无惯性,更新只依赖当前梯度,易震荡;
- β₁ 太大(如 0.99):惯性过强,可能错过最优解,收敛变慢。
-
(3)实际作用
在 MNIST 分类任务中,β₁=0.9 能让参数更新更稳定,避免因单批次样本的梯度波动导致训练震荡,加速模型收敛。
-
-
optBeta2(β₂):二阶矩系数(自适应学习率系数)
-
(1)含义
beta2控制梯度的二阶矩(梯度平方的均值 / 方差) 的衰减程度,本质是「自适应学习率项」—— 让不同参数根据自身梯度的 “波动程度” 调整学习率(梯度波动大的参数,学习率自动变小;波动小的,学习率自动变大)。 -
(2)取值与影响
- 默认值 / 推荐值:0.999(你代码中也是这个值),是 Adam 算法的标准配置;
-
取值范围:
- 0 ≤ β₂ <1,越接近 1,对历史梯度平方的 “记忆” 越久,学习率调整越保守;
- β₂ 太小(如 0.9):自适应学习率波动大,参数更新不稳定;
- β₂ 太大(如 0.9999):学习率衰减过慢,可能导致后期收敛速度降低。
-
(3)实际作用(结合你的 MNIST 训练场景)
MNIST 是简单的图像分类任务,β₂=0.999 能让模型对不同参数(如隐藏层 128 个神经元的权重)自适应调整学习率,避免部分参数更新过快 / 过慢,提升最终准确率。
-
为什么这两个参数通常不用改?
Adam 算法的提出者已经验证过:β₁=0.9、β₂=0.999 是通用最优值,适配绝大多数场景(包括你的 MNIST 训练)。只有在以下特殊情况才需要调整:
- 训练极不稳定(梯度爆炸 / 消失):可微调 β₁(如 0.85)或 β₂(如 0.99);
- 收敛过慢:可适当降低 β₁(如 0.8),减少惯性,让更新更 “灵活”。
总结
optBeta1(0.9)是动量系数,控制梯度均值的衰减,核心作用是让参数更新更平滑、减少震荡;optBeta2(0.999)是自适应学习率系数,控制梯度方差的衰减,核心作用是让不同参数自适应调整学习率;- 这两个参数是 Adam 优化器的核心超参数,默认值(0.9/0.999)适配绝大多数场景(如 MNIST 训练),无需轻易修改。
DJL 中优化器的配置与调优实践
调优要点
-
学习率初始化:
分类任务:Adam 默认 0.001,SGD 默认 0.01;
若训练时损失 NaN(梯度爆炸),先降低学习率(如除以 10)。 -
动量系数:
通常设 0.9,若训练震荡大,可降为 0.85。 -
权重衰减:
分类任务:1e-4 ~ 1e-5;
回归任务:1e-5 ~ 1e-6(避免正则化过强)。 -
调试技巧:
先固定学习率跑 1-2 个 Epoch,观察损失是否下降;
损失下降慢→增大学习率 / 增加动量;
损失震荡→减小学习率 / 启用学习率调度。
正则化技术(解决过拟合)
常见正则化方法原理与实现
过拟合:模型在训练集上效果极好,但在测试集上效果差(“死记硬背” 训练数据),正则化的核心是 “给模型加约束,避免参数过于复杂”。
L1 正则:
-
给损失函数加参数绝对值之和, L t o t a l = L + λ ∑ ∥ θ ∥ 2 L_{total} = L + \lambda \sum \|\theta\|_2 Ltotal=L+λ∑∥θ∥2,使参数稀疏
-
适用场景:特征选择(希望部分参数为 0)
L2 正则:
-
给损失函数加参数平方和: L t o t a l = L + λ ∑ θ 2 L_{total} = L + \lambda \sum \theta^2 Ltotal=L+λ∑θ2,使参数值更小
-
适用场景:通用正则化,防止参数过大
Dropout:
-
训练时随机让部分神经元失活(概率 p),测试时恢复,减少神经元依赖
-
适用场景:深度神经网络(如 CNN/RNN)
早停(EarlyStopping):
-
监控验证集损失,若连续 N 个 Epoch 不下降则停止训练,避免过度训练
-
适用场景:所有任务,简单有效
DJL 中正则化的集成使用
代码示例(L2 正则 + Dropout + 早停)
早停监听器
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.listener.TrainingListener;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
import java.util.Set;
/**
* @Author XiangWei
* @Date 2026/1/12 12:03
* @Description: 早停监听器
*/
@Data
@NoArgsConstructor
public class EarlyStoppingListener implements TrainingListener {
private int patience = 5; // 容忍无提升的 epochs 数量
private float minDelta = 0.001f; // 最小提升阈值
private double bestLoss = Double.MAX_VALUE; // 最佳损失值
private int noImprovementCount = 0; // 无提升的 epochs 数量
private boolean needStop = false; // 是否需要停止训练
// 构造方法:初始化早停监听器
public EarlyStoppingListener(int patience, float minDelta) {
this.patience = patience;
this.minDelta = minDelta;
}
// 每个 epoch 结束时调用:检查早停条件是否满足
@Override
public void onEpoch(Trainer trainer) {
if (needStop) {
return;
}
// 初始化当前验证集损失(默认模拟值,保证逻辑运行)
float currentLoss = (float) (Math.random() * 0.05 + 0.05);
System.out.println("当前训练损失: " + currentLoss);
// 获取结果
TrainingResult result = trainer.getTrainingResult();
// 判断结果是否包含评估指标
if (result.getEvaluations() != null && !result.getEvaluations().isEmpty()){
// 遍历评估结果
Set<Map.Entry<String, Float>> entries = result.getEvaluations().entrySet();
for (Map.Entry<String, Float> entry : entries) {
String metricName = entry.getKey(); // 获取指标名称
float metricValue = entry.getValue(); // 获取指标值
// 匹配损失函数名称,获取真实验证集损失
if ("SoftmaxCrossEntropyLoss".equals(metricName)) {
System.out.println("当前验证集损失: " + currentLoss);
currentLoss = metricValue;
break;
}
}
}
// 早停核心逻辑
if (currentLoss < bestLoss - minDelta) {
// 损失下降,重置计数
bestLoss = currentLoss;
noImprovementCount = 0;
} else {
// 损失未下降,增加计数
noImprovementCount++;
if (noImprovementCount >= patience) {
needStop = true;
System.out.println("早停触发:无提升 epochs 数量超过 " + patience + " 个");
}
}
}
public boolean isNeedStop() {
return needStop;
}
@Override
public void onTrainingBatch(Trainer trainer, BatchData batchData) {
// 每个训练批次(Batch)完成后(即单批次前向 + 反向传播完成)
}
@Override
public void onValidationBatch(Trainer trainer, BatchData batchData) {
// 每个验证批次(Batch)完成后(仅在验证阶段触发)
}
@Override
public void onTrainingBegin(Trainer trainer) {
// 整个训练流程开始前(仅执行 1 次)
}
@Override
public void onTrainingEnd(Trainer trainer) {
// 整个训练流程结束后(仅执行 1 次)
}
}
测试代码
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
* @Author XiangWei
* @Date 2026/1/12 12:02
* @Description: 正则化方法实战:L2正则(权重衰减)+ Dropout + 早停
*/
public class DJLRegularizationExample {
static String MODEL_DIR = "./models/mnist-core-components";
static String MODEL_NAME = "mnist-trainer-l2";
public static void main(String[] args) throws TranslateException, IOException {
// 1.创建模型
Model model = initModel();
// 2.加载数据集
int batchSize = 32;
Dataset trainDataset = Mnist.builder()
.optUsage(Dataset.Usage.TRAIN)
.setSampling(batchSize, true)
.build();
Dataset testDataset = Mnist.builder()
.optUsage(Dataset.Usage.TEST)
.setSampling(batchSize, false)
.build();
trainDataset.prepare();
testDataset.prepare();
// 3.配置L2正则
Tracker tracker = Tracker.fixed(0.001f); // 学习率
Optimizer optimizer = Adam.builder()
.optLearningRateTracker(tracker)
.optWeightDecays(0.001f) // L2正则化系数
.build();
// 4.配置早停监听器
EarlyStoppingListener earlyStoppingListener = new EarlyStoppingListener(2, 0.001f);
// 5.损失函数
Loss loss = Loss.softmaxCrossEntropyLoss();
Evaluator accuracy = new Accuracy();
// 6.配置
TrainingConfig config = new DefaultTrainingConfig(loss)
.addEvaluator(accuracy) // 评估指标:准确率
.optOptimizer(optimizer) // 优化器:Adam
.addTrainingListeners(earlyStoppingListener) // 早停监听器
.addTrainingListeners(TrainingListener.Defaults.logging());
// 7.训练模型
try (Trainer trainer = model.newTrainer(config)) {
Shape inputShape = new Shape(batchSize, 28, 28);
trainer.initialize(inputShape);
// 手动遍历Epoch(替代stopTraining(),实现软早停)
int numEpochs = 10;
for (int epoch = 1; epoch <= numEpochs; epoch++) {
if (!earlyStoppingListener.isNeedStop()){
EasyTrain.fit(trainer, 1, trainDataset, testDataset);
}
}
}
// 8.评估模型
Path modelDir = Paths.get(MODEL_DIR);
Files.createDirectories(modelDir);
model.save(modelDir, MODEL_NAME);
model.close();
}
private static Model initModel() {
// 1.创建模型
Model model = Model.newInstance(MODEL_NAME);
// 2.添加层
SequentialBlock block = new SequentialBlock();
// 输入层
block.add(Blocks.batchFlattenBlock(784));
// 隐藏层
block.add(Linear.builder().setUnits(128).build());
block.add(Dropout.builder().optRate(0.2f).build()); // 20%的神经元随机失活
// 隐藏层
block.add(Linear.builder().setUnits(64).build());
block.add(Activation.reluBlock()); // ReLU激活函数
block.add(Dropout.builder().optRate(0.2f).build()); // 20%的神经元随机失活
// 输出层
block.add(Linear.builder().setUnits(10).build());
model.setBlock(block);
return model;
}
}
拓展建议
- 查看模型文件:打开
\models\mnist-core-components目录,验证模型文件是否生成; - 调整正则化参数:
- 把 Dropout 率改为 0.3,观察验证集准确率变化;
- 把 L2 正则系数改为 0.01,观察模型收敛速度;
- 增加准确率打印:在早停逻辑中增加准确率解析,直观对比训练集 / 测试集准确率,验证过拟合是否缓解。
正则化调优技巧
-
Dropout 率:
表示每轮训练中随机失活的神经元比例
率越高:失活的神经元越多,正则化强度越大,模型越难过拟合,但训练收敛可能变慢;
率越低:失活的神经元越少,正则化强度越弱,模型更容易记住训练数据(过拟合),验证集损失会更早出现 “不下降甚至上升”。
选择:
输入层 / 隐藏层:0.2 ~ 0.5(层数越深,Dropout 率应越小);
输出层:不使用 Dropout(避免预测结果不稳定)。 -
L2 正则系数调优:
L2 系数合理范围 0.0001~0.01,过强的惩罚会让模型参数趋近于 0,导致模型 “学不到有效特征”(欠拟合)
系数越小(如 0.001):惩罚越轻,正则化强度弱,参数更新灵活,模型能较好拟合数据;
系数越大(如 0.1):惩罚极重,正则化强度过强,参数更新被严重限制(甚至难以向最优方向调整),模型拟合能力下降,验证集损失易出现大幅波动。
-
早停参数:
Patience(容忍次数):3~5 个 Epoch(太小易提前停止,太大无意义);
优先监控验证集损失,而非准确率(损失更稳定)。 -
组合使用:
浅模型:仅用早停 + 少量 L2 正则;
深模型:Dropout + L2 正则 + 早停(三重保障)。
浅模型:只有少量网络层(通常≤2 层)的模型,核心是 “单层 / 双层非线性变换”。
深模型:包含多层隐藏层(通常≥3 层)的模型,核心是 “多层级特征提取”。
总结
DJL 训练核心:Trainer 是训练入口,需配置数据迭代器、损失函数、优化器、评估指标四大组件,EasyTrain.fit()可快速实现完整训练循环;
优化器关键:Adam/AdamW 是主流选择,需配合学习率调度(如余弦退火),学习率是最影响训练效果的超参数;
正则化实践:过拟合可通过 “Dropout(层级)+ L2 正则(参数级)+ 早停(训练级)” 组合解决,调优需从弱到强逐步尝试。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)