在这里插入图片描述

PyTorch Java 高校计算机硕士研一课程

章节 4: 模型部署和性能优化

在开发和训练复杂的 PyTorch 模型之后,重点转向如何将它们投入实际应用。本章将介绍模型部署和性能优化,提供使模型在推理时更快、更小、更节省资源的方法。

我们将介绍使用 TorchScript 进行模型序列化,学习跟踪和脚本两种方法。您将学习模型压缩技术,包括量化(静态、动态和量化感知训练)和剪枝策略,以减小模型大小和计算需求。我们将使用 PyTorch Profiler 来识别 CPU 和 GPU 执行中的性能瓶颈。此外,您还将学习将模型导出为 ONNX 格式以获得更广泛的兼容性,并学习使用 TorchServe 高效地提供模型服务。

在本章结束时,您将掌握分析模型性能的实用技能,并运用多种优化技术,这对于将 PyTorch 模型从开发环境部署到生产环境是必不可少的。

TorchScript 基础: 追踪与脚本化

收藏

PyTorch 模型训练完成后,要将其部署到生产环境或嵌入到应用程序中,需要一种不依赖 Python 运行时、可序列化且针对推理进行优化的格式。TorchScript 通过将 PyTorch 模型转换为中间表示 (IR) 来实现这一功能,该 IR 可以在 C++ 服务器或移动设备等环境中保存、加载和执行,而无需 Python 依赖。

TorchScript 在 PyTorch 灵活的即时执行模式(其中操作按照 Python 中的定义立即运行)与部署环境通常需要的静态图和性能优化之间架起了一座桥梁。它通过两种主要方法实现此目的:追踪脚本化。了解这两种方法的区别对于有效使用 TorchScript 进行模型部署是十分重要的。

使用 torch.jit.trace 进行追踪

追踪通过使用一组示例输入执行 PyTorch 模型,并记录在此特定执行过程中执行的操作序列来运作。这个被记录的序列,或称“追踪”,随后被转换为封装在 torch.jit.ScriptModule 中的静态图表示。

工作原理: 当你调用 torch.jit.trace(model, example_inputs) 时,PyTorch 会使用所提供的 example_inputs 运行模型的 forward 方法。每次操作执行时,PyTorch 都会对其进行记录。生成的 ScriptModule 本质上包含该单次前向传播期间计算图的冻结快照。

示例:

import torch
import torch.nn as nn

class SimpleModel extends nn.Module:
    def __init__(self):
        super().__init__()
        val linear = nn.Linear(10, 5)

    def forward(x: Tensor):
        // 简单、直接的计算
        return torch.relu(linear(x))

// 实例化模型
val model = SimpleModel()
model.eval() // 设置为评估模式

// 提供示例输入
val example_input = torch.randn(1, 10)

// 追踪模型
val traced_model = torch.jit.trace(model, example_input)

// 打印生成的 TorchScript 代码(通常与追踪结果相似)
println(traced_model.code)

// 打印底层图表示
println(traced_model.graph)

// 测试追踪后的模型
val output = traced_model(example_input)
println("输出形状:", output.shape)
package vals;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.annotation.ByVal;
import org.bytedeco.javacpp.annotation.Namespace;
import org.bytedeco.javacpp.annotation.StdString;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;

import java.util.Arrays;

/**
 * Java实现PyTorch模型TorchScript追踪(等效torch.jit.trace)
 * 核心:基于torch::jit::trace_module封装,还原Scala的模型追踪、代码/图打印、测试输出逻辑
 */
public class TorchScriptTraceDemo {

    // ======================== 第一步:自定义SimpleModel(等效Scala的SimpleModel) ========================
    public static class SimpleModel extends Module {
        public LinearImpl linear; // 线性层:10→5

        public SimpleModel() {
            super("SimpleModel");
            // 初始化线性层:nn.Linear(10,5)
            linear = new LinearImpl(10, 5);
            // 注册子模块(必需,否则trace_module无法识别)
            register_module("linear", linear);
        }

//        @Override
        public Tensor forward(IValue x) {
            // 前向传播:linear(x) → relu激活(等效Scala的torch.relu(linear(x)))
            Tensor input = x.toTensor();
            Tensor linearOut = linear.forward(input);
            Tensor reluOut = torch.relu(linearOut);

            // 释放临时张量
            linearOut.close();
            return reluOut;
        }

        // 释放子模块资源
        @Override
        public void close() {
            if (!isNull()) {
                linear.close();
                super.close();
            }
        }
    }

    public static void main(String[] args) {
        // ======================== 第二步:实例化模型并设置评估模式 ========================
        SimpleModel model = new SimpleModel();
        model.eval(); // 等效Scala的model.eval(),设置评估模式(禁用Dropout/BatchNorm等)
        System.out.println("SimpleModel已实例化并设置为评估模式");

        // ======================== 第三步:创建示例输入(1,10) ========================
        Tensor exampleInput = torch.randn(new long[]{1, 10});
        System.out.println("示例输入形状:" + Arrays.toString(exampleInput.sizes().vec().get()));

        // ======================== 第四步:TorchScript追踪(核心:等效torch.jit.trace) ========================
        // 1. 封装模型为IValue(trace_module要求输入为IValue)
        IValue modelIValue = IValue.fromModule(model);
        //StringTensorVector
        // 2. 封装输入:trace_module要求输入为「方法名→输入列表」的映射,默认追踪"forward"方法
//        StringIValueMap inputs = new StringIValueMap();
        StringTensorVector inputs = new StringTensorVector();
        IValueVector inputVec = new IValueVector();
        inputVec.push_back(new IValue(exampleInput));
        inputs.put(new BytePointer("forward"), inputVec);

        // 3. 执行追踪(核心API:torch::jit::trace_module)
        //    参数:模型IValue、输入映射、追踪选项(默认即可)
        ScriptModule tracedModel = trace_module(modelIValue, inputs);
        System.out.println("\n=== 模型追踪完成,生成ScriptModule ===");

        // ======================== 第五步:打印TorchScript代码和图表示 ========================
        // 1. 打印生成的TorchScript代码(等效Scala的traced_model.code)
        BytePointer codePtr = tracedModel.code();
        String code = codePtr.getString();
        System.out.println("=== 生成的TorchScript代码 ===");
        System.out.println(code);

        // 2. 打印底层图表示(等效Scala的traced_model.graph)
        BytePointer graphPtr = tracedModel.graph();
        String graph = graphPtr.getString();
        System.out.println("\n=== 底层Graph表示 ===");
        System.out.println(graph);

        // ======================== 第六步:测试追踪后的模型 ========================
        Tensor output = tracedModel.forward(new IValue(exampleInput)).toTensor();
        long[] outputShape = output.sizes().vec().get();
        System.out.println("\n=== 测试追踪后的模型 ===");
        System.out.println("输出形状: " + Arrays.toString(outputShape)); // 预期:[1,5]

        // ======================== 资源释放(JNI资源必须手动释放) ========================
        output.close();
        graphPtr.close();
        codePtr.close();
        tracedModel.close();
        inputVec.close();
//        inputs.close();
        modelIValue.close();
        exampleInput.close();
        model.close();
    }

    // ======================== 补充:JavaCPP中torch::jit的核心API封装(确保编译通过) ========================
    // 注:以下接口已包含在bytedeco-pytorch的JNI绑定中,无需手动实现,仅作说明
    @Namespace("torch::jit")
    public static class ScriptModule extends Pointer {
        public ScriptModule(Pointer p) { super(p); }
        public native @ByVal BytePointer code(); // 获取TorchScript代码
        public native @ByVal BytePointer graph(); // 获取Graph表示
        public native @ByVal IValue forward(@ByVal IValue input); // 前向传播
        public native void close(); // 释放资源
    }

    // trace_module:JavaCPP封装的核心追踪API(对应C++ torch::jit::trace_module)
    @Namespace("torch::jit")
    public static native @ByVal ScriptModule trace_module(
            @ByVal IValue module,
            @ByVal StringTensorVector inputs
    );
}

追踪的优点:

  • 简便性: 通常易于应用,仅需模型和示例输入。
  • 现有代码: 对许多现有模型运行良好,无需代码修改,只要其结构是静态的。

追踪的局限性: 追踪的主要局限在于它无法捕获数据依赖的控制流。因为追踪只记录针对特定示例输入执行的操作,所以任何行为依赖于输入张量的条件语句 (if) 或循环 (forwhile) 都不会在追踪图中正确表示。追踪只包含示例输入所采用的路径。

考虑这个修改后的模型:

class ControlFlowModel extends nn.Module:
    def __init__(self):
        super().__init__()
        val linear1 = nn.Linear(10, 5)
        val linear2 = nn.Linear(5, 1)

    def forward(x: Tensor):
        x = torch.relu(linear1(x))
        // 数据依赖的控制流
        if x.mean() > 0.5:
            return linear2(x)
        else:
            return torch.zeros_like(linear2(x))

// 实例化模型
val model_cf = ControlFlowModel()
model_cf.eval()

// 示例输入 1(可能触发 'if' 分支)
val input1 = torch.randn(1, 10) * 2
val traced_model_cf1 = torch.jit.trace(model_cf, input1)

// 示例输入 2(可能触发 'else' 分支)
val input2 = torch.randn(1, 10) * -2
// 注意:使用 input2 进行追踪会产生 *不同* 的追踪结果!

println(f"输入 1 均值: {input1.mean().item()}")
println(f"输入 2 均值: {input2.mean().item()}")

// 将两个输入都通过使用 input1 追踪的模型运行
val output1_trace1 = traced_model_cf1(input1)
val output2_trace1 = traced_model_cf1(input2) // 如果 input2 走 'else' 路径,这很可能是错误的

println(f"输入 1 均值: {input1.mean().item()}")
println(f"输入 2 均值: {input2.mean().item()}")

// 将两个输入都通过使用 input1 追踪的模型运行
val output1_trace1 = traced_model_cf1(input1)
val output2_trace1 = traced_model_cf1(input2) // 如果 input2 走 'else' 路径,这很可能是错误的

println(f"输入 1 的输出(用 input1 追踪): {output1_trace1.item()}")
println(f"输入 2 的输出(用 input1 追踪): {output2_trace1.item()}") // 无论 input2 的均值如何,都遵循追踪到的路径

// 与即时执行进行比较
val output1_eager = model_cf(input1)
val output2_eager = model_cf(input2)
println(f"输入 1 的输出(即时执行): {output1_eager.item()}")
println(f"输入 2 的输出(即时执行): {output2_eager.item()}") // 正确使用了 'else' 路径

在上面的示例中,traced_model_cf1 总是会执行使用 input1 追踪时记录的操作序列,无论新输入是否应该实际触发 else 分支。

使用 torch.jit.script 进行脚本化

脚本化采取了不同的方法。torch.jit.script 不会执行代码并记录操作,而是直接使用 TorchScript 编译器分析你的 Python 源代码。该编译器能够识别 Python 语言的一个子集(包括 ifforwhile 等控制流结构),并将其转换为 TorchScript IR。

工作原理: You可以通过在函数或整个 nn.Module 类上使用 @torch.jit.script 装饰器,或通过在实例或函数上调用 torch.jit.script() 来应用脚本化。编译器会解析 Python 代码,检查与 TorchScript 语言子集的兼容性,并生成一个 ScriptModuleScriptFunction,它准确地表示了原始逻辑,包括控制流。

示例: 让我们对 ControlFlowModel 进行脚本化:

// 沿用之前的 ControlFlowModel 类

val model_cf = ControlFlowModel()
model_cf.eval()

// 脚本化模型实例
val scripted_model = torch.jit.script(model_cf)

println(scripted_model.code) // 打印 TorchScript 代码,包括 if/else

// 使用不同输入进行测试
val input1 = torch.randn(1, 10) * 2
val input2 = torch.randn(1, 10) * -2

println(f"\n输入 1 均值: {input1.mean().item()}")
println(f"输入 2 均值: {input2.mean().item()}")

// 脚本化模型实例运行
val output1_script = scripted_model(input1)
val output2_script = scripted_model(input2)

println(f"输入 1 的输出(脚本化): {output1_script.item()}")
println(f"输入 2 的输出(脚本化): {output2_script.item()}") // 正确处理了控制流

// 与即时执行进行比较(应该匹配)
val output1_eager = model_cf(input1)
val output2_eager = model_cf(input2)
println(f"输入 1 的输出(即时执行): {output1_eager.item()}")
println(f"输入 2 的输出(即时执行): {output2_eager.item()}") // 正确处理了控制流

正如你所见,脚本化的模型正确处理了数据依赖的控制流,因为 if/else 逻辑被编译器直接翻译了。

脚本化的优点:

  • 处理控制流: 精确捕获数据依赖的条件逻辑和循环。
  • 通用性: 生成更通用的模型表示,不与转换期间使用的特定输入形状或值绑定。
  • 稳定性: 更适合具有动态行为的复杂模型。

脚本化的局限性:

  • Python 子集: 要求代码符合 TorchScript 语言子集。并非所有 Python 功能或库都受支持(例如,任意外部库调用、高度动态的元编程)。你可能需要重构部分模型代码才能使其可脚本化。
  • 调试: 编译器错误有时可能不如标准 Python 错误直观,需要仔细检查有问题的代码段。

追踪与脚本化的选择

追踪和脚本化之间的选择主要取决于模型 forward 方法的性质:

是否需要部署 PyTorch 模型?使用 TorchScript模型/模块是否使用数据依赖的控制流(基于输入的 if、for 循环)?使用追踪(torch.jit.trace)否使用脚本化(torch.jit.script)是考虑混合方法(脚本化带控制流的部分,追踪较简单的部分)可能可序列化且可优化的ScriptModule

根据模型控制流决定使用 TorchScript 追踪还是脚本化。

  • 在以下情况下使用追踪(torch.jit.trace):
    • 你的模型或模块不含数据依赖的控制流
    • 计算图是静态的,无论输入值如何(尽管如果适当追踪,它可以依赖输入形状)。
    • 你希望快速捕获简单模块的操作。
  • 在以下情况下使用脚本化(torch.jit.script):
    • 你的模型包含 if 语句、for 循环或其行为依赖于所处理张量值的其他结构。
    • 你需要一种能够在不同输入下正确运行的表示,这些输入可能触发不同的执行路径。
    • 你愿意确保代码符合 TorchScript 子集。

混合方法: 也可以混合使用追踪和脚本化。你可以脚本化一个内部调用追踪子模块的模块,反之亦然。通常,你可能会脚本化包含控制流的主模型,并在其中追踪更简单、静态的组件。

序列化与使用

一旦你拥有了一个 ScriptModule(无论是通过追踪还是脚本化获得),你就可以方便地将其保存到文件并稍后加载,可能在不同的环境中:

// 保存脚本化模型
torch.jit.save(scripted_model, 'control_flow_model.pt')

// 稍后加载模型(可能在另一个进程或 C++ 中)
val loaded_model = torch.jit.load('control_flow_model.pt')
loaded_model.eval()

// 使用加载的模型
val output_loaded = loaded_model(input2)
println(f"加载模型的输出: {output_loaded.item()}") // 正确处理了控制流

这个保存的 .pt 文件包含模型的架构、参数以及执行所需的 TorchScript 代码/图,使其成为一个用于部署的自包含工件。

掌握 TorchScript,特别是追踪和脚本化之间的区别,是使你的 PyTorch 模型为在生产环境中进行高效且稳定部署做准备的重要一步。通过选择合适的方法,你可以创建优化的、独立的模型版本,以进行推理。

模型量化技术

减少深度学习模型的计算开销对于部署非常重要,尤其是在资源受限设备或对延迟敏感的应用上。模型量化是一种有效方法,通过将模型转换为使用低精度数值格式(通常是8位整数INT8),而不是训练期间使用的标准32位浮点(FP32)表示,来达成此目的。这种转换带来明显优势:

  • 减小模型大小: 精度降低意味着存储模型参数(权重和偏置)所需的内存减少,从FP32到INT8通常可以减少4倍。
  • 更快的推理: 在许多硬件平台,特别是CPU和专用加速器(如NPU或DSP)上,整数算术运算通常比浮点运算快得多。这会降低推理延迟。
  • 更低的功耗: 更快的执行和更简单的算术运算通常会减少能耗,这对于移动和边缘设备很重要。

然而,量化并非没有代价。用更少位数表示值会引入近似误差,这可能降低模型精度。目标是尽量减少这种精度下降,同时尽可能提高性能收益。PyTorch提供了一个torch.quantization工具包来实现多种量化策略。

量化核心理念

本质上,量化涉及将一系列浮点值映射到较小范围的整数值。最常用方案是仿射量化,由两个参数定义:比例因子 (SS) 和 零点 (ZZ)。比例因子是一个正浮点数,决定量化的步长;零点是一个整数,对应实数0.0。

从实数rr (FP32) 到其量化整数表示qq (例如INT8) 的映射由以下公式给出:

q=截断(取整(r/S+Z))q=截断(取整(r/S+Z))

反向映射(反量化)从qq 返回到近似实数r′r′ 为:

r′=(q−Z)×Sr′=(qZS

round操作将值四舍五入到最近的整数,clamp确保结果保持在目标整数类型的有效范围内(例如,有符号INT8为[-128, 127])。比例因子SS和零点ZZ是根据被量化的浮点值的范围(例如,在权重或激活中观察到的最小值/最大值)来确定的。

量化可以应用于逐张量(对整个张量使用一个SS和ZZ)或逐通道(对每个通道使用独立的SS和ZZ值,通常沿卷积权重的输出通道轴)。逐通道量化通常为卷积层带来更高的精度,但会增加一些复杂性。

PyTorch支持三种主要模型量化方法:

1. 动态量化(训练后动态量化)

这通常是最简单的方法。

  • 工作方式: 权重离线量化(转换为INT8并存储)。而激活在推理期间“即时”量化。支持动态量化的算子(如nn.Linearnn.LSTM)会动态量化激活,使用高效的INT8核进行计算,然后在传递给下一次操作前,将结果反量化回FP32。
  • 优点: 非常容易实现,无需改变模型定义或训练过程,并且不需要校准数据集。
  • 缺点: 激活的动态量化/反量化会引入运行时开销。性能提升通常不如静态量化明显,特别是对于计算量通常超过内存带宽的卷积网络。精度可能低于静态方法。
  • 使用场景: 很好的起点。特别适用于权重大小是瓶颈的模型,或LSTMs和Transformers等序列模型,其中全连接层通常主导计算。

以下是您可能对模型应用动态量化的方式:

import torch
import torch.quantization
import torch.nn as nn

// 假设 'model_fp32' 是您已训练好的FP32模型
// 确保模型处于评估模式
model_fp32.eval()

// 指定要动态量化的层
// 通常侧重于 nn.Linear, nn.LSTM, nn.GRU
val quantized_model = torch.quantization.quantize_dynamic(
    model=model_fp32,
    qconfig_spec=Set(nn.Linear, nn.LSTM), // 要量化的层类型集合
    dtype=torch.qint8                   // 目标数据类型
)

// 现在 'quantized_model' 可以用于推理
input_fp32 = torch.randn(1, input_size) // 示例输入
val output = quantized_model(input_fp32)
package vals;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;

import java.util.HashSet;
import java.util.Set;

/**
 * 手动实现PyTorch动态量化(quantize_dynamic):
 * 核心逻辑:遍历模型→识别目标层→替换为量化层→权重离线量化→激活动态量化
 * 替代方案:无quantized_matmul时,反量化后用FP32矩阵乘法模拟量化计算
 */
public class ManualDynamicQuantization {

    // ======================== 1. 定义量化Linear层(动态量化核心) ========================
    /**
     * 动态量化Linear层:
     * - 权重离线量化为qint8(压缩内存)
     * - 输入激活推理时动态量化,反量化后用FP32计算(替代quantized_matmul)
     */
    public static class QuantizedLinear extends Module {
        // 量化权重(qint8)+ 缩放因子 + 零点(用于反量化)
        public Tensor weightQ;
        public float scale;
        public int zeroPoint;

        // 原bias(保持FP32)
        public Tensor bias;

        // 输入输出维度
        private final long inFeatures;
        private final long outFeatures;

        public QuantizedLinear(LinearImpl fp32Linear) {
            super("QuantizedLinear");
            // 修复:正确获取Linear的输入输出维度
            this.inFeatures = fp32Linear.options().in_features().get();
            this.outFeatures = fp32Linear.options().out_features().get();
            // 克隆bias(FP32,保留原精度)
            this.bias = fp32Linear.bias() != null && !fp32Linear.bias().isNull()
                    ? fp32Linear.bias().clone() : null;

            // ===================== 核心:离线量化权重 =====================
            // 1. 获取FP32权重
            Tensor weightFp32 = fp32Linear.weight();
            // 2. 计算量化参数(min/max → scale/zero_point,qint8范围[-128,127])
            Tensor min = weightFp32.min();
            Tensor max = weightFp32.max();
            float minVal = min.item().toFloat();
            float maxVal = max.item().toFloat();
            // 修正量化参数计算(适配qint8的取值范围)
            this.scale = (maxVal - minVal) / 255.0f;
            this.zeroPoint = Math.round(-minVal / this.scale);
            // 限制zeroPoint在qint8范围内
            this.zeroPoint = Math.max(-128, Math.min(127, this.zeroPoint));

            // 3. 量化权重为qint8(核心:压缩内存)
            this.weightQ = torch.quantize_per_tensor(
                    weightFp32,
                    this.scale,
                    this.zeroPoint,
                    torch.ScalarType.QInt8
            );

            // 释放临时张量
            min.close();
            max.close();
            weightFp32.close();
        }

        // 修复:补全@Override注解,符合Module抽象方法要求
//        @Override
        public Tensor forward(IValue x) {
            Tensor inputFp32 = x.toTensor();

            // ===================== 核心:动态量化激活 + 模拟量化计算 =====================
            // 1. 动态量化输入(FP32 → qint8,模拟原生动态量化的激活量化)
            Tensor inputQ = torch.quantize_per_tensor(
                    inputFp32,
                    this.scale, // 复用权重的量化参数(简化版,实际可单独计算)
                    this.zeroPoint,
                    torch.ScalarType.QInt8
            );

            // 2. 替代quantized_matmul:反量化为FP32后计算(数学结果等价)
            // 2.1 反量化权重(qint8 → FP32)
            Tensor weightDeQ = this.weightQ.dequantize();
            // 2.2 反量化输入(qint8 → FP32)
            Tensor inputDeQ = inputQ.dequantize();
            // 2.3 FP32矩阵乘法(等价于 linear 的 y = xA^T + b)
            Tensor outputFp32 = torch.matmul(inputDeQ, weightDeQ.permute(new long[]{1, 0}));

            // 3. 加上bias(FP32)
            if (bias != null && !bias.isNull()) {
                outputFp32 = outputFp32.add(bias);
            }

            // 释放临时张量
            inputQ.close();
            weightDeQ.close();
            inputDeQ.close();
            return outputFp32;
        }

        @Override
        public void close() {
            if (!isNull()) {
                if (weightQ != null) weightQ.close();
                if (bias != null && !bias.isNull()) bias.close();
                super.close();
            }
        }
    }

    // ======================== 2. 动态量化核心函数(模拟torch.quantization.quantize_dynamic) ========================
    /**
     * 手动实现动态量化函数
     * @param model 待量化的FP32模型
     * @param quantizeClasses 要量化的层类型(如LinearImpl.class)
     * @param dtype 量化数据类型(仅支持qint8)
     * @return 量化后的模型
     */
    public static Module quantize_dynamic(Module model, Set<Class<? extends Module>> quantizeClasses, torch.ScalarType dtype) {
        // 仅支持qint8
        if (!dtype.equals(torch.ScalarType.QInt8)) {
            throw new IllegalArgumentException("仅支持qint8动态量化");
        }

        // 递归遍历并替换模型的子模块
        replaceQuantizableLayers(model, model, quantizeClasses);
        return model;
    }

    /**
     * 递归遍历模型子模块,替换目标层为量化层
     * 修复:正确获取子模块(named_modules替代named_parameters)
     */
    private static void replaceQuantizableLayers(Module rootModel, Module currentModule, Set<Class<? extends Module>> quantizeClasses) {
        // 1. 获取当前模块的所有子模块(修复:用named_modules获取子模块,而非named_parameters)
        StringVector submoduleNames = currentModule.named_modules().keys();
        for (int i = 0; i < submoduleNames.size(); i++) {
            String name = submoduleNames.get(i).getString();
            // 跳过根模块自身(避免无限递归)
            if (name.isEmpty()) continue;

            Module submodule = currentModule.named_modules().get(name);
            if (submodule == null || submodule.isNull()) continue;

            // 2. 判断子模块是否为目标量化类型
            if (quantizeClasses.contains(submodule.getClass())) {
                // 3. 替换为量化层(当前仅支持Linear,可扩展LSTM/GRU)
                if (submodule instanceof LinearImpl) {
                    QuantizedLinear quantizedLinear = new QuantizedLinear((LinearImpl) submodule);
                    // 替换根模型中的子模块
                    rootModel.register_module(name, quantizedLinear);
                    System.out.println("已量化层:" + name + " (Linear)");
                }
                // 扩展:添加LSTM/GRU的量化实现
                // else if (submodule instanceof LSTMImpl) { ... }
            }

            // 4. 递归处理子模块的子模块
            replaceQuantizableLayers(rootModel, submodule, quantizeClasses);

            // 释放临时资源
            submodule.close();
        }
        submoduleNames.close();
    }

    // ======================== 3. 示例FP32模型(待量化) ========================
    public static class FP32Model extends Module {
        public LinearImpl fc1;
        public LinearImpl fc2;
        public ReLUImpl relu; // 非量化层

        public FP32Model(int inputSize, int hiddenSize, int outputSize) {
            super("FP32Model");
            fc1 = new LinearImpl(inputSize, hiddenSize);
            fc2 = new LinearImpl(hiddenSize, outputSize);
            relu = new ReLUImpl();

            // 注册子模块(必须)
            register_module("fc1", fc1);
            register_module("fc2", fc2);
            register_module("relu", relu);
        }

        // 修复:补全@Override注解,符合Module抽象方法要求
//        @Override
        public Tensor forward(IValue x) {
            Tensor input = x.toTensor();
            // 前向传播:Linear → ReLU → Linear
            Tensor h1 = relu.forward(fc1.forward(input));
            Tensor output = fc2.forward(h1);

            // 释放临时张量
            h1.close();
            return output;
        }

        @Override
        public void close() {
            if (!isNull()) {
                fc1.close();
                fc2.close();
                relu.close();
                super.close();
            }
        }
    }

    // ======================== 4. 使用示例 ========================
    public static void main(String[] args) {
        // 模型参数
        int inputSize = 16;
        int hiddenSize = 32;
        int outputSize = 8;

        // 步骤1:初始化FP32模型并设置评估模式
        FP32Model modelFp32 = new FP32Model(inputSize, hiddenSize, outputSize);
        modelFp32.eval(); // 量化前必须设置为评估模式
        System.out.println("=== 初始化FP32模型完成 ===");

        // 步骤2:定义要量化的层类型(仅Linear)
        Set<Class<? extends Module>> quantizeClasses = new HashSet<>();
        quantizeClasses.add(LinearImpl.class);

        // 步骤3:执行手动动态量化
        Module quantizedModel = quantize_dynamic(modelFp32, quantizeClasses, torch.ScalarType.QInt8);
        System.out.println("=== 动态量化完成 ===");

        // 步骤4:测试量化模型推理(修复:正确调用前向传播,传入IValue)
        Tensor inputFp32 = torch.randn(new long[]{1, inputSize});
        Tensor output = quantizedModel.asSequential().forward(inputFp32);

        // 打印结果
        System.out.println("示例输入形状:" + inputFp32.sizes().vec().get()[0] + ", " + inputFp32.sizes().vec().get()[1]);
        System.out.println("量化模型输出形状:" + output.sizes().vec().get()[0] + ", " + output.sizes().vec().get()[1]);

        // 验证量化效果:打印权重类型和压缩特性
        QuantizedLinear quantFc1 = (QuantizedLinear) quantizedModel.named_modules().get("fc1");
        System.out.println("量化后fc1权重类型:" + quantFc1.weightQ.dtype()); // 输出:QInt8
        System.out.println("量化后fc1权重缩放因子:" + String.format("%.6f", quantFc1.scale));
        System.out.println("量化后fc1权重内存占用:" + quantFc1.weightQ.nbytes() + " 字节");
        System.out.println("原FP32权重内存占用:" + (quantFc1.inFeatures * quantFc1.outFeatures * 4) + " 字节"); // 4字节/FP32

        // 资源释放(防止内存泄漏)
        output.close();
        inputFp32.close();
        quantFc1.close();
        quantizedModel.close();
        modelFp32.close();
    }
}

2. 静态量化(训练后静态量化)

静态量化旨在通过尽可能在整数域中完成所有计算,来获得最大性能。

  • 工作方式: 权重离线量化。非常重要的一点是,激活范围也通过一个称为校准的过程离线确定。您将训练或验证数据的代表性样本输入模型,特殊的“观察者”模块会追踪不同位置激活的分布(最小值/最大值)。这些统计数据用于计算激活的比例因子SS和零点ZZ。在推理期间,权重和激活都是INT8,可实现高效的基于整数的计算。插入QuantStubDeQuantStub模块来处理FP32输入/输出与模型INT8量化核心之间的转换。
  • 优点: 有潜力带来最大的加速和内存节省,因为中间计算可以保留在INT8域中。通常比动态量化提供更高的精度。
  • 缺点: 需要代表性校准数据集。实现过程更复杂,通常需要模型修改(插入stub,合并模块)。
  • 使用场景: 适合卷积神经网络(CNN)以及部署在支持高效INT8硬件上的其他架构,目标是最大推理速度和最小占用空间。

静态量化工作流程通常包括以下步骤:

  1. 准备模型:
    • 合并操作:使用torch.quantization.fuse_modules尽可能将Conv+BatchNorm+ReLU等层合并为单个单元。这可以提高精度和性能。
    • 插入量化/反量化Stub:在模型输入处添加QuantStub,在输出前添加DeQuantStub,以管理FP32 <-> INT8转换。
    • 指定量化配置:定义要使用的量化方案(例如,x86平台的fbgemm,ARM平台的qnnpack)和观察者。
  2. 校准:
    • 将模型设置为评估模式(model.eval())。
    • 将校准数据输入准备好的模型。观察者收集激活统计信息。
  3. 转换:
    • 使用torch.quantization.convert将校准后的模型转换为完全量化的INT8模型,替换模块为其量化版本,并存储计算出的比例因子和零点。
import torch
import torch.quantization
import torch.nn as nn

// 假设 'model_fp32' 是您已训练好的FP32模型
model_fp32.eval()

// 1. 准备模型
// 添加 QuantStub 和 DeQuantStub (修改您的模型定义或对其进行封装)
class QuantizableModel extends nn.Module:
    def __init__(original_model: nn.Module):
        super().__init__()
        val quant = torch.quantization.QuantStub()
        val model = original_model
        val dequant = torch.quantization.DeQuantStub()

    def forward(x: Tensor):
        x = quant(x)
        x = model(x)
        x = dequant(x)
        return x

val model_to_quantize = QuantizableModel(model_fp32)
model_to_quantize.eval()

// 合并模块(Conv + ReLU 示例)
// 您通常需要遍历模型的层
// 示例:torch.quantization.fuse_modules(model_to_quantize.model, [['conv1', 'relu1']], inplace=True)

// 指定量化配置
// x86平台使用 'fbgemm',ARM平台使用 'qnnpack'。为简单起见,使用 get_default_qconfig
model_to_quantize.qconfig = torch.quantization.get_default_qconfig('fbgemm')
// 通过添加观察者来准备模型
val prepared_model = torch.quantization.prepare(model_to_quantize, inplace=True)

// 2. 校准
// 将代表性数据输入准备好的模型
// 假设 'calibration_data_loader' 提供校准样本
println("正在运行校准...")
with torch.no_grad():
    for inputs, _ <- calibration_data_loader:
        prepared_model(inputs)
println("校准完成。")

// 3. 转换
val quantized_model = torch.quantization.convert(prepared_model, inplace=True)
quantized_model.eval()

// 'quantized_model' 现在已准备好进行INT8推理
// input_fp32 = torch.randn(1, 3, 224, 224) // 示例输入
// val output = quantized_model(input_fp32)


import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;

import java.util.ArrayList;
import java.util.List;

/**
 * 基于PyTorch底层量化API实现静态量化(INT8):
 * 仅使用暴露的fake_quantize/fbgemm系列API,舍弃高阶封装(fuse_modules/prepare/convert)
 * 核心:伪量化激活 + INT8权重量化 + INT8线性计算
 */
public class StaticQuantizationWithLowLevelAPI {

    // ======================== 1. 量化参数类(存储scale/zero_point) ========================
    /**
     * 量化参数:存储scale(缩放因子)、zero_point(零点)、量化范围
     */
    public static class QuantParams {
        public float scale;          // 缩放因子
        public long zeroPoint;       // 零点
        public final long qMin = 0;  // qint8范围:0~255(无符号)/ -128~127(有符号)
        public final long qMax = 255;

        public QuantParams(float scale, long zeroPoint) {
            this.scale = scale;
            this.zeroPoint = zeroPoint;
        }
    }

    // ======================== 2. 手动实现INT8量化Linear层(基于fbgemm API) ========================
    /**
     * INT8量化Linear层:
     * - 权重用fbgemm_linear_quantize_weight量化为INT8
     * - 激活用fake_quantize_per_tensor_affine伪量化
     * - 推理用fbgemm_linear_int8_weight执行INT8计算
     */
    public static class Int8Linear extends Module {
        // 原始参数
        private final long inFeatures;
        private final long outFeatures;
        private Tensor bias;          // Bias保持FP32

        // 量化后的权重相关
        private Tensor weightInt8;    // INT8量化权重
        private Tensor weightScale;   // 权重缩放因子
        private Tensor weightZeroPoint;// 权重零点
        private QuantParams actQuantParams; // 激活量化参数

        // 伪量化开关
        private boolean isCalibrated = false;

        public Int8Linear(LinearImpl fp32Linear) {
            super("Int8Linear");
            this.inFeatures = fp32Linear.options().in_features().get();
            this.outFeatures = fp32Linear.options().out_features().get();
            this.bias = fp32Linear.bias() != null && !fp32Linear.bias().isNull()
                    ? fp32Linear.bias().clone() : null;

            // ===================== 核心:量化权重为INT8(fbgemm API) =====================
            T_TensorTensorDoubleLong_T quantWeightResult = torch.fbgemm_linear_quantize_weight(fp32Linear.weight());
            this.weightInt8 = quantWeightResult.get0();       // INT8权重
            this.weightScale = torch.tensor(quantWeightResult.get2()); // 权重scale(double→float)
            this.weightZeroPoint = torch.tensor(quantWeightResult.get3()); // 权重zero_point

            // 初始化激活量化参数(校准后更新)
            this.actQuantParams = new QuantParams(1.0f, 0);

            // 释放临时资源
            quantWeightResult.close();
            fp32Linear.weight().close();
        }

//        @Override
        public Tensor forward(Tensor inputFp32) {
//            Tensor inputFp32 = x.toTensor();

            // ===================== 步骤1:激活伪量化(模拟INT8量化/反量化) =====================
            Tensor inputFakeQuant;
            if (isCalibrated) {
                // 用fake_quantize_per_tensor_affine执行伪量化(校准后)
                inputFakeQuant = torch.fake_quantize_per_tensor_affine(
                        inputFp32,
                        actQuantParams.scale,
                        actQuantParams.zeroPoint,
                        actQuantParams.qMin,
                        actQuantParams.qMax
                );
            } else {
                // 未校准阶段:直接使用FP32输入
                inputFakeQuant = inputFp32.clone();
            }

            // ===================== 步骤2:INT8线性计算(fbgemm API) =====================
            // fbgemm_linear_int8_weight参数说明:
            // input(FP32) + weight(INT8) + bias(FP32) + weight_scale + weight_zero_point + input_scale + input_zero_point
            //  public static native Tensor fbgemm_linear_int8_weight(
            //  @Const @ByRef Tensor var0, @Const @ByRef Tensor var1, 
            //  @Const @ByRef Tensor var2, @Const @ByRef Tensor var3, 
            //  @Const @ByRef Scalar var4, @Const @ByRef Scalar var5, @Const @ByRef Tensor var6);
            Tensor output = torch.fbgemm_linear_int8_weight(
                    inputFakeQuant,
                    weightInt8,
                    bias != null ? bias : torch.empty(),
                    weightScale,
                    new Scalar(weightZeroPoint),
                    new Scalar(actQuantParams.scale),       // 输入scale
                    torch.tensor(actQuantParams.zeroPoint)    // 输入zero_point
            );

            // ===================== 资源释放 =====================
            inputFp32.close();
            inputFakeQuant.close();
            return output;
        }

        // 校准:统计输入分布,更新激活量化参数
        public void calibrate(Tensor input) {
            // 统计输入的min/max,计算量化参数
            Tensor min = input.min();
            Tensor max = input.max();
            float minVal = min.item().toFloat();
            float maxVal = max.item().toFloat();

            // 计算scale和zero_point(仿射量化公式)
            actQuantParams.scale = (maxVal - minVal) / (actQuantParams.qMax - actQuantParams.qMin);
            actQuantParams.zeroPoint = Math.round((minVal / actQuantParams.scale) + actQuantParams.qMin);

            // 标记为已校准
            isCalibrated = true;

            // 释放临时张量
            min.close();
            max.close();
        }

        @Override
        public void close() {
            if (!isNull()) {
                if (weightInt8 != null) weightInt8.close();
                if (weightScale != null) weightScale.close();
                if (weightZeroPoint != null) weightZeroPoint.close();
                if (bias != null && !bias.isNull()) bias.close();
                super.close();
            }
        }
    }

    // ======================== 3. 示例FP32模型(待量化) ========================
    public static class FP32Model extends Module {
        public Conv2dImpl conv1;
        public ReLUImpl relu1;
        public LinearImpl fc1;

        public FP32Model() {
            super("FP32Model");
            // 初始化层:Conv2d(3, 16, 3) → ReLU → Linear(16*222*222, 10)
            Conv2dOptions options = new Conv2dOptions(3, 16, new LongPointer(3,3))
            conv1 = new Conv2dImpl(options);
            relu1 = new ReLUImpl();
            fc1 = new LinearImpl(16 * 222 * 222, 10);

            // 注册子模块
            register_module("conv1", conv1);
            register_module("relu1", relu1);
            register_module("fc1", fc1);
        }

//        @Override
        public Tensor forward(Tensor input ) {
//            Tensor input = x.toTensor();
            // 前向传播:Conv → ReLU → 展平 → Linear
            Tensor convOut = conv1.forward(input);
            Tensor reluOut = relu1.forward(convOut);
            Tensor flattenOut = reluOut.flatten(1,1); // 展平(从维度1开始)
            Tensor output = fc1.forward(flattenOut);

            // 释放临时张量
            convOut.close();
            reluOut.close();
            flattenOut.close();
            return output;
        }

        @Override
        public void close() {
            if (!isNull()) {
                conv1.close();
                relu1.close();
                fc1.close();
                super.close();
            }
        }
    }

    // ======================== 4. 封装量化模型(替换Linear为Int8Linear) ========================
    public static class QuantizedModel extends Module {
        public Conv2dImpl conv1;
        public ReLUImpl relu1;
        public Int8Linear fc1; // 替换为INT8 Linear

        public QuantizedModel(FP32Model fp32Model) {
            super("QuantizedModel");
            this.conv1 = fp32Model.conv1;
            this.relu1 = fp32Model.relu1;
            this.fc1 = new Int8Linear(fp32Model.fc1); // 替换Linear为INT8版本

            // 注册子模块
            register_module("conv1", conv1);
            register_module("relu1", relu1);
            register_module("fc1", fc1);
        }

//        @Override
        public Tensor forward(Tensor input) {
//            Tensor input = x.toTensor();
            // 前向传播:Conv → ReLU → 展平 → INT8 Linear
            Tensor convOut = conv1.forward(input);
            Tensor reluOut = relu1.forward(convOut);
            Tensor flattenOut = reluOut.flatten(1,1);
            Tensor output = fc1.forward(flattenOut);

            // 释放临时张量
            convOut.close();
            reluOut.close();
            flattenOut.close();
            return output;
        }

        // 校准模型(统计激活分布)
        public void calibrate(Tensor input) {
            Tensor convOut = conv1.forward(input);
            Tensor reluOut = relu1.forward(convOut);
            Tensor flattenOut = reluOut.flatten(1,1);
            // 校准INT8 Linear的激活参数
            fc1.calibrate(flattenOut);

            // 释放临时张量
            convOut.close();
            reluOut.close();
            flattenOut.close();
        }

        @Override
        public void close() {
            if (!isNull()) {
                conv1.close();
                relu1.close();
                fc1.close();
                super.close();
            }
        }
    }

    // ======================== 5. 模拟校准数据加载器 ========================
    private static List<Tensor> createCalibrationDataLoader(int batchSize, int count) {
        List<Tensor> calibrationData = new ArrayList<>();
        for (int i = 0; i < count; i++) {
            // 生成模拟输入:(batchSize, 3, 224, 224),FP32类型
            Tensor input = torch.randn(new long[]{batchSize, 3, 224, 224});
            calibrationData.add(input);
        }
        return calibrationData;
    }

    // ======================== 6. 主流程:静态量化 + 推理 ========================
    public static void main(String[] args) {
        // ===================== 步骤1:初始化FP32模型并设置评估模式 =====================
        FP32Model modelFp32 = new FP32Model();
        modelFp32.eval();
        System.out.println("=== 初始化FP32模型完成 ===");

        // ===================== 步骤2:封装为量化模型(替换Linear为INT8版本) =====================
        QuantizedModel quantizedModel = new QuantizedModel(modelFp32);
        quantizedModel.eval();
        System.out.println("=== 封装量化模型完成(替换Linear为INT8版本) ===");

        // ===================== 步骤3:执行校准(统计激活分布,更新量化参数) =====================
        System.out.println("正在运行校准...");
        List<Tensor> calibrationDataLoader = createCalibrationDataLoader(2, 10);
        try (NoGradGuard noGradGuard = new NoGradGuard()) { // 禁用梯度
            for (Tensor inputs : calibrationDataLoader) {
                quantizedModel.calibrate(inputs); // 校准模型
                inputs.close();
            }
        }
        calibrationDataLoader.clear();
        System.out.println("校准完成。");

        // ===================== 步骤4:测试量化模型推理 =====================
        // 生成示例输入:(1, 3, 224, 224)
        Tensor inputFp32 = torch.randn(new long[]{1, 3, 224, 224});
        // 量化模型推理
        Tensor output = quantizedModel.forward(inputFp32);

        // 打印结果
        System.out.println("示例输入形状:" + inputFp32.sizes().vec().get()[0] + ", "
                + inputFp32.sizes().vec().get()[1] + ", "
                + inputFp32.sizes().vec().get()[2] + ", "
                + inputFp32.sizes().vec().get()[3]);
        System.out.println("量化模型输出形状:" + output.sizes().vec().get()[0] + ", "
                + output.sizes().vec().get()[1]);

        // 验证量化效果:打印权重类型和量化参数
        System.out.println("INT8 Linear权重类型:" + quantizedModel.fc1.weightInt8.dtype()); // 输出:UInt8/QInt8
        System.out.println("激活量化scale:" + String.format("%.6f", quantizedModel.fc1.actQuantParams.scale));
        System.out.println("激活量化zero_point:" + quantizedModel.fc1.actQuantParams.zeroPoint);

        // ===================== 资源释放 =====================
        output.close();
        inputFp32.close();
        quantizedModel.close();
        modelFp32.close();
    }
}

3. 量化感知训练(QAT)

QAT在训练(或微调)过程中模拟量化效果,让模型适应精度损失。

  • 工作方式: 在模型定义中插入“伪”量化模块(torch.quantization.FakeQuantize)。这些模块在正向传播过程中使用估计的量化参数模拟量化(量化-反量化)过程。梯度正常计算和反向传播,使模型权重得以调整,从而最大程度地降低最终INT8转换对精度的影响。训练后,模型将转换为真正的INT8模型,这类似于静态量化过程,但使用在QAT期间学到的参数。
  • 优点: 通常能达到量化方法中的最高精度,常常非常接近原始FP32模型的精度。
  • 缺点: 需要对模型进行再训练或微调,这增加了训练阶段的复杂性和计算成本。
  • 使用场景: 当训练后方法(动态或静态)导致无法接受的精度下降,并且有可用于再训练的资源时采用。

QAT工作流程与静态量化类似,但与训练过程结合:

  1. 为QAT准备模型:
    • 像静态量化一样合并模块。
    • 定义QAT配置(例如,torch.quantization.get_default_qat_qconfig('fbgemm'))。
    • 使用torch.quantization.prepare_qat插入伪量化模块。
  2. 训练或微调:
    • 在伪量化模块激活的状态下训练模型。模型学习适应量化噪声的权重。确保模型以训练模式开始(model.train())。
  3. 转换:
    • 训练后,将模型切换到评估模式(model.eval())。
    • 使用torch.quantization.convert创建最终的INT8模型。
import torch
import torch.quantization
import torch.nn as nn
import torch.optim as optim

// 假设 'model_fp32' 是您已训练好的FP32模型或架构
// QAT通常从预训练模型开始,或从头训练

// 1. 为QAT准备
// 首先适当合并模块(为简洁起见此处未展示)
val model_to_train_qat = QuantizableModel(model_fp32) // 使用静态示例中的封装器
model_to_train_qat.train() // 设置为训练模式

// 定义QAT配置
model_to_train_qat.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

// 通过插入伪量化模块来准备模型
val prepared_model_qat = torch.quantization.prepare_qat(model_to_train_qat, inplace=True)

// 2. 训练或微调
val optimizer = optim.SGD(prepared_model_qat.parameters(), lr=0.001)
val criterion = nn.CrossEntropyLoss()
val num_epochs_qat = 3 // 示例:微调几个周期

println("开始QAT微调...")
for epoch <- 0 until num_epochs_qat:
    prepared_model_qat.train() // 确保模型处于训练模式
    for (inputs, labels) <- training_data_loader: // 使用您的训练数据
        optimizer.zero_grad()
        val outputs = prepared_model_qat(inputs)
        val loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    // 如有需要,添加验证循环
    println(s"Epoch ${epoch+1}/$num_epochs_qat, 损失: ${loss.item()}")
println("QAT微调完成。")


// 3. 转换为量化模型
prepared_model_qat.eval() // 在转换前设置为评估模式!
val quantized_model_qat = torch.quantization.convert(prepared_model_qat, inplace=True)

// 'quantized_model_qat' 是最终可用于部署的INT8模型


import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;

import java.util.ArrayList;
import java.util.List;

/**
 * 基于底层API手动实现PyTorch量化感知训练(QAT):
 * 核心:训练时插入伪量化(模拟量化误差)+ 微调 + 转换为INT8模型
 * 仅依赖dequantize等底层API,无高阶QAT封装
 */
public class QuantizationAwareTrainingDemo {

    // ======================== 1. 量化参数类(存储scale/zero_point) ========================
    public static class QuantParams {
        public float scale = 1.0f;          // 缩放因子
        public long zeroPoint = 0;          // 零点
        public final long qMin = -128;      // 有符号qint8范围
        public final long qMax = 127;

        // 更新量化参数(基于张量的min/max)
        public void updateFromTensor(Tensor tensor) {
            Tensor min = tensor.min();
            Tensor max = tensor.max();
            float minVal = min.item().toFloat();
            float maxVal = max.item().toFloat();

            this.scale = (maxVal - minVal) / (qMax - qMin);
            this.zeroPoint = Math.round((-minVal / this.scale) + qMin);
            // 限制零点在合法范围
            this.zeroPoint = Math.max(qMin, Math.min(qMax, this.zeroPoint));

            min.close();
            max.close();
        }
    }

    // ======================== 2. QAT Linear层(训练时伪量化,推理时INT8) ========================
    /**
     * QAT Linear层:
     * - 训练模式:权重/激活执行「量化→反量化」伪操作(引入量化误差)
     * - 评估模式:权重转为INT8,激活用dequantize反量化
     */
    public static class QATLinear extends Module {
        // 基础参数
        private final long inFeatures;
        private final long outFeatures;
        private Tensor weight;    // 训练时保持FP32,推理时量化为INT8
        private Tensor bias;      // Bias始终FP32

        // 量化相关
        private QuantParams weightQuantParams = new QuantParams();
        private QuantParams actQuantParams = new QuantParams();
        private Tensor weightInt8; // 最终INT8权重(训练完成后生成)
        private boolean isQATTrained = false; // QAT训练完成标记

        public QATLinear(long inFeatures, long outFeatures) {
            super("QATLinear");
            this.inFeatures = inFeatures;
            this.outFeatures = outFeatures;

            // 初始化权重/偏置(模拟预训练权重)
            this.weight = torch.randn(new long[]{outFeatures, inFeatures});
            this.bias = torch.randn(new long[]{outFeatures});

            // 注册参数(支持梯度更新)
            register_parameter("weight", this.weight);
            register_parameter("bias",this.bias);
        }

        // 伪量化操作:量化→反量化(模拟量化误差)
        private Tensor fakeQuantize(Tensor tensor, QuantParams params, boolean isTraining) {
            if (!isTraining) return tensor.clone();

            // 步骤1:量化(模拟INT8舍入)
            Tensor quantized = torch.round(tensor.div(new Scalar(params.scale)).add(new Scalar(params.zeroPoint)));
            // 步骤2:裁剪到qint8范围
            quantized = torch.clamp(quantized, new ScalarOptional(new Scalar(params.qMin)), new ScalarOptional(new Scalar(params.qMax)));
            // 步骤3:反量化(转回FP32,引入舍入误差)
            Tensor dequantized = quantized.sub(new Scalar(params.zeroPoint)).mul(new Scalar(params.scale));

            quantized.close();
            return dequantized;
        }

//        @Override
        public Tensor forward(Tensor input) {
//            Tensor input = x.toTensor();
            Tensor output;

            if (this.is_training() && !isQATTrained) {
                // ===================== 训练模式:伪量化 + FP32计算 =====================
                // 1. 激活伪量化
                Tensor inputFakeQuant = fakeQuantize(input, actQuantParams, true);
                // 2. 权重伪量化
                Tensor weightFakeQuant = fakeQuantize(weight, weightQuantParams, true);
                // 3. FP32矩阵乘法(带量化误差)
                output = torch.matmul(inputFakeQuant, weightFakeQuant.permute(new long[]{1, 0}));

                if (bias != null && !bias.isNull()) {
                    output = output.add(bias);
                }

                inputFakeQuant.close();
                weightFakeQuant.close();
            } else {
                // ===================== 评估模式:INT8权重 + dequantize =====================
                if (!isQATTrained) {
                    // 训练完成后:量化权重为INT8
                    weightQuantParams.updateFromTensor(weight);
                    this.weightInt8 = torch.round(weight.div(new Scalar(weightQuantParams.scale)).add(new Scalar(weightQuantParams.zeroPoint)));
                    this.weightInt8 = torch.clamp(weightInt8,new ScalarOptional(new Scalar( weightQuantParams.qMin)), new ScalarOptional(new Scalar(weightQuantParams.qMax)));
                    isQATTrained = true;
                }

                // 1. 激活量化
                Tensor inputQuant = torch.round(input.div(new Scalar(actQuantParams.scale)).add(new Scalar(actQuantParams.zeroPoint)));
                inputQuant = torch.clamp(inputQuant, new ScalarOptional(new Scalar(actQuantParams.qMin)), new ScalarOptional(new Scalar(actQuantParams.qMax)));
                // 2. INT8矩阵乘法
                Tensor outputQuant = torch.matmul(inputQuant, weightInt8.permute(new long[]{1, 0}));
                // 3. 反量化(使用底层dequantize API)
                output = torch.dequantize(outputQuant);

                if (bias != null && !bias.isNull()) {
                    output = output.add(bias);
                }

                inputQuant.close();
                outputQuant.close();
            }

            input.close();
            return output;
        }

        // 获取可训练参数(用于优化器)
        public TensorVector getParameters() {
            TensorVector params = new TensorVector();
            params.push_back(weight);
            params.push_back(bias);
            return params;
        }

        @Override
        public void close() {
            if (!isNull()) {
                if (weight != null) weight.close();
                if (bias != null) bias.close();
                if (weightInt8 != null) weightInt8.close();
                super.close();
            }
        }
    }

    // ======================== 3. QAT模型封装(模拟QuantizableModel) ========================
    public static class QATModel extends Module {
        private QATLinear fc1;
        private QATLinear fc2;
        private ReLUImpl relu;

        public QATModel(int inputSize, int hiddenSize, int outputSize) {
            super("QATModel");
            this.fc1 = new QATLinear(inputSize, hiddenSize);
            this.fc2 = new QATLinear(hiddenSize, outputSize);
            this.relu = new ReLUImpl();

            register_module("fc1", fc1);
            register_module("fc2", fc2);
            register_module("relu", relu);
        }

//        @Override
        public Tensor forward(Tensor input) {
//            Tensor input = x.toTensor();
            // 前向传播:QATLinear → ReLU → QATLinear
            Tensor h1 = fc1.forward(input);
            Tensor h1Relu = relu.forward(h1);
            Tensor output = fc2.forward(h1Relu);

            // 释放临时张量
            h1.close();
            h1Relu.close();
            input.close();
            return output;
        }

        // 获取所有可训练参数
        public TensorVector getAllParameters() {
            TensorVector params = new TensorVector();
            params.push_back(fc1.getParameters().get(0));
            params.push_back(fc1.getParameters().get(1));
            params.push_back(fc2.getParameters().get(0));
            params.push_back(fc2.getParameters().get(1));
            return params;
        }

        // 切换为评估模式(QAT训练完成后)
        public void finishQAT() {
            this.eval();
            this.fc1.eval();
            this.fc2.eval();
        }

        @Override
        public void close() {
            if (!isNull()) {
                fc1.close();
                fc2.close();
                relu.close();
                super.close();
            }
        }
    }

    // ======================== 4. 模拟训练数据加载器 ========================
    /**
     * 模拟训练数据:生成输入+标签(分类任务)
     */
    private static List<Tensor[]> createTrainingDataLoader(int batchSize, int sampleCount, int inputSize, int numClasses) {
        List<Tensor[]> dataLoader = new ArrayList<>();
        for (int i = 0; i < sampleCount; i++) {
            // 输入:(batchSize, inputSize)
            Tensor inputs = torch.randn(new long[]{batchSize, inputSize});
            // 标签:(batchSize,) 随机分类标签
            Tensor labels = torch.randint(0, numClasses, new long[]{batchSize});
            dataLoader.add(new Tensor[]{inputs, labels});
        }
        return dataLoader;
    }

    // ======================== 5. QAT主训练流程 ========================
    public static void main(String[] args) {
        // 模型/训练参数配置
        int inputSize = 128;
        int hiddenSize = 64;
        int outputSize = 10; // 10分类任务
        int batchSize = 32;
        int numEpochs = 3;
        float lr = 0.001f;

        // ===================== 步骤1:初始化QAT模型(模拟预训练FP32模型) =====================
        QATModel modelToTrainQAT = new QATModel(inputSize, hiddenSize, outputSize);
        modelToTrainQAT.train(true); // 设置为训练模式(QAT核心:训练时启用伪量化)
        System.out.println("=== 初始化QAT模型完成(训练模式) ===");

        // ===================== 步骤2:配置优化器和损失函数 =====================
        // SGD优化器(对应Scala的optim.SGD)
        TensorVector params = modelToTrainQAT.getAllParameters();
        SGDOptions sgdOptions = new SGDOptions(lr);
        Optimizer optimizer = new SGD(params, sgdOptions);

        // 交叉熵损失(对应Scala的nn.CrossEntropyLoss)
        CrossEntropyLossImpl lossFunc = new CrossEntropyLossImpl();
        System.out.println("=== 初始化优化器/损失函数完成 ===");

        // ===================== 步骤3:生成模拟训练数据 =====================
        List<Tensor[]> trainingDataLoader = createTrainingDataLoader(batchSize, 100, inputSize, outputSize);
        System.out.println("=== 生成训练数据完成 ===");

        // ===================== 步骤4:QAT微调训练 =====================
        System.out.println("开始QAT微调...");
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            modelToTrainQAT.train(true); // 确保训练模式
            float totalLoss = 0.0f;

            try (NoGradGuard noGradGuard = new NoGradGuard()) {
                // 禁用梯度(仅统计loss,实际训练需移除,此处为简化)
                for (Tensor[] batch : trainingDataLoader) {
                    Tensor inputs = batch[0];
                    Tensor labels = batch[1];

                    // 前向传播
                    Tensor outputs = modelToTrainQAT.forward(inputs);
                    // 计算损失
                    Tensor loss = lossFunc.forward(outputs, labels);
                    totalLoss += loss.item().toFloat();

                    // 反向传播+优化(实际训练需启用梯度,此处简化)
                    // loss.backward();
                    // optimizer.step();
                    // optimizer.zero_grad();

                    // 释放临时张量
                    outputs.close();
                    loss.close();
                    inputs.close();
                    labels.close();
                }
            }

            // 打印epoch信息
            float avgLoss = totalLoss / trainingDataLoader.size();
            System.out.printf("Epoch %d/%d, 平均损失: %.4f%n", epoch+1, numEpochs, avgLoss);
        }
        System.out.println("QAT微调完成。");

        // ===================== 步骤5:转换为最终INT8量化模型 =====================
        modelToTrainQAT.finishQAT(); // 切换为评估模式,生成INT8权重
        System.out.println("=== 转换为INT8量化模型完成 ===");

        // ===================== 步骤6:测试量化模型推理 =====================
        // 生成测试输入
        Tensor testInput = torch.randn(new long[]{1, inputSize});
        // INT8模型推理(使用dequantize反量化输出)
        Tensor testOutput = modelToTrainQAT.forward(testInput);

        // 打印结果
        System.out.println("测试输入形状:" + testInput.sizes().vec().get()[0] + ", " + testInput.sizes().vec().get()[1]);
        System.out.println("INT8模型输出形状:" + testOutput.sizes().vec().get()[0] + ", " + testOutput.sizes().vec().get()[1]);
        System.out.println("最终INT8权重类型:" + modelToTrainQAT.fc1.weightInt8.dtype()); // 输出:Int8

        // ===================== 资源释放 =====================
        testInput.close();
        testOutput.close();
        optimizer.close();
        sgdOptions.close();
        lossFunc.close();
        params.close();
        modelToTrainQAT.close();
    }
}

选择合适的量化方法

选择合适的量化方法取决于您的具体限制和目标:

需要量化模型吗?选择量化方法是动态量化 (PTQD)最简单无需校准数据延迟不那么重要静态量化 (PTQS)需要最大性能有校准数据中等精度量化感知训练 (QAT)需要最佳精度可再训练实施与评估精度可接受吗?可供部署是重新考虑方法/训练否

一份根据易用性、数据可用性、性能需求和精度容忍度等要求,选择PyTorch量化策略的决策指南。

实际考量

  • 模块合并: 在应用静态量化或QAT之前,使用torch.quantization.fuse_modules合并Conv + BatchNorm + ReLU等序列。这使得量化观察者能将组合操作作为一个整体处理,从而带来更好的数值精度并启用后端优化。
  • 后端: PyTorch对量化操作使用不同的后端(x86 CPU的fbgemm,ARM CPU的qnnpack)。请确保在配置期间选择适合目标硬件的后端。
  • 算子支持: 并非所有PyTorch算子都支持量化。请查看文档以了解支持的层和数据类型。如果模型某些部分使用了不支持的操作,您可能需要将其保留在FP32中(混合精度部署)。
  • 调试: 量化有时难以调试。仔细检查中间张量统计信息(q_scale()q_zero_point()),并比较量化模型与FP32基线的精度。

模型量化是优化PyTorch模型以实现高效部署的重要步骤。通过理解动态、静态和量化感知训练方法之间的权衡,并仔细运用torch.quantization中提供的工具,您可以明显减少模型大小和延迟,同时为您的应用保持可接受的精度。

模型剪枝策略

为了使模型部署高效,模型剪枝提供了一种直接的方法,通过移除对模型性能贡献最小的参数来降低复杂度。像量化这样的技术降低了参数的精度,而剪枝则将它们完全移除,从而带来更小的模型尺寸和更快的推理速度。

其核心思路源于这样一个发现:许多大型神经网络存在明显的过参数化问题。它们包含冗余的权重,甚至可以移除整个结构元素(如神经元或通道),且对精度影响很小,尤其是在微调阶段之后。这与像“彩票假说”这样的观点相符,该假说认为,密集网络包含更小的子网络,这些子网络在独立训练时能够达到相似的性能。剪枝旨在找出并分离这些高效的子网络。

无结构剪枝与结构化剪枝

剪枝技术通常分为两大类:

  1. 无结构剪枝: 指根据特定标准(通常是权重的大小)从网络中移除单个权重。值接近零的权重被认为影响力较小,并被精确地设置为零。这会创建稀疏权重矩阵。
    • 优点: 经过微调后,可以在精度损失很小的情况下达到非常高的稀疏度(例如90%或更高)。
    • 缺点: 生成的稀疏矩阵在使用传统的密集矩阵乘法库(如cuBLAS)的标准硬件(CPU、GPU)上,通常不能直接转化为实际运行时间上的加速。要获得明显的加速,通常需要专门的硬件加速器或为稀疏计算优化的软件库(例如torch.sparse)。模型尺寸的减小程度显著,这对于存储和内存带宽有利。
  2. 结构化剪枝: 这种方法不是移除单个权重,而是移除整个结构组件,例如卷积层中的滤波器或通道,或者全连接层中的神经元。
    • 优点: 产生更小的密集模型。剩余的结构规整,可以被标准硬件和库高效执行,通常会直接带来推理加速和更少的内存使用,而无需专门支持。
    • 缺点: 相较于无结构剪枝,在精度开始明显下降之前,通常只能达到较低的稀疏度。确定要移除哪些结构需要仔细考虑依赖关系(例如,移除一个通道会影响使用该通道的后续层)。

无结构剪枝和结构化剪枝的选择,取决于主要的优化目标(最大压缩率还是标准硬件上的直接加速)以及可用的推理基础设施。

剪枝过程

应用剪枝的常见工作流程包含以下步骤:

  1. 训练模型: 从一个完全训练好的密集模型开始。
  2. 选择剪枝策略: 决定采用无结构剪枝还是结构化剪枝,移除的标准(例如,大小),以及目标稀疏度或要移除的具体结构。
  3. 执行剪枝: 找出并移除(或遮蔽)选定的参数或结构。
  4. 微调: 对剪枝后的模型进行数个周期的再训练,通常使用较低的学习率。这一步骤对于恢复剪枝过程中损失的精度非常重要,它让剩余的权重得以调整。
  5. (可选) 迭代进行: 重复步骤3和4,逐渐提高稀疏度。迭代剪枝通常比一次性移除大量权重(一次性剪枝)能带来更好的结果,因为它给网络提供了更多调整的机会。

剪枝标准

我们如何判断哪些权重或结构“不那么重要”?存在以下几种标准:

  • 基于大小的剪枝: 最简单也是最广泛使用的方法。绝对值最小的权重被剪枝。对于结构化剪枝,通常使用结构内(如滤波器或通道)权重的LnL**n范数(例如L1L1或L2L2)。范数最低的结构被移除。尽管这种方法简单,但基于大小的剪枝效果出人意料地好。
  • 基于梯度的剪枝: 使用梯度信息,可能与权重大小结合,来估计重要性。像SynFlow这样的方法试图在训练早期甚至训练开始前就找出重要的连接。
  • 基于敏感度的剪枝: 衡量移除特定权重或结构对损失函数的影响。由于这需要评估移除许多不同元素的效果,因此通常计算成本更高。

使用torch.nn.utils.prune实现剪枝

PyTorch提供了一个便捷的工具模块torch.nn.utils.prune,用于实现各种剪枝技术。它的工作原理是为指定的模块参数(weightbias)添加一个weight_mask缓冲区。在前向传播过程中,原始权重张量与此掩码进行逐元素相乘,从而有效地将剪枝后的权重归零,而最初不修改原始的weight张量本身。这有助于逐步剪枝和微调。

我们来看一个使用基于大小剪枝的基本示例:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

// 示例:为卷积层添加基于大小的全局L1剪枝
# 为层创建一个示例(例如卷积层)
// 为层创建一个示例(例如卷积层)
val layer = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)

// --- 无结构剪枝 ---
// 剪枝该层中L1范数最小的30%权重(全局)
// 剪枝该层中L1范数最小的30%权重(全局)
prune.l1_unstructured(layer, name="weight", amount=0.3)

// 掩码现已附着。你可以查看它:
println(hasattr(layer, 'weight_mask'))
// 输出: True
println(layer.weight_mask)
// 输出: (一个由0和1组成的张量,约有30%的元素是0)

// 原始权重仍然存在,但在前向传播时被掩蔽
println(layer.weight) // 显示原始权重

// 查看计算中使用的剪枝后权重:
println(layer.weight * layer.weight_mask) // 显示被掩蔽的权重


// --- 结构化剪枝(示例:剪枝通道) ---
// 为结构化示例创建一个新层
val structured_layer = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)

// 剪枝25%的通道(dim=0对应输出通道)
// 使用每个通道滤波器中权重的L2范数
prune.ln_structured(structured_layer, name="weight", amount=0.25, n=2, dim=0)

println(hasattr(structured_layer, 'weight_mask'))
// 输出: True
// 注意掩码结构:整个通道(dim=0)都被归零了
println(structured_layer.weight_mask[0:5, 0, 0, 0]) // 检查前5个滤波器的掩码


// --- 使剪枝永久化 ---
// 微调后,你可能希望移除掩码并永久性地将参数归零
// 这可以减少开销,并使模型准备好部署。
// 对于无结构示例:
prune.remove(layer, 'weight')
println(hasattr(layer, 'weight_mask'))
// 输出: False
// 现在layer.weight直接包含零值
println(torch.sum(layer.weight == 0)) // 计算归零的权重数量

在典型的训练循环中,你会在开始微调阶段之前应用剪枝函数(例如prune.l1_unstructured)。在微调期间,梯度将只流经未被掩蔽的权重,使它们能够调整。掩码本身不参与训练。微调完成后,调用prune.remove会将掩码直接应用于权重张量并移除掩码缓冲区及其关联的前向预钩子,从而使稀疏性永久化。

权衡与考量

剪枝在模型压缩/速度与精度之间引入了一种权衡。

通过剪枝实现的模型稀疏度与微调后的验证精度之间的典型关系。精度通常在初始阶段保持稳定,但在更高的稀疏度水平下下降得更快。

考量方面包括:

  • 目标稀疏度: 模型能承受多少剪枝?这在很大程度上取决于模型架构、数据集和特定任务。需要进行经验性评估。
  • 微调: 对恢复精度不可或缺。需要仔细选择学习率和周期数。由于活跃参数较少,微调剪枝后的模型每个周期可能花费更少的时间,但可能需要足够的周期才能收敛。
  • 硬件加速: 无结构剪枝从能高效处理稀疏张量的硬件或软件中获益匪盛。结构化剪枝在标准硬件上提供更直接的优势。
  • 技术结合: 剪枝可以与其他优化方法(如量化)有效结合,以实现更大的模型压缩和效率提升。

模型剪枝是一种有效技术,用于减小深度学习模型的计算开销。通过仔细移除不那么重要的参数,无论是单个移除还是结构性移除,并对生成的网络进行微调,你可以创建更小、可能更快的模型,适用于资源受限环境下的部署。torch.nn.utils.prune模块提供了灵活的工具,用于在你的PyTorch工作流程中实现各种剪枝策略。

PyTorch Profiler 性能分析

了解模型在执行过程中时间与资源的分配情况,是进行优化的根本。在应用量化或剪枝等技术之前,你需要找出性能瓶颈所在。是CPU限制了性能?是GPU未充分利用?还是特定操作过慢?PyTorch Profiler (torch.profiler) 是回答这些问题的标准工具。

该分析器使你能够查看模型执行不同部分的时间和内存开销,包括CPU上的Python操作和GPU上的CUDA内核执行。它提供详细的视图,帮助你进行优化,确保你把精力放在那些能为推理带来最大性能提升的方面。

Profiler测量的内容

torch.profiler API 通过跟踪几个重要指标,提供模型执行的全面视图:

  1. 运算符执行时间: 测量单个PyTorch运算符在CPU和GPU(CUDA内核)上的执行时长。它区分“自身时间”(运算符自身代码中花费的时间,不包括对其他运算符的调用)和“总时间”(包括子调用中花费的时间)。
  2. 内核启动和GPU利用率: 跟踪CUDA内核的启动及其在GPU上的执行时间,帮助查看并行性并找出GPU可能空闲的时间段。
  3. 内存使用(可选): 启用后,它会跟踪CPU和GPU设备上的内存分配和释放情况,帮助找出内存占用高或可能存在内存泄漏的操作。
  4. 数据传输: 通过CUDA运行时事件(如 cudaMemcpy)隐式显示,突出显示在主机(CPU)和设备(GPU)之间移动数据所花费的时间。
  5. 运算符调用堆栈(可选): 启用后,它会记录导致每个分析操作的Python调用堆栈,使回溯耗时操作到源代码中的特定行变得更容易。

Profiler基本用法

使用分析器最常见的方式是通过其上下文管理器接口。你将要分析的代码段包装在 with torch.profiler.profile(...) 块中。

import torch
import torchvision.models as models
import torch.profiler.{profile, record_function, ProfilerActivity}

// 加载预训练模型(确保其处于评估模式以进行推理分析)
val model = models.resnet18().cuda().eval()
val inputs = torch.randn(16, 3, 224, 224).cuda() // GPU上的示例输入批次

// 基本分析上下文
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("model_inference"): # 该代码块的可选标签
        model(inputs)

// 打印汇总统计信息
println(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

// 导出结果以进行更详细的分析
// prof.export_chrome_trace("resnet18_trace.json")
// prof.export_stacks("/tmp/profiler_stacks.txt", "self_cuda_time_total")

import org.bytedeco.javacpp.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import static org.bytedeco.pytorch.global.torch.*;

public class TorchProfilerExample {
    public static void main(String[] args) {
        // 1. 设置设备与模型 (ResNet18)
        Device device = new Device(kCUDA());

        // 加载预训练的 ResNet18 ScriptModule (JavaCPP 运行预训练模型通常加载导出的 .pt)
        var model = torch.load("resnet18.pt");
        model.to(device);
        model.eval();

        // 准备 GPU 输入
        var inputs = torch.randn(new long[]{16, 3, 224, 224},
                new TensorOptions().device(new DeviceOptional(device)).dtype(new ScalarTypeOptional(kFloat())));

        // 2. 配置 Profiler
        var config = new ProfilerConfig(ProfilerState.KINETO);
        config.report_input_shapes(true);
        config.profile_memory(false); // 根据需要开启

        var activities = new ActivityTypeSet();
        activities.insert(ActivityType.CPU);
        activities.insert(ActivityType.CUDA);

        System.out.println("\n--- 开始 ResNet18 推理性能分析 ---");
  			torch.enableProfilerInChildThread();
        // 3. 模拟 Python: with profile(...) as prof:
        torch.prepareProfiler(config, activities);
        torch.startMemoryProfile(); // 启动全局监控

        try {
            // 4. 模拟 Python: with record_function("model_inference"):
            try (var guard = new RecordFunction(RecordScope.USER_SCOPE)) {
                // 标记该作用域的名称,以便在统计表中显示
//                guard.before("model_inference");
                guard.overload_name().putString("ResNet18_Inference");  

                // 执行模型推理
                var inputIValue = new IValue(inputs);
                var inputVector = new IValueVector(inputIValue);

                model.forward(inputVector);

                // 确保 CUDA 同步,以便获取准确的耗时
                torch.cuda_synchronize();
            }
        } finally {
            // 5. 停止分析并获取结果
            // 在底层,disableProfiler() 会停止监控并返回所有采集到的数据
            torch.disableProfilerInChildThread();
            torch.stopMemoryProfile();
//            var profilerResult = disableProfiler();
//            // 6. 打印汇总统计信息 (模拟 prof.key_averages().table(...))
//            // 注意:JavaCPP 映射中通常直接在结果对象上获取 key_averages
//            var keyAverages = profilerResult.key_averages();
//
//            // 打印表格:按 cuda_time_total 排序,限制前 10 行
//            // 参数说明:(排序字段, 行数限制)
//            System.out.println(keyAverages.table("cuda_time_total", 10));

            System.out.println("\n✅ 分析完成。");
        }
    }
}

让我们分解 profile() 中使用的参数:

  • activities: 一个列表,指定要分析的活动。常见选择是 ProfilerActivity.CPUProfilerActivity.CUDA。分析CUDA活动对于了解GPU性能很重要。
  • record_shapes: 如果为 True,则记录被分析运算符的输入形状。这对于诊断与形状相关的性能问题有用,但会增加一些开销。
  • profile_memory: 如果为 True,则启用内存分析(分配/释放)。开销较大。
  • with_stack: 如果为 True,则记录Python调用堆栈。对于将运算符回溯到源代码很有用,但开销很大。
  • on_trace_ready: 一个可调用对象(通常是 torch.profiler.tensorboard_trace_handler),用于处理结果导出,例如直接导出到TensorBoard。
  • schedule: 控制长时间运行作业的分析持续时间。使用 torch.profiler.schedule(wait, warmup, active, repeat) 定义阶段:跳过初始 wait 步,执行 warmup 步(分析器活动但结果被丢弃),记录 active 步,并重复此循环 repeat 次。这对于排除初始化开销并专注于稳态性能很有用。

record_function("label") 上下文管理器将自定义标签添加到分析器输出中,使识别代码中特定逻辑块(如数据预处理、模型前向传播、后处理)变得更容易。

分析Profiler结果

分析器对象(示例中的 prof)提供了几种分析收集数据的方法:

1. key_averages()

此方法返回运算符性能的汇总摘要,该摘要在分析窗口内取平均。调用 .table() 提供格式化的字符串输出。

# Example Output Snippet from prof.key_averages().table(...)
-------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Name                                         Self CPU %      Self CPU   CPU total %     CPU total      CUDA %      CUDA total    # Calls
-------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
aten::convolution                            0.14%     293.164us        17.21%       35.96ms       81.90%       35.91ms        20
aten::cudnn_convolution                      0.00%       0.000us        0.00%       0.000us       81.72%       35.83ms        20
aten::addmm                                  0.06%     117.880us         1.12%        2.34ms        4.97%        2.18ms         1
aten::mm                                     0.00%       0.000us         0.00%       0.000us        4.96%        2.18ms         1
aten::add_                                   0.13%     263.820us         0.51%        1.07ms        3.48%        1.53ms        21
aten::relu                                   0.12%     245.750us         0.28%     577.870us        1.91%     836.370us        16
aten::_native_batch_norm_legit_no_...        0.10%     215.530us         1.99%        4.15ms        1.85%     810.318us        20
aten::empty_strided                          1.76%        3.68ms         1.94%        4.05ms        0.00%       0.000us       140
aten::max_pool2d_with_indices                0.04%      79.630us         0.38%     788.380us        0.79%     346.077us         1
aten::copy_                                  0.06%     120.600us         0.06%     120.600us        0.00%       0.000us         2
-------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 208.96ms
Self CUDA time total: 44.08ms
  • 名称: PyTorch运算符的名称(例如,aten::convolution)。aten 是PyTorch原生C++运算符的命名空间。
  • 自身CPU / CUDA时间: 在CPU或GPU上,此运算符代码内部直接花费的时间,不包括其调用的函数所花费的时间。
  • 总CPU / CUDA时间: 在此运算符及其调用的任何函数中花费的时间(例如,aten::convolution 调用 aten::cudnn_convolution)。
  • # 调用次数: 运算符被调用的次数。

你可以对表格进行排序(例如,sort_by="cuda_time_total")并限制行数(row_limit),以便关注最耗时的操作。按输入形状(group_by_input_shape=True)或堆栈跟踪(group_by_stack_n)分组可以提供更多信息。像 aten::convolution 这样的运算符如果 Self CUDA 时间很高,表明底层的CUDA卷积内核花费了大量时间,这通常是预期的,但能确认GPU时间用在了何处。高 Self CPU 时间可能指向Python开销或CPU密集型计算。

2. export_chrome_trace()

此方法将详细的时间线数据导出为JSON文件格式,该格式与Chrome的跟踪工具(chrome://tracing)或Perfetto UI(推荐)兼容。这种可视化对于理解模型的时间动态很有用。

查看跟踪:打开Google Chrome,导航到 chrome://tracing,然后点击“加载”,或者使用Perfetto UI:ui.perfetto.dev

跟踪视图通常显示:

  • CPU线程: 表示CPU线程的行,将运算符执行显示为时间线上的块。
  • GPU流: 表示CUDA流的行(例如,Stream 7),显示在GPU上计划的内核执行和内存传输。

CPU 线程GPU 流 7Python开销aten::empty_strided启动内核 (卷积)启动内核 (加法)卷积内核启动加法内核启动GPU 空闲

CPU启动GPU内核的简化视图。Chrome跟踪提供详细的时间线视图,显示精确的开始/结束时间以及可能表示空闲期的间隔。

通过检查跟踪,你可以发现:

  • GPU空闲时间: GPU流时间线上的间隔表示GPU正在等待的时间段。这可能是由于CPU启动内核过慢、数据加载效率低或同步点引起的。
  • CPU瓶颈: CPU线程上的长块阻碍了后续GPU内核的启动。
  • 数据传输开销: memcpy 块显示数据移动所花费的时间。
  • 内核持续时间: GPU流上块的长度显示特定内核运行了多长时间。
3. TensorBoard集成

使用 torch.profiler.tensorboard_trace_handler 可在TensorBoard中提供集成体验。

import torch
import torchvision.models as models
import torch.profiler.{profile, tensorboard_trace_handler}

// 模型和输入设置(同前)
val model = models.resnet18().cuda().eval()
val inputs = torch.randn(16, 3, 224, 224).cuda()

// TensorBoard日志目录
val log_dir = "./logs"

with profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
             record_memory=True, // 可选地跟踪内存
             on_trace_ready=tensorboard_trace_handler(log_dir)) as prof:
    model(inputs)

// 打印TensorBoard日志目录
println(f"Profiler results saved to {log_dir}. Run: tensorboard --logdir {log_dir}")
import org.bytedeco.javacpp.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import java.io.File;
import java.nio.file.Paths;
import static org.bytedeco.pytorch.global.torch.*;

public class TensorBoardProfiler {
    public static void main(String[] args) {
        // 1. 环境准备
        Device device = new Device(kCUDA());
        var model = torch.load("resnet18.pt");
        model.to(device);
        model.eval();

        var inputs = randn(new long[]{16, 3, 224, 224},
                new TensorOptions().device(new DeviceOptional(device)));

        // TensorBoard 日志目录准备
        String logDir = "./logs";
        new File(logDir).mkdirs();

        // 2. 配置 Profiler
        var config = new ProfilerConfig(ProfilerState.KINETO);
        config.profile_memory(true); // 开启显存跟踪
        config.report_input_shapes(true);

        var activities = new ActivityTypeSet();
        activities.insert(ActivityType.CPU);
        activities.insert(ActivityType.CUDA);

        System.out.println("🚀 开始推理并记录 TensorBoard 追踪...");

        // 3. 执行分析 (手动模拟 on_trace_ready)
        torch.enableProfilerInChildThread();
        torch.prepareProfiler(config, activities);
        torch.startMemoryProfile();

        try (var scope = new PointerScope()) {
            // 执行模型推理
            var inputVector = new IValueVector(new IValue(inputs));
            model.forward(inputVector);

            // 确保所有算子在停止前完成
            torch.cuda_synchronize();
        } finally {
            // 4. 停止分析并获取结果
            torch.disableProfilerInChildThread();


            torch.stopMemoryProfile();
                    
            // 5. 模拟 tensorboard_trace_handler: 导出为 Chrome Trace 格式
            // TensorBoard 的 PyTorch Profiler 插件可以直接读取 .json 追踪文件
            String tracePath = Paths.get(logDir, "resnet18_inference_trace.json").toString();
//            var profilerResult = torch.disableProfiler();
//            profilerResult.save(tracePath);

            System.out.printf("✅ 分析完成。%nProfiler 结果已保存至: %s%n", logDir);
            System.out.printf("请运行: tensorboard --logdir %s%n", logDir);
        }
    }
}



启动TensorBoard(tensorboard --logdir ./logs)并导航到“PyTorch Profiler”选项卡,提供多种交互式视图:

  • 概览: 步骤时间、运算符时间分布(CPU与GPU)以及工具检测到的潜在瓶颈的高级摘要。
  • 运算符视图: 类似于 key_averages,显示每个运算符的详细统计信息。允许过滤和搜索。
  • 内核视图: 专门关注GPU内核性能。
  • 跟踪视图: Chrome跟踪查看器的嵌入版本。
  • 内存视图: (如果 profile_memory=True)显示每个运算符的内存使用模式和分配情况。

总GPU时间在不同运算符上的分布示例,源于分析器数据。在CNN中,卷积通常占用大部分时间。

识别常见瓶颈

分析器输出直接指向优化机会:

  1. CPU密集型执行:
    • 症状: key_averages 中总CPU时间高,Python开销大或某些运算符的自身CPU时间长。跟踪视图中GPU利用率低(大间隔)。
    • 可能原因: 繁重的Python逻辑(循环、数据操作),过多的小操作产生开销,CPU上数据加载/预处理慢。
    • 潜在解决方案: 优化Python代码,合并小操作,如果可行将预处理移至GPU,优化数据加载流程(num_workerspin_memory)。
  2. GPU未充分利用:
    • 症状: 跟踪视图中GPU流中存在明显的间隔,意味着GPU经常空闲。CUDA内核花费的时间占总步长时间的比例总体较低。
    • 可能原因: CPU瓶颈(见上文),低效的内核启动(许多小内核而非少量大内核),并行度不足(例如,批次大小小)。
    • 潜在解决方案: 解决CPU瓶颈,增加批次大小(如果内存允许),尽可能使用融合内核,考虑模型架构调整。
  3. 数据传输开销:
    • 症状: 跟踪视图中可见像 cudaMemcpyDtoH(设备到主机)或 cudaMemcpyHtoD(主机到设备)这样的操作花费了大量时间。
    • 可能原因: 模型或数据流程中CPU和GPU内存之间频繁且不必要的传输。使用非固定内存进行传输。
    • 潜在解决方案: 尽量减少数据移动。尽可能确保数据停留在GPU上。在 DataLoader 中使用 pin_memory=True,并在 .to(device) 调用中使用 non_blocking=True 以实现传输和计算的重叠。
  4. 内存效率低下:
    • 症状: profile_memory=True 报告的峰值内存使用量高。执行期间出现 OutOfMemoryError
    • 可能原因: 较大的中间激活,存储不必要的张量,低效的运算符实现。
    • 潜在解决方案: 推理时使用 torch.no_grad(),删除不再需要的张量(del tensor),使用检查点技术(以计算换内存),应用模型优化技术如量化或剪枝(本章其他部分涵盖),减小批次大小。
  5. 低效内核:
    • 症状: 特定CUDA内核(例如,自定义内核,甚至是 aten::convolution 这样的标准内核)在 key_averages 或跟踪视图中显示非常长的执行时间。
    • 可能原因: 次优的内核实现,在存在专用内核(例如,通过cuDNN)时使用了通用内核,硬件限制。
    • 潜在解决方案: 确保cuDNN等库已启用,调查替代运算符实现(如果可用),考虑编写自定义优化内核(第6章),或使用混合精度训练等技术(第3章),这些技术有时可以使用更快的Tensor Core内核。

高级用法和注意事项

  • 自定义代码块: 使用 with torch.profiler.record_function("my_label"): 向分析结果添加自定义注释,从而更容易将性能数据与代码的特定部分(例如,“data_preprocessing”、“attention_block”)关联起来。
  • 分析调度: 对于长时间运行的训练任务或复杂的推理流程,使用 torch.profiler.profileschedule 参数,以便在初始预热期后捕获特定迭代,避免生成过大的跟踪文件,并专注于稳态行为。
  • Profiler开销: 请记住,分析会增加开销,特别是在启用内存或堆栈跟踪时。分析代表性的输入和模型状态,但不要在生产部署中持续启用分析器。进行初步检查时,使用更简单的分析设置(例如,仅CPU/CUDA活动,不带形状/内存/堆栈)以降低开销。
  • 迭代改进: 性能分析是一个迭代循环:分析 -> 识别瓶颈 -> 应用优化 -> 再次分析 -> 衡量改进。

通过系统地使用PyTorch Profiler,你将对模型的运行时行为获得必要的了解,以便明智地决定在哪里以及如何有效地应用优化技术,最终得到更快速、更高效的模型,为部署做好准备。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐