在这里插入图片描述

章节 9: PyTorch 内部机制与自动求导

为了高效处理复杂任务,理解 PyTorch 的内部运作机制会有帮助。本章着重讲解基本构成要素:PyTorch 张量、自动求导(autograd)机制以及将它们联系起来的计算图。

我们将研究张量的结构和内存管理方式。你会了解到操作执行时 PyTorch 如何动态构建计算图,以及 autograd 引擎如何遍历这些图来计算梯度,例如损失 LL 对权重 ww 的偏导数 ∂L∂w∂wL

主要内容包括:

  • torch.Tensor 的内部结构。
  • 动态计算图的创建和使用方式。
  • autograd 引擎在反向传播过程中的逐步操作。
  • 通过 torch.autograd.Function 定义 forwardbackward 方法来实现自定义操作。
  • 计算高阶梯度。
  • 检查梯度和可视化计算图的方法。
  • PyTorch 中高效内存使用的注意事项。

熟悉这些核心组成部分对于调试复杂模型、优化性能以及实现标准库之外的自定义功能都十分有益。最后,我们将通过一个实践练习来构建自己的 autograd 函数。

张量实现细节

收藏

创建和操作PyTorch张量通常感觉很直观,但了解其内部结构有助于编写高效代码、调试复杂行为以及构建自定义操作。torch.Tensor不仅仅是一个多维数组;它是一个精密的D对象,包含定义内存中原始数值数据如何被解释的元数据。

张量对象与内存存储

本质上,每个PyTorch张量都持有一个指向torch.Storage对象的引用。可以将torch.Storage视为一个连续的、一维的特定类型(例如,float32int64)数值数据数组。Tensor对象本身不直接包含数值,而是持有描述如何在与其相关联的Storage中查看数据的元数据。

这种分离很重要,因为多个张量可以共享相同的内存存储。切片、转置或重塑等操作通常会创建具有不同元数据的新Tensor对象,但它们指向由Storage管理的相同内存块。这使得这些操作非常节省内存,因为它们通常不涉及数据复制。

import torch.*

// 创建一个张量
val x = torch.arange(12, dtype=torch.float32)
println(f"原始张量 x: {x}")

// 存储是一个包含12个浮点数的一维数组
println(f"存储元素: {x.storage().tolist()}")
println(f"存储类型: {x.storage().dtype}")
println(f"存储大小: {len(x.storage())}")

// 通过重塑创建视图
val y = x.view(3, 4)
println(f"\n重塑后的张量 y:\n{y}")

// y 具有不同的形状/步幅,但共享相同的存储
println(f"y 是否与 x 共享存储? {y.storage().data_ptr() == x.storage().data_ptr()}")

// 修改视图会影响原始张量(反之亦然)
y(0, 0) = 99.0
println(f"\n修改后的 y:\n{y}")
println(f"修改 y 后的原始 x: {x}")


import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import java.util.Arrays;

/**
 * PyTorch张量存储共享+视图修改演示Java实现:
 * 1. 创建一维浮点张量,打印原始张量和存储信息
 * 2. 重塑为二维视图,验证视图与原张量共享存储
 * 3. 修改视图值,验证原张量同步变化(存储共享特性)
 */
public class TensorStorageShareDemo {

    public static void main(String[] args) {
        // ======================== 1. 创建原始张量(等效Scala torch.arange(12, dtype=torch.float32)) ========================
        // arange参数:start=0, end=12, step=1, dtype=Float, device=CPU
        var options = new TensorOptions()
                .dtype(new ScalarTypeOptional(torch.ScalarType.Float))
                .device(new DeviceOptional(new Device(torch.DeviceType.CPU)));
        Tensor x = torch.arange(new Scalar(0), new Scalar(12), new Scalar(1),  options);
        System.out.println("原始张量 x: " + tensorToString(x));

        // ======================== 2. 打印张量存储信息(等效Scala x.storage()相关操作) ========================
        // 获取张量存储
        Storage storageX = x.storage();
        // 存储元素转为列表(等效Scala storage().tolist())
//        float[] storageElements = getStorageFloatArray(storageX);
//        System.out.println("存储元素: " + Arrays.toString(storageElements));
        // 存储数据类型(等效Scala storage().dtype)
        torch.ScalarType storageDtype = x.scalar_type();
        System.out.println("存储类型: " + storageDtype);
        // 修复2:Storage.size()不存在,改用nbytes()/元素字节数计算存储元素数
        // 浮点型(Float)每个元素占4字节,nbytes()是总字节数 → 元素数 = nbytes / 4
        long storageBytes = storageX.nbytes();
        long storageElementCount = storageBytes / 4; // Float: 4 bytes per element
        System.out.println("存储总字节数: " + storageBytes);
        System.out.println("存储元素数(等效len(storage)): " + storageElementCount);

        // 修复3:读取存储数据(FloatPointer转换)
        float[] storageElements = getStorageFloatArray(storageX, (int) storageElementCount);
        System.out.println("存储元素: " + Arrays.toString(storageElements));
        
//        // 存储大小(等效Scala len(x.storage()))
//        long storageSize = storageX.size();
//        System.out.println("存储大小: " + storageSize);

        // ======================== 3. 重塑张量为视图(等效Scala x.view(3, 4)) ========================
        Tensor y = x.view(new long[]{3, 4});
        System.out.println("\n重塑后的张量 y:\n" + tensorToString(y));

        // ======================== 4. 验证存储共享(等效Scala y.storage().data_ptr() == x.storage().data_ptr()) ========================
        Storage storageY = y.storage();
        boolean isShareStorage = storageY.data_ptr() == storageX.data_ptr();
        System.out.println("y 是否与 x 共享存储? " + isShareStorage);

        // ======================== 5. 修改视图值,验证原张量同步变化 ========================
        // 等效Scala y(0, 0) = 99.0:修改y[0][0]位置的值
        y.put(torch.tensor(new long[]{0, 0}), torch.tensor(99.0f));
        System.out.println("\n修改后的 y:\n" + tensorToString(y));
        System.out.println("修改 y 后的原始 x: " + tensorToString(x));

        // ======================== 6. 资源释放 ========================
        x.close();
        y.close();
        storageX.close();
        storageY.close();
    }

    /**
     * 辅助方法:将Tensor转为可读字符串(模拟Scala的张量打印格式)
     */
    private static String tensorToString(Tensor tensor) {
        // 1. 获取张量数据数组
        float[] data = new float[(int) tensor.numel()];
        tensor.data().get(torch.tensor(data));
        // 2. 获取张量形状
        long[] shape = tensor.sizes().vec().get();
        // 3. 拼接为可读字符串(适配一维/二维张量)
        StringBuilder sb = new StringBuilder();
        if (shape.length == 1) {
            // 一维张量:[0.0, 1.0, ...]
            sb.append("[");
            for (int i = 0; i < data.length; i++) {
                sb.append(String.format("%.1f", data[i]));
                if (i < data.length - 1) sb.append(", ");
            }
            sb.append("]");
        } else if (shape.length == 2) {
            // 二维张量:分行打印
            int rows = (int) shape[0];
            int cols = (int) shape[1];
            sb.append("[");
            for (int i = 0; i < rows; i++) {
                sb.append("[");
                for (int j = 0; j < cols; j++) {
                    sb.append(String.format("%.1f", data[i * cols + j]));
                    if (j < cols - 1) sb.append(", ");
                }
                sb.append("]");
                if (i < rows - 1) sb.append(",\n ");
            }
            sb.append("]");
        }
        return sb.toString();
    }

    /**
     * 辅助方法:将Float类型的Storage转为float数组(等效Scala storage().tolist())
     */
//    private static float[] getStorageFloatArray(Storage storage) {
//        float[] arr = new float[(int) storage.size()];
//        // 读取存储数据到数组
//        storage.data().get(arr);
//        return arr;
//    }
    /**
     * 修复版:读取Float类型Storage的元素(关键:FloatPointer转换)
     */
    private static float[] getStorageFloatArray(Storage storage, int elementCount) {
        float[] arr = new float[elementCount];
        // 核心修复:将通用Pointer转为FloatPointer,再读取到float数组
        FloatPointer floatPtr = (FloatPointer)storage.data();
        floatPtr.get(arr); // 正确读取浮点数据
        return arr;
    }
}

在上面的例子中,xy是不同的Tensor对象,但由于y是通过reshape创建的视图,它们共享相同的内存存储。修改y中的元素也会改变通过x可见的相应元素。

张量元数据

除了对其Storage的引用之外,Tensor对象还维护多项元数据,这些数据定义了其属性和数据解释方式:

  1. 设备(device): 指定张量数据所在的设备,可以是CPU(torch.device('cpu'))或特定的GPU(torch.device('cuda:0'))。张量间的操作通常需要数据在相同设备上。在设备之间移动数据(例如,使用.to(device))涉及内存复制,这可能影响性能。
  2. 数据类型(dtype): 定义张量中元素的数值类型,例如torch.float32torch.int64torch.bool。操作通常要求张量具有兼容的数据类型,并且数据类型的选择显著影响内存使用和数值精度。
  3. 形状(shapesize()): 一个表示张量维度的元组。例如,一个3x4矩阵的形状是(3, 4)
  4. 存储偏移(storage_offset()): 一个整数,表示此张量数据在内存存储中开始的索引。对于直接创建的张量(非视图),这通常是0。切片可能具有非零偏移量。
  5. 步幅(stride()): 这也许是理解内存布局最重要的元数据。步幅是一个元组,其中第 ii 个元素指定了沿张量第 ii 维度移动一步所需在内存(Storage中的元素数量)中的步长。

考虑一个3x4的张量t

// 创建一个张量
val t = torch.arange(12, dtype=torch.float32).view(3, 4)
println(f"张量 t:\n{t}")
println(f"形状: {t.shape}")
println(f"步幅: {t.stride()}")

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import org.bytedeco.javacpp.FloatPointer;

import java.util.Arrays;

/**
 * 创建二维张量并打印形状/步幅(适配指定的arange调用格式)
 * 核心:严格使用 Scalar 参数调用 arange,完整复刻张量创建→重塑→属性打印流程
 */
public class TensorShapeStrideDemo {

    public static void main(String[] args) {
        // ======================== 1. 配置张量参数(float32类型) ========================
        // 定义dtype=float32的张量选项(等效Scala的dtype=torch.float32)
        TensorOptions options = new TensorOptions().dtype(torch.ScalarType.Float);

        // ======================== 2. 创建一维张量(严格匹配指定的arange调用格式) ========================
        // 等效Scala:torch.arange(12, dtype=torch.float32)
        // 参数对应:start=0, end=12, step=1, options(指定dtype)
        Tensor t1d = torch.arange(new Scalar(0), new Scalar(12), new Scalar(1), options);

        // ======================== 3. 重塑为二维张量(3行4列) ========================
        // 等效Scala:.view(3, 4)
        Tensor t = t1d.view(new long[]{3, 4});

        // ======================== 4. 打印张量及属性(复刻Scala输出格式) ========================
        System.out.println("张量 t:");
        System.out.println(tensorToString(t)); // 格式化打印二维张量

        // 打印形状(等效Scala t.shape)
        long[] shape = t.sizes().get();
        System.out.println("形状: " + Arrays.toString(shape));

        // 打印步幅(等效Scala t.stride())
        long[] stride = t.stride().get();
        System.out.println("步幅: " + Arrays.toString(stride));

        // ======================== 5. 资源释放 ========================
        t1d.close();
        t.close();
        options.close();
    }

    /**
     * 辅助方法:将二维Float张量转为可读字符串(模拟Scala的张量打印格式)
     */
    private static String tensorToString(Tensor tensor) {
        // 1. 获取张量基本信息
        long[] shape = tensor.sizes().get();
        int rows = (int) shape[0];
        int cols = (int) shape[1];
        int totalElements = rows * cols;

        // 2. 读取Float类型张量数据(转为float数组)
        float[] data = new float[totalElements];
        FloatPointer floatPtr = tensor.data().asFloatPointer();
        floatPtr.get(data);

        // 3. 拼接为二维格式字符串
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < rows; i++) {
            sb.append("[");
            for (int j = 0; j < cols; j++) {
                sb.append(String.format("%.1f", data[i * cols + j]));
                if (j < cols - 1) {
                    sb.append(", ");
                }
            }
            sb.append("]");
            if (i < rows - 1) {
                sb.append("\n");
            }
        }
        return sb.toString();
    }
}

输出:

张量 t:
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
形状: torch.Size([3, 4])
步幅: (4, 1)

步幅(4, 1)表示:

  • 沿维度0(向下行)移动一步,你需要在内存一维Storage中跳跃4个元素。(例如,从元素0到元素4)。
  • 沿维度1(横跨列)移动一步,你需要在内存一维Storage中跳跃1个元素。(例如,从元素0到元素1)。

步幅确定了多维张量如何映射到线性Storage

内存布局:连续与非连续

如果张量的元素在Storage中的排列顺序与标准的C风格(行主序)遍历顺序相同,则该张量在内存中被认为是连续的。对于连续张量,步幅通常遵循一种模式:最后一个维度的步幅是1,倒数第二个维度的步幅是最后一个维度的大小,以此类推。对于我们上面3x4的张量t,步幅是(4, 1),这符合这种模式(stride[1] == 1stride[0] == shape[1] == 4),因此它是连续的。

// 检查张量是否连续
println(f"t 是否连续? {t.is_contiguous()}") // 输出: True

然而,转置等操作可以创建非连续张量(视图)。

// 转置操作
val t_transposed = t.t()
println(f"\n转置后的张量 t_transposed:\n{t_transposed}")
println(f"形状: {t_transposed.shape}")
println(f"步幅: {t_transposed.stride()}")
println(f"t_transposed 是否连续? {t_transposed.is_contiguous()}") // 输出: False
println(f"t_transposed 是否与 t 共享存储? {t_transposed.storage().data_ptr() == t.storage().data_ptr()}") // 输出: True

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import org.bytedeco.javacpp.FloatPointer;
import java.util.Arrays;

/**
 * 张量转置操作演示:
 * 1. 创建二维张量 → 转置 → 打印转置后属性
 * 2. 验证转置张量的连续性、存储共享特性
 * 3. 严格适配JavaCPP-PyTorch 1.5.13-SNAPSHOT API
 */
public class TensorTransposeDemo {

    public static void main(String[] args) {
        // ======================== 1. 创建原始二维张量 ========================
        TensorOptions options = new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Float));
        Tensor t1d = torch.arange(new Scalar(0), new Scalar(12), new Scalar(1), options);
        Tensor t = t1d.view(new long[]{3, 4});

        System.out.println("原始张量 t:");
        System.out.println(tensorToString(t));
        System.out.println("原始形状: " + Arrays.toString(t.sizes().vec().get()));
        System.out.println("原始步幅: " + Arrays.toString(t.strides().vec().get()));

        // ======================== 2. 转置操作(等效Scala t.t()) ========================
        // t.t() 是二维张量的转置(等价于 transpose(0,1))
        Tensor tTransposed = t.t();

        // ======================== 3. 打印转置后张量及属性 ========================
        System.out.println("\n转置后的张量 t_transposed:");
        System.out.println(tensorToString(tTransposed));
        System.out.println("形状: " + Arrays.toString(tTransposed.sizes().vec().get()));
        System.out.println("步幅: " + Arrays.toString(tTransposed.strides().vec().get()));

        // 验证连续性(等效Scala t_transposed.is_contiguous())
        boolean isContiguous = tTransposed.is_contiguous();
        System.out.println("t_transposed 是否连续? " + isContiguous); // 输出: false

        // 验证存储共享(等效Scala t_transposed.storage().data_ptr() == t.storage().data_ptr())
        Storage storageT = t.storage();
        Storage storageTTransposed = tTransposed.storage();
        boolean isShareStorage = storageTTransposed.data_ptr().address() == storageT.data_ptr().address();
        System.out.println("t_transposed 是否与 t 共享存储? " + isShareStorage); // 输出: true

        // ======================== 4. 资源释放 ========================
        t1d.close();
        t.close();
        tTransposed.close();
        options.close();
        storageT.close();
        storageTTransposed.close();
    }

    /**
     * 辅助方法:将二维Float张量转为可读字符串
     */
    private static String tensorToString(Tensor tensor) {
        long[] shape = tensor.sizes().vec().get();
        int rows = (int) shape[0];
        int cols = (int) shape[1];
        int totalElements = rows * cols;

        float[] data = new float[totalElements];
        FloatPointer floatPtr = (FloatPointer) tensor.data().data_ptr_float();
        floatPtr.get(data);

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < rows; i++) {
            sb.append("[");
            for (int j = 0; j < cols; j++) {
                sb.append(String.format("%.1f", data[i * cols + j]));
                if (j < cols - 1) {
                    sb.append(", ");
                }
            }
            sb.append("]");
            if (i < rows - 1) {
                sb.append("\n");
            }
        }
        return sb.toString();
    }
}

输出:

转置后的张量 t_transposed:
tensor([[ 0.,  4.,  8.],
        [ 1.,  5.,  9.],
        [ 2.,  6., 10.],
        [ 3.,  7., 11.]])
形状: torch.Size([4, 3])
步幅: (1, 4)
t_transposed 是否连续? False
t_transposed 是否与 t 共享存储? True

请注意,t_transposed的形状是(4, 3),但其步幅是(1, 4)。沿维度0(在转置视图中向下行)移动,在原始存储中跳跃1个元素。沿维度1(横跨列)移动,跳跃4个元素。这种布局不是C语言连续的。

为什么连续性很重要?

  • 性能: 许多PyTorch操作(特别是低层CPU/GPU核)针对连续张量进行了优化。当数据在内存中按顺序排列时,它们可以更高效地处理数据。非连续张量可能触发内部内存复制或使用较慢的算法。
  • 兼容性: 某些操作,如view(),要求张量是连续的。如果你尝试在非连续张量(如t_transposed)上使用view(),会得到错误。在这种情况下,你通常需要使用reshape(),如果可能它会返回一个视图,但如果需要满足形状改变,它会返回一个副本。或者,你可以使用.contiguous()方法显式创建一个连续的副本。
// 这会引发 RuntimeError,因为 t_transposed 不连续
// flat_view = t_transposed.view(-1)

// .contiguous() 在需要时会创建一个具有连续内存布局的新张量
val t_contiguous_copy = t_transposed.contiguous()
println(f"\n连续副本是否连续? {t_contiguous_copy.is_contiguous()}") // 输出: True
println(f"连续副本的步幅: {t_contiguous_copy.stride()}") // 输出: (3, 1)
println(f"存储共享? {t_contiguous_copy.storage().data_ptr() == t_transposed.storage().data_ptr()}") // 输出: False (这是一个副本)

// 现在视图可以工作了
val flat_view = t_contiguous_copy.view(-1)
println(f"展平视图: {flat_view}")

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import org.bytedeco.javacpp.FloatPointer;

import java.util.Arrays;

/**
 * 张量连续性修复+展平视图演示:
 * 1. 验证非连续张量view报错 → 用contiguous创建连续副本
 * 2. 验证连续副本的属性(连续性/步幅/存储独立性)
 * 3. 对连续副本执行展平视图操作
 */
public class TensorContiguousDemo {

    public static void main(String[] args) {
        // ======================== 1. 创建原始张量并转置 ========================
        TensorOptions options = new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Float));
        Tensor t1d = torch.arange(new Scalar(0), new Scalar(12), new Scalar(1), options);
        Tensor t = t1d.view(new long[]{3, 4});
        Tensor tTransposed = t.t();

        // ======================== 2. 验证非连续张量view会报错 ========================
        System.out.println("=== 验证非连续张量view报错 ===");
        try {
            // 等效Scala: flat_view = t_transposed.view(-1) → 会抛出RuntimeError
            Tensor flatViewError = tTransposed.view(new long[]{-1});
            System.out.println("非连续张量view执行成功(预期报错): " + flatViewError);
        } catch (Exception e) {
            System.err.println("捕获到预期的RuntimeError: " + e.getMessage());
        }

        // ======================== 3. 创建连续副本(contiguous) ========================
        // 等效Scala: t_contiguous_copy = t_transposed.contiguous()
        Tensor tContiguousCopy = tTransposed.contiguous();

        // ======================== 4. 打印连续副本的属性 ========================
        System.out.println("\n=== 连续副本属性 ===");
        System.out.println("连续副本是否连续? " + tContiguousCopy.is_contiguous()); // 输出: true
        System.out.println("连续副本的步幅: " + Arrays.toString(tContiguousCopy.strides().vec().get())); // 输出: [3, 1]
        System.out.println("连续副本形状: " + Arrays.toString(tContiguousCopy.sizes().vec().get())); // 输出: [4, 3]

        // 验证存储是否共享(连续副本是新存储,不共享)
        Storage storageTTransposed = tTransposed.storage();
        Storage storageContiguous = tContiguousCopy.storage();
        boolean isShareStorage = storageContiguous.data_ptr().address() == storageTTransposed.data_ptr().address();
        System.out.println("连续副本与原转置张量存储共享? " + isShareStorage); // 输出: false

        // ======================== 5. 对连续副本执行展平视图 ========================
        // 等效Scala: flat_view = t_contiguous_copy.view(-1)
        Tensor flatView = tContiguousCopy.view(new long[]{-1});
        System.out.println("\n展平视图: " + tensorToString(flatView));

        // ======================== 6. 资源释放 ========================
        t1d.close();
        t.close();
        tTransposed.close();
        tContiguousCopy.close();
        flatView.close();
        options.close();
        storageTTransposed.close();
        storageContiguous.close();
    }

    /**
     * 通用辅助方法:将任意维度Float张量转为可读字符串
     */
    private static String tensorToString(Tensor tensor) {
        long[] shape = tensor.sizes().vec().get();
        int totalElements = (int) tensor.numel(); // 总元素数(适配任意维度)

        // 读取Float类型数据
        float[] data = new float[totalElements];
        FloatPointer floatPtr = tensor.data().data_ptr_float();
        floatPtr.get(data);

        // 拼接字符串(适配一维/二维)
        StringBuilder sb = new StringBuilder();
        if (shape.length == 1) {
            // 一维(展平视图)
            sb.append("[");
            for (int i = 0; i < data.length; i++) {
                sb.append(String.format("%.1f", data[i]));
                if (i < data.length - 1) sb.append(", ");
            }
            sb.append("]");
        } else if (shape.length == 2) {
            // 二维
            int rows = (int) shape[0];
            int cols = (int) shape[1];
            for (int i = 0; i < rows; i++) {
                sb.append("[");
                for (int j = 0; j < cols; j++) {
                    sb.append(String.format("%.1f", data[i * cols + j]));
                    if (j < cols - 1) sb.append(", ");
                }
                sb.append("]");
                if (i < rows - 1) sb.append("\n");
            }
        }
        return sb.toString();
    }
}

以下图表说明了两个张量,T(原始3x4)和T_transpose(其转置),如何将其元素映射到相同的底层一维存储块。请注意步幅如何决定不同的访问模式。

01234567891011torch.Storage (一维浮点数组)张量Tshape=(3, 4)stride=(4, 1)offset=0contiguous=TrueT[0,0]T[0,1]T[1,0]T[1,1]T[2,0]张量T_转置shape=(4, 3)stride=(1, 4)offset=0contiguous=FalseT_T[0,0]T_T[0,1]T_T[0,2]T_T[1,0]T_T[1,1]

张量元数据与内存存储的关系,适用于3x4张量T及其转置T_transpose。两个张量对象都指向相同的存储,但根据它们的形状、步幅和偏移量以不同方式解释它。箭头表示元素如何从张量视图映射到存储索引。

理解这些实现细节,Tensor元数据与Storage的区别,步幅的作用,以及连续性的思想,为推断PyTorch中内存使用、性能特点和各种张量操作行为提供了坚实的支持。当优化瓶颈或与低层代码交互时,这些知识尤其有用。

理解计算图

收藏

PyTorch 自动微分功能的主要组成部分是计算图。它并非预先定义的静态结构;相反,PyTorch 在对张量执行操作时动态地构建它。可以将其看作一个有向无环图(DAG),其中节点代表张量或操作,边代表数据流和功能依赖关系。

理解这个图是根本所在,因为 autograd 引擎在反向传播过程中正是遍历它来使用链式法则计算梯度。涉及跟踪梯度的张量的每个操作,都在幕后有助于构建这个图结构。

动态图与静态图

像 TensorFlow 1.x 或 Theano 这样的框架采用的是静态计算图。在这些系统中,你首先定义整个图结构,编译它,然后用不同的输入数据执行它,可能多次执行。这种“先定义后运行”的方法在执行前允许进行重要的图级别优化。

相反,PyTorch 采用的是动态计算图方法,通常被称为“边运行边定义”。该图是隐式地、逐个操作地构建的,随着你的 Python 代码执行而形成。如果你的模型前向传播中包含循环或条件语句(例如 if 块),图结构实际上可以根据所采取的执行路径在不同迭代之间发生变化。

动态图的优点:

  1. 灵活性: 动态图本质上更灵活。具有依赖于中间结果的控制流结构(循环、条件)的模型更容易实现和理解。
  2. 调试: 调试通常更直接。由于图是作为标准 Python 代码运行时构建的,你可以直接在模型执行流程中使用熟悉的 Python 调试工具(例如 pdb 或打印语句)来检查中间值或图连接性。

权衡:

尽管极其灵活,但边运行边定义的特性可能对某些在静态图环境中更简单的整体图优化带来挑战。然而,PyTorch 通过 TorchScript(第 4 章介绍)等工具弥补了这一点,这些工具允许图捕获和优化。

构建图:grad_fn 属性

PyTorch 实际如何跟踪操作以构建这个图?当你对一个 requires_grad=True 的张量执行操作时,生成的输出张量会自动获得对其创建函数的引用。

这个引用存储在输出张量的 grad_fn 属性中。

让我们用一个简单例子来说明:

import torch

// 需要梯度的输入张量
val a = torch.tensor([2.0, 3.0], requires_grad=true)

// 操作 1: 乘以 3
val b = a * 3

// 操作 2: 计算均值
val c = b.mean()

// 检查 grad_fn 属性
println(f"Tensor a: requires_grad={a.requires_grad}, grad_fn={a.grad_fn}")
// 预期输出: 张量 a: requires_grad=True, grad_fn=None

println(f"Tensor b: requires_grad={b.requires_grad}, grad_fn={b.grad_fn}")
// 预期输出: 张量 b: requires_grad=True, grad_fn=<MulBackward0 object at 0x...>

println(f"Tensor c: requires_grad={c.requires_grad}, grad_fn={c.grad_fn}")
// 预期输出: 张量 c: requires_grad=True, grad_fn=<MeanBackward0 object at 0x...>


import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import org.bytedeco.javacpp.Pointer;

/**
 * PyTorch张量梯度追踪演示:
 * 1. 创建带requires_grad=True的输入张量
 * 2. 执行乘法/均值运算,追踪梯度函数(grad_fn)
 * 3. 打印各张量的requires_grad和grad_fn属性
 */
public class TensorGradTrackingDemo {

    public static void main(String[] args) {
        // ======================== 1. 创建需要梯度的输入张量 ========================
        // 等效Scala: val a = torch.tensor([2.0, 3.0], requires_grad=true)
        // 步骤1: 创建数据数组
        float[] aData = {2.0f, 3.0f};
        // 步骤2: 配置张量选项(float32 + requires_grad=true)
        TensorOptions aOptions = new TensorOptions()
                .dtype(new ScalarTypeOptional(torch.ScalarType.Float))
                .requires_grad(new BoolOptional(true));
        // 步骤3: 创建张量
        Tensor a = torch.tensor(aData, aOptions);

        // ======================== 2. 执行张量运算 ========================
        // 操作1: 乘以3(等效Scala val b = a * 3)
        Tensor b = a.mul(new Scalar(3.0f));

        // 操作2: 计算均值(等效Scala val c = b.mean())
        Tensor c = b.mean();

        // ======================== 3. 打印各张量的梯度属性 ========================
        // 打印张量a的属性
        boolean aRequiresGrad = a.requires_grad();
        String aGradFn = getGradFnString(a.grad_fn());
        System.out.printf("Tensor a: requires_grad=%b, grad_fn=%s%n", aRequiresGrad, aGradFn);
        // 预期输出: Tensor a: requires_grad=true, grad_fn=None

        // 打印张量b的属性
        boolean bRequiresGrad = b.requires_grad();
        String bGradFn = getGradFnString(b.grad_fn());
        System.out.printf("Tensor b: requires_grad=%b, grad_fn=%s%n", bRequiresGrad, bGradFn);
        // 预期输出: Tensor b: requires_grad=true, grad_fn=<MulBackward0>

        // 打印张量c的属性
        boolean cRequiresGrad = c.requires_grad();
        String cGradFn = getGradFnString(c.grad_fn());
        System.out.printf("Tensor c: requires_grad=%b, grad_fn=%s%n", cRequiresGrad, cGradFn);
        // 预期输出: Tensor c: requires_grad=true, grad_fn=<MeanBackward0>

        // ======================== 4. 资源释放 ========================
        a.close();
        b.close();
        c.close();
        aOptions.close();
    }

    /**
     * 辅助方法:将grad_fn指针转为可读字符串(模拟Scala的grad_fn输出格式)
     * @param gradFnPtr 梯度函数指针(来自Tensor.grad_fn())
     * @return 可读的grad_fn描述字符串
     */
    private static String getGradFnString(Pointer gradFnPtr) {
        // 1. 原始输入张量的grad_fn为null → 返回"None"
        if (gradFnPtr == null || gradFnPtr.address() == 0) {
            return "None";
        }
        // 2. 获取grad_fn的类型名称(简化版,模拟Scala的输出)
        String ptrInfo = gradFnPtr.getClass().getSimpleName();
        // 适配JavaCPP的指针命名,映射为PyTorch标准的Backward名称
        if (ptrInfo.contains("MulBackward")) {
            return "<MulBackward0>";
        } else if (ptrInfo.contains("MeanBackward")) {
            return "<MeanBackward0>";
        } else {
            return String.format("<%s>", ptrInfo);
        }
    }
}

请注意以下几点:

  1. 张量 a 是图中的一个叶节点。它是用户直接创建的,而非 autograd 跟踪的操作结果。因此,它的 grad_fnNone
  2. 张量 b 是由 a 乘以 3 产生的。它的 grad_fn 指向 MulBackward0,代表乘法操作。这个对象持有对乘法输入(张量 a 和标量 3)的引用,并且知道如何计算对 a 的梯度。
  3. 张量 c 是由对 b 进行 mean 操作产生的。它的 grad_fn 指向 MeanBackward0,它知道如何计算对它的输入 b 的梯度。

这些 grad_fn 引用形成了一个链表,从输出张量(c)经过操作(MeanBackward0MulBackward0)向后追溯到输入叶张量(a)。这个链式结构就是 autograd 使用的反向计算图。

可视化前向和反向图

尽管 PyTorch 不提供像 TensorBoard 为静态图提供的图视图那样的内置实时图可视化工具,但我们可以将前面例子中构建的图进行可视化。前向传播创建张量并关联 grad_fn 对象。反向传播(c.backward())反向遍历这个结构。

a(叶节点, requires_grad=True)*b均值c(输出)grad_fn=MulBackward0grad_fn=MeanBackward0

c = (a * 3).mean() 的计算图表示。矩形是张量,椭圆形是操作。边显示数据流。grad_fn 将创建的张量链接到它们的生成操作,从而形成反向路径。

图与 Autograd

当你对一个标量张量(像我们例子中的 c,或通常是一个损失值)调用 .backward() 时,autograd 引擎会从该张量开始向后遍历图。

  1. 它调用与张量的 grad_fn 关联的函数(cMeanBackward0)。
  2. 此函数计算输出(c)对其输入(b)的梯度。
  3. 引擎随后将这些梯度进一步向后传播到输入的 grad_fn 对象。因此,为 b 计算的梯度被传递给 MulBackward0
  4. MulBackward0 计算对其输入(a)的梯度。
  5. 由于 a 是叶节点(grad_fnNone)并且 requires_grad=True,计算出的梯度会累积在 a.grad 中。

这个过程一直持续,直到所有路径都到达叶节点或不需要梯度的张量。计算图为链式法则的这种应用提供了路线图。

图的属性

  • 无环: 图必须是 DAG。循环会导致梯度计算期间的无限循环。如果某个操作创建了涉及需要梯度跟踪的节点的循环,PyTorch 将引发错误。
  • 动态: 如前所述,图结构可以根据运行时控制流发生变化。这使得像 RNNs 这样的模型能够直观地实现,其中计算依赖于序列长度。

理解计算图不仅仅是理论性的。它告诉你如何构建模型、调试梯度问题(例如,None 梯度通常意味着图的一部分已断开连接或不需要梯度),以及如何实现带有自己反向传播的自定义操作,正如我们将在本章后面看到的那样。它是使 PyTorch 自动微分得以实现的看不见的机制。

Autograd 引擎机制

收藏

autograd 是驱动动态计算图进行梯度计算的引擎。该系统是 PyTorch 自动计算梯度的基础,是训练神经网络通过反向传播不可或缺的。

其核心是,autograd 执行反向模式自动微分。当您对 requires_grad 设为 True 的张量进行操作时,PyTorch 会构建一个表示操作序列的图。此图是随着计算的发生动态构建的。autograd 引擎随后从最终输出(通常是标量损失)开始,反向遍历此图,以计算该输出相对于参数(图的叶节点,通常是模型权重和偏置)的梯度。

backward() 调用

此过程通常通过在一个张量上调用 .backward() 方法来启动,最常见的是在正向传播结束时计算的标量损失值。

import torch.*

// 示例设置
val w = torch.randn(5, 3, requires_grad=true)
val x = torch.randn(1, 5)
// 如果 x 仅是输入数据,确保它不需要梯度
// x.requires_grad_(False) // 或在创建时不设置 requires_grad

val y = x @@ w  // y 通过矩阵乘法依赖于 w
val z = y.mean() // z 是从 y 派生的标量

// 从 z 开始计算梯度
z.backward()

// 梯度现在填充在 w.grad 中
// d(z)/dw 梯度被计算并存储
println(w.grad.shape)
// 输出: torch.Size([5, 3])


import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import java.util.Arrays;

/**
 * PyTorch矩阵乘法反向传播+梯度计算演示:
 * 1. 创建带梯度的权重张量w + 输入张量x(无梯度)
 * 2. 矩阵乘法 → 均值运算生成标量
 * 3. 反向传播计算梯度,验证梯度形状
 */
public class TensorBackwardGradDemo {

    public static void main(String[] args) {
        // ======================== 1. 初始化张量(权重w带梯度,输入x无梯度) ========================
        // 等效Scala: val w = torch.randn(5, 3, requires_grad=true)
        TensorOptions wOptions = new TensorOptions()
                .dtype(new ScalarTypeOptional(torch.ScalarType.Float))
                .requires_grad(new BoolOptional(true)); // 权重需要梯度
        Tensor w = torch.randn(new long[]{5, 3}, wOptions);

        // 等效Scala: val x = torch.randn(1, 5)(输入数据,默认requires_grad=false)
        Tensor x = torch.randn(new long[]{1, 5}); // 创建时不设置requires_grad,默认false

        // 可选:显式确保x无梯度(等效Scala x.requires_grad_(False))
        // x.requires_grad_(false);

        // ======================== 2. 执行张量运算 ========================
        // 等效Scala: val y = x @@ w(矩阵乘法,x(1,5) × w(5,3) → y(1,3))
        // Java中矩阵乘法用mm()方法(matmul的简写,适配二维矩阵)
        Tensor y = x.mm(w);

        // 等效Scala: val z = y.mean()(标量均值,用于反向传播)
        Tensor z = y.mean();

        // ======================== 3. 反向传播计算梯度 ========================
        // 等效Scala: z.backward() → 从标量z开始反向传播,计算w的梯度
        z.backward();

        // ======================== 4. 获取并打印梯度形状 ========================
        // 等效Scala: println(w.grad.shape) → 梯度存储在w.grad中,形状与w一致[5,3]
        Tensor wGrad = w.grad(); // 获取w的梯度张量
        long[] gradShape = wGrad.sizes().vec().get(); // 获取梯度形状
        System.out.println("w.grad 形状: " + Arrays.toString(gradShape));
        // 预期输出: w.grad 形状: [5, 3]

        // (可选)打印梯度张量内容,验证梯度计算结果
        System.out.println("\nw.grad 梯度值:\n" + tensorToString(wGrad));

        // ======================== 5. 资源释放 ========================
        w.close();
        x.close();
        y.close();
        z.close();
        wGrad.close();
        wOptions.close();
    }

    /**
     * 辅助方法:将Float张量转为可读字符串(适配二维张量)
     */
    private static String tensorToString(Tensor tensor) {
        long[] shape = tensor.sizes().vec().get();
        int rows = (int) shape[0];
        int cols = (int) shape[1];
        int totalElements = rows * cols;

        // 读取Float类型数据
        float[] data = new float[totalElements];
        tensor.data().data_ptr_float().get(data);

        // 拼接为二维格式字符串
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < rows; i++) {
            sb.append("[");
            for (int j = 0; j < cols; j++) {
                sb.append(String.format("%.4f", data[i * cols + j]));
                if (j < cols - 1) sb.append(", ");
            }
            sb.append("]");
            if (i < rows - 1) sb.append("\n");
        }
        return sb.toString();
    }
}

当调用 z.backward() 时,autogradz 开始反向运算。由于 z 是一个标量,backward() 隐式使用 1.0 作为起始梯度。这意味着 ∂z∂z=1∂zz=1。如果您在一个非标量张量 t 上调用 backward(),则必须提供一个 gradient 输入参数,它应与 t 的形状相同。此输入表示某个最终标量损失 LL 相对于张量 t 的梯度,即 ∂L∂t∂tL

从本质上讲,autograd 计算向量-雅可比乘积 (VJP)。回想一下导数的链式法则。如果我们有一个标量损失 LL,它是向量 yy 的函数,L=g(y)L=g(y),并且 yy 本身是另一个向量 xx 的函数,y=f(x)y=f(x),那么 LL 相对于 xx 的梯度由下式给出:

∂L∂x=∂L∂y∂y∂x∂xL=∂yLxy

此处,∂y∂x∂xy 是 ff 相对于 xx 的雅可比矩阵,而 ∂L∂y∂yL 是一个行向量,表示 LL 相对于 yy 的梯度。由 y.backward(gradient=dL_dy) 计算的 VJP 正是乘积 vTJvTJ,其中 vv 是上游梯度(由 gradient=dL_dy 表示),而 JJ 是雅可比矩阵 ∂y∂x∂xy。在一个标量损失 zz 上调用 z.backward() 对应于使用初始梯度向量 v=[1.0]v=[1.0]。与显式构建可能庞大的雅可比矩阵相比,这种 VJP 方法在计算上更为高效。

使用 grad_fn 遍历图

每个由至少一个 requires_grad=True 的张量参与操作而产生的张量,都将具有 grad_fn 属性。此属性是指向在正向传播期间创建该张量的函数对象(例如 AddBackward0MulBackward0MmBackward0 等)的引用。重要的是,此函数对象存储对其输入的引用,并包含其相应反向操作的实现,这是梯度计算所必需的。

我们来检查之前示例中的 grad_fn 属性:

// 我们需要中间张量来检查它们的 grad_fn
val w = torch.randn(5, 3, requires_grad=true)
val x = torch.randn(1, 5)
val y = x @@ w
val z = y.mean()

println(f"y 源自: {y.grad_fn}")
// 输出: y originated from: <MmBackward0 object at 0x...>

println(f"z 源自: {z.grad_fn}")
// 输出: z originated from: <MeanBackward0 object at 0x...>

// 用户创建的叶张量没有 grad_fn
println(f"w.grad_fn: {w.grad_fn}")
// 输出: w.grad_fn: None
println(f"x.grad_fn: {x.grad_fn}")
// 输出: x.grad_fn: None


import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import org.bytedeco.javacpp.Pointer;

import java.util.Arrays;

/**
 * 扩展演示PyTorch张量grad_fn(梯度函数)归属:
 * 1. 严格使用指定的TensorOptions格式创建带梯度的权重张量
 * 2. 打印矩阵乘法/均值运算张量的grad_fn,验证叶子张量grad_fn为None
 * 3. 适配指定的shape获取方式(sizes().vec().get())
 */
public class TensorGradFnDetailDemo {

    public static void main(String[] args) {
        // ======================== 1. 创建张量(严格适配参考的TensorOptions格式) ========================
        // 权重张量w:带requires_grad=true,使用指定的Optional参数格式
        TensorOptions wOptions = new TensorOptions()
                .dtype(new ScalarTypeOptional(torch.ScalarType.Float)) // 参考格式:ScalarTypeOptional封装
                .requires_grad(new BoolOptional(true)); // 参考格式:BoolOptional封装(权重需要梯度)
        Tensor w = torch.randn(new long[]{5, 3}, wOptions);

        // 输入张量x:无梯度(默认requires_grad=false)
        Tensor x = torch.randn(new long[]{1, 5});

        // ======================== 2. 执行张量运算 ========================
        Tensor y = x.mm(w); // 矩阵乘法(等效Scala x @@ w)
        Tensor z = y.mean(); // 均值运算生成标量

        // ======================== 3. 打印各张量的grad_fn(梯度函数) ========================
        // 打印y的grad_fn(矩阵乘法对应的MmBackward0)
        String yGradFn = getGradFnDesc(y.grad_fn());
        System.out.printf("y 源自: %s%n", yGradFn);
        // 预期输出: y 源自: <MmBackward0>

        // 打印z的grad_fn(均值运算对应的MeanBackward0)
        String zGradFn = getGradFnDesc(z.grad_fn());
        System.out.printf("z 源自: %s%n", zGradFn);
        // 预期输出: z 源自: <MeanBackward0>

        // 打印叶子张量的grad_fn(均为None)
        String wGradFn = getGradFnDesc(w.grad_fn());
        System.out.printf("w.grad_fn: %s%n", wGradFn);
        // 预期输出: w.grad_fn: None

        String xGradFn = getGradFnDesc(x.grad_fn());
        System.out.printf("x.grad_fn: %s%n", xGradFn);
        // 预期输出: x.grad_fn: None

        // ======================== 4. 验证shape获取方式(参考格式:sizes().vec().get()) ========================
        long[] wShape = w.sizes().vec().get(); // 参考格式:通过vec().get()获取形状数组
        System.out.println("\nw 的形状(参考格式获取): " + Arrays.toString(wShape));
        // 预期输出: w 的形状(参考格式获取): [5, 3]

        // ======================== 5. 资源释放 ========================
        w.close();
        x.close();
        y.close();
        z.close();
        wOptions.close();
    }

    /**
     * 辅助方法:将grad_fn指针转为PyTorch标准的梯度函数描述字符串
     * @param gradFnPtr 梯度函数指针(Tensor.grad_fn()返回值)
     * @return 可读的grad_fn描述(匹配Scala输出格式)
     */
    private static String getGradFnDesc(Pointer gradFnPtr) {
        // 叶子张量的grad_fn为null/空指针 → 返回None
        if (gradFnPtr == null || gradFnPtr.address() == 0) {
            return "None";
        }

        // 获取梯度函数指针的类型信息,映射为PyTorch标准Backward名称
        String ptrClassName = gradFnPtr.getClass().getSimpleName();
        if (ptrClassName.contains("MmBackward")) {
            return "<MmBackward0>"; // 矩阵乘法反向传播函数
        } else if (ptrClassName.contains("MeanBackward")) {
            return "<MeanBackward0>"; // 均值运算反向传播函数
        } else {
            return String.format("<%s>", ptrClassName); // 其他梯度函数兜底
        }
    }
}

grad_fn 对象形成一个有向无环图 (DAG),记录计算历史。当执行 z.backward() 时:

  1. autograd 引擎从目标张量 z 开始。
  2. 它访问 z.grad_fn(即 MeanBackward0)。使用传入梯度(隐式为 1.0),MeanBackward0 计算均值操作相对于其输入 y 的梯度。我们称之为 ∂z∂y∂yz
  3. 引擎随后移动到 y 指示的图中下一个节点。它使用 y.grad_fnMmBackward0)和传入梯度 ∂z∂y∂yz 来计算矩阵乘法相对于其输入 xw 的梯度。这涉及计算 ∂z∂y∂y∂w∂yzwy 和 ∂z∂y∂y∂x∂yzxy
  4. 由于 w 是一个叶张量(由用户创建)并且 requires_grad=True,计算出的梯度 ∂z∂w∂wz累积w.grad 属性中。
  5. 由于 xrequires_grad=False,沿此路径的梯度计算停止,x.grad 保持为 None

此递归过程持续进行,将链式法则反向应用于图,直到所有返回到需要梯度的叶张量的路径都已处理完毕。

正向传播z (标量)MeanBackward0grad_z=1.0ymean()MatMulBackward0grad_yw (叶节点)requires_grad=True@AccumulateGradgrad_wx (叶节点)requires_grad=False@grad_ygrad_wgrad_x (忽略)存储到 w.grad

z.backward() 启动的反向传播图示。Autograd 沿着 grad_fn 指针(由从张量指向函数节点的箭头表示)从输出 z 反向回溯。计算出的相对于 w 的梯度累积到 w.grad 中。对于涉及 requires_grad=False 的张量(如 x)的路径,计算会停止。

梯度累积

autograd 行为的一个重要方面是梯度会累积到叶张量的 .grad 属性中。每次调用 backward() 时,为参数新计算的梯度都会添加到其 .grad 属性中当前存储的值上。如果 .grad 最初为 None,则会用第一次计算的梯度对其进行初始化。

这种设计选择在典型的训练循环中需要明确的管理。在计算损失并对新批次数据执行反向传播之前,您必须重置所有模型参数的梯度。否则,当前批次的梯度将与前一批次的梯度相加,导致错误的权重更新。执行此操作的标准方法是使用优化器的 zero_grad() 方法:

// 假设已定义 model、optimizer、criterion、dataloader

for inputs, targets <- dataloader:
    // 1. 重置上一迭代的梯度
    optimizer.zero_grad()

    // 2. 执行正向传播
    val outputs = model(inputs)
    val loss = criterion(outputs, targets)

    // 3. 执行反向传播以计算梯度
    loss.backward()

    // 4. 使用计算出的梯度更新模型参数
    optimizer.step()


import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import java.util.Iterator;

/**
 * PyTorch标准训练循环Java实现:
 * 1. 迭代DataLoader获取输入/目标张量
 * 2. 梯度清零 → 前向传播 → 反向传播 → 优化器更新参数
 * 3. 完整的资源管理与异常处理
 */
public class ModelTrainingLoopDemo {

    public static void main(String[] args) {
        // ======================== 模拟定义训练依赖(需根据实际场景替换) ========================
        // 1. 模拟模型(示例:简单线性模型)
        LinearImpl model = new LinearImpl(10, 2); // 输入10维,输出2维(替换为你的实际模型)
        // 2. 模拟优化器(SGD,学习率0.01)
        SGDOptions options = new SGDOptions(0.01f);
        Optimizer optimizer = new SGD(model.parameters(),options);
        // 3. 模拟损失函数(MSE)
        MSELossImpl criterion = new MSELossImpl();
        // 4. 模拟DataLoader(需替换为你的实际DataLoader)
        // 注:此处为演示,实际需通过Dataset+DataLoader构建
        JavaRandomDataLoader dataloader = createDummyDataLoader();

        // ======================== 核心训练循环(等效Scala代码) ========================
        // 获取DataLoader迭代器(等效Scala for inputs, targets <- dataloader)
        Iterator<DataLoader.Batch> iterator = dataloader.iterator();
        while (iterator.hasNext()) {
            DataLoader.Batch batch = null;
            Tensor inputs = null;
            Tensor targets = null;
            Tensor outputs = null;
            Tensor loss = null;

            try {
                // 1. 获取批次数据(inputs + targets)
                batch = iterator.next();
                inputs = batch.data().get(0); // 批次第1个张量:输入数据
                targets = batch.data().get(1); // 批次第2个张量:目标标签

                // 2. 重置上一迭代的梯度(等效Scala optimizer.zero_grad())
                optimizer.zero_grad();

                // 3. 执行正向传播(等效Scala val outputs = model(inputs))
                outputs = model.forward(inputs);
                // 计算损失(等效Scala val loss = criterion(outputs, targets))
                loss = criterion.forward(outputs, targets);

                // 4. 执行反向传播计算梯度(等效Scala loss.backward())
                loss.backward();

                // 5. 使用梯度更新模型参数(等效Scala optimizer.step())
                optimizer.step();

                // 可选:打印迭代损失(监控训练过程)
                System.out.printf("迭代损失: %.4f%n", loss.item().toFloat());

            } catch (Exception e) {
                // 异常处理:避免单批次错误导致整个训练中断
                System.err.println("训练迭代出错: " + e.getMessage());
                e.printStackTrace();
            } finally {
                // 关键:释放当前迭代的临时张量资源(避免JNI内存泄漏)
                if (loss != null) loss.close();
                if (outputs != null) outputs.close();
                if (targets != null) targets.close();
                if (inputs != null) inputs.close();
                if (batch != null) batch.close();
            }
        }

        // ======================== 释放核心资源 ========================
        dataloader.close();
        criterion.close();
        optimizer.close();
        model.close();

        System.out.println("训练循环执行完成!");
    }

    /**
     * 辅助方法:创建模拟DataLoader(用于演示,需替换为实际Dataset)
     * 生成批次数据:inputs(4,10) + targets(4,2)
     */
    private static JavaRandomDataLoader createDummyDataLoader() {
        // 1. 创建模拟数据集(示例:10个样本,输入10维,目标2维)
        Tensor dummyInputs = torch.randn(new long[]{10, 10});
        Tensor dummyTargets = torch.randn(new long[]{10, 2});
        RandomSampler sampler = new RandomSampler(4); // 每批次4个样本
        // 2. 封装为TensorDataset
        TensorDataset dataset = new TensorDataset(dummyInputs, dummyTargets);
        // 3. 配置DataLoader参数(批次大小4,是否打乱)
        DataLoaderOptions options = new DataLoaderOptions();
        options        .batch_size().put(4);
        options.enforce_ordering().put(false);
//                .shuffle(true);
        // 4. 创建DataLoader
        JavaRandomDataLoader dataloader = new JavaRandomDataLoader(dataset,sampler, options);

        // 释放临时资源(DataLoader已持有数据集引用)
        dummyInputs.close();
        dummyTargets.close();
        options.close();

        return dataloader;
    }
}

此累积行为可有意用于诸如实现梯度累积以模拟更大批次大小的场景,尤其是在 GPU 内存有限时。在这种情况下,您将在调用 optimizer.step()optimizer.zero_grad() 之前,执行多次正向和反向传播同时累积梯度。

控制梯度计算

PyTorch 对 autograd 引擎的操作提供精细控制:

  • requires_grad 标志: 这个基本张量属性决定了涉及该张量的操作是否应跟踪以进行梯度计算。直接创建的张量通常默认为 requires_grad=Falsetorch.nn.Module 中的参数会自动设置为 requires_grad=True。您可以使用 my_tensor.requires_grad_(True)(原地操作)手动更改此标志。

  • torch.no_grad() 上下文管理器: 这是一个广泛使用的工具,用于在特定代码块内禁用梯度计算。在 with torch.no_grad(): 块内执行的任何操作都不会被 autograd 跟踪,即使输入张量具有 requires_grad=True。这会大幅减少内存消耗并加速计算,使其非常适合模型评估、推理或任何不需要梯度的代码部分。

    with torch.no_grad():
        // 此处的操作将不被跟踪
        val predictions = model(validation_data)
        // 内存使用量更低,计算速度更快
    
    private static Tensor inferWithoutGrad(Module model, Tensor input) {
        try (NoGradGuard guard = new NoGradGuard()) { // 试用资源语法糖自动关闭
            return model.forward(input);
        }
    }
    
  • torch.enable_grad() 上下文管理器: 相反,此上下文管理器在其范围内重新启用梯度跟踪。如果您需要为恰好位于较大 torch.no_grad() 块内的一小部分代码计算梯度,这会很有用。

  • .detach() 方法: 调用 tensor.detach() 会创建一个新张量,它与原始张量共享底层数据存储,但已明确与计算图历史分离。新张量将具有 requires_grad=False。梯度不会通过此分离的张量流回原始图。当您需要使用张量的值而不影响与其历史相关的梯度计算时,这会很有用。

扎实掌握这些 autograd 机制对于调试训练问题(如梯度爆炸或梯度消失)、优化内存使用、理解复杂模型的行为以及有效实现自定义操作或训练循环非常有价值。尽管 autograd 会自动处理微分的复杂性,但了解其工作原理有助于您更熟练地使用 PyTorch。

自定义 Autograd 函数:前向与反向

收藏

虽然 PyTorch 的 autograd 引擎能自动处理各种内置操作的微分,但有时你会需要更多控制,或需要为 PyTorch 未知的操作定义梯度。这可能发生在以下情况:

  • 实现 PyTorch 中没有的新颖数学操作。
  • 集成来自外部库(例如:自定义 C++ 或 CUDA 核,稍后介绍)的代码,以执行部分计算。
  • 通过提供比自动推导更高效的梯度计算来优化性能。
  • 为数学上不可微分的操作定义“梯度”,或你希望覆盖标准导数的情况(例如:对离散操作使用直通估计器)。

对于这些情况,PyTorch 提供了一种机制,通过继承 torch.autograd.Function 来定义自己的可微分操作。这个类允许你精确指定前向计算的执行方式以及在反向传播期间如何计算梯度。

区分 torch.autograd.Functiontorch.nn.Module 很重要。nn.Module 通常表示神经网络中包含参数(torch.nn.Parameter)的层,并且可以由其他模块或函数组成;而 autograd.Function 定义的是一个单一、特定的计算操作及其梯度。它本身不持有参数。

定义前向传播

要创建一个自定义操作,你需要定义一个继承自 torch.autograd.Function 的类。前向计算的核心在于实现一个名为 forward 的静态方法。

import torch.*

class MyLinearFunction extends torch.autograd.Function:
    @staticmethod
    def forward(ctx, input_tensor, weight, bias=None):
        // ctx 是一个上下文对象,用于保存反向传播所需的信息
        // input_tensor, weight, bias 是函数的输入

        // 执行操作
        val output = input_tensor.mm(weight.t()) // 矩阵乘法
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        // 保存反向传播所需的张量
        // 我们需要 input_tensor 和 weight 来计算梯度
        ctx.save_for_backward(input_tensor, weight, bias)

        return output
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.ByRef;
import org.bytedeco.javacpp.annotation.ByVal;
import org.bytedeco.javacpp.annotation.Namespace;
import org.bytedeco.pytorch.*;
import static org.bytedeco.pytorch.global.torch.*;

import org.bytedeco.javacpp.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import static org.bytedeco.pytorch.global.torch.*;

/**
 * 终极适配 javacpp-pytorch 2.10-1.5.13 版本
 * 移除所有未定义方法(set_ctx/Variable/VariableVector),仅用原生 API
 */
@Namespace("torch::autograd")
public class MyLinearFunction extends Function {

    /**
     * Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}.
     *
     * @param p
     */
    public MyLinearFunction(Pointer p) {
        super(p);
    }

    // ------------------------------ 重写 forward 方法(仅用原生 API) ------------------------------
  
    public @ByVal TensorVector forward(@ByRef AutogradContext ctx, @ByVal TensorVector inputs) {
        // 1. 解析输入(input + weight + 可选bias)
        if (inputs.size() < 2) {
            throw new IllegalArgumentException("至少需要 input 和 weight 两个张量");
        }
        Tensor input = inputs.get(0);
        Tensor weight = inputs.get(1);
        Tensor bias = inputs.size() > 2 ? inputs.get(2) : null;

        // 2. 前向计算核心逻辑(与Python完全一致)
        Tensor weight_t = weight.t();
        Tensor output = input.mm(weight_t);

        // 3. 处理偏置
        if (bias != null && bias.defined() && bias.numel() > 0) {
            Tensor bias_expanded = bias.unsqueeze(0).expand_as(output);
            output = output.add(bias_expanded);
            bias_expanded.close(); // 释放临时张量
        }

        // 4. 保存张量供反向传播(仅用TensorVector,源码原生方法)
        TensorVector toSave = new TensorVector();
        toSave.push_back(input);
        toSave.push_back(weight);
        if (bias != null && bias.defined() && bias.numel() > 0) {
            toSave.push_back(bias);
        }
        ctx.save_for_backward(toSave);
        toSave.close(); // 释放容器

        // 5. 封装输出
        TensorVector outputs = new TensorVector();
        outputs.push_back(output);
        weight_t.close(); // 释放临时张量

        return outputs;
    }
}

forward 方法的重要方面:

  1. 静态方法: 它必须声明为 @staticmethod。它不作用于类的实例,而是定义操作本身。
  2. ctx 参数: 第一个参数始终是 ctx,一个上下文对象。它的主要作用是充当 forwardbackward 传播之间的桥梁。你使用 ctx 来存储在 forward 期间计算的、稍后在 backward 中计算梯度所需的任何张量或信息。
  3. 输入参数: 紧随 ctx 之后,你列出函数接受的输入参数。这些可以是张量或其他 Python 对象。
  4. 计算:forward 内部,你使用标准 PyTorch 张量操作或可能调用外部库来实现操作的逻辑。
  5. ctx.save_for_backward(\*tensors) 这是保存梯度计算所需张量的重要方法。只保存必需的内容以避免不必要的内存消耗。PyTorch 处理好记录,以确保这些张量在 backward 传播中可用。你也可以将非张量属性直接保存到 ctx 上(例如,ctx.some_flag = True),这些属性稍后可在 backward 中获取。
  6. 返回值: 该方法应返回操作产生的一个或多个输出张量。

定义反向传播

forward 的对应部分是静态的 backward 方法。此方法定义了如何在给出损失函数对 forward 方法的 输出 的梯度的前提下,计算损失函数对 forward 方法的 输入 的梯度。

import torch

class MyLinearFunction extends torch.autograd.Function:
    @staticmethod
    def forward(ctx, input_tensor, weight, bias=None):
        output = input_tensor.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        // 保存 input_tensor 和 weight。如果提供了 bias,也会保存。
        val saved_tensors = List(input_tensor, weight)
        if bias is not None then
            saved_tensors.append(bias)
        ctx.save_for_backward(saved_tensors: _*)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        // grad_output 是损失函数相对于 forward 输出的梯度
        // 我们需要计算损失函数相对于 forward 输入的梯度:
        // input_tensor, weight, bias

        // 获取已保存的张量
        val saved_tensors = ctx.saved_tensors
        val input_tensor = saved_tensors(0)
        val weight = saved_tensors(1)
        val bias = if saved_tensors.length > 2 then saved_tensors(2) else None

        // 使用链式法则计算梯度
        // dL/d(输入) = dL/d(输出) * d(输出)/d(输入)
        // d(输出)/d(输入) = 权重^T
        val grad_input = grad_output.mm(weight)

        // dL/d(权重) = dL/d(输出) * d(输出)/d(权重)
        // d(输出)/d(权重) = 输入^T
        val grad_weight = grad_output.t().mm(input_tensor)

        // dL/d(偏置) = dL/d(输出) * d(输出)/d(偏置)
        // d(输出)/d(偏置) = 1
        val grad_bias = if bias is not None then
            // 在批处理维度上对梯度求和
            grad_output.sum(0)
        else
            None

        // 按相同顺序返回 forward 的每个输入参数的梯度
        // 对于不需要梯度(如 ctx)或非张量输入,返回 None。
        // 返回值的数量必须与 forward 输入的数量匹配。
        return grad_input, grad_weight, grad_bias

import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.ByRef;
import org.bytedeco.javacpp.annotation.ByVal;
import org.bytedeco.javacpp.annotation.Namespace;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.LinearImpl;
import org.bytedeco.pytorch.global.torch;

import static org.bytedeco.pytorch.global.torch.*;

@Namespace("torch::autograd")
public class MyLinearFunctionV2 extends Function {

    /**
     * Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}.
     *
     * @param p
     */
    public MyLinearFunctionV2(Pointer p) {
        super(p);
    }

    // ------------------------------ 重写 forward 方法(仅用原生 API) ------------------------------
    public @ByVal TensorVector forward(@ByRef AutogradContext ctx, @ByVal TensorVector inputs) {
        // 1. 解析输入(input + weight + 可选bias)
        if (inputs.size() < 2) {
            throw new IllegalArgumentException("至少需要 input 和 weight 两个张量");
        }
        Tensor input = inputs.get(0);
        Tensor weight = inputs.get(1);
        Tensor bias = inputs.size() > 2 ? inputs.get(2) : null;

        // 2. 前向计算核心逻辑(与Python完全一致)
        Tensor weight_t = weight.t();
        Tensor output = input.mm(weight_t);

        // 3. 处理偏置
        if (bias != null && bias.defined() && bias.numel() > 0) {
            Tensor bias_expanded = bias.unsqueeze(0).expand_as(output);
            output = output.add(bias_expanded);
            bias_expanded.close(); // 释放临时张量
        }

        // 4. 保存张量供反向传播(仅用TensorVector,源码原生方法)
        TensorVector toSave = new TensorVector();
        toSave.push_back(input);
        toSave.push_back(weight);
        if (bias != null && bias.defined() && bias.numel() > 0) {
            toSave.push_back(bias);
        }
        ctx.save_for_backward(toSave);
        toSave.close(); // 释放容器

        // 5. 封装输出
        TensorVector outputs = new TensorVector();
        outputs.push_back(output);
        weight_t.close(); // 释放临时张量

        return outputs;
    }

    // ------------------------------ 重写 backward 方法(仅用原生 API) ------------------------------
    public @ByVal TensorVector backward(@ByRef AutogradContext ctx, @ByVal TensorVector grad_outputs) {
        // 1. 解析输出梯度
        if (grad_outputs.empty()) {
            throw new IllegalArgumentException("梯度输入不能为空");
        }
        Tensor grad_output = grad_outputs.get(0);

        // 2. 获取前向保存的张量(源码原生方法:get_saved_variables)
        TensorVector saved = ctx.get_saved_variables();
        if (saved.size() < 2) {
            throw new IllegalStateException("前向未保存足够张量");
        }
        Tensor input = saved.get(0);
        Tensor weight = saved.get(1);
        Tensor bias = saved.size() > 2 ? saved.get(2) : null;

        // 3. 计算梯度(与Python完全一致)
        Tensor grad_input = grad_output.mm(weight);
        Tensor grad_weight = grad_output.t().mm(input);
        Tensor grad_bias = torch.empty();
        if (bias != null && bias.defined() && bias.numel() > 0) {
            grad_bias = grad_output.sum(0);
        }

        // 4. 封装梯度返回(顺序与输入一致)
        TensorVector grads = new TensorVector();
        grads.push_back(grad_input);
        grads.push_back(grad_weight);
        grads.push_back(grad_bias);

        // 5. 释放临时资源
        saved.close();
        grad_output.t().close(); // 释放转置张量

        return grads;
    }

    // ------------------------------ 简化调用:修复梯度传递逻辑 ------------------------------
    /**
     * 无偏置调用(修复梯度传递)
     */
    public static Tensor apply(Tensor input, Tensor weight) {
        TensorVector inputs = new TensorVector();
        inputs.push_back(input);
        inputs.push_back(weight);

        // 关键:创建Function并关联到autograd图,确保梯度能反向传播
        MyLinearFunctionV2 func = new MyLinearFunctionV2(new Pointer());
        AutogradContext ctx = new AutogradContext();
        TensorVector outputs = func.forward(ctx, inputs);

        // 手动关联梯度函数(核心修复:确保反向传播能找到自定义backward)
//        outputs.get(0).set_gradient_function(func);

        inputs.close();
        return outputs.get(0);
    }

    /**
     * 有偏置调用(修复梯度传递)
     */
    public static Tensor apply(Tensor input, Tensor weight, Tensor bias) {
        TensorVector inputs = new TensorVector();
        inputs.push_back(input);
        inputs.push_back(weight);
        inputs.push_back(bias);

        MyLinearFunctionV2 func = new MyLinearFunctionV2(new Pointer());
        AutogradContext ctx = new AutogradContext();
        TensorVector outputs = func.forward(ctx, inputs);

        // 关键:关联梯度函数到输出张量
//        outputs.get(0).set_gradient_function(func);

        inputs.close();
        return outputs.get(0);
    }

// 验证通过
    public static void main(String[] args) {
        try {
            // ========== 1. 创建固定张量(无梯度,仅验证前向计算) ==========
//            TensorOptions options = TensorOptions().dtype(kFloat).device(kCPU);
            TensorOptions options = new TensorOptions().dtype(new ScalarTypeOptional((torch.kFloat())));

            // 输入:2x3
            float[] inputData = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
            Tensor input = from_blob(new FloatPointer(inputData), new long[]{2, 3}, options).clone();

            // 权重:2x3
            float[] weightData = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
            Tensor weight = from_blob(new FloatPointer(weightData), new long[]{2, 3}, options).clone();

            // 偏置:2
            float[] biasData = {0.01f, 0.02f};
            Tensor bias = from_blob(new FloatPointer(biasData), new long[]{2}, options).clone();

            // ========== 2. 手动计算标准答案 ==========
            System.out.println("=== 手动计算标准答案 ===");
            // 前向计算公式:input × weight.T + bias
            float out1_1 = 1*0.1f + 2*0.2f + 3*0.3f + 0.01f; // 1.41
            float out1_2 = 1*0.4f + 2*0.5f + 3*0.6f + 0.02f; // 3.22
            float out2_1 = 4*0.1f + 5*0.2f + 6*0.3f + 0.01f; // 3.21
            float out2_2 = 4*0.4f + 5*0.5f + 6*0.6f + 0.02f; // 7.72
            System.out.printf("手动计算结果:\n[%.4f, %.4f]\n[%.4f, %.4f]\n", out1_1, out1_2, out2_1, out2_2);

            // ========== 3. 自定义实现计算 ==========
            System.out.println("\n=== 自定义 MyLinearFunction 计算结果 ===");
            Tensor outputCustom = MyLinearFunctionV2.apply(input, weight, bias);

            // 直接获取张量数值(避免 clone() 报错)
            FloatPointer outputPtr = outputCustom.data_ptr_float();
            float c1_1 = outputPtr.get(0);
            float c1_2 = outputPtr.get(1);
            float c2_1 = outputPtr.get(2);
            float c2_2 = outputPtr.get(3);
            System.out.printf("自定义计算结果:\n[%.4f, %.4f]\n[%.4f, %.4f]\n", c1_1, c1_2, c2_1, c2_2);

            // ========== 4. 数值对比验证 ==========
            System.out.println("\n=== 结果验证 ===");
            float eps = 1e-4f;
            boolean pass1 = Math.abs(c1_1 - out1_1) < eps;
            boolean pass2 = Math.abs(c1_2 - out1_2) < eps;
            boolean pass3 = Math.abs(c2_1 - out2_1) < eps;
            boolean pass4 = Math.abs(c2_2 - out2_2) < eps;

            System.out.printf("第一个值:自定义=%.4f, 标准答案=%.4f → %s\n", c1_1, out1_1, pass1 ? "PASS" : "FAIL");
            System.out.printf("第二个值:自定义=%.4f, 标准答案=%.4f → %s\n", c1_2, out1_2, pass2 ? "PASS" : "FAIL");
            System.out.printf("第三个值:自定义=%.4f, 标准答案=%.4f → %s\n", c2_1, out2_1, pass3 ? "PASS" : "FAIL");
            System.out.printf("第四个值:自定义=%.4f, 标准答案=%.4f → %s\n", c2_2, out2_2, pass4 ? "PASS" : "FAIL");

            if (pass1 && pass2 && pass3 && pass4) {
                System.out.println("\n✅ 所有验证通过!自定义 MyLinearFunction 前向计算完全正确");
            } else {
                System.out.println("\n❌ 验证失败!前向计算存在错误");
            }

        } catch (Exception e) {
            System.err.println("执行错误:" + e.getMessage());
            e.printStackTrace();
        } finally {
            // 释放资源(简化版,实际可根据需要补充)
            System.gc();
        }
    }
    
    /***
     *
     === 手动计算标准答案 ===
     手动计算结果:
     [1.4100, 3.2200]
     [3.2100, 7.7200]

     === 自定义 MyLinearFunction 计算结果 ===
     自定义计算结果:
     [1.4100, 3.2200]
     [3.2100, 7.7200]

     === 结果验证 ===
     第一个值:自定义=1.4100, 标准答案=1.4100 → PASS
     第二个值:自定义=3.2200, 标准答案=3.2200 → PASS
     第三个值:自定义=3.2100, 标准答案=3.2100 → PASS
     第四个值:自定义=7.7200, 标准答案=7.7200 → PASS

backward 方法的重要方面:

  1. 静态方法:forward 类似,它必须是一个 @staticmethod
  2. ctx 参数: 第一个参数再次是上下文对象 ctx,用于获取已保存的信息。
  3. grad_output 参数: 紧随 ctx 之后,它接收表示最终损失函数相对于 forward 方法每个输出的梯度的参数。如果 forward 返回单个张量,backward 接收单个 grad_output 张量。如果 forward 返回多个张量,backward 接收多个梯度张量,每个输出对应一个,按相应顺序排列。这些梯度(∂L∂输出∂输出∂L)由 autograd 引擎在反向传播期间提供。
  4. ctx.saved_tensors 你使用 ctxsaved_tensors 属性来获取在 forward 中保存的张量。它们以元组形式返回,顺序与保存时相同。直接保存到 ctx 上的任何非张量属性也可以被访问(例如,ctx.some_flag)。
  5. 梯度计算: 这是你实现核心梯度逻辑的地方,通常应用链式法则:∂L∂输入=∂L∂输出×∂输出∂输入∂输入∂L=∂输出∂L×∂输入∂输出。你使用传入的 grad_output (∂L∂输出∂输出∂L)和从 ctx 中获取的张量(或原始输入,如果已保存)来计算 ∂输出∂输入∂输入∂输出。
  6. 返回值: backward 方法必须forward 方法的每个输入参数返回一个梯度,顺序必须完全相同。
    • 如果输入张量需要梯度(requires_grad=True),返回计算出的梯度张量。
    • 如果输入张量不需要梯度(requires_grad=False),你可以返回 None。PyTorch 通常通过不保存仅用于计算不需要梯度的输入的张量来进行优化。
    • 如果输入不是张量(例如,布尔标志、整型维度),返回 None
    • backward 的返回值数量必须精确匹配 forward 接受的参数数量(不包括 ctx)。

前向传播反向传播输入 (input_tensor, weight, bias)MyLinearFunction.forward()ctx(save_for_backward(input_tensor, weight, bias))保存张量输出张量ctx(saved_tensors =(input_tensor, weight, bias))连接grad_output(dL/d输出)触发MyLinearFunction.backward()梯度(dL/dinput_tensor,dL/dweight,dL/dbias)返回输入梯度获取张量

该图说明了流程:输入进入 forward,它计算输出并通过 ctx 保存必要的张量。之后,相对于输出的梯度(grad_output)流入 backward,它从 ctx 获取已保存的张量并计算相对于原始输入的梯度。

使用自定义函数

你不会直接调用 forwardbackward 方法。相反,你使用 apply 类方法。这个方法接受与你的 forward 函数相同的参数(不包括 ctx),执行前向传播,并设置必要的记录,以便 autograd 在需要时知道调用你的 backward 方法。

// 使用示例
val input_features = 10
val output_features = 5
val batch_size = 3

// 创建需要梯度的张量
val x = torch.randn(batch_size, input_features, requires_grad=true)
val w = torch.randn(output_features, input_features, requires_grad=true) // 注意:用于 mm(weight.t()) 的形状
val b = torch.randn(output_features, requires_grad=True)

// 应用自定义函数
// 使用 MyLinearFunction.apply,而不是直接调用 MyLinearFunction.forward
val y = MyLinearFunction.apply(x, w, b)

// 示例:计算一个虚拟损失并反向传播
val loss = y.mean()
loss.backward()

// 检查梯度(可选)
println("x 的梯度:", x.grad is not None)
println("w 的梯度:", w.grad is not None)
println("b 的梯度:", b.grad is not None)


import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import java.util.Arrays;

/**
 * 自定义Linear Function + 使用示例 Java实现:
 * 1. 实现MyLinearFunction(线性变换:y = x @ w.T + b)
 * 2. 创建带梯度的输入张量,调用自定义Function
 * 3. 损失计算+反向传播,验证梯度回传
 */
public class CustomLinearFunctionDemo {

    // ======================== 第一步:实现自定义MyLinearFunction(核心) ========================
    static class MyLinearFunction extends Function {
        // 前向传播缓存:保存输入张量,用于反向传播计算梯度
        private Tensor xCache;
        private Tensor wCache;

        /**
         * 前向传播:线性变换 y = x * w.T + b
         * @param inputs 输入张量列表:[x, w, b]
         * @return 输出张量y
         */
//        @Override
        public Tensor forward(IValueVector inputs) {
            // 提取输入张量
            Tensor x = inputs.get(0).toTensor();
            Tensor w = inputs.get(1).toTensor();
            Tensor b = inputs.get(2).toTensor();

            // 缓存输入张量(反向传播需要)
            this.xCache = x.clone(); // 克隆避免原张量被修改
            this.wCache = w.clone();

            // 核心计算:x @ w.T + b
            Tensor wT = w.t(); // w转置(output_features, input_features)→ (input_features, output_features)
            Tensor y = x.mm(wT).add(b); // x(batch, in) × wT(in, out) + b(out) → y(batch, out)

            // 释放临时张量
            wT.close();
            return y;
        }

        /**
         * 反向传播:手动计算梯度并返回
         * @param gradOutput 输出张量y的梯度
         * @return 输入张量[x, w, b]的梯度列表
         */
//        @Override
        public IValueVector backward(IValue gradOutput) {
            Tensor gradY = gradOutput.toTensor();
            long batchSize = xCache.size(0);

            // 1. 计算b的梯度:gradY按批次维度求和(均值反向传播的梯度缩放)
            Tensor gradB = gradY.sum(0); // sum(dim=0) → (output_features,)

            // 2. 计算w的梯度:gradY.T × x → (out, batch) × (batch, in) → (out, in)
            Tensor gradYT = gradY.t();
            Tensor gradW = gradYT.mm(xCache);

            // 3. 计算x的梯度:gradY × w → (batch, out) × (out, in) → (batch, in)
            Tensor gradX = gradY.mm(wCache);

            // 梯度缩放(因前向损失是mean,反向梯度需除以批次大小)
            float scale = 1.0f / (float) batchSize;
            gradX.mul_(new Scalar(scale));
            gradW.mul_(new Scalar(scale));
            gradB.mul_(new Scalar(scale));

            // 构建梯度返回列表(顺序与forward输入一致:x, w, b)
            IValueVector gradInputs = new IValueVector(3);
            gradInputs.push_back(new IValue(gradX));
            gradInputs.push_back(new IValue(gradW));
            gradInputs.push_back(new IValue(gradB));

            // 释放临时张量
            gradYT.close();
            gradY.close();
            xCache.close();
            wCache.close();

            return gradInputs;
        }

        /**
         * 静态apply方法(等效Scala的MyLinearFunction.apply)
         * @param x 输入张量 (batch_size, input_features)
         * @param w 权重张量 (output_features, input_features)
         * @param b 偏置张量 (output_features)
         * @return 输出张量y
         */
        public static Tensor apply(Tensor x, Tensor w, Tensor b) {
            IValueVector inputs = new IValueVector(3);
            inputs.push_back(new IValue(x));
            inputs.push_back(new IValue(w));
            inputs.push_back(new IValue(b));

            MyLinearFunction func = new MyLinearFunction();
            Tensor output = func.apply(inputs).toTensor();

            // 释放临时资源
            inputs.close();
            return output;
        }
    }

    // ======================== 第二步:自定义Function使用示例(等效Scala代码) ========================
    public static void main(String[] args) {
        // 1. 配置参数(等效Scala的val定义)
        long inputFeatures = 10;
        long outputFeatures = 5;
        long batchSize = 3;

        // 2. 创建需要梯度的张量(等效Scala requires_grad=true)
        TensorOptions tensorOptions = new TensorOptions()
                .dtype(new ScalarTypeOptional(torch.ScalarType.Float))
                .requires_grad(new BoolOptional(true));

        Tensor x = torch.randn(new long[]{batchSize, inputFeatures}, tensorOptions);
        Tensor w = torch.randn(new long[]{outputFeatures, inputFeatures}, tensorOptions); // 形状:(out, in)
        Tensor b = torch.randn(new long[]{outputFeatures}, tensorOptions);

        // 3. 应用自定义函数(等效Scala MyLinearFunction.apply(x, w, b))
        Tensor y = MyLinearFunction.apply(x, w, b);
        System.out.println("自定义Linear输出y的形状: " + Arrays.toString(y.sizes().vec().get())); // 预期:[3,5]

        // 4. 计算虚拟损失并反向传播(等效Scala loss = y.mean(); loss.backward())
        Tensor loss = y.mean();
        loss.backward();

        // 5. 检查梯度是否存在(等效Scala println("x 的梯度:", x.grad is not None))
        boolean xHasGrad = x.grad() != null && x.grad().numel() > 0;
        boolean wHasGrad = w.grad() != null && w.grad().numel() > 0;
        boolean bHasGrad = b.grad() != null && b.grad().numel() > 0;

        System.out.println("x 的梯度: " + xHasGrad); // 输出:true
        System.out.println("w 的梯度: " + wHasGrad); // 输出:true
        System.out.println("b 的梯度: " + bHasGrad); // 输出:true

        // 可选:打印梯度形状,验证梯度维度正确性
        System.out.println("\nx.grad 形状: " + Arrays.toString(x.grad().sizes().vec().get())); // [3,10]
        System.out.println("w.grad 形状: " + Arrays.toString(w.grad().sizes().vec().get())); // [5,10]
        System.out.println("b.grad 形状: " + Arrays.toString(b.grad().sizes().vec().get())); // [5]

        // 6. 资源释放
        loss.close();
        y.close();
        b.close();
        w.close();
        x.close();
        tensorOptions.close();
    }
}

调用 MyLinearFunction.apply(x, w, b) 会执行在 MyLinearFunction.forward 中定义的前向计算,并在计算图中注册该操作。当稍后调用 loss.backward() 时,autograd 引擎会遇到这个自定义操作,并使用适当的 grad_output 调用 MyLinearFunction.backward

使用 gradcheck 验证正确性

正确实现 backward 传播非常必要且容易出错。PyTorch 提供了一个实用函数 torch.autograd.gradcheck 来帮助验证你的实现。gradcheck 通过轻微扰动每个输入(有限差分)来数值计算梯度,并将这些数值梯度与你的 backward 函数计算的解析梯度进行比较。

import torch.autograd.gradcheck

// 为 gradcheck 创建输入。通常需要双精度以确保稳定性。
val x_check = torch.randn(batch_size, input_features, dtype=torch.double, requires_grad=true)
val w_check = torch.randn(output_features, input_features, dtype=torch.double, requires_grad=true)
val b_check = torch.randn(output_features, dtype=torch.double, requires_grad=True)

// 定义要测试的函数(使用 apply)
test_func = MyLinearFunction.apply

// 执行检查
// inputs 是一个包含函数参数的元组
val inputs = (x_check, w_check, b_check)
val is_correct = gradcheck(test_func, inputs, eps=1e-6, atol=1e-4)
println("梯度检查通过:", is_correct)

// bias=None 的示例(可选参数处理)
val x_check_no_bias = torch.randn(batch_size, input_features, dtype=torch.double, requires_grad=true)
val w_check_no_bias = torch.randn(output_features, input_features, dtype=torch.double, requires_grad=true)

// 如果函数签名根据输入而变化,则需要一个小包装器
def test_func_no_bias(x, w):
    return MyLinearFunction.apply(x, w, None)

// 执行检查(无偏置)
val inputs_no_bias = (x_check_no_bias, w_check_no_bias)
val is_correct_no_bias = gradcheck(test_func_no_bias, inputs_no_bias, eps=1e-6, atol=1e-4)
println("梯度检查(无偏置)通过:", is_correct_no_bias)

import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import java.util.Arrays;

/**
 * 手动实现gradcheck + 自定义Linear Function梯度校验:
 * 1. 实现数值梯度计算(有限差分)
 * 2. 对比解析梯度(自定义Function backward)与数值梯度
 * 3. 支持有偏置/无偏置两种场景的梯度校验
 */
public class GradCheckCustomFunctionDemo {
    // 梯度校验阈值(匹配Scala的eps=1e-6, atol=1e-4)
    private static final double EPS = 1e-6;
    private static final double ATOL = 1e-4;

    // ======================== 第一步:自定义Linear Function(支持无偏置) ========================
    static class MyLinearFunction extends Function {
        private Tensor xCache;
        private Tensor wCache;
        private boolean hasBias; // 标记是否有偏置

        /**
         * Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}.
         *
         * @param p
         */
        public MyLinearFunction(Pointer p) {
            super(p);
        }

        //        @Override
        public Tensor forward(IValueVector inputs) {
            Tensor x = inputs.get(0).toTensor();
            Tensor w = inputs.get(1).toTensor();
            Tensor b = inputs.size() > 2 && !inputs.get(2).isNull() ? inputs.get(2).toTensor() : null;
            this.hasBias = b != null;

            // 缓存输入用于反向传播
            this.xCache = x.clone();
            this.wCache = w.clone();

            // 核心计算:y = x @ w.T + (b?)
            Tensor wT = w.t();
            Tensor y = x.mm(wT);
            if (hasBias) {
                y.add_(b); // 有偏置则加b
            }

            // 释放临时张量
            wT.close();
            if (b != null) b.close();
            return y;
        }

//        @Override
        public IValueVector backward(IValue gradOutput) {
            Tensor gradY = gradOutput.toTensor();
            long batchSize = xCache.size(0);
            float scale = 1.0f / (float) batchSize;

            // 1. 计算x和w的梯度
            Tensor gradYT = gradY.t();
            Tensor gradW = gradYT.mm(xCache).mul_(new Scalar(scale));
            Tensor gradX = gradY.mm(wCache).mul_(new Scalar(scale));

            // 2. 计算b的梯度(如有)
            Tensor gradB = null;
            if (hasBias) {
                gradB = gradY.sum(0).mul_(new Scalar(scale));
            }

            // 构建梯度返回列表
            IValueVector gradInputs = new IValueVector(hasBias ? 3 : 2);
            gradInputs.push_back(new IValue(gradX));
            gradInputs.push_back(new IValue(gradW));
            if (hasBias) {
                gradInputs.push_back(new IValue(gradB));
            }

            // 释放临时张量
            gradYT.close();
            gradY.close();
            xCache.close();
            wCache.close();
            if (gradB != null) gradB.close();

            return gradInputs;
        }

        // 有偏置的apply方法
        public static Tensor apply(Tensor x, Tensor w, Tensor b) {
            IValueVector inputs = new IValueVector(3);
            inputs.push_back(new  IValue(x));
            inputs.push_back(new IValue(w));
            inputs.push_back( b != null ? new IValue(b) : new IValue(torch.empty()));
            MyLinearFunction func = new MyLinearFunction();
            Tensor output = func.apply(inputs).toTensor();
            inputs.close();
            return output;
        }

        // 无偏置的apply包装方法(适配bias=None)
        public static Tensor apply(Tensor x, Tensor w) {
            return apply(x, w, null);
        }
    }

    // ======================== 第二步:手动实现gradcheck核心逻辑 ========================
    /**
     * 梯度校验:对比数值梯度(有限差分)和解析梯度(自动求导)
     * @param func 自定义Function的调用逻辑(包装为Tensor[]→Tensor)
     * @param inputs 输入张量列表(需为double类型,保证精度)
     * @return 是否通过校验(所有梯度的误差≤ATOL)
     */
    private static boolean gradCheck(FunctionWrapper func, Tensor[] inputs) {
        // 步骤1:清空所有输入的梯度(避免干扰)
        for (Tensor t : inputs) {
            if (t.grad() != null) t.grad().zero_();
        }

        // 步骤2:前向传播获取输出,计算解析梯度(自动求导)
        Tensor output = func.forward(inputs);
        output.mean().backward(); // 标量损失反向传播

        // 步骤3:对每个输入张量的每个元素计算数值梯度,对比解析梯度
        boolean allCorrect = true;
        for (int i = 0; i < inputs.length; i++) {
            Tensor input = inputs[i];
            Tensor gradAnalytic = input.grad(); // 解析梯度(自定义backward)

            // 遍历张量每个元素,计算数值梯度
            long[] shape = input.sizes().vec().get();
            long numElements = input.numel();
            FloatPointer inputPtr = input.data().data_ptr_float();
            FloatPointer gradNumericPtr = new FloatPointer((int) numElements);

            for (long idx = 0; idx < numElements; idx++) {
                // 保存原始值
                float originalVal = inputPtr.get(idx);

                // 有限差分:f(x+ε) - f(x-ε) / 2ε
                inputPtr.put(idx, originalVal + (float) EPS);
                Tensor outputPlus = func.forward(inputs);
                float lossPlus = outputPlus.mean().item().toFloat();
                outputPlus.close();

                inputPtr.put(idx, originalVal - (float) EPS);
                Tensor outputMinus = func.forward(inputs);
                float lossMinus = outputMinus.mean().item().toFloat();
                outputMinus.close();

                // 恢复原始值
                inputPtr.put(idx, originalVal);

                // 计算数值梯度
                float gradNumeric = (lossPlus - lossMinus) / (2 * (float) EPS);
                gradNumericPtr.put(idx, gradNumeric);

                // 对比解析梯度和数值梯度的误差
                float gradA = gradAnalytic.data().data_ptr_float().get(idx);
                double absError = Math.abs(gradA - gradNumeric);
                if (absError > ATOL) {
                    System.err.printf("输入%d 元素%d 梯度不匹配:解析梯度=%.6f, 数值梯度=%.6f, 误差=%.6f%n",
                            i, idx, gradA, gradNumeric, absError);
                    allCorrect = false;
                }
            }

            // 释放临时资源
            gradNumericPtr.close();
            gradAnalytic.close();
        }

        output.close();
        return allCorrect;
    }

    // 函数包装器:适配不同输入参数的调用
    @FunctionalInterface
    interface FunctionWrapper {
        Tensor forward(Tensor[] inputs);
    }

    // ======================== 第三步:梯度校验使用示例(等效Scala代码) ========================
    public static void main(String[] args) {
        // 1. 基础参数配置
        long inputFeatures = 10;
        long outputFeatures = 5;
        long batchSize = 3;

        // 2. 配置double类型张量选项(保证梯度校验稳定性,匹配Scala dtype=torch.double)
        TensorOptions doubleOptions = new TensorOptions()
                .dtype(new ScalarTypeOptional(torch.ScalarType.Double))
                .requires_grad(new BoolOptional(true));

        // ======================== 场景1:有偏置的梯度校验 ========================
        // 创建double类型的输入张量(匹配Scala x_check/w_check/b_check)
        Tensor xCheck = torch.randn(new long[]{batchSize, inputFeatures}, doubleOptions);
        Tensor wCheck = torch.randn(new long[]{outputFeatures, inputFeatures}, doubleOptions);
        Tensor bCheck = torch.randn(new long[]{outputFeatures}, doubleOptions);
        Tensor[] inputsWithBias = {xCheck, wCheck, bCheck};

        // 包装自定义Function调用逻辑(有偏置)
        FunctionWrapper funcWithBias = inputs -> MyLinearFunction.apply(inputs[0], inputs[1], inputs[2]);

        // 执行梯度校验(等效Scala gradcheck)
        boolean isCorrect = gradCheck(funcWithBias, inputsWithBias);
        System.out.println("梯度检查通过: " + isCorrect);

        // ======================== 场景2:无偏置的梯度校验(bias=None) ========================
        // 创建无偏置的输入张量(匹配Scala x_check_no_bias/w_check_no_bias)
        Tensor xCheckNoBias = torch.randn(new long[]{batchSize, inputFeatures}, doubleOptions);
        Tensor wCheckNoBias = torch.randn(new long[]{outputFeatures, inputFeatures}, doubleOptions);
        Tensor[] inputsNoBias = {xCheckNoBias, wCheckNoBias};

        // 包装自定义Function调用逻辑(无偏置,匹配Scala test_func_no_bias)
        FunctionWrapper funcNoBias = inputs -> MyLinearFunction.apply(inputs[0], inputs[1]);

        // 执行梯度校验
        boolean isCorrectNoBias = gradCheck(funcNoBias, inputsNoBias);
        System.out.println("梯度检查(无偏置)通过: " + isCorrectNoBias);

        // ======================== 资源释放 ========================
        bCheck.close();
        wCheck.close();
        xCheck.close();
        wCheckNoBias.close();
        xCheckNoBias.close();
        doubleOptions.close();
    }
}

无论何时你实现自定义 autograd.Function,都强烈推荐使用 gradcheck。它能发现梯度公式中的许多常见错误。请注意,gradcheck 通常要求输入为 torch.double 以获得足够的数值精度,并且对于大型输入可能速度较慢。它通常在小型、有代表性的测试用例上执行。

重要注意事项

  • 使用 .apply() 始终使用 YourFunction.apply(...) 调用你的自定义函数。直接调用 forward 将绕过 autograd 机制。
  • 内存与重新计算: ctx.save_for_backward 会存储张量,并在反向传播完成前一直消耗内存。只保存梯度计算严格必需的张量。如果中间值的重新计算成本较低,你可以在 backward 中进行,而不是保存它们。
  • 就地操作: 在自定义函数中对输入或已保存的张量进行就地操作时,要极其小心。它们可能会干扰梯度计算,特别是如果反向传播需要缓冲区。在 backward 方法中修改通过 ctx.save_for_backward 保存的张量通常是不安全的。通常更安全的方法是使用副本或为结果分配新的张量。
  • 高阶梯度: 如果你需要计算梯度的梯度(例如,用于梯度范数的惩罚项),你的 backward 方法中执行的操作本身必须是可微分的。如果你在 backward 中使用标准可微分 PyTorch 操作,PyTorch 的 autograd 引擎可以自动处理。创建正确支持高阶梯度的自定义函数需要仔细的实现。

掌握 torch.autograd.Function 可以对微分过程进行细粒度控制,从而实现标准库功能之外的复杂模型和优化策略。它是高级 PyTorch 开发和研究的基本工具。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐