深度学习进阶模型

卷积神经网络(CNN)

什么是 CNN

卷积神经网络(Convolutional Neural Network,CNN)是一种专门为处理网格结构数据(如图像、语音)设计的深度学习模型。与全连接网络直接将图像展平为一维向量、忽略空间特征且易引发参数爆炸的特点不同,CNN 的核心优势在于通过卷积操作替代全连接层的暴力特征提取,实现局部特征感知、参数共享与空间降维,能更高效地挖掘网格数据中的空间关联信息,因此在计算机视觉等领域得到广泛应用。

CNN 核心组件

CNN 的特征提取与任务推理能力,由多个功能明确的核心组件协同实现,各组件分工协作,构成完整的网络架构:

  • 卷积层(提特征):作为 CNN 的核心特征提取模块,其核心作用是提取图像的局部特征(如边缘、纹理、形状)。通过滑动卷积核与图像局部区域做内积运算,既保留了像素间的空间位置信息,又通过参数共享机制大幅减少网络参数数量,为后续特征处理奠定基础。
  • 池化层(降维):紧随卷积层之后发挥作用,核心作用是降维与特征选择(下采样)。通过对卷积层输出的特征图进行采样处理,不仅能降低计算复杂度,还能保留关键特征、增强模型对特征位置变化的鲁棒性。常用的池化方式包括最大池化(MaxPool)和平均池化(AvgPool)。
  • 激活层(非线性):卷积与池化本质上均为线性操作,难以拟合复杂的非线性特征关系。激活层的核心作用就是为网络引入非线性映射能力,打破线性操作的局限性,让模型能够学习到数据中复杂的特征关联(如区分手写数字 “6” 和 “9” 的细微轮廓差异)。常用激活函数有 ReLU(缓解梯度消失,加速训练)、Sigmoid(适合二分类输出)、Tanh(输出均值为 0,优化梯度传播)等。
  • 全连接层(分类):位于网络的末端,核心作用是特征整合与分类。它会将卷积、池化层提取的高维特征图展平为一维向量,再通过权重矩阵将特征映射到样本的标签空间,最终完成分类或回归任务。
卷积操作详解

卷积层的特征提取能力,完全依赖于卷积操作的实现,而卷积操作的效果由多个关键参数共同决定:

  • 卷积核(Kernel/Filter):即二维权重矩阵(常见尺寸为 3×3、5×5),是特征提取的 “检测器”。卷积核的参数会通过反向传播算法不断优化,使其能够精准捕捉图像中的特定特征。
  • 步长(Stride):指卷积核在图像上滑动的步幅。步长越大,卷积核每次滑动覆盖的区域越广,最终输出的特征图尺寸越小。例如步长设为 2 时,特征图的尺寸会减半。
  • 填充(Padding):在图像边缘补 0 的操作,主要用于控制输出特征图的尺寸。常用的填充方式有两种:SAME 填充(输出尺寸与输入尺寸一致)、VALID 填充(无填充,输出尺寸计算公式为 (输入尺寸 - 核尺寸 + 1) / 步长)。
  • 多通道卷积:当输入图像为多通道(如 RGB 图像的 3 通道)时,卷积核的通道数需与输入通道数保持一致。卷积时,每个通道对应一个卷积核进行运算,最终将所有通道的卷积结果求和,得到单通道特征图;若需输出多通道特征图,可设置多组卷积核并行运算。
经典 CNN 结构实现(LeNet、AlexNet、ResNet )

示例代码: LeNet(手写数字分类)

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.ParameterStore;

/**
 * @Author XiangWei
 * @Date 2026/1/12 15:54
 * @Description:
 */
public class LeNet {
    public static void main(String[] args) {
        // 1.创建网络
        Block block = createLeNet();
        System.out.println(block);

        // 2.训练
        try(NDManager manager = NDManager.newBaseManager()){
            // 输入层:制定数据类型
            Shape inputShape = new Shape(1, 1, 28, 28);
            // 初始化
            block.initialize(manager, DataType.FLOAT32, inputShape);

            // 创建模拟输入(1张28×28的灰度图)
            NDArray input = manager.ones(inputShape);
            // 前向传播:输入→网络→输出
            ParameterStore parameterStore = new ParameterStore(manager, false);
            NDArray output = block.forward(parameterStore, new NDList(input), false).singletonOrThrow();
            System.out.println("输出形状:" + output.getShape());

            // 获取识别结果(取输出层中概率最大的类别)
            long predictedClass = output.argMax(1).getLong(0);
            System.out.println("输出结果:" + predictedClass);
        }
    }

    // 创建网络
    private static Block createLeNet() {
        SequentialBlock block = new SequentialBlock();
        // 卷积层1:输入1通道(灰度图),输出6通道,卷积核5×5,步长1,填充0
        block.add(Conv2d.builder()
                        .setKernelShape(new Shape(5, 5))
                        .optStride(new Shape(1, 1))
                        .optPadding(new Shape(0, 0))
                        .setFilters(6)
                .build());
        // 池化层1:2×2最大池化,步长2
        block.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)));

        // 卷积层2:输入6通道,输出16通道,卷积核5×5,步长1,填充0
        block.add(Conv2d.builder()
                        .setKernelShape(new Shape(5, 5))
                        .optStride(new Shape(1, 1))
                        .optPadding(new Shape(0, 0))
                        .setFilters(16)
                .build());
        // 池化层2:2×2最大池化,步长2
        block.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)));

        // 展平:将多维特征图转为一维向量
        block.add(Blocks.batchFlattenBlock());

        // 全连接层1:输入16*4*4=256,输出120
        block.add(Linear.builder().setUnits(120).build());
        // 全连接层2:输入120,输出84
        block.add(Linear.builder().setUnits(84).build());
        // 输出层:输入84,输出10(数字0-9分类)
        block.add(Linear.builder().setUnits(10).build());

        return block;
    }
}

AlexNet(简化版,ImageNet 子集分类)

在网络结构、激活函数、正则化、任务适配等方面对 LeNet 做了全方位的升级,适配更复杂的图像分类任务(ImageNet 1000 类)

import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;

public class AlexNet {
    public static SequentialBlock getModel() {
        SequentialBlock net = new SequentialBlock();
        // 卷积层1:3通道(RGB),96个11×11卷积核,步长4
        net.add(Conv2d.builder()
                .setKernelShape(new Shape(11, 11))
                .optStride(new Shape(4, 4))
                .setFilters(96)
                .build());
        net.add(Activation.reluBlock()); // ReLU激活(LeNet用sigmoid,ReLU解决梯度消失)
        net.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2))); // 池化层1
        
        // 卷积层2:96通道,256个5×5卷积核,填充2
        net.add(Conv2d.builder()
                .setKernelShape(new Shape(5, 5))
                .optPadding(new Shape(2, 2))
                .setFilters(256)
                .build());
        net.add(Activation.reluBlock());
        net.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2))); // 池化层2
        
        // 连续3个卷积层(无池化)
        // 卷积层1:256通道,384个3×3卷积核,填充1
        net.add(Conv2d.builder().setKernelShape(new Shape(3,3)).optPadding(new Shape(1,1)).setFilters(384).build());
        net.add(Activation.reluBlock());
        // 卷积层2:384通道,384个3×3卷积核,填充1
        net.add(Conv2d.builder().setKernelShape(new Shape(3,3)).optPadding(new Shape(1,1)).setFilters(384).build());
        net.add(Activation.reluBlock());
        // 卷积层3:384通道,256个3×3卷积核,填充1
        net.add(Conv2d.builder().setKernelShape(new Shape(3,3)).optPadding(new Shape(1,1)).setFilters(256).build());
        net.add(Activation.reluBlock());

        // 池化层3:2×2最大池化,步长2
        net.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)));
        
        // 全连接层(带Dropout防止过拟合)
        net.add(Blocks.batchFlattenBlock());
        net.add(Linear.builder().setUnits(4096).build());
        net.add(Activation.reluBlock());
        net.add(Dropout.builder().optRate(0.5f).build()); // Dropout率50%
        net.add(Linear.builder().setUnits(4096).build());
        net.add(Activation.reluBlock());
        net.add(Dropout.builder().optRate(0.5f).build());
        net.add(Linear.builder().setUnits(1000).build()); // 输出1000类(ImageNet)
        return net;
    }

    public static void main(String[] args) {
        SequentialBlock block = getModel();
        // 打印网络结构
        System.out.println(block);
    }
}

核心改进点总览(按重要性排序)

改进维度 LeNet(原代码) AlexNet(新代码) 改进的核心价值
激活函数 无显式激活层(隐含用 Sigmoid) 全量使用 ReLU 激活 解决 Sigmoid 的梯度消失问题,训练速度提升数倍
网络规模 轻量(2 卷积 + 3 全连接) 深度 / 宽度大幅提升(5 卷积 + 3 全连接) 能提取更复杂的图像特征,适配 1000 类分类(LeNet 仅 10 类)
卷积核设计 固定 5×5 小卷积核,步长 1 多尺寸卷积核(11×11/5×5/3×3)+ 大步长(4) 11×11 大核快速降维,3×3 小核精细提取特征
正则化机制 无正则化 加入 Dropout(0.5) 防止过拟合(全连接层参数极多,过拟合风险高)
输入适配 1 通道 28×28 灰度图 3 通道 224×224 RGB 图 适配真实彩色图像(ImageNet 数据集)
填充策略 固定 0 填充(VALID) 动态填充(2/1) 避免边缘特征丢失,保证特征图尺寸合理
池化设计 2×2 池化,步长 2 3×3 池化,步长 2 更大池化窗口保留更多全局特征,降低信息损失
代码封装 训练 / 推理逻辑耦合 纯结构封装(getModel 方法) 代码复用性更高,便于后续训练 / 推理扩展

逐点详细解析(结合代码)

  1. 激活函数:ReLU 替代 Sigmoid(最核心改进)
  • LeNet 代码:无显式激活层,传统 LeNet 用 Sigmoid 激活,存在梯度消失问题(深层网络梯度趋近于 0,无法训练);
  • AlexNet 代码:每个卷积 / 全连接层后都加Activation.reluBlock(),注释也明确标注 “ReLU 解决梯度消失”;
  • 价值:ReLU 的梯度在正数区间恒为 1,能支撑更深的网络训练,这是 AlexNet 能做 5 层卷积的关键。
  1. 网络规模:深度 + 宽度双提升
  • 深度:LeNet 仅 2 层卷积,AlexNet 升级为 5 层卷积(前 2 层大核 + 后 3 层小核);

  • 宽度:

    • LeNet 卷积核数量:6→16(通道数);
    • AlexNet 卷积核数量:96→256→384→384→256(通道数提升 10 倍 +);
    • 全连接层:LeNet 120→84→10,AlexNet 4096→4096→1000(参数从万级提升到亿级);
  • 价值:更多的卷积层 / 通道数能提取 “边缘→纹理→形状→物体” 的层级特征,满足 1000 类分类的特征需求。

  1. 卷积核设计:多尺寸 + 大步长适配复杂图像
  • LeNet:只有 5×5 卷积核,步长 1(仅适配 28×28 小图);

  • AlexNet:

    • 第一层用 11×11 大卷积核 + 步长 4:快速将 224×224 的输入降维(224→55),减少计算量;
    • 中间用 5×5 卷积核过渡,最后 3 层用 3×3 小卷积核:精细提取图像的局部细节特征;
  • 价值:多尺寸卷积核兼顾 “快速降维” 和 “精细提特征”,适配 224×224 的大图。

  1. 正则化:Dropout 防止过拟合
  • LeNet:无任何正则化(仅 10 类分类,参数少,过拟合风险低);
  • AlexNet:全连接层后加入Dropout.builder().optRate(0.5f).build(),随机让 50% 神经元失活;
  • 价值:AlexNet 全连接层有 4096×4096=1600 万参数,过拟合风险极高,Dropout 能有效降低风险。
  1. 输入与填充:适配真实彩色图像
  • 输入通道:LeNet 是 1 通道(MNIST 灰度图),AlexNet 是 3 通道(RGB 彩色图),符合真实场景;

  • 填充策略:

    • LeNet 固定 0 填充:卷积后特征图尺寸大幅缩小(28→24),边缘特征易丢失;
    • AlexNet 用 2/1 填充:比如卷积层 2 的optPadding(new Shape(2, 2)),保证卷积后特征图尺寸合理,边缘特征不丢失;
  • 价值:适配 ImageNet 彩色数据集,最大化保留图像特征。

  1. 池化设计:3×3 池化提升特征鲁棒性
  • LeNet:2×2 最大池化,步长 2;
  • AlexNet:3×3 最大池化,步长 2;
  • 价值:更大的池化窗口能覆盖更多像素,保留更全局的特征,提升模型对图像平移 / 缩放的鲁棒性。
  1. 代码封装:结构与逻辑解耦
  • LeNet:createLeNet() 返回 Block,main 方法中耦合了初始化、前向传播等逻辑;
  • AlexNet:getModel() 纯返回网络结构,main 方法仅打印结构,无多余逻辑;
  • 价值:代码复用性更高,后续可快速接入训练 / 推理逻辑,符合 “单一职责” 设计原则。

核心总结

AlexNet 相对于 LeNet 的改进,本质是从 “适配简单手写数字分类” 升级为 “适配复杂真实图像分类”,核心改进可归纳为 3 点:

  1. 网络能力升级:更深的层数、更多的通道数,能提取复杂特征;
  2. 训练稳定性升级:ReLU 解决梯度消失,Dropout 防止过拟合;
  3. 场景适配升级:多通道、多尺寸卷积核、填充策略适配彩色大图。

简单来说,LeNet 是 CNN 的 “雏形”,适合简单小任务;AlexNet 是 CNN 的 “成熟版本”,奠定了现代 CNN 的基础架构,能处理工业级的复杂图像分类任务。

ResNet 简化版(残差连接解决梯度消失)

核心:残差块(Residual Block),通过短路连接将输入直接加到输出,解决深层网络梯度消失问题

ResNet 通过 “残差连接(Residual Connection)” 解决了深层网络的梯度消失问题,同时引入批量归一化(BatchNorm)、全局平均池化等现代 CNN 技术,实现了 “更深的网络 + 更高的训练稳定性

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.*;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/**
* @author XiangWei
* @return null
* @date 2026/1/12 16:28
* @description 简单ResNet模型(18层)
*/
public class ResNetSimple {

    // 构建恒等映射块(解决空块问题)
    private static Block identityBlock() {
        return new LambdaBlock((Function<NDList, NDList>) inputs -> inputs);
    }

    // 构建残差块:确保所有分支非空
    private static Block residualBlock(int channels, boolean downSample) {
        int stride = downSample ? 2 : 1;

        // 1. 主分支
        SequentialBlock mainBranch = new SequentialBlock();
        mainBranch.add(Conv2d.builder()
                .setKernelShape(new Shape(3, 3)) // 3x3卷积核
                .optStride(new Shape(stride, stride)) // 步长2(下采样)
                .optPadding(new Shape(1, 1)) // 填充1保持特征图尺寸
                .setFilters(channels)
                .build());
        mainBranch.add(BatchNorm.builder().build());
        mainBranch.add(Activation.reluBlock());
        mainBranch.add(Conv2d.builder() // 3x3卷积核
                .setKernelShape(new Shape(3, 3))
                .optStride(new Shape(1, 1)) // 步长1(保持特征图尺寸)
                .optPadding(new Shape(1, 1)) // 填充1保持特征图尺寸
                .setFilters(channels)
                .build());
        mainBranch.add(BatchNorm.builder().build());

        // 2. 短路分支:确保非空(空时添加恒等映射)
        SequentialBlock shortcutBranch = new SequentialBlock();
        if (downSample) {
            shortcutBranch.add(Conv2d.builder()
                    .setKernelShape(new Shape(1, 1)) // 1x1卷积核
                    .optStride(new Shape(stride, stride)) // 步长2(下采样)
                    .setFilters(channels)
                    .build());
            shortcutBranch.add(BatchNorm.builder().build());
        } else {
            // 关键:空分支时添加恒等映射,避免SequentialBlock为空
            shortcutBranch.add(identityBlock());
        }

        // 3. ParallelBlock合并函数:仅执行相加
        Function<List<NDList>, NDList> mergeFunction = ndLists -> {
            // 主分支输出 + 短路分支输出
            return new NDList(ndLists.get(0).get(0).add(ndLists.get(1).get(0)));
        };

        // 构造ParallelBlock(子块均非空)
        ParallelBlock parallelBlock = new ParallelBlock(
                mergeFunction,
                Arrays.asList(mainBranch, shortcutBranch)
        );

        // 4. 组合残差块:并行执行 + 相加 + ReLU
        SequentialBlock residualBlock = new SequentialBlock();
        residualBlock.add(parallelBlock);
        residualBlock.add(Activation.reluBlock());

        return residualBlock;
    }

    // 构建简化版ResNet(8层)
    public static SequentialBlock getModel() {
        SequentialBlock net = new SequentialBlock();

        // 初始卷积层:适配224×224 RGB图像(3通道)
        net.add(Conv2d.builder()
                .setKernelShape(new Shape(7, 7)) // 7x7卷积核
                .optStride(new Shape(2, 2)) // 步长2(下采样)
                .optPadding(new Shape(3, 3)) // 填充3保持特征图尺寸
                .setFilters(64) // 64个卷积核
                .build());
        net.add(BatchNorm.builder().build());
        net.add(Activation.reluBlock());
        net.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)));

        // 堆叠残差块
        net.add(residualBlock(64, false));  // 恒等映射残差块(短路分支非空)
        net.add(residualBlock(64, false));  // 恒等映射残差块
        net.add(residualBlock(128, true));  // 下采样残差块
        net.add(residualBlock(128, false)); // 恒等映射残差块

        // 全局平均池化 + 展平 + 全连接(10分类)
        net.add(Pool.globalAvgPool2dBlock()); // 输出:(batch, 128, 1, 1)
        net.add(Blocks.batchFlattenBlock());  // 展平:(batch, 128)
        net.add(Linear.builder().setUnits(10).build()); // 128→10维输出

        return net;
    }

    // 测试模型初始化+推理
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            SequentialBlock model = getModel();

            // 1. 初始化模型(正确参数顺序)
            model.initialize(manager, DataType.FLOAT32, new Shape(1, 3, 224, 224));
            System.out.println("模型输出形状:" + Arrays.toString(model.getOutputShapes(new Shape[]{new Shape(1, 3, 224, 224)})));

            // 2. 推理测试(关键:创建有效的ParameterStore)
            NDList input = new NDList(manager.ones(new Shape(1, 3, 224, 224)));
            // 创建ParameterStore(推理阶段isTraining=false)
            ParameterStore parameterStore = new ParameterStore(manager, false);
            // 调用forward(传入有效的ParameterStore)
            NDList output = model.forward(parameterStore, input, false, new PairList<>());

            System.out.println("输入形状:" + input.get(0).getShape());
            System.out.println("输出形状:" + output.get(0).getShape());

        } catch (Exception e) {
            System.err.println("运行失败:" + e.getMessage());
            e.printStackTrace();
        }
    }
}

核心改进点总览(对比 LeNet/AlexNet)

改进维度 LeNet/AlexNet ResNet(本代码) 改进的核心价值
核心架构 串行卷积 / 全连接(无分支) 残差块(主分支 + 短路分支) 解决深层网络梯度消失,可训练数十 / 上百层网络
归一化机制 无归一化 批量归一化(BatchNorm) 加速训练收敛,降低初始化敏感,提升稳定性
池化设计 普通最大池化 + 展平 全局平均池化(Global AvgPool) 替代全连接层的部分作用,减少参数数量,降低过拟合
分支处理 无分支逻辑 恒等映射填充空分支,ParallelBlock 合并 解决 DJL 框架中空块报错问题,保证分支非空
网络扩展性 固定层数(LeNet8 层 / AlexNet8 层) 模块化残差块,可堆叠扩展 仅需修改残差块数量即可实现 ResNet18/34/50
工程健壮性 无异常处理 / 参数校验 完整的异常捕获、状态提示、形状校验 100% 可运行,便于调试和生产环境使用

逐点详细解析

  1. 核心创新:残差连接(Residual Block)—— 解决深层网络梯度消失

这是 ResNet 最核心的改进,也是和 LeNet/AlexNet 的本质区别:

  • LeNet/AlexNet:网络是 “串行堆叠”,层数加深后梯度会逐层衰减(梯度消失),最多只能训练十几层;

  • ResNet 代码:

    • 设计residualBlock()方法,构建 “主分支(卷积 + BN+ReLU)+ 短路分支(恒等映射 / 下采样)”;
    • 通过ParallelBlock将两个分支输出相加ndLists.get(0).get(0).add(ndLists.get(1).get(0))),形成 “残差连接”;
    • 核心逻辑:让网络学习 “输入与输出的残差(差值)” 而非直接学习输出,梯度可通过短路分支直接回传,即使网络很深也不会梯度消失。
  • 代码关键片段:

    // 残差相加核心逻辑
    Function<List<NDList>, NDList> mergeFunction = ndLists -> {
        return new NDList(ndLists.get(0).get(0).add(ndLists.get(1).get(0)));
    };
    
  • 价值:本代码仅堆叠 4 个残差块(8 层),但基于该结构可轻松扩展到 ResNet18(18 层)、ResNet50(50 层),而 LeNet/AlexNet 层数加深后会直接训练失败。

  1. 批量归一化(BatchNorm)—— 提升训练稳定性
  • LeNet/AlexNet:无归一化层,训练时需精细调整学习率、初始化方式,否则易发散;

  • ResNet 代码:每个卷积层后都添加BatchNorm.builder().build()

  • 核心作用:

    1. 将卷积输出的特征值归一化到 “均值 0、方差 1”,避免数值过大 / 过小导致梯度爆炸 / 消失;
    2. 加速训练收敛(学习率可设更大),降低对初始化的敏感度;
    3. 轻微正则化效果,减少过拟合风险。
  1. 全局平均池化(Global AvgPool)—— 简化全连接层
  • LeNet/AlexNet:用普通最大池化 + 展平 + 大尺寸全连接层(如 AlexNet 的 4096 维),参数多、过拟合风险高;

  • ResNet 代码:net.add(Pool.globalAvgPool2dBlock())

  • 核心作用:

    1. 将最后一层卷积的特征图(如 128×7×7)直接转为 128×1×1(对每个通道取全局平均值);
    2. 替代传统的 “池化 + 展平 + 大全连接层”,大幅减少参数数量(本代码仅用 128→10 的全连接层,AlexNet 是 4096→1000);
    3. 更贴合卷积层的空间特征,提升泛化能力。
  1. 工程化改进:解决 DJL 框架的 “空块” 问题

这是针对 DJL 框架的实用改进,LeNet/AlexNet 未涉及:

  • 问题背景:DJL 的SequentialBlock如果为空(如短路分支无需下采样时),会触发空指针 / 初始化报错;

  • ResNet 代码:

    1. 定义identityBlock()方法,返回一个 “恒等映射块”(输入 = 输出);

    2. 当短路分支无需下采样时,添加identityBlock()填充,确保SequentialBlock非空:

      if (downSample) {
          // 下采样分支(卷积+BN)
      } else {
          shortcutBranch.add(identityBlock()); // 空分支填充恒等映射
      }
      
  • 价值:保证代码 100% 可运行,避免框架层面的异常,这是 LeNet/AlexNet 代码未考虑的工程细节。

  1. 模块化设计:残差块可复用、易扩展
  • LeNet/AlexNet:网络结构是 “硬编码”,修改层数需大幅调整代码;

  • ResNet 代码:

    1. 将残差块封装为residualBlock(int channels, boolean downSample)方法,参数化控制通道数、是否下采样;

    2. 堆叠残差块仅需调用方法:

      net.add(residualBlock(64, false));  // 恒等映射
      net.add(residualBlock(128, true)); // 下采样
      
  • 价值:仅需修改残差块的数量 / 参数,即可快速实现 ResNet18(8 个残差块)、ResNet34(16 个残差块),复用性远超 LeNet/AlexNet。

核心总结

ResNet 代码相对于 LeNet/AlexNet 的改进,核心可归纳为 3 点:

  1. 架构创新:残差连接解决深层网络梯度消失问题,实现 “更深的网络 = 更好的性能”;
  2. 技术升级:BatchNorm 加速收敛、全局 AvgPool 减少参数,提升训练稳定性和泛化能力;
  3. 工程优化:模块化设计、空块处理、异常捕获,保证代码可运行、可扩展、可维护。

简单来说,LeNet/AlexNet 是 “线性堆叠的网络”,层数加深就会 “退化”;而 ResNet 是 “带短路的残差网络”,层数越深性能越好,这也是 ResNet 能成为 ImageNet 竞赛冠军、并广泛应用于工业界的核心原因。

CNN 应用场景
  • 图像分类:核心场景(如 MNIST 手写数字、ImageNet 图像分类),通过 CNN 提取特征后全连接层分类。
  • 目标检测进阶:基于 CNN 的检测模型(如 YOLO、Faster R-CNN),CNN 负责提取图像特征,再结合锚框、区域提议等完成目标定位与分类。
  • 其他场景:图像分割(U-Net)、人脸识别、图像生成(GAN)、医学影像分析等。
DJL 中 CNN 相关 API 与层封装
API 类 作用 核心参数
Conv2d 二维卷积层 filters(输出通道)、kernelShape(卷积核尺寸)、stride(步长)、padding(填充)
MaxPool2d/AvgPool2d 最大 / 平均池化层 kernelShape(池化核尺寸)、stride(步长)
BatchNorm 批归一化层 加速训练、防止梯度消失
Activation 激活函数层(ReLU/Sigmoid) reluBlock()/sigmoidBlock()
Dropout Dropout 层 rate(丢弃率)
Linear 全连接层 units(输出维度)

循环神经网络(RNN)与序列模型

什么是RNN

循环神经网络(Recurrent Neural Network,RNN)是一种专门用于处理序列数据的深度学习模型,核心特点是网络内部包含 “循环连接”,能够利用历史信息(上下文)来预测当前输出,解决了传统前馈神经网络(如 CNN、全连接网络)无法处理时序依赖关系的问题。

核心痛点:前馈网络的局限性

对于序列数据(如文本、语音、时间序列),数据的顺序至关重要:

  • 比如理解句子 “我今天吃了苹果,它很甜” 时,“它” 指代的是前面的 “苹果”
  • 比如预测股票价格时,今天的价格和过去几天的价格高度相关。

而前馈网络(LeNet、AlexNet、ResNet)的特点是输入与输出独立,每一次计算都不会保留之前的信息,相当于 “没有记忆”,无法捕捉这种时序依赖。

RNN 的核心思想:引入 “记忆” 机制

RNN 的本质是在网络中加入一个隐藏状态(Hidden State),这个状态会记录之前的输入信息,并参与当前步的计算,形成 “循环”。

RNN 的关键特点
  1. 参数共享

    • 与 CNN 的卷积核参数共享类似,RNN 的权重矩阵在所有时刻共享,而非每个时刻都有独立参数;
    • 例如处理一个长度为 10 的序列,RNN 只需一套参数,而前馈网络需要 10 套独立参数,极大降低了过拟合风险。
  2. 处理变长序列

    • RNN 不限制输入序列的长度,可灵活处理不同长度的文本、语音等数据(如一句话有 5 个单词或 10 个单词都能处理)。
  3. 局限性:梯度消失 / 爆炸

    • 标准 RNN 的最大问题是无法捕捉长距离依赖:当序列很长时(如超过 10 个时刻),梯度在反向传播过程中会快速衰减(梯度消失)或急剧增大(梯度爆炸);
    • 导致 RNN 只能记住短期记忆,无法记住长距离的信息(比如一句话开头的单词和结尾的单词的关联)。
RNN 原理
  • 循环核结构:核心是隐藏状态(H)的循环更新,公式:Ht=σ(XtWxh+Ht−1Whh+bh)H_t = \sigma(X_t W_{xh} + H_{t-1} W_{hh} + b_h)Ht=σ(XtWxh+Ht1Whh+bh)Yt=HtWhy+byY_t = H_t W_{hy} + b_yYt=HtWhy+by(Xt为当前输入,W为权重,b为偏置)。

  • 隐藏状态传递:每个时间步的隐藏状态依赖前一个时间步的隐藏状态,实现序列信息的记忆。

  • 梯度消失 / 爆炸问题:

    • 梯度消失:深层 RNN 中,早期时间步的梯度传递到输出时趋近于 0,模型无法学习长距离依赖。
    • 梯度爆炸:梯度值过大导致参数更新异常,可通过梯度裁剪缓解。
改进 RNN 结构(LSTM、GRU 原理与优势)

为了克服标准 RNN 的缺陷,研究者提出了LSTM(长短期记忆网络)GRU(门控循环单元),它们是目前实际应用中最常用的 RNN 变体:

LSTM(长短期记忆网络)
  • 核心结构:输入门、遗忘门、输出门 + 细胞状态(C),解决长距离依赖问题。

    • 遗忘门:控制遗忘多少历史细胞状态,ft=σ(XtWxf+Ht−1Whf+bf)f_t = \sigma(X_t W_{xf} + H_{t-1} W_{hf} + b_f)ft=σ(XtWxf+Ht1Whf+bf)
    • 输入门:控制更新多少新信息到细胞状态,it=σ(XtWxi+Ht−1Whi+bi)i_t = \sigma(X_t W_{xi} + H_{t-1} W_{hi} + b_i)it=σ(XtWxi+Ht1Whi+bi)C~t=tanh⁡(XtWxc+Ht−1Whc+bc)\tilde{C}_t = \tanh(X_t W_{xc} + H_{t-1} W_{hc} + b_c)C~t=tanh(XtWxc+Ht1Whc+bc)
    • 细胞状态更新:Ct=ft⊙Ct−1+it⊙C~tC_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_tCt=ftCt1+itC~t(⊙为按元素乘)。
    • 输出门:控制输出多少细胞状态到隐藏状态,ot=σ(XtWxo+Ht−1Who+bo)o_t = \sigma(X_t W_{xo} + H_{t-1} W_{ho} + b_o)ot=σ(XtWxo+Ht1Who+bo)
  • 优势:通过细胞状态的长期存储和门控机制,有效缓解梯度消失,能学习长距离依赖。

GRU(门控循环单元)
  • 核心结构:重置门、更新门,简化 LSTM 结构,参数更少,训练更快。

    • 更新门:融合遗忘门和输入门,控制保留多少历史隐藏状态,zt=σ(XtWxz+Ht−1Whz+bz)z_t = \sigma(X_t W_{xz} + H_{t-1} W_{hz} + b_z)zt=σ(XtWxz+Ht1Whz+bz)
    • 重置门:控制遗忘多少历史隐藏状态,rt=σ(XtWxr+Ht−1Whr+br)r_t = \sigma(X_t W_{xr} + H_{t-1} W_{hr} + b_r)rt=σ(XtWxr+Ht1Whr+br)
    • 候选隐藏状态:H~t=tanh⁡(XtWxh+rt⊙Ht−1Whh+bh)\tilde{H}_t = \tanh(X_t W_{xh} + r_t \odot H_{t-1} W_{hh} + b_h)H~t=tanh(XtWxh+rtHt1Whh+bh)
    • 隐藏状态更新:Ht=(1−zt)⊙Ht−1+zt⊙H~tH_t = (1 - z_t) \odot H_{t-1} + z_t \odot \tilde{H}_tHt=(1zt)Ht1+ztH~t
  • 优势:结构更简单,计算效率更高,效果接近 LSTM,适合数据量较大的场景。

RNN 的典型应用场景
  1. 自然语言处理(NLP):文本生成、机器翻译、情感分析、命名实体识别;
  2. 语音识别:将语音信号(时序数据)转换为文字;
  3. 时间序列预测:股票价格预测、天气预报、销量预测;
  4. 视频分析:视频帧的动作识别(视频是连续的图像序列)。
序列任务实现(文本分类、时序预测)
LSTM

为什么需要 LSTM?(从 RNN 的痛点说起)

在讲 LSTM 之前,先理解它要解决的问题 —— 传统 RNN 处理文本时的「记忆短板」:

  • 比如句子「我昨天看了一部电影,它的剧情特别精彩,我觉得这部____很好看」,RNN 在填最后一个空时,很难记住前面的「电影」这个词(长距离依赖);
  • 原因:RNN 的梯度会随着序列长度增加而「消失」或「爆炸」,导致模型记不住早出现的信息;
  • LSTM 的核心目标:专门设计了「记忆门控机制」,让模型能选择性地记住 / 忘记信息,解决长依赖问题

LSTM 的核心原理

可以把 LSTM 想象成一个「带记忆功能的信息处理工厂」,每个 LSTM 单元(处理一个单词)有 3 个核心「门」和 1 个「记忆细胞」:

组件 通俗功能(结合文本分类示例)
遗忘门(Forget Gate) 决定「忘记哪些旧信息」:比如处理到「精彩」时,忘记无关的「昨天」,保留「电影」
输入门(Input Gate) 决定「记住哪些新信息」:比如记住「剧情精彩」这个关键信息
输出门(Output Gate) 决定「输出哪些信息给下一个单元」:把「电影 + 剧情精彩」传给下一个单词的处理单元
记忆细胞(Cell State) 相当于「长期记忆」:全程保存关键信息(如「电影」),直到需要时使用

核心流程

我们的示例中,LSTM 处理的是「50 个单词的序列(One-Hot 编码后)」,每个单词的处理流程:

  1. 输入:当前单词的 128 维词向量(One-Hot 经 Linear 层转换);
  2. 遗忘门:筛选并忘记无关信息(如无意义的语气词);
  3. 输入门:把当前单词的关键信息存入记忆细胞;
  4. 输出门:生成当前步的输出(64 维隐藏状态);
  5. 传递:把输出传给下一个单词的 LSTM 单元,同时记忆细胞保留关键信息。

最终,50 个单词处理完后,LSTM 会输出一个 (32,50,64) 的张量(32 个样本,50 个单词,每个单词 64 维特征),这就是「序列的特征表示」。

LSTM 层的核心价值

// 我们示例中的LSTM层:捕捉50个单词的上下文依赖
net.add(createLSTMBlock(hiddenSize, 1)); // hiddenSize=64 → 输出64维特征
  • 输入:(32,50,128)(32 个样本,50 个单词,每个单词 128 维词向量);
  • 输出:(32,50,64)(32 个样本,50 个单词,每个单词 64 维上下文特征);
  • 关键价值:不再孤立处理每个单词,而是结合上下文(比如「不好看」中的「不」会反转「好看」的语义)

比如句子「这部电影一点都不好看」:

  • 传统 RNN 可能只记住「好看」,输出正面情感;
  • LSTM 能通过门控机制记住「不」这个否定词,正确识别为负面情感。

自定义数据集

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.util.Progress;

import java.io.IOException;
import java.util.Random;

// 自定义IMDB数据集
public class IMDBTextDataset extends RandomAccessDataset {
    private final int vocabSize; // 词汇表大小
    private final int maxSeqLength; // 最大序列长度
    private final long[][] data;  // 存储文本数据(索引序列)
    private final long[] labels; // 存储标签(0/1)
    private final long datasetSize; // 数据集大小

    public static class Builder extends BaseBuilder<Builder> {
        private int vocabSize = 10000;
        private int maxSeqLength = 50; // 最大序列长度
        private NDManager manager;
        private long datasetSize = 1000; // 数据集大小

        public Builder setManager(NDManager manager) {
            this.manager = manager;
            return this;
        }

        public Builder setVocabSize(int vocabSize) {
            this.vocabSize = vocabSize;
            return this;
        }

        public Builder setMaxSequenceLength(int maxSeqLength) {
            this.maxSeqLength = maxSeqLength;
            return this;
        }

        public Builder setDatasetSize(long datasetSize) {
            this.datasetSize = datasetSize;
            return this;
        }

        public IMDBTextDataset build() {
            this.setSampling(32, false);
            return new IMDBTextDataset(this);
        }

        @Override
        protected Builder self() {
            return this;
        }
    }

    private IMDBTextDataset(Builder builder) {
        super(builder);
        this.vocabSize = builder.vocabSize;
        this.maxSeqLength = builder.maxSeqLength;
        this.datasetSize = builder.datasetSize;
        this.data = new long[(int) datasetSize][maxSeqLength];
        this.labels = new long[(int) datasetSize];

        // 预先生成数据(保存为基本类型数组)
        Random random = new Random(12345);
        for (int i = 0; i < datasetSize; i++) {
            for (int j = 0; j < maxSeqLength; j++) {
                data[i][j] = random.nextInt(vocabSize);
            }
            labels[i] = random.nextInt(2);
        }
    }

    @Override
    public Record get(NDManager manager, long index) throws IOException {
        if (index >= datasetSize) {
            throw new IndexOutOfBoundsException("索引超出范围");
        }

        // 每次创建新的NDArray,避免资源冲突
        long[] wordIndexes = data[(int) index];
        NDArray indexArray = manager.create(wordIndexes).reshape(new Shape(maxSeqLength));

        // 生成One-Hot编码
        NDArray oneHotText = indexArray.oneHot(vocabSize).toType(DataType.FLOAT32, false);

        // 创建标签
        NDArray label = manager.create(labels[(int) index]).reshape(new Shape(1));

        return new Record(new NDList(oneHotText), new NDList(label));
    }

    @Override
    protected long availableSize() {
        return datasetSize;
    }

    @Override
    public void prepare(Progress progress) {
        // 预处理:空实现
    }

    public static Builder builder() {
        return new Builder();
    }
}
One-Hot 编码

One-Hot 编码(独热编码)是一种将离散型特征(如单词、类别)转换为机器学习模型可理解的数值形式的编码方式。

  • 核心逻辑:为每个离散值创建一个长度等于「类别总数」的向量,只有该值对应的位置为 1,其余位置全为 0
  • 举例:如果词汇表包含 [“电影”, “好看”, “不好看”] 3 个词,那么:
    • “电影” → [1, 0, 0]
    • “好看” → [0, 1, 0]
    • “不好看” → [0, 0, 1]

在我们的 LSTM 文本分类示例中,文本是由单词组成的,而模型只能处理数字,所以需要把「单词」这个离散特征转为数字:

  1. 第一步:建立词汇表(词到索引的映射)

    比如把所有出现的单词编号:

    {"我":0, "爱":1, "这部":2, "电影":3, "烂":4}
    

    (词汇表大小 = 5)。

  2. 第二步:One-Hot 编码转换

    每个单词对应一个「长度 = 词汇表大小」的向量,只有编号对应的位置为 1:

    • “我” → [1, 0, 0, 0, 0]
    • “爱” → [0, 1, 0, 0, 0]
    • “烂” → [0, 0, 0, 0, 1]
  3. 第三步:文本序列转换

    比如句子「我 爱 这部 电影」→ 转为 4 个 One-Hot 向量的组合:

    [1,0,0,0,0]  # 我
    [0,1,0,0,0]  # 爱
    [0,0,1,0,0]  # 这部
    [0,0,0,1,0]  # 电影
    

    对应到我们代码中的形状就是

    (1, 4, 5)
    

    (1 个样本,4 个单词,词汇表大小 5)。

核心应用场景

在我们的 LSTM 文本分类示例中,One-Hot 是文本预处理的关键步骤,除此之外,它还广泛用于:

  1. 自然语言处理(NLP):单词 / 字符的数值化(如我们的示例);
  2. 类别特征处理:如性别(男 / 女)、职业(教师 / 医生 / 工程师)等离散特征;
  3. 多分类任务:标签的编码(如情感分类的「正面 = 0→[1,0],负面 = 1→[0,1]」)。

测试代码

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.recurrent.LSTM;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.loss.SoftmaxCrossEntropyLoss;
import ai.djl.training.optimizer.Adam;

import java.lang.reflect.Field;

public class LSTMTextClassification {

    // LSTM层创建
    private static Block createLSTMBlock(int hiddenSize, int numLayers) {
        try {
            LSTM.Builder builder = LSTM.builder();
            builder.setStateSize(hiddenSize);
            // 反射设置numLayers
            Field numLayersField = builder.getClass().getSuperclass().getDeclaredField("numLayers");
            numLayersField.setAccessible(true);
            numLayersField.set(builder, numLayers);
            return builder.build();
        } catch (Exception e) {
            throw new RuntimeException("创建LSTM层失败", e);
        }
    }

    // 构建LSTM文本分类模型
    public static SequentialBlock getModel(int embeddingSize, int hiddenSize, int numClasses) {
        SequentialBlock net = new SequentialBlock();

        // 模拟Embedding层:One-Hot → Linear → 词向量
        net.add(Linear.builder().setUnits(embeddingSize).build());

        // LSTM层:捕捉序列依赖
        net.add(createLSTMBlock(hiddenSize, 1));

        // 池化层(将(32,50,64) → (32,64),解决序列长度维度问题)
        net.add(ndList -> {
            NDArray x = ndList.get(0);
            // 对序列维度(第1维)做均值池化,消除50这个维度
            return new NDList(x.mean(new int[]{1}));
        });

        // 分类输出层:输入(32,64) → 输出(32,2)(匹配标签形状)
        net.add(Linear.builder().setUnits(numClasses).build());

        return net;
    }

    // 主方法:完整训练流程
    public static void main(String[] args) throws Exception {
        // 1. 强制指定PyTorch引擎
        System.setProperty("ai.djl.default_engine", "PyTorch");

        // 2. 模型核心参数
        int vocabSize = 10000;     // 词汇表大小
        int embeddingSize = 128;   // 模拟词向量维度
        int hiddenSize = 64;       // LSTM隐藏层维度
        int numClasses = 2;        // 分类类别数(正面/负面)
        int batchSize = 32;        // 批量大小
        int maxSeqLength = 50;     // 序列长度

        // 3. 初始化核心组件
        try (NDManager manager = NDManager.newBaseManager()) {
            // 构建并初始化模型
            SequentialBlock modelBlock = getModel(embeddingSize, hiddenSize, numClasses);
            Model model = Model.newInstance("lstm-text-classification");
            model.setBlock(modelBlock);

            // 初始化模型:输入形状 (batchSize, maxSeqLength, vocabSize)
            Shape inputShape = new Shape(batchSize, maxSeqLength, vocabSize);
            modelBlock.initialize(manager, DataType.FLOAT32, inputShape);

            // 配置训练
            DefaultTrainingConfig trainingConfig = new DefaultTrainingConfig(new SoftmaxCrossEntropyLoss())
                    .optOptimizer(Adam.builder().build()) // Adam优化器
                    .optDevices(new Device[]{Device.cpu()}); // CPU训练

            // 加载数据集
            IMDBTextDataset trainDataset = IMDBTextDataset.builder()
                    .setManager(manager)
                    .setVocabSize(vocabSize)
                    .setMaxSequenceLength(maxSeqLength)
                    .setDatasetSize(1000) // 1000条模拟样本
                    .build();

            // 训练模型
            try (Trainer trainer = model.newTrainer(trainingConfig)) {
                trainer.initialize(inputShape); // 初始化训练器

                System.out.println("===== 开始训练LSTM文本分类模型 =====");
                System.out.println("📌 输入形状:" + inputShape);
                System.out.println("📌 数据集大小:" + trainDataset.size() + "\n");

                int epochs = 3; // 训练轮数
                for (int epoch = 0; epoch < epochs; epoch++) {
                    EasyTrain.fit(trainer, 1, trainDataset, null);
                    System.out.println("第 " + (epoch + 1) + " 轮训练完成\n");
                }

            } finally {
                model.close();
            }
        }
    }
}
时序预测(基于 GRU)

GRU 是什么?(和 LSTM 的关系)

GRU 是 2014 年提出的「轻量化 LSTM」,核心目标和 LSTM 一致 —— 解决传统 RNN 的长依赖问题,但做了大幅简化

  • 去掉了 LSTM 的「记忆细胞(Cell State)」,只用「隐藏状态(Hidden State)」存储信息;
  • 把 LSTM 的「遗忘门 + 输入门」合并为 1 个「更新门(Update Gate)」,保留「重置门(Reset Gate)」;
  • 最终只有 2 个门控(LSTM 是 3 个),参数更少、计算更快,效果却和 LSTM 接近。

简单总结:GRU = 简化版 LSTM,适合想减少计算量、又要处理长依赖的场景(比如我们的文本分类示例)。

GRU 的核心原理(对比 LSTM)

同样把 GRU 想象成「信息处理工厂」,但比 LSTM 更精简,只有 2 个核心门:

GRU 组件 通俗功能(结合文本分类示例) 对比 LSTM
重置门(Reset Gate) 决定「是否忘记之前的信息」:比如处理到「不好看」时,重置掉「好看」的正面信息,只保留「不」的否定 替代 LSTM 的「遗忘门」部分功能
更新门(Update Gate) 决定「多少旧信息保留 + 多少新信息加入」:比如保留「电影」的核心信息,加入「不」的否定信息 合并 LSTM 的「遗忘门 + 输入门」
隐藏状态(Hidden State) 同时承担 LSTM「隐藏状态 + 记忆细胞」的作用,全程保存关键信息 简化了 LSTM 的双状态设计

GRU 处理文本的核心流程(以「这部电影不好看」为例):

  1. 输入:「不」的 128 维词向量;
  2. 重置门:触发!决定忘记「好看」的正面信息;
  3. 更新门:触发!保留「电影」的核心信息,加入「不」的否定信息;
  4. 隐藏状态更新:把「电影 + 不」的组合信息存入隐藏状态;
  5. 输出:传给下一个单词的处理单元,最终体现「负面情感」。

GRU vs LSTM

维度 GRU LSTM
门控数量 2 个(重置门 + 更新门) 3 个(遗忘门 + 输入门 + 输出门)
状态数量 1 个(隐藏状态) 2 个(隐藏状态 + 记忆细胞)
参数数量 更少(计算更快,内存占用小) 更多(计算稍慢,内存占用大)
效果 多数场景和 LSTM 接近 极长序列(>100 步)效果略好
上手难度 更低(逻辑更简单) 更高(双状态 + 三门控)
适用场景 短序列(如我们的 50 词文本) 长序列(如长文本、语音)

对我们的「50 词 IMDB 情感分类」示例来说:用 GRU 完全可以替代 LSTM,且训练速度更快

示例代码

import ai.djl.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.loss.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.tracker.*;

/**
 * 完整实现sin波序列预测
 */
public class SimpleRNNExample {

    public static void main(String[] args) throws Exception {
        System.out.println("开始序列学习示例");
        System.out.println("目标:学习sin波序列的下一个值预测");

        try (NDManager manager = NDManager.newBaseManager()) {
            // 1. 准备数据 - sin波序列
            System.out.println("准备训练数据...");
            int seqLength = 10;  // 序列长度
            int numSamples = 1000;  // 样本数量

            NDArray data = manager.zeros(new Shape(numSamples, seqLength, 1));
            NDArray labels = manager.zeros(new Shape(numSamples, 1));

            // 生成sin波序列数据
            for (int i = 0; i < numSamples; i++) {
                float start = i * 0.1f;
                for (int j = 0; j < seqLength; j++) {
                    float x = start + j * 0.1f;
                    data.set(new NDIndex(i, j, 0), Math.sin(x));
                }
                // 标签:序列的下一个值
                labels.set(new NDIndex(i, 0), Math.sin(start + seqLength * 0.1f));
            }

            // 2. 创建数据集
            RandomAccessDataset dataset = new ArrayDataset.Builder()
                    .setData(data)
                    .optLabels(labels)
                    .setSampling(32, true)  // 批次大小32
                    .build();

            // 3. 创建模型 - 全连接网络
            Model model = Model.newInstance("sequence-model");
            ai.djl.nn.SequentialBlock block = new ai.djl.nn.SequentialBlock();

            // 展平输入 (batch, seq_len, 1) -> (batch, seq_len)
            block.add(ai.djl.nn.Blocks.batchFlattenBlock());

            // 全连接层 + ReLU激活
            block.add(ai.djl.nn.core.Linear.builder().setUnits(64).build());
            block.add(new ai.djl.nn.LambdaBlock(ndList ->
                    new NDList(ai.djl.nn.Activation.relu(ndList.singletonOrThrow()))
            ));

            block.add(ai.djl.nn.core.Linear.builder().setUnits(32).build());
            block.add(new ai.djl.nn.LambdaBlock(ndList ->
                    new NDList(ai.djl.nn.Activation.relu(ndList.singletonOrThrow()))
            ));

            // 输出层(预测单个值)
            block.add(ai.djl.nn.core.Linear.builder().setUnits(1).build());

            model.setBlock(block);

            // 4. 配置训练
            Loss l2Loss = Loss.l2Loss();
            TrainingConfig config = new DefaultTrainingConfig(l2Loss)
                    .optOptimizer(Optimizer.adam()
                            .optLearningRateTracker(Tracker.fixed(0.001f))
                            .build());

            // 5. 训练模型
            try (Trainer trainer = model.newTrainer(config)) {
                // 初始化模型参数(适配输入形状)
                trainer.initialize(new Shape(32, seqLength, 1));

                System.out.println("开始训练...");

                // 训练循环
                for (int epoch = 1; epoch <= 10; epoch++) {
                    System.out.println("\n第 " + epoch + " 轮训练");
                    float totalLoss = 0;
                    int batchCount = 0;

                    // 遍历数据集
                    for (Batch batch : trainer.iterateDataset(dataset)) {
                        try (GradientCollector gc = trainer.newGradientCollector()) {
                            // 手动前向+反向传播
                            NDList outputs = trainer.forward(batch.getData(), batch.getLabels());
                            NDArray loss = l2Loss.evaluate(batch.getLabels(), outputs);

                            // 计算批次平均损失
                            float batchLoss = loss.mean().getFloat();
                            totalLoss += batchLoss;
                            batchCount++;

                            // 反向传播+参数更新
                            gc.backward(loss);
                            trainer.step();

                        } finally {
                            batch.close();
                        }
                    }

                    // 计算并打印平均损失
                    float avgLoss = totalLoss / batchCount;
                    System.out.printf("平均损失: %.6f%n", avgLoss);

                    // 每2轮测试一次(核心:用Trainer验证,避免Block.forward的复杂调用)
                    if (epoch % 2 == 0) {
                        validateWithTrainer(trainer, manager, seqLength);
                    }
                }
            }

            System.out.println("\n训练完成!");

        } catch (Exception e) {
            System.err.println("训练出错: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * 核心修复:用Trainer验证模型
     */
    private static void validateWithTrainer(Trainer trainer, NDManager manager, int seqLength) throws Exception {
        System.out.println("\n===== 验证模型预测能力 =====");

        // 创建单个测试样本(转为批次形状 [1, seqLength, 1])
        NDArray testData = manager.zeros(new Shape(1, seqLength, 1));
        NDArray testLabel = manager.zeros(new Shape(1, 1));

        // 填充测试数据
        float start = 0.0f;
        for (int j = 0; j < seqLength; j++) {
            float x = start + j * 0.1f;
            testData.set(new NDIndex(0, j, 0), Math.sin(x));
        }
        testLabel.set(new NDIndex(0, 0), Math.sin(start + seqLength * 0.1f));

        // 核心:用Trainer的forward验证(无需手动创建ParameterStore/Workspace)
        NDList inputs = new NDList(testData);
        NDList labels = new NDList(testLabel);
        NDList outputs = trainer.forward(inputs, labels);

        // 安全获取预测值
        NDArray pred = outputs.get(0);
        float predictedValue = pred.getFloat();  // 自动适配一维/二维张量
        float trueValue = testLabel.getFloat();

        // 输出结果
        System.out.printf("输入序列: sin(0.0), sin(0.1), ..., sin(%.1f)%n", (seqLength-1)*0.1f);
        System.out.printf("预测下一个值: %.6f%n", predictedValue);
        System.out.printf("真实下一个值: %.6f%n", trueValue);
        System.out.printf("绝对误差: %.6f%n", Math.abs(predictedValue - trueValue));
    }
}

执行结果

准备训练数据...
开始训练...

第 1 轮训练
平均损失: 0.079387

第 2 轮训练
平均损失: 0.001763

===== 验证模型预测能力 =====
输入序列: sin(0.0), sin(0.1), ..., sin(0.9)
预测下一个值: 0.847484
真实下一个值: 0.841471
绝对误差: 0.006013

第 3 轮训练
平均损失: 0.000243

第 4 轮训练
平均损失: 0.000120

===== 验证模型预测能力 =====
输入序列: sin(0.0), sin(0.1), ..., sin(0.9)
预测下一个值: 0.843613
真实下一个值: 0.841471
绝对误差: 0.002142

第 5 轮训练
平均损失: 0.000078

第 6 轮训练
平均损失: 0.000058

===== 验证模型预测能力 =====
输入序列: sin(0.0), sin(0.1), ..., sin(0.9)
预测下一个值: 0.838216
真实下一个值: 0.841471
绝对误差: 0.003254

第 7 轮训练
平均损失: 0.000049

第 8 轮训练
平均损失: 0.000038

===== 验证模型预测能力 =====
输入序列: sin(0.0), sin(0.1), ..., sin(0.9)
预测下一个值: 0.836871
真实下一个值: 0.841471
绝对误差: 0.004600

第 9 轮训练
平均损失: 0.000030

第 10 轮训练
平均损失: 0.000027

===== 验证模型预测能力 =====
输入序列: sin(0.0), sin(0.1), ..., sin(0.9)
预测下一个值: 0.849429
真实下一个值: 0.841471
绝对误差: 0.007958

训练完成!

import ai.djl.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.*;
import ai.djl.nn.core.Linear;
import ai.djl.nn.recurrent.RNN;
import ai.djl.training.*;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.dataset.Batch;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.tracker.Tracker;

/**
 * 极简RNN示例
 * 核心映射:代码 → RNN公式 → LSTM/GRU改进点
 */
public class SimpleRnnDemo {

    public static void main(String[] args) {
        // 基础配置(序列长度=时间步,隐藏层维度=隐藏状态H的维度)
        int seqLength = 10;      // 时间步T=10(对应公式中的t=1~10)
        int inputSize = 1;       // 每个时间步输入维度X_t(仅注释说明)
        int hiddenSize = 16;     // 隐藏状态H_t的维度(对应公式中的W_hh权重维度)
        int sampleCount = 500;   // 训练样本数
        int batchSize = 16;      // 批次大小

        try (NDManager manager = NDManager.newBaseManager()) {
            // 1. 生成序列数据(模拟时间步输入)
            System.out.println("=== 步骤1:生成序列数据(X_t) ===");
            NDArray inputData = manager.zeros(new Shape(sampleCount, seqLength, inputSize));
            NDArray labelData = manager.zeros(new Shape(sampleCount, 1));

            // 生成sin序列:每个样本是长度为10的序列(X_1~X_10),标签是第11个值(Y)
            for (int i = 0; i < sampleCount; i++) {
                float start = i * 0.1f;
                for (int t = 0; t < seqLength; t++) {  // t=时间步(对应公式中的t)
                    float x = start + t * 0.1f;
                    inputData.set(new NDIndex(i, t, 0), Math.sin(x));  // X_t:第t个时间步的输入
                }
                labelData.set(new NDIndex(i, 0), Math.sin(start + seqLength * 0.1f));
            }

            // 2. 定义RNN模型
            System.out.println("=== 步骤2:定义RNN模型(对应核心公式) ===");
            Model model = Model.newInstance("simple-rnn");
            SequentialBlock network = new SequentialBlock();

            // 核心:RNN循环核(对应公式 H_t = tanh(X_t W_xh + H_{t-1} W_hh + b_h))
            network.add(
                    RNN.builder()
                            .setActivation(RNN.Activation.TANH)  // 激活函数tanh(对应公式中的σ)
                            .setStateSize(hiddenSize)             // 隐藏状态H_t的维度(必填)
                            .setNumLayers(1)                      // 单层RNN(避免梯度消失/爆炸)
                            .build()
            );

            // 取最后一个时间步的隐藏状态H_T(体现RNN的序列记忆)
            network.add(new LambdaBlock(ndList -> {
                NDArray rnnOutput = ndList.singletonOrThrow();  // RNN输出:(batch, seqLen, hiddenSize)
                return new NDList(rnnOutput.get(new NDIndex(":, -1, :")));  // 取最后一个时间步的H_T
            }));

            // 输出层(对应公式 Y_t = H_t W_hy + b_y)
            network.add(Linear.builder().setUnits(1).build());
            model.setBlock(network);

            // 3. 配置训练参数
            System.out.println("=== 步骤3:配置训练参数 ===");
            Loss l2Loss = Loss.l2Loss();

            // 配置Adam优化器
            Adam optimizer = Adam.builder()
                    .optLearningRateTracker(Tracker.fixed(0.001f))  // 固定学习率
                    .build();

            // 配置TrainingConfig
            DefaultTrainingConfig config = new DefaultTrainingConfig(l2Loss)
                    .optOptimizer(optimizer);  // 设置优化器(唯一配置项)

            // 4. 训练模型(观察隐藏状态传递)
            System.out.println("=== 步骤4:训练RNN模型(观察隐藏状态传递) ===");
            ArrayDataset dataset = new ArrayDataset.Builder()
                    .setData(inputData)
                    .optLabels(labelData)
                    .setSampling(batchSize, true)
                    .build();

            try (Trainer trainer = model.newTrainer(config)) {
                // 初始化参数:匹配输入形状 (batch, seqLen, inputSize)
                trainer.initialize(new Shape(batchSize, seqLength, inputSize));

                // 训练5轮(单层RNN易收敛,多层易出现梯度消失)
                for (int epoch = 1; epoch <= 5; epoch++) {
                    System.out.println("\n--- 第" + epoch + "轮训练 ---");
                    float totalLoss = 0;
                    int batchCount = 0;

                    for (Batch batch : trainer.iterateDataset(dataset)) {
                        try (GradientCollector gc = trainer.newGradientCollector()) {
                            // 前向传播:RNN自动完成H_1→H_2→...→H_T的循环计算
                            NDList outputs = trainer.forward(batch.getData(), batch.getLabels());
                            // 计算损失(L2损失适合回归预测)
                            NDArray loss = l2Loss.evaluate(batch.getLabels(), outputs);
                            // 反向传播:单层RNN梯度稳定(多层易消失,LSTM/GRU解决此问题)
                            gc.backward(loss);
                            // 更新参数
                            trainer.step();

                            // 累加损失值
                            totalLoss += loss.mean().getFloat();
                            batchCount++;
                        } finally {
                            batch.close();  // 释放资源(必写)
                        }
                    }

                    // 打印本轮平均损失
                    float avgLoss = totalLoss / batchCount;
                    System.out.println("平均损失值: " + String.format("%.6f", avgLoss));
                }

                // 5. 验证+原理对比
                System.out.println("\n=== 步骤5:模型验证 + RNN/LSTM/GRU核心原理 ===");
                // 测试序列:X_1~X_10 = sin(0)~sin(0.9)
                NDArray testData = manager.zeros(new Shape(1, seqLength, inputSize));
                for (int t = 0; t < seqLength; t++) {
                    testData.set(new NDIndex(0, t, 0), Math.sin(t * 0.1f));
                }
                // 真实值:第11个sin值(Y=sin(1.0))
                float trueValue = (float) Math.sin(seqLength * 0.1f);

                // 预测:RNN通过隐藏状态传递记忆序列信息
                NDList testOutputs = trainer.forward(new NDList(testData), new NDList(manager.zeros(new Shape(1, 1))));
                float predictedValue = testOutputs.get(0).getFloat();

                // 打印结果+核心原理(保留关键知识点)
                System.out.println("测试序列:X_1=sin(0.0), X_2=sin(0.1), ..., X_10=sin(0.9)");
                System.out.println("基于最后一个隐藏状态H_10的预测值: " + String.format("%.6f", predictedValue));
                System.out.println("真实值(sin(1.0)): " + String.format("%.6f", trueValue));
                System.out.println("预测误差: " + String.format("%.6f", Math.abs(predictedValue - trueValue)));
            }

            System.out.println("\n=== 训练圆满完成! ===");

        } catch (Exception e) {
            System.err.println("运行时错误: " + e.getMessage());
            e.printStackTrace();
        }
    }
}

深度学习实战:基于 DJL 的迁移学习与预训练模型应用

引言:迁移学习的价值与应用场景

在深度学习实践中,从零开始训练一个高性能模型往往面临两大核心难题:一是需要海量标注数据,二是训练过程消耗大量算力和时间。而迁移学习作为深度学习领域的核心实用技术,完美解决了这一痛点 —— 它将大规模通用数据集上训练好的预训练模型参数,迁移到数据量较小的特定任务中,既能复用通用特征、节省训练成本,又能显著提升小数据集场景下的任务效果(如小众品类图像分类、特定领域文本情感分析)。

对于 Java 开发者而言,DJL(Deep Java Library)提供了无需掌握 Python 深度学习框架即可调用工业级预训练模型的极简方案,是 Java 生态下落地迁移学习的首选工具。本文将从迁移学习核心原理出发,结合 DJL 框架,完整讲解预训练模型加载、迁移学习实操的全流程。

迁移学习核心原理与实施策略

核心思想

迁移学习的本质是 “知识复用”:将预训练模型在通用数据集(如 ImageNet 图像数据集、WikiText 文本数据集)上学到的通用特征(如图像的边缘 / 纹理、文本的语义编码),适配到新的特定任务中,避免从零训练模型。

三大关键实施策略

迁移学习的实施策略按复杂度从低到高可分为三类,核心是通过 “冻结层” 操作控制参数更新范围:

  1. 特征提取(Feature Extraction):冻结预训练模型的底层 / 中层网络(通用特征层),仅替换最后一层全连接分类层,用新任务数据仅训练该层。此时预训练层仅作为固定特征提取器,不参与参数更新。
  2. 微调(Fine-tuning):解冻预训练模型的部分高层网络(或全部网络),使用远小于从头训练的学习率,用新任务数据训练整个模型。该策略平衡 “通用特征复用” 与 “任务定制化”,让通用特征适配新任务专属特征。
  3. 冻结层(Freezing Layers):是特征提取和微调的核心操作,通过setFreeze(true)固定层参数(不参与反向传播),需训练的层则设置setFreeze(false),保障迁移学习参数更新策略落地。
DJL 预训练模型库:Java 调用预训练模型的核心工具

DJL 为 Java 开发者封装了预训练模型的加载与使用逻辑,无需关注底层框架(PyTorch/TensorFlow)细节,即可快速调用主流预训练模型。

核心能力与支持的模型

DJL 通过ModelZoo(模型库)和PretrainedModel(预训练模型)核心接口,原生支持两大领域的主流预训练模型:

  • 计算机视觉:ResNet、VGG、MobileNet、YOLO(覆盖图像分类、目标检测);
  • 自然语言处理:BERT、DistilBERT、GPT(覆盖文本分类、问答、情感分析)。

核心使用流程

加载预训练模型的核心是通过Criteria类配置参数,再调用ModelZoo.loadModel()完成加载,关键配置项如下:

配置项 作用
optApplication 指定模型应用场景(如CV.IMAGE_CLASSIFICATIONNLP.TEXT_CLASSIFICATION
setTypes 指定输入 / 输出数据类型(如图像分类输入Image、输出Classifications
optFilter 按属性过滤模型(如layers=50匹配 ResNet50)
optModelUrls 直接指定模型地址(兜底方案,避免模型匹配失败)
optProgress 添加进度条,可视化模型下载过程

实操要点

  1. 模型匹配:BERT 等模型易出现ModelNotFoundException,优先用optModelUrls直接指定模型地址;
  2. 资源管理:使用后调用close()释放ZooModel,避免内存泄漏;
  3. 异常处理:区分ModelException(模型匹配失败)和IOException(下载 / 读取失败);
  4. 缓存机制:DJL 将模型缓存到~/.djl.ai/cache,后续加载无需重复下载。

实战:加载预训练模型

以下代码实现 ResNet50(图像分类)和 DistilBERT(文本分类)预训练模型的加载,包含完整的异常处理和资源释放:

import ai.djl.ModelException;
import ai.djl.modality.Classifications;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;

import java.io.IOException;

public class PretrainedModelLoader {
    // 加载图像分类预训练模型(ResNet50)
    public static ZooModel loadImageClassificationModel() throws ModelException, IOException {
        Criteria<ai.djl.modality.cv.Image, Classifications> criteria = Criteria.builder()
                .optApplication(ai.djl.Application.CV.IMAGE_CLASSIFICATION)
                .setTypes(ai.djl.modality.cv.Image.class, Classifications.class)
                .optFilter("layers", "50") // 匹配ResNet50
                .optProgress(new ProgressBar())
                .build();
        return ModelZoo.loadModel(criteria);
    }

    // 加载文本分类预训练模型(DistilBERT,直接指定模型地址避免匹配失败)
    public static ZooModel loadTextModel() throws ModelException, IOException {
        Criteria<String, Classifications> criteria = Criteria.builder()
                .optModelUrls("djl://ai.djl.pytorch/distilbert") // 直接指定模型地址
                .setTypes(String.class, Classifications.class)
                .optProgress(new ProgressBar())
                .build();
        return ModelZoo.loadModel(criteria);
    }

    public static void main(String[] args) {
        ZooModel imageModel = null;
        ZooModel textModel = null;
        try {
            // 加载ResNet50
            imageModel = loadImageClassificationModel();
            System.out.println("ResNet50 预训练模型加载完成:" + imageModel.getName());

            // 加载DistilBERT
            textModel = loadTextModel();
            System.out.println("DistilBERT 预训练模型加载完成:" + textModel.getName());

        } catch (ModelException e) {
            System.err.println("模型加载失败:找不到匹配的预训练模型");
            e.printStackTrace();
        } catch (IOException e) {
            System.err.println("IO异常:模型文件下载/读取失败");
            e.printStackTrace();
        } finally {
            // 确保资源释放
            if (imageModel != null) {
                imageModel.close();
            }
            if (textModel != null) {
                textModel.close();
            }
        }
    }
}
迁移学习实操:基于 DJL 微调预训练模型

加载预训练模型是基础,结合自定义数据集完成迁移学习微调,才是落地特定任务的核心。以下以 ResNet50 为例,完整实现 “加载预训练模型→修改输出层→加载自定义数据集→冻结层→微调” 的全流程。

步骤 1:加载预训练模型并修改输出层

核心逻辑:加载 ResNet50 预训练模型后,替换最后一层分类器为自定义分类数的输出层,适配新任务:

import ai.djl.Application;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.loss.Loss;

import java.io.IOException;

public class TransferLearningDemo {

    // 加载ResNet50预训练模型 + 修改输出层为指定分类数
    public static Model fineTuneResNet(int numClasses) throws ModelException, IOException {
        // 1. 加载ResNet50预训练模型
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .setTypes(Image.class, Classifications.class)
                .optFilter("layers", "50") // 匹配ResNet50
                .build();
        ZooModel<Image, Classifications> pretrainedZooModel = ModelZoo.loadModel(criteria);
        Block pretrainedBlock = pretrainedZooModel.getBlock();

        // 2. 构建新输出层并替换
        Linear newOutputLayer = Linear.builder().setUnits(numClasses).build();
        SequentialBlock newModelBlock = new SequentialBlock();
        newModelBlock.add(pretrainedBlock).add(newOutputLayer);

        // 3. 初始化新模型(必须步骤,否则模型无法使用)
        Model fineTunedModel = Model.newInstance("resnet50-finetuned");
        fineTunedModel.setBlock(newModelBlock);

        // 初始化模型输入形状(ResNet50固定:1样本、3通道、224x224)
        try (Trainer trainer = fineTunedModel.newTrainer(new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()))) {
            trainer.initialize(new Shape(1, 3, 224, 224));
        }

        // 释放预训练模型资源
        pretrainedZooModel.close();

        return fineTunedModel;
    }

    // 测试入口:仅验证模型加载和输出层修改
    public static void main(String[] args) {
        try {
            Model model = fineTuneResNet(10); // 修改输出层为10分类
            System.out.println("模型加载并修改输出层成功:" + model.getName());
            model.close();
        } catch (ModelException | IOException e) {
            e.printStackTrace();
        }
    }
}

步骤 2:加载自定义数据集并完成微调

在步骤 1 基础上,新增自定义数据集加载、冻结层配置、训练流程,实现完整的迁移学习微调:

import ai.djl.Application;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.loss.Loss;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;

import java.io.IOException;
import java.nio.file.Paths;

public class TransferLearningDemo {

    // 自定义数据集根路径(需替换为实际路径)
    private static final String DATASET_ROOT = "C:\\Users\\Administrator\\Desktop\\catAndDog";

    // 核心:加载ResNet50 + 修改输出层 + 加载自定义数据集 + 微调
    public static Model fineTuneResNet(int numClasses, int epochs) throws ModelException, IOException, TranslateException {
        // 1. 加载ResNet50预训练模型
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .setTypes(Image.class, Classifications.class)
                .optFilter("layers", "50")
                .build();
        ZooModel<Image, Classifications> pretrainedZooModel = ModelZoo.loadModel(criteria);
        Block pretrainedBlock = pretrainedZooModel.getBlock();

        // 2. 替换输出层(适配自定义分类数)
        Linear newOutputLayer = Linear.builder().setUnits(numClasses).build();
        SequentialBlock newModelBlock = new SequentialBlock();
        newModelBlock.add(pretrainedBlock).add(newOutputLayer);

        // 3. 初始化模型
        Model fineTunedModel = Model.newInstance("resnet50-finetuned");
        fineTunedModel.setBlock(newModelBlock);

        // 4. 配置训练(损失函数+精度评估)
        DefaultTrainingConfig trainConfig = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy()); // 仅保留核心精度评估

        try (Trainer trainer = fineTunedModel.newTrainer(trainConfig)) {
            trainer.initialize(new Shape(1, 3, 224, 224));

            // 5. 冻结预训练层(迁移学习核心):仅解冻最后2个参数(新输出层)
            ParameterList allParams = fineTunedModel.getBlock().getParameters();
            for (int i = 0; i < allParams.size(); i++) {
                Pair<String, Parameter> paramPair = allParams.get(i);
                Parameter param = paramPair.getValue();
                if (i > allParams.size() - 3) {
                    param.freeze(false); // 解冻新输出层
                } else {
                    param.freeze(true); // 冻结预训练层
                }
            }

            // 6. 加载自定义数据集(ImageFolder格式:按文件夹分类存储图片)
            Pipeline pipeline = new Pipeline();
            pipeline.add(new Resize(224, 224)).add(new ToTensor()); // 图像预处理

            // 训练集(批次大小8,避免内存溢出)
            ImageFolder trainDataset = ImageFolder.builder()
                    .setRepositoryPath(Paths.get(DATASET_ROOT))
                    .optPipeline(pipeline)
                    .setSampling(8, true)
                    .build();
            trainDataset.prepare();

            // 验证集(实际建议拆分train/test子目录)
            ImageFolder valDataset = ImageFolder.builder()
                    .setRepositoryPath(Paths.get(DATASET_ROOT))
                    .optPipeline(pipeline)
                    .setSampling(8, false)
                    .build();
            valDataset.prepare();

            // 7. 执行微调训练
            System.out.println("开始微调:分类数=" + numClasses + ",轮数=" + epochs);
            System.out.println("训练集样本数=" + trainDataset.size() + ",验证集=" + valDataset.size());

            for (int epoch = 0; epoch < epochs; epoch++) {
                // 训练轮次
                for (Batch batch : trainer.iterateDataset(trainDataset)) {
                    EasyTrain.trainBatch(trainer, batch);
                    trainer.step();
                    batch.close();
                }
                // 验证轮次
                for (Batch batch : trainer.iterateDataset(valDataset)) {
                    EasyTrain.validateBatch(trainer, batch);
                    batch.close();
                }
                System.out.println("完成第 " + (epoch + 1) + " 轮训练");
            }
        }

        pretrainedZooModel.close();
        return fineTunedModel;
    }

    // 测试入口
    public static void main(String[] args) {
        try {
            Model model = fineTuneResNet(10, 3); // 10分类,训练3轮
            System.out.println("\n模型微调完成!");
            model.close();
        } catch (Exception e) {
            System.err.println("异常:" + e.getMessage());
            e.printStackTrace();
        }
    }
}
全章节核心总结

迁移学习核心

  1. 核心价值:复用预训练模型的通用特征,解决小数据集、低算力场景下的模型训练问题;
  2. 核心操作:冻结底层通用特征层,替换输出层适配新任务,按需微调高层网络。

DJL 使用关键

  1. 模型加载:优先用optModelUrls指定模型地址,避免匹配失败;使用后必须调用close()释放资源;
  2. 迁移学习实操:加载预训练模型→替换输出层→冻结预训练层→加载自定义数据集→微调训练。

工程化要点

  1. 异常处理:区分ModelException(模型匹配)和IOException(下载 / 读取);
  2. 资源管理:模型缓存路径~/.djl.ai/cache,无需重复下载;
  3. 训练优化:减小批次大小避免内存溢出,冻结大部分层降低训练成本。

通过以上流程,你可以基于 DJL 快速落地 Java 端的迁移学习任务,无论是图像分类还是文本分析,都能借助预训练模型实现高效、高性能的模型开发。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐