Java AI 之 DJL 实战(第 10 篇):深度学习进阶模型
深度学习进阶模型
卷积神经网络(CNN)
什么是 CNN
卷积神经网络(Convolutional Neural Network,CNN)是一种专门为处理网格结构数据(如图像、语音)设计的深度学习模型。与全连接网络直接将图像展平为一维向量、忽略空间特征且易引发参数爆炸的特点不同,CNN 的核心优势在于通过卷积操作替代全连接层的暴力特征提取,实现局部特征感知、参数共享与空间降维,能更高效地挖掘网格数据中的空间关联信息,因此在计算机视觉等领域得到广泛应用。
CNN 核心组件
CNN 的特征提取与任务推理能力,由多个功能明确的核心组件协同实现,各组件分工协作,构成完整的网络架构:
- 卷积层(提特征):作为 CNN 的核心特征提取模块,其核心作用是提取图像的局部特征(如边缘、纹理、形状)。通过滑动卷积核与图像局部区域做内积运算,既保留了像素间的空间位置信息,又通过参数共享机制大幅减少网络参数数量,为后续特征处理奠定基础。
- 池化层(降维):紧随卷积层之后发挥作用,核心作用是降维与特征选择(下采样)。通过对卷积层输出的特征图进行采样处理,不仅能降低计算复杂度,还能保留关键特征、增强模型对特征位置变化的鲁棒性。常用的池化方式包括最大池化(MaxPool)和平均池化(AvgPool)。
- 激活层(非线性):卷积与池化本质上均为线性操作,难以拟合复杂的非线性特征关系。激活层的核心作用就是为网络引入非线性映射能力,打破线性操作的局限性,让模型能够学习到数据中复杂的特征关联(如区分手写数字 “6” 和 “9” 的细微轮廓差异)。常用激活函数有 ReLU(缓解梯度消失,加速训练)、Sigmoid(适合二分类输出)、Tanh(输出均值为 0,优化梯度传播)等。
- 全连接层(分类):位于网络的末端,核心作用是特征整合与分类。它会将卷积、池化层提取的高维特征图展平为一维向量,再通过权重矩阵将特征映射到样本的标签空间,最终完成分类或回归任务。
卷积操作详解
卷积层的特征提取能力,完全依赖于卷积操作的实现,而卷积操作的效果由多个关键参数共同决定:
- 卷积核(Kernel/Filter):即二维权重矩阵(常见尺寸为 3×3、5×5),是特征提取的 “检测器”。卷积核的参数会通过反向传播算法不断优化,使其能够精准捕捉图像中的特定特征。
- 步长(Stride):指卷积核在图像上滑动的步幅。步长越大,卷积核每次滑动覆盖的区域越广,最终输出的特征图尺寸越小。例如步长设为 2 时,特征图的尺寸会减半。
- 填充(Padding):在图像边缘补 0 的操作,主要用于控制输出特征图的尺寸。常用的填充方式有两种:SAME 填充(输出尺寸与输入尺寸一致)、VALID 填充(无填充,输出尺寸计算公式为
(输入尺寸 - 核尺寸 + 1) / 步长)。 - 多通道卷积:当输入图像为多通道(如 RGB 图像的 3 通道)时,卷积核的通道数需与输入通道数保持一致。卷积时,每个通道对应一个卷积核进行运算,最终将所有通道的卷积结果求和,得到单通道特征图;若需输出多通道特征图,可设置多组卷积核并行运算。
经典 CNN 结构实现(LeNet、AlexNet、ResNet )
示例代码: LeNet(手写数字分类)
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.ParameterStore;
/**
* @Author XiangWei
* @Date 2026/1/12 15:54
* @Description:
*/
public class LeNet {
public static void main(String[] args) {
// 1.创建网络
Block block = createLeNet();
System.out.println(block);
// 2.训练
try(NDManager manager = NDManager.newBaseManager()){
// 输入层:制定数据类型
Shape inputShape = new Shape(1, 1, 28, 28);
// 初始化
block.initialize(manager, DataType.FLOAT32, inputShape);
// 创建模拟输入(1张28×28的灰度图)
NDArray input = manager.ones(inputShape);
// 前向传播:输入→网络→输出
ParameterStore parameterStore = new ParameterStore(manager, false);
NDArray output = block.forward(parameterStore, new NDList(input), false).singletonOrThrow();
System.out.println("输出形状:" + output.getShape());
// 获取识别结果(取输出层中概率最大的类别)
long predictedClass = output.argMax(1).getLong(0);
System.out.println("输出结果:" + predictedClass);
}
}
// 创建网络
private static Block createLeNet() {
SequentialBlock block = new SequentialBlock();
// 卷积层1:输入1通道(灰度图),输出6通道,卷积核5×5,步长1,填充0
block.add(Conv2d.builder()
.setKernelShape(new Shape(5, 5))
.optStride(new Shape(1, 1))
.optPadding(new Shape(0, 0))
.setFilters(6)
.build());
// 池化层1:2×2最大池化,步长2
block.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)));
// 卷积层2:输入6通道,输出16通道,卷积核5×5,步长1,填充0
block.add(Conv2d.builder()
.setKernelShape(new Shape(5, 5))
.optStride(new Shape(1, 1))
.optPadding(new Shape(0, 0))
.setFilters(16)
.build());
// 池化层2:2×2最大池化,步长2
block.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)));
// 展平:将多维特征图转为一维向量
block.add(Blocks.batchFlattenBlock());
// 全连接层1:输入16*4*4=256,输出120
block.add(Linear.builder().setUnits(120).build());
// 全连接层2:输入120,输出84
block.add(Linear.builder().setUnits(84).build());
// 输出层:输入84,输出10(数字0-9分类)
block.add(Linear.builder().setUnits(10).build());
return block;
}
}
AlexNet(简化版,ImageNet 子集分类)
在网络结构、激活函数、正则化、任务适配等方面对 LeNet 做了全方位的升级,适配更复杂的图像分类任务(ImageNet 1000 类)
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
public class AlexNet {
public static SequentialBlock getModel() {
SequentialBlock net = new SequentialBlock();
// 卷积层1:3通道(RGB),96个11×11卷积核,步长4
net.add(Conv2d.builder()
.setKernelShape(new Shape(11, 11))
.optStride(new Shape(4, 4))
.setFilters(96)
.build());
net.add(Activation.reluBlock()); // ReLU激活(LeNet用sigmoid,ReLU解决梯度消失)
net.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2))); // 池化层1
// 卷积层2:96通道,256个5×5卷积核,填充2
net.add(Conv2d.builder()
.setKernelShape(new Shape(5, 5))
.optPadding(new Shape(2, 2))
.setFilters(256)
.build());
net.add(Activation.reluBlock());
net.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2))); // 池化层2
// 连续3个卷积层(无池化)
// 卷积层1:256通道,384个3×3卷积核,填充1
net.add(Conv2d.builder().setKernelShape(new Shape(3,3)).optPadding(new Shape(1,1)).setFilters(384).build());
net.add(Activation.reluBlock());
// 卷积层2:384通道,384个3×3卷积核,填充1
net.add(Conv2d.builder().setKernelShape(new Shape(3,3)).optPadding(new Shape(1,1)).setFilters(384).build());
net.add(Activation.reluBlock());
// 卷积层3:384通道,256个3×3卷积核,填充1
net.add(Conv2d.builder().setKernelShape(new Shape(3,3)).optPadding(new Shape(1,1)).setFilters(256).build());
net.add(Activation.reluBlock());
// 池化层3:2×2最大池化,步长2
net.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)));
// 全连接层(带Dropout防止过拟合)
net.add(Blocks.batchFlattenBlock());
net.add(Linear.builder().setUnits(4096).build());
net.add(Activation.reluBlock());
net.add(Dropout.builder().optRate(0.5f).build()); // Dropout率50%
net.add(Linear.builder().setUnits(4096).build());
net.add(Activation.reluBlock());
net.add(Dropout.builder().optRate(0.5f).build());
net.add(Linear.builder().setUnits(1000).build()); // 输出1000类(ImageNet)
return net;
}
public static void main(String[] args) {
SequentialBlock block = getModel();
// 打印网络结构
System.out.println(block);
}
}
核心改进点总览(按重要性排序)
| 改进维度 | LeNet(原代码) | AlexNet(新代码) | 改进的核心价值 |
|---|---|---|---|
| 激活函数 | 无显式激活层(隐含用 Sigmoid) | 全量使用 ReLU 激活 | 解决 Sigmoid 的梯度消失问题,训练速度提升数倍 |
| 网络规模 | 轻量(2 卷积 + 3 全连接) | 深度 / 宽度大幅提升(5 卷积 + 3 全连接) | 能提取更复杂的图像特征,适配 1000 类分类(LeNet 仅 10 类) |
| 卷积核设计 | 固定 5×5 小卷积核,步长 1 | 多尺寸卷积核(11×11/5×5/3×3)+ 大步长(4) | 11×11 大核快速降维,3×3 小核精细提取特征 |
| 正则化机制 | 无正则化 | 加入 Dropout(0.5) | 防止过拟合(全连接层参数极多,过拟合风险高) |
| 输入适配 | 1 通道 28×28 灰度图 | 3 通道 224×224 RGB 图 | 适配真实彩色图像(ImageNet 数据集) |
| 填充策略 | 固定 0 填充(VALID) | 动态填充(2/1) | 避免边缘特征丢失,保证特征图尺寸合理 |
| 池化设计 | 2×2 池化,步长 2 | 3×3 池化,步长 2 | 更大池化窗口保留更多全局特征,降低信息损失 |
| 代码封装 | 训练 / 推理逻辑耦合 | 纯结构封装(getModel 方法) | 代码复用性更高,便于后续训练 / 推理扩展 |
逐点详细解析(结合代码)
- 激活函数:ReLU 替代 Sigmoid(最核心改进)
- LeNet 代码:无显式激活层,传统 LeNet 用 Sigmoid 激活,存在梯度消失问题(深层网络梯度趋近于 0,无法训练);
- AlexNet 代码:每个卷积 / 全连接层后都加
Activation.reluBlock(),注释也明确标注 “ReLU 解决梯度消失”; - 价值:ReLU 的梯度在正数区间恒为 1,能支撑更深的网络训练,这是 AlexNet 能做 5 层卷积的关键。
- 网络规模:深度 + 宽度双提升
-
深度:LeNet 仅 2 层卷积,AlexNet 升级为 5 层卷积(前 2 层大核 + 后 3 层小核);
-
宽度:
- LeNet 卷积核数量:6→16(通道数);
- AlexNet 卷积核数量:96→256→384→384→256(通道数提升 10 倍 +);
- 全连接层:LeNet 120→84→10,AlexNet 4096→4096→1000(参数从万级提升到亿级);
-
价值:更多的卷积层 / 通道数能提取 “边缘→纹理→形状→物体” 的层级特征,满足 1000 类分类的特征需求。
- 卷积核设计:多尺寸 + 大步长适配复杂图像
-
LeNet:只有 5×5 卷积核,步长 1(仅适配 28×28 小图);
-
AlexNet:
- 第一层用 11×11 大卷积核 + 步长 4:快速将 224×224 的输入降维(224→55),减少计算量;
- 中间用 5×5 卷积核过渡,最后 3 层用 3×3 小卷积核:精细提取图像的局部细节特征;
-
价值:多尺寸卷积核兼顾 “快速降维” 和 “精细提特征”,适配 224×224 的大图。
- 正则化:Dropout 防止过拟合
- LeNet:无任何正则化(仅 10 类分类,参数少,过拟合风险低);
- AlexNet:全连接层后加入
Dropout.builder().optRate(0.5f).build(),随机让 50% 神经元失活; - 价值:AlexNet 全连接层有 4096×4096=1600 万参数,过拟合风险极高,Dropout 能有效降低风险。
- 输入与填充:适配真实彩色图像
-
输入通道:LeNet 是 1 通道(MNIST 灰度图),AlexNet 是 3 通道(RGB 彩色图),符合真实场景;
-
填充策略:
- LeNet 固定 0 填充:卷积后特征图尺寸大幅缩小(28→24),边缘特征易丢失;
- AlexNet 用 2/1 填充:比如卷积层 2 的
optPadding(new Shape(2, 2)),保证卷积后特征图尺寸合理,边缘特征不丢失;
-
价值:适配 ImageNet 彩色数据集,最大化保留图像特征。
- 池化设计:3×3 池化提升特征鲁棒性
- LeNet:2×2 最大池化,步长 2;
- AlexNet:3×3 最大池化,步长 2;
- 价值:更大的池化窗口能覆盖更多像素,保留更全局的特征,提升模型对图像平移 / 缩放的鲁棒性。
- 代码封装:结构与逻辑解耦
- LeNet:
createLeNet()返回 Block,main 方法中耦合了初始化、前向传播等逻辑; - AlexNet:
getModel()纯返回网络结构,main 方法仅打印结构,无多余逻辑; - 价值:代码复用性更高,后续可快速接入训练 / 推理逻辑,符合 “单一职责” 设计原则。
核心总结
AlexNet 相对于 LeNet 的改进,本质是从 “适配简单手写数字分类” 升级为 “适配复杂真实图像分类”,核心改进可归纳为 3 点:
- 网络能力升级:更深的层数、更多的通道数,能提取复杂特征;
- 训练稳定性升级:ReLU 解决梯度消失,Dropout 防止过拟合;
- 场景适配升级:多通道、多尺寸卷积核、填充策略适配彩色大图。
简单来说,LeNet 是 CNN 的 “雏形”,适合简单小任务;AlexNet 是 CNN 的 “成熟版本”,奠定了现代 CNN 的基础架构,能处理工业级的复杂图像分类任务。
ResNet 简化版(残差连接解决梯度消失)
核心:残差块(Residual Block),通过短路连接将输入直接加到输出,解决深层网络梯度消失问题
ResNet 通过 “残差连接(Residual Connection)” 解决了深层网络的梯度消失问题,同时引入批量归一化(BatchNorm)、全局平均池化等现代 CNN 技术,实现了 “更深的网络 + 更高的训练稳定性
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.*;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
/**
* @author XiangWei
* @return null
* @date 2026/1/12 16:28
* @description 简单ResNet模型(18层)
*/
public class ResNetSimple {
// 构建恒等映射块(解决空块问题)
private static Block identityBlock() {
return new LambdaBlock((Function<NDList, NDList>) inputs -> inputs);
}
// 构建残差块:确保所有分支非空
private static Block residualBlock(int channels, boolean downSample) {
int stride = downSample ? 2 : 1;
// 1. 主分支
SequentialBlock mainBranch = new SequentialBlock();
mainBranch.add(Conv2d.builder()
.setKernelShape(new Shape(3, 3)) // 3x3卷积核
.optStride(new Shape(stride, stride)) // 步长2(下采样)
.optPadding(new Shape(1, 1)) // 填充1保持特征图尺寸
.setFilters(channels)
.build());
mainBranch.add(BatchNorm.builder().build());
mainBranch.add(Activation.reluBlock());
mainBranch.add(Conv2d.builder() // 3x3卷积核
.setKernelShape(new Shape(3, 3))
.optStride(new Shape(1, 1)) // 步长1(保持特征图尺寸)
.optPadding(new Shape(1, 1)) // 填充1保持特征图尺寸
.setFilters(channels)
.build());
mainBranch.add(BatchNorm.builder().build());
// 2. 短路分支:确保非空(空时添加恒等映射)
SequentialBlock shortcutBranch = new SequentialBlock();
if (downSample) {
shortcutBranch.add(Conv2d.builder()
.setKernelShape(new Shape(1, 1)) // 1x1卷积核
.optStride(new Shape(stride, stride)) // 步长2(下采样)
.setFilters(channels)
.build());
shortcutBranch.add(BatchNorm.builder().build());
} else {
// 关键:空分支时添加恒等映射,避免SequentialBlock为空
shortcutBranch.add(identityBlock());
}
// 3. ParallelBlock合并函数:仅执行相加
Function<List<NDList>, NDList> mergeFunction = ndLists -> {
// 主分支输出 + 短路分支输出
return new NDList(ndLists.get(0).get(0).add(ndLists.get(1).get(0)));
};
// 构造ParallelBlock(子块均非空)
ParallelBlock parallelBlock = new ParallelBlock(
mergeFunction,
Arrays.asList(mainBranch, shortcutBranch)
);
// 4. 组合残差块:并行执行 + 相加 + ReLU
SequentialBlock residualBlock = new SequentialBlock();
residualBlock.add(parallelBlock);
residualBlock.add(Activation.reluBlock());
return residualBlock;
}
// 构建简化版ResNet(8层)
public static SequentialBlock getModel() {
SequentialBlock net = new SequentialBlock();
// 初始卷积层:适配224×224 RGB图像(3通道)
net.add(Conv2d.builder()
.setKernelShape(new Shape(7, 7)) // 7x7卷积核
.optStride(new Shape(2, 2)) // 步长2(下采样)
.optPadding(new Shape(3, 3)) // 填充3保持特征图尺寸
.setFilters(64) // 64个卷积核
.build());
net.add(BatchNorm.builder().build());
net.add(Activation.reluBlock());
net.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)));
// 堆叠残差块
net.add(residualBlock(64, false)); // 恒等映射残差块(短路分支非空)
net.add(residualBlock(64, false)); // 恒等映射残差块
net.add(residualBlock(128, true)); // 下采样残差块
net.add(residualBlock(128, false)); // 恒等映射残差块
// 全局平均池化 + 展平 + 全连接(10分类)
net.add(Pool.globalAvgPool2dBlock()); // 输出:(batch, 128, 1, 1)
net.add(Blocks.batchFlattenBlock()); // 展平:(batch, 128)
net.add(Linear.builder().setUnits(10).build()); // 128→10维输出
return net;
}
// 测试模型初始化+推理
public static void main(String[] args) {
try (NDManager manager = NDManager.newBaseManager()) {
SequentialBlock model = getModel();
// 1. 初始化模型(正确参数顺序)
model.initialize(manager, DataType.FLOAT32, new Shape(1, 3, 224, 224));
System.out.println("模型输出形状:" + Arrays.toString(model.getOutputShapes(new Shape[]{new Shape(1, 3, 224, 224)})));
// 2. 推理测试(关键:创建有效的ParameterStore)
NDList input = new NDList(manager.ones(new Shape(1, 3, 224, 224)));
// 创建ParameterStore(推理阶段isTraining=false)
ParameterStore parameterStore = new ParameterStore(manager, false);
// 调用forward(传入有效的ParameterStore)
NDList output = model.forward(parameterStore, input, false, new PairList<>());
System.out.println("输入形状:" + input.get(0).getShape());
System.out.println("输出形状:" + output.get(0).getShape());
} catch (Exception e) {
System.err.println("运行失败:" + e.getMessage());
e.printStackTrace();
}
}
}
核心改进点总览(对比 LeNet/AlexNet)
| 改进维度 | LeNet/AlexNet | ResNet(本代码) | 改进的核心价值 |
|---|---|---|---|
| 核心架构 | 串行卷积 / 全连接(无分支) | 残差块(主分支 + 短路分支) | 解决深层网络梯度消失,可训练数十 / 上百层网络 |
| 归一化机制 | 无归一化 | 批量归一化(BatchNorm) | 加速训练收敛,降低初始化敏感,提升稳定性 |
| 池化设计 | 普通最大池化 + 展平 | 全局平均池化(Global AvgPool) | 替代全连接层的部分作用,减少参数数量,降低过拟合 |
| 分支处理 | 无分支逻辑 | 恒等映射填充空分支,ParallelBlock 合并 | 解决 DJL 框架中空块报错问题,保证分支非空 |
| 网络扩展性 | 固定层数(LeNet8 层 / AlexNet8 层) | 模块化残差块,可堆叠扩展 | 仅需修改残差块数量即可实现 ResNet18/34/50 |
| 工程健壮性 | 无异常处理 / 参数校验 | 完整的异常捕获、状态提示、形状校验 | 100% 可运行,便于调试和生产环境使用 |
逐点详细解析
- 核心创新:残差连接(Residual Block)—— 解决深层网络梯度消失
这是 ResNet 最核心的改进,也是和 LeNet/AlexNet 的本质区别:
-
LeNet/AlexNet:网络是 “串行堆叠”,层数加深后梯度会逐层衰减(梯度消失),最多只能训练十几层;
-
ResNet 代码:
- 设计
residualBlock()方法,构建 “主分支(卷积 + BN+ReLU)+ 短路分支(恒等映射 / 下采样)”; - 通过
ParallelBlock将两个分支输出相加(ndLists.get(0).get(0).add(ndLists.get(1).get(0))),形成 “残差连接”; - 核心逻辑:让网络学习 “输入与输出的残差(差值)” 而非直接学习输出,梯度可通过短路分支直接回传,即使网络很深也不会梯度消失。
- 设计
-
代码关键片段:
// 残差相加核心逻辑 Function<List<NDList>, NDList> mergeFunction = ndLists -> { return new NDList(ndLists.get(0).get(0).add(ndLists.get(1).get(0))); }; -
价值:本代码仅堆叠 4 个残差块(8 层),但基于该结构可轻松扩展到 ResNet18(18 层)、ResNet50(50 层),而 LeNet/AlexNet 层数加深后会直接训练失败。
- 批量归一化(BatchNorm)—— 提升训练稳定性
-
LeNet/AlexNet:无归一化层,训练时需精细调整学习率、初始化方式,否则易发散;
-
ResNet 代码:每个卷积层后都添加
BatchNorm.builder().build(); -
核心作用:
- 将卷积输出的特征值归一化到 “均值 0、方差 1”,避免数值过大 / 过小导致梯度爆炸 / 消失;
- 加速训练收敛(学习率可设更大),降低对初始化的敏感度;
- 轻微正则化效果,减少过拟合风险。
- 全局平均池化(Global AvgPool)—— 简化全连接层
-
LeNet/AlexNet:用普通最大池化 + 展平 + 大尺寸全连接层(如 AlexNet 的 4096 维),参数多、过拟合风险高;
-
ResNet 代码:
net.add(Pool.globalAvgPool2dBlock()); -
核心作用:
- 将最后一层卷积的特征图(如 128×7×7)直接转为 128×1×1(对每个通道取全局平均值);
- 替代传统的 “池化 + 展平 + 大全连接层”,大幅减少参数数量(本代码仅用 128→10 的全连接层,AlexNet 是 4096→1000);
- 更贴合卷积层的空间特征,提升泛化能力。
- 工程化改进:解决 DJL 框架的 “空块” 问题
这是针对 DJL 框架的实用改进,LeNet/AlexNet 未涉及:
-
问题背景:DJL 的
SequentialBlock如果为空(如短路分支无需下采样时),会触发空指针 / 初始化报错; -
ResNet 代码:
-
定义
identityBlock()方法,返回一个 “恒等映射块”(输入 = 输出); -
当短路分支无需下采样时,添加identityBlock()填充,确保SequentialBlock非空:
if (downSample) { // 下采样分支(卷积+BN) } else { shortcutBranch.add(identityBlock()); // 空分支填充恒等映射 }
-
-
价值:保证代码 100% 可运行,避免框架层面的异常,这是 LeNet/AlexNet 代码未考虑的工程细节。
- 模块化设计:残差块可复用、易扩展
-
LeNet/AlexNet:网络结构是 “硬编码”,修改层数需大幅调整代码;
-
ResNet 代码:
-
将残差块封装为
residualBlock(int channels, boolean downSample)方法,参数化控制通道数、是否下采样; -
堆叠残差块仅需调用方法:
net.add(residualBlock(64, false)); // 恒等映射 net.add(residualBlock(128, true)); // 下采样
-
-
价值:仅需修改残差块的数量 / 参数,即可快速实现 ResNet18(8 个残差块)、ResNet34(16 个残差块),复用性远超 LeNet/AlexNet。
核心总结
ResNet 代码相对于 LeNet/AlexNet 的改进,核心可归纳为 3 点:
- 架构创新:残差连接解决深层网络梯度消失问题,实现 “更深的网络 = 更好的性能”;
- 技术升级:BatchNorm 加速收敛、全局 AvgPool 减少参数,提升训练稳定性和泛化能力;
- 工程优化:模块化设计、空块处理、异常捕获,保证代码可运行、可扩展、可维护。
简单来说,LeNet/AlexNet 是 “线性堆叠的网络”,层数加深就会 “退化”;而 ResNet 是 “带短路的残差网络”,层数越深性能越好,这也是 ResNet 能成为 ImageNet 竞赛冠军、并广泛应用于工业界的核心原因。
CNN 应用场景
- 图像分类:核心场景(如 MNIST 手写数字、ImageNet 图像分类),通过 CNN 提取特征后全连接层分类。
- 目标检测进阶:基于 CNN 的检测模型(如 YOLO、Faster R-CNN),CNN 负责提取图像特征,再结合锚框、区域提议等完成目标定位与分类。
- 其他场景:图像分割(U-Net)、人脸识别、图像生成(GAN)、医学影像分析等。
DJL 中 CNN 相关 API 与层封装
| API 类 | 作用 | 核心参数 |
|---|---|---|
Conv2d |
二维卷积层 | filters(输出通道)、kernelShape(卷积核尺寸)、stride(步长)、padding(填充) |
MaxPool2d/AvgPool2d |
最大 / 平均池化层 | kernelShape(池化核尺寸)、stride(步长) |
BatchNorm |
批归一化层 | 加速训练、防止梯度消失 |
Activation |
激活函数层(ReLU/Sigmoid) | reluBlock()/sigmoidBlock() |
Dropout |
Dropout 层 | rate(丢弃率) |
Linear |
全连接层 | units(输出维度) |
循环神经网络(RNN)与序列模型
什么是RNN
循环神经网络(Recurrent Neural Network,RNN)是一种专门用于处理序列数据的深度学习模型,核心特点是网络内部包含 “循环连接”,能够利用历史信息(上下文)来预测当前输出,解决了传统前馈神经网络(如 CNN、全连接网络)无法处理时序依赖关系的问题。
核心痛点:前馈网络的局限性
对于序列数据(如文本、语音、时间序列),数据的顺序至关重要:
- 比如理解句子
“我今天吃了苹果,它很甜”时,“它”指代的是前面的“苹果”; - 比如预测股票价格时,今天的价格和过去几天的价格高度相关。
而前馈网络(LeNet、AlexNet、ResNet)的特点是输入与输出独立,每一次计算都不会保留之前的信息,相当于 “没有记忆”,无法捕捉这种时序依赖。
RNN 的核心思想:引入 “记忆” 机制
RNN 的本质是在网络中加入一个隐藏状态(Hidden State),这个状态会记录之前的输入信息,并参与当前步的计算,形成 “循环”。
RNN 的关键特点
-
参数共享
- 与 CNN 的卷积核参数共享类似,RNN 的权重矩阵在所有时刻共享,而非每个时刻都有独立参数;
- 例如处理一个长度为 10 的序列,RNN 只需一套参数,而前馈网络需要 10 套独立参数,极大降低了过拟合风险。
-
处理变长序列
- RNN 不限制输入序列的长度,可灵活处理不同长度的文本、语音等数据(如一句话有 5 个单词或 10 个单词都能处理)。
-
局限性:梯度消失 / 爆炸
- 标准 RNN 的最大问题是无法捕捉长距离依赖:当序列很长时(如超过 10 个时刻),梯度在反向传播过程中会快速衰减(梯度消失)或急剧增大(梯度爆炸);
- 导致 RNN 只能记住短期记忆,无法记住长距离的信息(比如一句话开头的单词和结尾的单词的关联)。
RNN 原理
-
循环核结构:核心是隐藏状态(H)的循环更新,公式:Ht=σ(XtWxh+Ht−1Whh+bh)H_t = \sigma(X_t W_{xh} + H_{t-1} W_{hh} + b_h)Ht=σ(XtWxh+Ht−1Whh+bh),Yt=HtWhy+byY_t = H_t W_{hy} + b_yYt=HtWhy+by(Xt为当前输入,W为权重,b为偏置)。
-
隐藏状态传递:每个时间步的隐藏状态依赖前一个时间步的隐藏状态,实现序列信息的记忆。
-
梯度消失 / 爆炸问题:
- 梯度消失:深层 RNN 中,早期时间步的梯度传递到输出时趋近于 0,模型无法学习长距离依赖。
- 梯度爆炸:梯度值过大导致参数更新异常,可通过梯度裁剪缓解。
改进 RNN 结构(LSTM、GRU 原理与优势)
为了克服标准 RNN 的缺陷,研究者提出了LSTM(长短期记忆网络) 和 GRU(门控循环单元),它们是目前实际应用中最常用的 RNN 变体:
LSTM(长短期记忆网络)
-
核心结构:输入门、遗忘门、输出门 + 细胞状态(C),解决长距离依赖问题。
- 遗忘门:控制遗忘多少历史细胞状态,ft=σ(XtWxf+Ht−1Whf+bf)f_t = \sigma(X_t W_{xf} + H_{t-1} W_{hf} + b_f)ft=σ(XtWxf+Ht−1Whf+bf)。
- 输入门:控制更新多少新信息到细胞状态,it=σ(XtWxi+Ht−1Whi+bi)i_t = \sigma(X_t W_{xi} + H_{t-1} W_{hi} + b_i)it=σ(XtWxi+Ht−1Whi+bi),C~t=tanh(XtWxc+Ht−1Whc+bc)\tilde{C}_t = \tanh(X_t W_{xc} + H_{t-1} W_{hc} + b_c)C~t=tanh(XtWxc+Ht−1Whc+bc)。
- 细胞状态更新:Ct=ft⊙Ct−1+it⊙C~tC_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_tCt=ft⊙Ct−1+it⊙C~t(⊙为按元素乘)。
- 输出门:控制输出多少细胞状态到隐藏状态,ot=σ(XtWxo+Ht−1Who+bo)o_t = \sigma(X_t W_{xo} + H_{t-1} W_{ho} + b_o)ot=σ(XtWxo+Ht−1Who+bo)。
-
优势:通过细胞状态的长期存储和门控机制,有效缓解梯度消失,能学习长距离依赖。
GRU(门控循环单元)
-
核心结构:重置门、更新门,简化 LSTM 结构,参数更少,训练更快。
- 更新门:融合遗忘门和输入门,控制保留多少历史隐藏状态,zt=σ(XtWxz+Ht−1Whz+bz)z_t = \sigma(X_t W_{xz} + H_{t-1} W_{hz} + b_z)zt=σ(XtWxz+Ht−1Whz+bz)。
- 重置门:控制遗忘多少历史隐藏状态,rt=σ(XtWxr+Ht−1Whr+br)r_t = \sigma(X_t W_{xr} + H_{t-1} W_{hr} + b_r)rt=σ(XtWxr+Ht−1Whr+br)。
- 候选隐藏状态:H~t=tanh(XtWxh+rt⊙Ht−1Whh+bh)\tilde{H}_t = \tanh(X_t W_{xh} + r_t \odot H_{t-1} W_{hh} + b_h)H~t=tanh(XtWxh+rt⊙Ht−1Whh+bh)。
- 隐藏状态更新:Ht=(1−zt)⊙Ht−1+zt⊙H~tH_t = (1 - z_t) \odot H_{t-1} + z_t \odot \tilde{H}_tHt=(1−zt)⊙Ht−1+zt⊙H~t。
-
优势:结构更简单,计算效率更高,效果接近 LSTM,适合数据量较大的场景。
RNN 的典型应用场景
- 自然语言处理(NLP):文本生成、机器翻译、情感分析、命名实体识别;
- 语音识别:将语音信号(时序数据)转换为文字;
- 时间序列预测:股票价格预测、天气预报、销量预测;
- 视频分析:视频帧的动作识别(视频是连续的图像序列)。
序列任务实现(文本分类、时序预测)
LSTM
为什么需要 LSTM?(从 RNN 的痛点说起)
在讲 LSTM 之前,先理解它要解决的问题 —— 传统 RNN 处理文本时的「记忆短板」:
- 比如句子「我昨天看了一部电影,它的剧情特别精彩,我觉得这部____很好看」,RNN 在填最后一个空时,很难记住前面的「电影」这个词(长距离依赖);
- 原因:RNN 的梯度会随着序列长度增加而「消失」或「爆炸」,导致模型记不住早出现的信息;
- LSTM 的核心目标:专门设计了「记忆门控机制」,让模型能选择性地记住 / 忘记信息,解决长依赖问题。
LSTM 的核心原理
可以把 LSTM 想象成一个「带记忆功能的信息处理工厂」,每个 LSTM 单元(处理一个单词)有 3 个核心「门」和 1 个「记忆细胞」:
| 组件 | 通俗功能(结合文本分类示例) |
|---|---|
| 遗忘门(Forget Gate) | 决定「忘记哪些旧信息」:比如处理到「精彩」时,忘记无关的「昨天」,保留「电影」 |
| 输入门(Input Gate) | 决定「记住哪些新信息」:比如记住「剧情精彩」这个关键信息 |
| 输出门(Output Gate) | 决定「输出哪些信息给下一个单元」:把「电影 + 剧情精彩」传给下一个单词的处理单元 |
| 记忆细胞(Cell State) | 相当于「长期记忆」:全程保存关键信息(如「电影」),直到需要时使用 |
核心流程
我们的示例中,LSTM 处理的是「50 个单词的序列(One-Hot 编码后)」,每个单词的处理流程:
- 输入:当前单词的 128 维词向量(One-Hot 经 Linear 层转换);
- 遗忘门:筛选并忘记无关信息(如无意义的语气词);
- 输入门:把当前单词的关键信息存入记忆细胞;
- 输出门:生成当前步的输出(64 维隐藏状态);
- 传递:把输出传给下一个单词的 LSTM 单元,同时记忆细胞保留关键信息。
最终,50 个单词处理完后,LSTM 会输出一个 (32,50,64) 的张量(32 个样本,50 个单词,每个单词 64 维特征),这就是「序列的特征表示」。
LSTM 层的核心价值
// 我们示例中的LSTM层:捕捉50个单词的上下文依赖
net.add(createLSTMBlock(hiddenSize, 1)); // hiddenSize=64 → 输出64维特征
- 输入:
(32,50,128)(32 个样本,50 个单词,每个单词 128 维词向量); - 输出:
(32,50,64)(32 个样本,50 个单词,每个单词 64 维上下文特征); - 关键价值:不再孤立处理每个单词,而是结合上下文(比如「不好看」中的「不」会反转「好看」的语义)。
比如句子「这部电影一点都不好看」:
- 传统 RNN 可能只记住「好看」,输出正面情感;
- LSTM 能通过门控机制记住「不」这个否定词,正确识别为负面情感。
自定义数据集
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.util.Progress;
import java.io.IOException;
import java.util.Random;
// 自定义IMDB数据集
public class IMDBTextDataset extends RandomAccessDataset {
private final int vocabSize; // 词汇表大小
private final int maxSeqLength; // 最大序列长度
private final long[][] data; // 存储文本数据(索引序列)
private final long[] labels; // 存储标签(0/1)
private final long datasetSize; // 数据集大小
public static class Builder extends BaseBuilder<Builder> {
private int vocabSize = 10000;
private int maxSeqLength = 50; // 最大序列长度
private NDManager manager;
private long datasetSize = 1000; // 数据集大小
public Builder setManager(NDManager manager) {
this.manager = manager;
return this;
}
public Builder setVocabSize(int vocabSize) {
this.vocabSize = vocabSize;
return this;
}
public Builder setMaxSequenceLength(int maxSeqLength) {
this.maxSeqLength = maxSeqLength;
return this;
}
public Builder setDatasetSize(long datasetSize) {
this.datasetSize = datasetSize;
return this;
}
public IMDBTextDataset build() {
this.setSampling(32, false);
return new IMDBTextDataset(this);
}
@Override
protected Builder self() {
return this;
}
}
private IMDBTextDataset(Builder builder) {
super(builder);
this.vocabSize = builder.vocabSize;
this.maxSeqLength = builder.maxSeqLength;
this.datasetSize = builder.datasetSize;
this.data = new long[(int) datasetSize][maxSeqLength];
this.labels = new long[(int) datasetSize];
// 预先生成数据(保存为基本类型数组)
Random random = new Random(12345);
for (int i = 0; i < datasetSize; i++) {
for (int j = 0; j < maxSeqLength; j++) {
data[i][j] = random.nextInt(vocabSize);
}
labels[i] = random.nextInt(2);
}
}
@Override
public Record get(NDManager manager, long index) throws IOException {
if (index >= datasetSize) {
throw new IndexOutOfBoundsException("索引超出范围");
}
// 每次创建新的NDArray,避免资源冲突
long[] wordIndexes = data[(int) index];
NDArray indexArray = manager.create(wordIndexes).reshape(new Shape(maxSeqLength));
// 生成One-Hot编码
NDArray oneHotText = indexArray.oneHot(vocabSize).toType(DataType.FLOAT32, false);
// 创建标签
NDArray label = manager.create(labels[(int) index]).reshape(new Shape(1));
return new Record(new NDList(oneHotText), new NDList(label));
}
@Override
protected long availableSize() {
return datasetSize;
}
@Override
public void prepare(Progress progress) {
// 预处理:空实现
}
public static Builder builder() {
return new Builder();
}
}
One-Hot 编码
One-Hot 编码(独热编码)是一种将离散型特征(如单词、类别)转换为机器学习模型可理解的数值形式的编码方式。
- 核心逻辑:为每个离散值创建一个长度等于「类别总数」的向量,只有该值对应的位置为 1,其余位置全为 0。
- 举例:如果词汇表包含 [“电影”, “好看”, “不好看”] 3 个词,那么:
- “电影” → [1, 0, 0]
- “好看” → [0, 1, 0]
- “不好看” → [0, 0, 1]
在我们的 LSTM 文本分类示例中,文本是由单词组成的,而模型只能处理数字,所以需要把「单词」这个离散特征转为数字:
-
第一步:建立词汇表(词到索引的映射)
比如把所有出现的单词编号:
{"我":0, "爱":1, "这部":2, "电影":3, "烂":4}(词汇表大小 = 5)。
-
第二步:One-Hot 编码转换
每个单词对应一个「长度 = 词汇表大小」的向量,只有编号对应的位置为 1:
- “我” → [1, 0, 0, 0, 0]
- “爱” → [0, 1, 0, 0, 0]
- “烂” → [0, 0, 0, 0, 1]
-
第三步:文本序列转换
比如句子「我 爱 这部 电影」→ 转为 4 个 One-Hot 向量的组合:
[1,0,0,0,0] # 我 [0,1,0,0,0] # 爱 [0,0,1,0,0] # 这部 [0,0,0,1,0] # 电影对应到我们代码中的形状就是
(1, 4, 5)(1 个样本,4 个单词,词汇表大小 5)。
核心应用场景
在我们的 LSTM 文本分类示例中,One-Hot 是文本预处理的关键步骤,除此之外,它还广泛用于:
- 自然语言处理(NLP):单词 / 字符的数值化(如我们的示例);
- 类别特征处理:如性别(男 / 女)、职业(教师 / 医生 / 工程师)等离散特征;
- 多分类任务:标签的编码(如情感分类的「正面 = 0→[1,0],负面 = 1→[0,1]」)。
测试代码
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.recurrent.LSTM;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.loss.SoftmaxCrossEntropyLoss;
import ai.djl.training.optimizer.Adam;
import java.lang.reflect.Field;
public class LSTMTextClassification {
// LSTM层创建
private static Block createLSTMBlock(int hiddenSize, int numLayers) {
try {
LSTM.Builder builder = LSTM.builder();
builder.setStateSize(hiddenSize);
// 反射设置numLayers
Field numLayersField = builder.getClass().getSuperclass().getDeclaredField("numLayers");
numLayersField.setAccessible(true);
numLayersField.set(builder, numLayers);
return builder.build();
} catch (Exception e) {
throw new RuntimeException("创建LSTM层失败", e);
}
}
// 构建LSTM文本分类模型
public static SequentialBlock getModel(int embeddingSize, int hiddenSize, int numClasses) {
SequentialBlock net = new SequentialBlock();
// 模拟Embedding层:One-Hot → Linear → 词向量
net.add(Linear.builder().setUnits(embeddingSize).build());
// LSTM层:捕捉序列依赖
net.add(createLSTMBlock(hiddenSize, 1));
// 池化层(将(32,50,64) → (32,64),解决序列长度维度问题)
net.add(ndList -> {
NDArray x = ndList.get(0);
// 对序列维度(第1维)做均值池化,消除50这个维度
return new NDList(x.mean(new int[]{1}));
});
// 分类输出层:输入(32,64) → 输出(32,2)(匹配标签形状)
net.add(Linear.builder().setUnits(numClasses).build());
return net;
}
// 主方法:完整训练流程
public static void main(String[] args) throws Exception {
// 1. 强制指定PyTorch引擎
System.setProperty("ai.djl.default_engine", "PyTorch");
// 2. 模型核心参数
int vocabSize = 10000; // 词汇表大小
int embeddingSize = 128; // 模拟词向量维度
int hiddenSize = 64; // LSTM隐藏层维度
int numClasses = 2; // 分类类别数(正面/负面)
int batchSize = 32; // 批量大小
int maxSeqLength = 50; // 序列长度
// 3. 初始化核心组件
try (NDManager manager = NDManager.newBaseManager()) {
// 构建并初始化模型
SequentialBlock modelBlock = getModel(embeddingSize, hiddenSize, numClasses);
Model model = Model.newInstance("lstm-text-classification");
model.setBlock(modelBlock);
// 初始化模型:输入形状 (batchSize, maxSeqLength, vocabSize)
Shape inputShape = new Shape(batchSize, maxSeqLength, vocabSize);
modelBlock.initialize(manager, DataType.FLOAT32, inputShape);
// 配置训练
DefaultTrainingConfig trainingConfig = new DefaultTrainingConfig(new SoftmaxCrossEntropyLoss())
.optOptimizer(Adam.builder().build()) // Adam优化器
.optDevices(new Device[]{Device.cpu()}); // CPU训练
// 加载数据集
IMDBTextDataset trainDataset = IMDBTextDataset.builder()
.setManager(manager)
.setVocabSize(vocabSize)
.setMaxSequenceLength(maxSeqLength)
.setDatasetSize(1000) // 1000条模拟样本
.build();
// 训练模型
try (Trainer trainer = model.newTrainer(trainingConfig)) {
trainer.initialize(inputShape); // 初始化训练器
System.out.println("===== 开始训练LSTM文本分类模型 =====");
System.out.println("📌 输入形状:" + inputShape);
System.out.println("📌 数据集大小:" + trainDataset.size() + "\n");
int epochs = 3; // 训练轮数
for (int epoch = 0; epoch < epochs; epoch++) {
EasyTrain.fit(trainer, 1, trainDataset, null);
System.out.println("第 " + (epoch + 1) + " 轮训练完成\n");
}
} finally {
model.close();
}
}
}
}
时序预测(基于 GRU)
GRU 是什么?(和 LSTM 的关系)
GRU 是 2014 年提出的「轻量化 LSTM」,核心目标和 LSTM 一致 —— 解决传统 RNN 的长依赖问题,但做了大幅简化:
- 去掉了 LSTM 的「记忆细胞(Cell State)」,只用「隐藏状态(Hidden State)」存储信息;
- 把 LSTM 的「遗忘门 + 输入门」合并为 1 个「更新门(Update Gate)」,保留「重置门(Reset Gate)」;
- 最终只有 2 个门控(LSTM 是 3 个),参数更少、计算更快,效果却和 LSTM 接近。
简单总结:GRU = 简化版 LSTM,适合想减少计算量、又要处理长依赖的场景(比如我们的文本分类示例)。
GRU 的核心原理(对比 LSTM)
同样把 GRU 想象成「信息处理工厂」,但比 LSTM 更精简,只有 2 个核心门:
| GRU 组件 | 通俗功能(结合文本分类示例) | 对比 LSTM |
|---|---|---|
| 重置门(Reset Gate) | 决定「是否忘记之前的信息」:比如处理到「不好看」时,重置掉「好看」的正面信息,只保留「不」的否定 | 替代 LSTM 的「遗忘门」部分功能 |
| 更新门(Update Gate) | 决定「多少旧信息保留 + 多少新信息加入」:比如保留「电影」的核心信息,加入「不」的否定信息 | 合并 LSTM 的「遗忘门 + 输入门」 |
| 隐藏状态(Hidden State) | 同时承担 LSTM「隐藏状态 + 记忆细胞」的作用,全程保存关键信息 | 简化了 LSTM 的双状态设计 |
GRU 处理文本的核心流程(以「这部电影不好看」为例):
- 输入:「不」的 128 维词向量;
- 重置门:触发!决定忘记「好看」的正面信息;
- 更新门:触发!保留「电影」的核心信息,加入「不」的否定信息;
- 隐藏状态更新:把「电影 + 不」的组合信息存入隐藏状态;
- 输出:传给下一个单词的处理单元,最终体现「负面情感」。
GRU vs LSTM
| 维度 | GRU | LSTM |
|---|---|---|
| 门控数量 | 2 个(重置门 + 更新门) | 3 个(遗忘门 + 输入门 + 输出门) |
| 状态数量 | 1 个(隐藏状态) | 2 个(隐藏状态 + 记忆细胞) |
| 参数数量 | 更少(计算更快,内存占用小) | 更多(计算稍慢,内存占用大) |
| 效果 | 多数场景和 LSTM 接近 | 极长序列(>100 步)效果略好 |
| 上手难度 | 更低(逻辑更简单) | 更高(双状态 + 三门控) |
| 适用场景 | 短序列(如我们的 50 词文本) | 长序列(如长文本、语音) |
对我们的「50 词 IMDB 情感分类」示例来说:用 GRU 完全可以替代 LSTM,且训练速度更快。
示例代码
import ai.djl.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.loss.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.tracker.*;
/**
* 完整实现sin波序列预测
*/
public class SimpleRNNExample {
public static void main(String[] args) throws Exception {
System.out.println("开始序列学习示例");
System.out.println("目标:学习sin波序列的下一个值预测");
try (NDManager manager = NDManager.newBaseManager()) {
// 1. 准备数据 - sin波序列
System.out.println("准备训练数据...");
int seqLength = 10; // 序列长度
int numSamples = 1000; // 样本数量
NDArray data = manager.zeros(new Shape(numSamples, seqLength, 1));
NDArray labels = manager.zeros(new Shape(numSamples, 1));
// 生成sin波序列数据
for (int i = 0; i < numSamples; i++) {
float start = i * 0.1f;
for (int j = 0; j < seqLength; j++) {
float x = start + j * 0.1f;
data.set(new NDIndex(i, j, 0), Math.sin(x));
}
// 标签:序列的下一个值
labels.set(new NDIndex(i, 0), Math.sin(start + seqLength * 0.1f));
}
// 2. 创建数据集
RandomAccessDataset dataset = new ArrayDataset.Builder()
.setData(data)
.optLabels(labels)
.setSampling(32, true) // 批次大小32
.build();
// 3. 创建模型 - 全连接网络
Model model = Model.newInstance("sequence-model");
ai.djl.nn.SequentialBlock block = new ai.djl.nn.SequentialBlock();
// 展平输入 (batch, seq_len, 1) -> (batch, seq_len)
block.add(ai.djl.nn.Blocks.batchFlattenBlock());
// 全连接层 + ReLU激活
block.add(ai.djl.nn.core.Linear.builder().setUnits(64).build());
block.add(new ai.djl.nn.LambdaBlock(ndList ->
new NDList(ai.djl.nn.Activation.relu(ndList.singletonOrThrow()))
));
block.add(ai.djl.nn.core.Linear.builder().setUnits(32).build());
block.add(new ai.djl.nn.LambdaBlock(ndList ->
new NDList(ai.djl.nn.Activation.relu(ndList.singletonOrThrow()))
));
// 输出层(预测单个值)
block.add(ai.djl.nn.core.Linear.builder().setUnits(1).build());
model.setBlock(block);
// 4. 配置训练
Loss l2Loss = Loss.l2Loss();
TrainingConfig config = new DefaultTrainingConfig(l2Loss)
.optOptimizer(Optimizer.adam()
.optLearningRateTracker(Tracker.fixed(0.001f))
.build());
// 5. 训练模型
try (Trainer trainer = model.newTrainer(config)) {
// 初始化模型参数(适配输入形状)
trainer.initialize(new Shape(32, seqLength, 1));
System.out.println("开始训练...");
// 训练循环
for (int epoch = 1; epoch <= 10; epoch++) {
System.out.println("\n第 " + epoch + " 轮训练");
float totalLoss = 0;
int batchCount = 0;
// 遍历数据集
for (Batch batch : trainer.iterateDataset(dataset)) {
try (GradientCollector gc = trainer.newGradientCollector()) {
// 手动前向+反向传播
NDList outputs = trainer.forward(batch.getData(), batch.getLabels());
NDArray loss = l2Loss.evaluate(batch.getLabels(), outputs);
// 计算批次平均损失
float batchLoss = loss.mean().getFloat();
totalLoss += batchLoss;
batchCount++;
// 反向传播+参数更新
gc.backward(loss);
trainer.step();
} finally {
batch.close();
}
}
// 计算并打印平均损失
float avgLoss = totalLoss / batchCount;
System.out.printf("平均损失: %.6f%n", avgLoss);
// 每2轮测试一次(核心:用Trainer验证,避免Block.forward的复杂调用)
if (epoch % 2 == 0) {
validateWithTrainer(trainer, manager, seqLength);
}
}
}
System.out.println("\n训练完成!");
} catch (Exception e) {
System.err.println("训练出错: " + e.getMessage());
e.printStackTrace();
}
}
/**
* 核心修复:用Trainer验证模型
*/
private static void validateWithTrainer(Trainer trainer, NDManager manager, int seqLength) throws Exception {
System.out.println("\n===== 验证模型预测能力 =====");
// 创建单个测试样本(转为批次形状 [1, seqLength, 1])
NDArray testData = manager.zeros(new Shape(1, seqLength, 1));
NDArray testLabel = manager.zeros(new Shape(1, 1));
// 填充测试数据
float start = 0.0f;
for (int j = 0; j < seqLength; j++) {
float x = start + j * 0.1f;
testData.set(new NDIndex(0, j, 0), Math.sin(x));
}
testLabel.set(new NDIndex(0, 0), Math.sin(start + seqLength * 0.1f));
// 核心:用Trainer的forward验证(无需手动创建ParameterStore/Workspace)
NDList inputs = new NDList(testData);
NDList labels = new NDList(testLabel);
NDList outputs = trainer.forward(inputs, labels);
// 安全获取预测值
NDArray pred = outputs.get(0);
float predictedValue = pred.getFloat(); // 自动适配一维/二维张量
float trueValue = testLabel.getFloat();
// 输出结果
System.out.printf("输入序列: sin(0.0), sin(0.1), ..., sin(%.1f)%n", (seqLength-1)*0.1f);
System.out.printf("预测下一个值: %.6f%n", predictedValue);
System.out.printf("真实下一个值: %.6f%n", trueValue);
System.out.printf("绝对误差: %.6f%n", Math.abs(predictedValue - trueValue));
}
}
执行结果
准备训练数据...
开始训练...
第 1 轮训练
平均损失: 0.079387
第 2 轮训练
平均损失: 0.001763
===== 验证模型预测能力 =====
输入序列: sin(0.0), sin(0.1), ..., sin(0.9)
预测下一个值: 0.847484
真实下一个值: 0.841471
绝对误差: 0.006013
第 3 轮训练
平均损失: 0.000243
第 4 轮训练
平均损失: 0.000120
===== 验证模型预测能力 =====
输入序列: sin(0.0), sin(0.1), ..., sin(0.9)
预测下一个值: 0.843613
真实下一个值: 0.841471
绝对误差: 0.002142
第 5 轮训练
平均损失: 0.000078
第 6 轮训练
平均损失: 0.000058
===== 验证模型预测能力 =====
输入序列: sin(0.0), sin(0.1), ..., sin(0.9)
预测下一个值: 0.838216
真实下一个值: 0.841471
绝对误差: 0.003254
第 7 轮训练
平均损失: 0.000049
第 8 轮训练
平均损失: 0.000038
===== 验证模型预测能力 =====
输入序列: sin(0.0), sin(0.1), ..., sin(0.9)
预测下一个值: 0.836871
真实下一个值: 0.841471
绝对误差: 0.004600
第 9 轮训练
平均损失: 0.000030
第 10 轮训练
平均损失: 0.000027
===== 验证模型预测能力 =====
输入序列: sin(0.0), sin(0.1), ..., sin(0.9)
预测下一个值: 0.849429
真实下一个值: 0.841471
绝对误差: 0.007958
训练完成!
或
import ai.djl.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.*;
import ai.djl.nn.core.Linear;
import ai.djl.nn.recurrent.RNN;
import ai.djl.training.*;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.dataset.Batch;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.tracker.Tracker;
/**
* 极简RNN示例
* 核心映射:代码 → RNN公式 → LSTM/GRU改进点
*/
public class SimpleRnnDemo {
public static void main(String[] args) {
// 基础配置(序列长度=时间步,隐藏层维度=隐藏状态H的维度)
int seqLength = 10; // 时间步T=10(对应公式中的t=1~10)
int inputSize = 1; // 每个时间步输入维度X_t(仅注释说明)
int hiddenSize = 16; // 隐藏状态H_t的维度(对应公式中的W_hh权重维度)
int sampleCount = 500; // 训练样本数
int batchSize = 16; // 批次大小
try (NDManager manager = NDManager.newBaseManager()) {
// 1. 生成序列数据(模拟时间步输入)
System.out.println("=== 步骤1:生成序列数据(X_t) ===");
NDArray inputData = manager.zeros(new Shape(sampleCount, seqLength, inputSize));
NDArray labelData = manager.zeros(new Shape(sampleCount, 1));
// 生成sin序列:每个样本是长度为10的序列(X_1~X_10),标签是第11个值(Y)
for (int i = 0; i < sampleCount; i++) {
float start = i * 0.1f;
for (int t = 0; t < seqLength; t++) { // t=时间步(对应公式中的t)
float x = start + t * 0.1f;
inputData.set(new NDIndex(i, t, 0), Math.sin(x)); // X_t:第t个时间步的输入
}
labelData.set(new NDIndex(i, 0), Math.sin(start + seqLength * 0.1f));
}
// 2. 定义RNN模型
System.out.println("=== 步骤2:定义RNN模型(对应核心公式) ===");
Model model = Model.newInstance("simple-rnn");
SequentialBlock network = new SequentialBlock();
// 核心:RNN循环核(对应公式 H_t = tanh(X_t W_xh + H_{t-1} W_hh + b_h))
network.add(
RNN.builder()
.setActivation(RNN.Activation.TANH) // 激活函数tanh(对应公式中的σ)
.setStateSize(hiddenSize) // 隐藏状态H_t的维度(必填)
.setNumLayers(1) // 单层RNN(避免梯度消失/爆炸)
.build()
);
// 取最后一个时间步的隐藏状态H_T(体现RNN的序列记忆)
network.add(new LambdaBlock(ndList -> {
NDArray rnnOutput = ndList.singletonOrThrow(); // RNN输出:(batch, seqLen, hiddenSize)
return new NDList(rnnOutput.get(new NDIndex(":, -1, :"))); // 取最后一个时间步的H_T
}));
// 输出层(对应公式 Y_t = H_t W_hy + b_y)
network.add(Linear.builder().setUnits(1).build());
model.setBlock(network);
// 3. 配置训练参数
System.out.println("=== 步骤3:配置训练参数 ===");
Loss l2Loss = Loss.l2Loss();
// 配置Adam优化器
Adam optimizer = Adam.builder()
.optLearningRateTracker(Tracker.fixed(0.001f)) // 固定学习率
.build();
// 配置TrainingConfig
DefaultTrainingConfig config = new DefaultTrainingConfig(l2Loss)
.optOptimizer(optimizer); // 设置优化器(唯一配置项)
// 4. 训练模型(观察隐藏状态传递)
System.out.println("=== 步骤4:训练RNN模型(观察隐藏状态传递) ===");
ArrayDataset dataset = new ArrayDataset.Builder()
.setData(inputData)
.optLabels(labelData)
.setSampling(batchSize, true)
.build();
try (Trainer trainer = model.newTrainer(config)) {
// 初始化参数:匹配输入形状 (batch, seqLen, inputSize)
trainer.initialize(new Shape(batchSize, seqLength, inputSize));
// 训练5轮(单层RNN易收敛,多层易出现梯度消失)
for (int epoch = 1; epoch <= 5; epoch++) {
System.out.println("\n--- 第" + epoch + "轮训练 ---");
float totalLoss = 0;
int batchCount = 0;
for (Batch batch : trainer.iterateDataset(dataset)) {
try (GradientCollector gc = trainer.newGradientCollector()) {
// 前向传播:RNN自动完成H_1→H_2→...→H_T的循环计算
NDList outputs = trainer.forward(batch.getData(), batch.getLabels());
// 计算损失(L2损失适合回归预测)
NDArray loss = l2Loss.evaluate(batch.getLabels(), outputs);
// 反向传播:单层RNN梯度稳定(多层易消失,LSTM/GRU解决此问题)
gc.backward(loss);
// 更新参数
trainer.step();
// 累加损失值
totalLoss += loss.mean().getFloat();
batchCount++;
} finally {
batch.close(); // 释放资源(必写)
}
}
// 打印本轮平均损失
float avgLoss = totalLoss / batchCount;
System.out.println("平均损失值: " + String.format("%.6f", avgLoss));
}
// 5. 验证+原理对比
System.out.println("\n=== 步骤5:模型验证 + RNN/LSTM/GRU核心原理 ===");
// 测试序列:X_1~X_10 = sin(0)~sin(0.9)
NDArray testData = manager.zeros(new Shape(1, seqLength, inputSize));
for (int t = 0; t < seqLength; t++) {
testData.set(new NDIndex(0, t, 0), Math.sin(t * 0.1f));
}
// 真实值:第11个sin值(Y=sin(1.0))
float trueValue = (float) Math.sin(seqLength * 0.1f);
// 预测:RNN通过隐藏状态传递记忆序列信息
NDList testOutputs = trainer.forward(new NDList(testData), new NDList(manager.zeros(new Shape(1, 1))));
float predictedValue = testOutputs.get(0).getFloat();
// 打印结果+核心原理(保留关键知识点)
System.out.println("测试序列:X_1=sin(0.0), X_2=sin(0.1), ..., X_10=sin(0.9)");
System.out.println("基于最后一个隐藏状态H_10的预测值: " + String.format("%.6f", predictedValue));
System.out.println("真实值(sin(1.0)): " + String.format("%.6f", trueValue));
System.out.println("预测误差: " + String.format("%.6f", Math.abs(predictedValue - trueValue)));
}
System.out.println("\n=== 训练圆满完成! ===");
} catch (Exception e) {
System.err.println("运行时错误: " + e.getMessage());
e.printStackTrace();
}
}
}
深度学习实战:基于 DJL 的迁移学习与预训练模型应用
引言:迁移学习的价值与应用场景
在深度学习实践中,从零开始训练一个高性能模型往往面临两大核心难题:一是需要海量标注数据,二是训练过程消耗大量算力和时间。而迁移学习作为深度学习领域的核心实用技术,完美解决了这一痛点 —— 它将大规模通用数据集上训练好的预训练模型参数,迁移到数据量较小的特定任务中,既能复用通用特征、节省训练成本,又能显著提升小数据集场景下的任务效果(如小众品类图像分类、特定领域文本情感分析)。
对于 Java 开发者而言,DJL(Deep Java Library)提供了无需掌握 Python 深度学习框架即可调用工业级预训练模型的极简方案,是 Java 生态下落地迁移学习的首选工具。本文将从迁移学习核心原理出发,结合 DJL 框架,完整讲解预训练模型加载、迁移学习实操的全流程。
迁移学习核心原理与实施策略
核心思想
迁移学习的本质是 “知识复用”:将预训练模型在通用数据集(如 ImageNet 图像数据集、WikiText 文本数据集)上学到的通用特征(如图像的边缘 / 纹理、文本的语义编码),适配到新的特定任务中,避免从零训练模型。
三大关键实施策略
迁移学习的实施策略按复杂度从低到高可分为三类,核心是通过 “冻结层” 操作控制参数更新范围:
- 特征提取(Feature Extraction):冻结预训练模型的底层 / 中层网络(通用特征层),仅替换最后一层全连接分类层,用新任务数据仅训练该层。此时预训练层仅作为固定特征提取器,不参与参数更新。
- 微调(Fine-tuning):解冻预训练模型的部分高层网络(或全部网络),使用远小于从头训练的学习率,用新任务数据训练整个模型。该策略平衡 “通用特征复用” 与 “任务定制化”,让通用特征适配新任务专属特征。
- 冻结层(Freezing Layers):是特征提取和微调的核心操作,通过
setFreeze(true)固定层参数(不参与反向传播),需训练的层则设置setFreeze(false),保障迁移学习参数更新策略落地。
DJL 预训练模型库:Java 调用预训练模型的核心工具
DJL 为 Java 开发者封装了预训练模型的加载与使用逻辑,无需关注底层框架(PyTorch/TensorFlow)细节,即可快速调用主流预训练模型。
核心能力与支持的模型
DJL 通过ModelZoo(模型库)和PretrainedModel(预训练模型)核心接口,原生支持两大领域的主流预训练模型:
- 计算机视觉:ResNet、VGG、MobileNet、YOLO(覆盖图像分类、目标检测);
- 自然语言处理:BERT、DistilBERT、GPT(覆盖文本分类、问答、情感分析)。
核心使用流程
加载预训练模型的核心是通过Criteria类配置参数,再调用ModelZoo.loadModel()完成加载,关键配置项如下:
| 配置项 | 作用 |
|---|---|
| optApplication | 指定模型应用场景(如CV.IMAGE_CLASSIFICATION、NLP.TEXT_CLASSIFICATION) |
| setTypes | 指定输入 / 输出数据类型(如图像分类输入Image、输出Classifications) |
| optFilter | 按属性过滤模型(如layers=50匹配 ResNet50) |
| optModelUrls | 直接指定模型地址(兜底方案,避免模型匹配失败) |
| optProgress | 添加进度条,可视化模型下载过程 |
实操要点
- 模型匹配:BERT 等模型易出现
ModelNotFoundException,优先用optModelUrls直接指定模型地址; - 资源管理:使用后调用
close()释放ZooModel,避免内存泄漏; - 异常处理:区分
ModelException(模型匹配失败)和IOException(下载 / 读取失败); - 缓存机制:DJL 将模型缓存到
~/.djl.ai/cache,后续加载无需重复下载。
实战:加载预训练模型
以下代码实现 ResNet50(图像分类)和 DistilBERT(文本分类)预训练模型的加载,包含完整的异常处理和资源释放:
import ai.djl.ModelException;
import ai.djl.modality.Classifications;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;
public class PretrainedModelLoader {
// 加载图像分类预训练模型(ResNet50)
public static ZooModel loadImageClassificationModel() throws ModelException, IOException {
Criteria<ai.djl.modality.cv.Image, Classifications> criteria = Criteria.builder()
.optApplication(ai.djl.Application.CV.IMAGE_CLASSIFICATION)
.setTypes(ai.djl.modality.cv.Image.class, Classifications.class)
.optFilter("layers", "50") // 匹配ResNet50
.optProgress(new ProgressBar())
.build();
return ModelZoo.loadModel(criteria);
}
// 加载文本分类预训练模型(DistilBERT,直接指定模型地址避免匹配失败)
public static ZooModel loadTextModel() throws ModelException, IOException {
Criteria<String, Classifications> criteria = Criteria.builder()
.optModelUrls("djl://ai.djl.pytorch/distilbert") // 直接指定模型地址
.setTypes(String.class, Classifications.class)
.optProgress(new ProgressBar())
.build();
return ModelZoo.loadModel(criteria);
}
public static void main(String[] args) {
ZooModel imageModel = null;
ZooModel textModel = null;
try {
// 加载ResNet50
imageModel = loadImageClassificationModel();
System.out.println("ResNet50 预训练模型加载完成:" + imageModel.getName());
// 加载DistilBERT
textModel = loadTextModel();
System.out.println("DistilBERT 预训练模型加载完成:" + textModel.getName());
} catch (ModelException e) {
System.err.println("模型加载失败:找不到匹配的预训练模型");
e.printStackTrace();
} catch (IOException e) {
System.err.println("IO异常:模型文件下载/读取失败");
e.printStackTrace();
} finally {
// 确保资源释放
if (imageModel != null) {
imageModel.close();
}
if (textModel != null) {
textModel.close();
}
}
}
}
迁移学习实操:基于 DJL 微调预训练模型
加载预训练模型是基础,结合自定义数据集完成迁移学习微调,才是落地特定任务的核心。以下以 ResNet50 为例,完整实现 “加载预训练模型→修改输出层→加载自定义数据集→冻结层→微调” 的全流程。
步骤 1:加载预训练模型并修改输出层
核心逻辑:加载 ResNet50 预训练模型后,替换最后一层分类器为自定义分类数的输出层,适配新任务:
import ai.djl.Application;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.loss.Loss;
import java.io.IOException;
public class TransferLearningDemo {
// 加载ResNet50预训练模型 + 修改输出层为指定分类数
public static Model fineTuneResNet(int numClasses) throws ModelException, IOException {
// 1. 加载ResNet50预训练模型
Criteria<Image, Classifications> criteria = Criteria.builder()
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.setTypes(Image.class, Classifications.class)
.optFilter("layers", "50") // 匹配ResNet50
.build();
ZooModel<Image, Classifications> pretrainedZooModel = ModelZoo.loadModel(criteria);
Block pretrainedBlock = pretrainedZooModel.getBlock();
// 2. 构建新输出层并替换
Linear newOutputLayer = Linear.builder().setUnits(numClasses).build();
SequentialBlock newModelBlock = new SequentialBlock();
newModelBlock.add(pretrainedBlock).add(newOutputLayer);
// 3. 初始化新模型(必须步骤,否则模型无法使用)
Model fineTunedModel = Model.newInstance("resnet50-finetuned");
fineTunedModel.setBlock(newModelBlock);
// 初始化模型输入形状(ResNet50固定:1样本、3通道、224x224)
try (Trainer trainer = fineTunedModel.newTrainer(new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()))) {
trainer.initialize(new Shape(1, 3, 224, 224));
}
// 释放预训练模型资源
pretrainedZooModel.close();
return fineTunedModel;
}
// 测试入口:仅验证模型加载和输出层修改
public static void main(String[] args) {
try {
Model model = fineTuneResNet(10); // 修改输出层为10分类
System.out.println("模型加载并修改输出层成功:" + model.getName());
model.close();
} catch (ModelException | IOException e) {
e.printStackTrace();
}
}
}
步骤 2:加载自定义数据集并完成微调
在步骤 1 基础上,新增自定义数据集加载、冻结层配置、训练流程,实现完整的迁移学习微调:
import ai.djl.Application;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.loss.Loss;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;
import java.io.IOException;
import java.nio.file.Paths;
public class TransferLearningDemo {
// 自定义数据集根路径(需替换为实际路径)
private static final String DATASET_ROOT = "C:\\Users\\Administrator\\Desktop\\catAndDog";
// 核心:加载ResNet50 + 修改输出层 + 加载自定义数据集 + 微调
public static Model fineTuneResNet(int numClasses, int epochs) throws ModelException, IOException, TranslateException {
// 1. 加载ResNet50预训练模型
Criteria<Image, Classifications> criteria = Criteria.builder()
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.setTypes(Image.class, Classifications.class)
.optFilter("layers", "50")
.build();
ZooModel<Image, Classifications> pretrainedZooModel = ModelZoo.loadModel(criteria);
Block pretrainedBlock = pretrainedZooModel.getBlock();
// 2. 替换输出层(适配自定义分类数)
Linear newOutputLayer = Linear.builder().setUnits(numClasses).build();
SequentialBlock newModelBlock = new SequentialBlock();
newModelBlock.add(pretrainedBlock).add(newOutputLayer);
// 3. 初始化模型
Model fineTunedModel = Model.newInstance("resnet50-finetuned");
fineTunedModel.setBlock(newModelBlock);
// 4. 配置训练(损失函数+精度评估)
DefaultTrainingConfig trainConfig = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy()); // 仅保留核心精度评估
try (Trainer trainer = fineTunedModel.newTrainer(trainConfig)) {
trainer.initialize(new Shape(1, 3, 224, 224));
// 5. 冻结预训练层(迁移学习核心):仅解冻最后2个参数(新输出层)
ParameterList allParams = fineTunedModel.getBlock().getParameters();
for (int i = 0; i < allParams.size(); i++) {
Pair<String, Parameter> paramPair = allParams.get(i);
Parameter param = paramPair.getValue();
if (i > allParams.size() - 3) {
param.freeze(false); // 解冻新输出层
} else {
param.freeze(true); // 冻结预训练层
}
}
// 6. 加载自定义数据集(ImageFolder格式:按文件夹分类存储图片)
Pipeline pipeline = new Pipeline();
pipeline.add(new Resize(224, 224)).add(new ToTensor()); // 图像预处理
// 训练集(批次大小8,避免内存溢出)
ImageFolder trainDataset = ImageFolder.builder()
.setRepositoryPath(Paths.get(DATASET_ROOT))
.optPipeline(pipeline)
.setSampling(8, true)
.build();
trainDataset.prepare();
// 验证集(实际建议拆分train/test子目录)
ImageFolder valDataset = ImageFolder.builder()
.setRepositoryPath(Paths.get(DATASET_ROOT))
.optPipeline(pipeline)
.setSampling(8, false)
.build();
valDataset.prepare();
// 7. 执行微调训练
System.out.println("开始微调:分类数=" + numClasses + ",轮数=" + epochs);
System.out.println("训练集样本数=" + trainDataset.size() + ",验证集=" + valDataset.size());
for (int epoch = 0; epoch < epochs; epoch++) {
// 训练轮次
for (Batch batch : trainer.iterateDataset(trainDataset)) {
EasyTrain.trainBatch(trainer, batch);
trainer.step();
batch.close();
}
// 验证轮次
for (Batch batch : trainer.iterateDataset(valDataset)) {
EasyTrain.validateBatch(trainer, batch);
batch.close();
}
System.out.println("完成第 " + (epoch + 1) + " 轮训练");
}
}
pretrainedZooModel.close();
return fineTunedModel;
}
// 测试入口
public static void main(String[] args) {
try {
Model model = fineTuneResNet(10, 3); // 10分类,训练3轮
System.out.println("\n模型微调完成!");
model.close();
} catch (Exception e) {
System.err.println("异常:" + e.getMessage());
e.printStackTrace();
}
}
}
全章节核心总结
迁移学习核心
- 核心价值:复用预训练模型的通用特征,解决小数据集、低算力场景下的模型训练问题;
- 核心操作:冻结底层通用特征层,替换输出层适配新任务,按需微调高层网络。
DJL 使用关键
- 模型加载:优先用
optModelUrls指定模型地址,避免匹配失败;使用后必须调用close()释放资源; - 迁移学习实操:加载预训练模型→替换输出层→冻结预训练层→加载自定义数据集→微调训练。
工程化要点
- 异常处理:区分
ModelException(模型匹配)和IOException(下载 / 读取); - 资源管理:模型缓存路径
~/.djl.ai/cache,无需重复下载; - 训练优化:减小批次大小避免内存溢出,冻结大部分层降低训练成本。
通过以上流程,你可以基于 DJL 快速落地 Java 端的迁移学习任务,无论是图像分类还是文本分析,都能借助预训练模型实现高效、高性能的模型开发。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)