在这里插入图片描述

[PyTorch Java 高校计算机硕士研一课程]

Spring Boot 集成 JavaCPP-PyTorch (2.10-1.5.13) 实战:在 Java 中运行 PyTorch 模型

在人工智能应用落地场景中,模型训练常依托 Python 生态的 PyTorch 框架,而生产环境多采用 Java 生态构建微服务。如何在 Spring Boot 项目中直接加载、运行 PyTorch 模型,避免跨语言调用的性能损耗与部署复杂度?JavaCPP-PyTorch 给出了完美解决方案。

本文将基于 JavaCPP-PyTorch 2.10-1.5.13 版本,手把手带你实现 Spring Boot 与 PyTorch 模型的无缝集成,覆盖环境搭建、依赖配置、模型加载、推理预测、接口封装全流程,打造生产级 AI 推理服务。

一、技术选型与核心概念

1. 核心组件介绍

  • Spring Boot 3.x:主流 Java 微服务框架,快速构建稳定、可扩展的后端服务,适配 AI 推理接口的生产部署。

  • JavaCPP-PyTorch:基于 JavaCPP 封装的 PyTorch Java 绑定库,无需安装 Python、无需部署 PyTorch 环境,直接在 JVM 上调用 PyTorch C++ 核心 API,支持模型加载、张量操作、GPU/CPU 推理。

  • 版本匹配:本次使用 org.bytedeco:pytorch-platform:2.10-1.5.13,对应 PyTorch 2.1 核心版本,兼容 Windows/Linux/Mac 多平台,支持 CPU/GPU 推理。

2. 适用场景

  • 生产环境需用 Java 部署 PyTorch 模型(图像分类、目标检测、文本向量、回归预测等)

  • 避免 Python 服务的性能瓶颈、部署复杂度

  • 微服务架构下,AI 推理能力与业务服务一体化

二、环境准备

  1. 开发环境:JDK 17+、Maven 3.6+、Spring Boot 3.2.x

  2. 模型准备:提前将 PyTorch 模型导出为 TorchScript 格式(.pt/.pth)(JavaCPP-PyTorch 仅支持 TorchScript 模型)

  3. 平台说明:依赖包自动适配操作系统,无需手动下载底层库

三、Spring Boot 项目搭建

1. 创建 Spring Boot 项目

初始化基础 Spring Boot 项目,引入 spring-boot-starter-web用于接口封装。

2. 核心 Maven 依赖(关键)

JavaCPP-PyTorch 需引入平台通用依赖,自动加载对应系统的底层库,pom.xml 配置如下:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.5</version>
        <relativePath/>
    </parent>

    <groupId>com.example</groupId>
    <artifactId>springboot-pytorch-demo</artifactId>
    <version>0.0.1-SNAPSHOT</version>

    <properties>
        <java.version>17</java.version>
        <javacpp.version>1.5.13</javacpp.version>
        <pytorch.version>2.10-1.5.13</pytorch.version>
    </properties>

    <dependencies>
        <!-- Spring Boot Web 核心依赖 -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <!-- JavaCPP 核心依赖 -->
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacpp</artifactId>
            <version>${javacpp.version}</version>
        </dependency>

        <!-- PyTorch Java 绑定 平台通用依赖(核心) -->
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>pytorch-platform</artifactId>
            <version>${pytorch.version}</version>
        </dependency>

        <!-- 测试依赖 -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>
</project>

注意:首次加载会自动下载 500MB+ 底层库,属于正常现象,包含 PyTorch 核心运行时。

四、PyTorch 模型封装与推理工具类

创建单例模式的模型加载类,避免重复加载模型导致的内存泄漏,适配 Spring Boot 启动时初始化模型。

import org.bytedeco.pytorch.IValue;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.Tensor;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.io.File;
import static org.bytedeco.pytorch.global.torch.*;

/**
 * PyTorch 模型加载与推理工具类(单例)
 * 适配 JavaCPP-PyTorch 2.10-1.5.13
 */
@Component
public class PyTorchModelManager {

    // PyTorch 模型对象(线程安全,可并发推理)
    private Module model;

    /**
     * Spring 容器初始化时加载模型
     */
    @PostConstruct
    public void loadModel() {
        try {
            // 模型路径(resources 目录下的 model.pt)
            String modelPath = new File("src/main/resources/model.pt").getAbsolutePath();
            // 加载 TorchScript 模型
            model = load(modelPath);
            // 设置为推理模式(禁用梯度计算,提升性能)
            model.eval();
            System.out.println("✅ PyTorch 模型加载成功,推理模式已开启");
            // 打印设备信息(CPU/GPU)
            System.out.println("✅ 运行设备:" + (cuda_is_available() ? "GPU" : "CPU"));
        } catch (Exception e) {
            System.err.println("❌ 模型加载失败:" + e.getMessage());
            throw new RuntimeException("模型初始化失败");
        }
    }

    /**
     * 模型推理核心方法
     * @param inputTensor 输入张量
     * @return 推理结果张量
     */
    public Tensor infer(Tensor inputTensor) {
        try (IValue input = new IValue(inputTensor);
             IValue output = model.forward(input)) {
            // 释放输入张量内存
            inputTensor.close();
            return output.toTensor();
        }
    }

    /**
     * 数组转 PyTorch 张量(通用转换方法)
     */
    public Tensor arrayToTensor(float[] data, long[] shape) {
        return tensor_from_float_buffer(data, shape);
    }

    /**
     * 张量转数组(结果解析)
     */
    public float[] tensorToArray(Tensor tensor) {
        float[] result = new float[(int) tensor.numel()];
        tensor.getDataAsFloatBuffer().get(result);
        // 释放张量内存
        tensor.close();
        return result;
    }

    /**
     * 容器销毁时释放模型资源
     */
    @PreDestroy
    public void close() {
        if (model != null) {
            model.close();
            System.out.println("✅ PyTorch 模型资源已释放");
        }
    }
}

五、封装 AI 推理接口

创建 Controller 层,对外提供 HTTP 接口,接收业务参数,完成模型推理并返回结果。

import org.bytedeco.pytorch.Tensor;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.annotation.Resource;
import java.util.HashMap;
import java.util.Map;

@RestController
@RequestMapping("/ai/pytorch")
public class PyTorchInferController {

    @Resource
    private PyTorchModelManager modelManager;

    /**
     * 通用推理接口
     * @param params 输入参数(示例:一维数组)
     * @return 推理结果
     */
    @PostMapping("/infer")
    public Map<String, Object> infer(@RequestBody float[] params) {
        Map<String, Object> result = new HashMap<>();
        try {
            // 1. 构造张量形状(根据模型输入定义)
            long[] shape = {1, params.length};
            // 2. 数组转张量
            Tensor inputTensor = modelManager.arrayToTensor(params, shape);
            // 3. 模型推理
            Tensor outputTensor = modelManager.infer(inputTensor);
            // 4. 张量转结果数组
            float[] output = modelManager.tensorToArray(outputTensor);
            // 5. 封装返回结果
            result.put("code", 200);
            result.put("msg", "推理成功");
            result.put("data", output);
        } catch (Exception e) {
            result.put("code", 500);
            result.put("msg", "推理失败:" + e.getMessage());
        }
        return result;
    }
}

六、模型准备与测试

1. 模型导出(Python 端)

JavaCPP-PyTorch 仅支持 TorchScript 模型,Python 端导出代码示例:

import torch

# 1. 定义/加载训练好的模型
class MyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2

model = MyModel()
# 2. 导出为 TorchScript 模型
example_input = torch.rand(1, 5)  # 示例输入
traced_model = torch.jit.trace(model, example_input)
# 3. 保存模型
traced_model.save("src/main/resources/model.pt")

2. 启动项目测试

启动 Spring Boot 项目,使用 Postman/Curl 调用接口:

  • 请求地址POST http://localhost:8080/ai/pytorch/infer

  • 请求参数[1.0, 2.0, 3.0, 4.0, 5.0]

  • 返回结果:推理成功,返回模型计算后的数组

七、关键优化与生产注意事项

1. 性能优化

  • 推理模式:必须调用 model.eval(),禁用梯度计算,提升 30%+ 推理速度

  • 内存管理:及时调用 close() 释放张量/模型资源,避免 JVM 内存溢出

  • 并发推理:PyTorch Module 线程安全,可直接支持多线程并发推理

2. 跨平台适配

pytorch-platform依赖自动适配 Windows/Linux/Mac,打包部署时无需修改代码,直接运行。

3. GPU 加速

若服务器配备 NVIDIA 显卡,安装 CUDA 后,JavaCPP-PyTorch 自动调用 GPU 推理,无需额外配置。

八、总结

通过 Spring Boot + JavaCPP-PyTorch 2.10-1.5.13,我们实现了:

  1. 纯 Java 环境运行 PyTorch 模型,无需 Python 依赖

  2. 一键集成,底层库自动加载,开发成本极低

  3. 生产级部署,线程安全、内存可控、性能优异

  4. 微服务一体化,AI 推理能力与业务服务无缝融合

该方案完美解决了 Java 生态与 PyTorch 模型的落地壁垒,适用于图像识别、文本处理、数据预测等各类 AI 场景,是生产环境部署 PyTorch 模型的最优选择之一。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐