在这里插入图片描述

PyTorch Java 高校计算机 硕士研一课程

使用 torch.cuda.amp 进行混合精度训练

训练大型深度学习模型计算量很大,常常受限于 GPU 处理速度和内存容量。虽然标准的单精度浮点数(FP32)提供了较宽的动态范围和良好的精度,但神经网络中的许多操作都可以使用半精度(FP16)数充分完成。采用 FP16 可以大大加快计算速度,尤其是在配备 Tensor Cores 的 NVIDIA GPU 上,并能显著减少激活值和梯度的内存占用。

然而,简单地将整个模型切换到 FP16 通常会导致数值不稳定。FP16 的可表示范围比 FP32 小得多,因此容易出现梯度下溢(即小梯度值变为零)或梯度上溢(即大值变为无穷大)。这会阻碍或完全停止训练过程。

这就是混合精度训练的作用所在。其主要思路是,在 FP16 能带来显著速度提升和内存节约的操作中(例如大型矩阵乘法和卷积)使用 FP16,同时将重要组成部分(如权重更新和某些对数值敏感的操作)保留在 FP32 中,以保持稳定性和准确性。PyTorch 通过 torch.cuda.amp 模块提供了一种方便高效的方法来实现这一点,该模块代表自动混合精度(Automatic Mixed Precision)。

使用 torch.cuda.amp 进行自动混合精度训练

PyTorch 的 amp 模块大体上自动化了混合精度训练的过程。它识别出受益于 FP16 执行的操作,并自动将其输入转换为 FP16,同时将其他操作(如可能需要更高精度的归约或损失函数)保留在 FP32 中。

autocast 上下文管理器

启用自动类型转换的主要工具是 torch.cuda.amp.autocast 上下文管理器。你只需将模型的前向传播(包括损失计算)包装在此上下文中即可。

import torch
import torch.cuda.amp.autocast

// 假设模型、数据、准则已定义并移至 GPU
val model = model.cuda()
val data = data.cuda()
val target = target.cuda()
val criterion = criterion.cuda()

// 为前向传播启用自动类型转换
with autocast():
    val output = model(data)
    val loss = criterion(output, target)

// autocast 上下文之外的操作以默认精度(FP32)运行
// loss.backward() # 我们接下来会修改这一行
// optimizer.step()
        // 1. 初始化并移至 CUDA
        // 假设 model, data, target, criterion 已根据之前定义的逻辑创建
        LinearImpl model = new LinearImpl(new LinearOptions(10, 1));
        model.to(new Device(kCUDA()),false);
        MSELossImpl criterion = new MSELossImpl();
        criterion.to(new Device(kCUDA()),false);

        Tensor data = randn(new long[]{16, 10},new TensorOptions().device(new DeviceOptional(new Device(kCUDA()))));
        Tensor target = randn(new long[]{16, 1},new TensorOptions().device(new DeviceOptional(new Device(kCUDA()))));

        // 2. 混合精度训练核心逻辑
        try (var scope = new PointerScope()) {

            // 在 LibTorch C++ 实现中,autocast 的等效操作是显式启用开关
            // 注意:JDK 25 的 try-with-resources 可以很好地处理作用域管理

            // 开启 Autocast
            torch.set_autocast_enabled(DeviceType.CUDA,true);
            try {
                // 前向传播:在混合精度下自动选择 FP16 或 FP32
                Tensor output = model.forward(data);
                Tensor loss = criterion.forward(output, target);

                System.out.println("混合精度计算完成。");

                // 后续通常接 GradScaler (梯度缩放),因为 FP16 容易出现数值下溢
                // var scaler = new GradScaler();
                // scaler.scale(loss).backward();

            } finally {
                // 关闭 Autocast,恢复正常的 FP32 状态(模拟 Python 的 with 作用域结束)
                torch.set_autocast_enabled(DeviceType.CUDA,false);
            }
        }

autocast 上下文内部,PyTorch 自动确定每项操作的最佳精度:

  • 卷积和线性层等操作,它们显著受益于 Tensor Cores,通常以 FP16 运行。
  • 可能遭受精度损失或需要更宽范围的操作,例如归约或某些激活函数,可能会保持在 FP32 中。
  • autocast 管理的区域的输出通常是 FP32 张量,但中间操作可能大量使用了 FP16。

这种选择性的精度处理最大程度地降低了纯 FP16 训练带来的数值风险,同时获得了大部分性能提升。

使用 GradScaler 进行梯度缩放

虽然 autocast 管理前向传播,但将其直接与标准反向传播一起使用仍可能导致问题。从 FP16 激活值计算出的梯度有时会非常小,超出 FP16 的可表示范围并变为零(下溢)。这会阻止相应权重得到更新。

为了解决这个问题,torch.cuda.amp 提供了 GradScaler。它的工作方式是在反向传播之前缩放损失值。这有效地将所有产生的梯度乘以相同的缩放因子。

损失缩放后=损失×缩放因子损失缩放后=损失×缩放因子∇w缩放后=∇w×缩放因子∇w缩放后=∇w×缩放因子

这种缩放将小梯度值推入 FP16 的可表示范围,从而防止下溢。在优化器更新权重之前,GradScaler 会将梯度反向缩放回其原始值。

∇w=∇w缩放后缩放因子∇w=缩放因子∇w缩放后

GradScaler 在训练期间动态调整缩放因子。如果在一定步数内未检测到溢出,它会增加因子,试图最大化 FP16 范围的使用。如果在反向缩放后,梯度中确实检测到溢出(梯度变为 infNaN),GradScaler 会跳过该批次的优化器步骤,并降低缩放因子以防止未来的溢出。

以下是将 GradScaler 整合到训练循环中的方法:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp.autocast
import torch.cuda.amp.GradScaler

// --- 初始化 ---
val model = YourModel().cuda()
val optimizer = optim.AdamW(model.parameters(), lr=1e-3)
val criterion = nn.CrossEntropyLoss().cuda()
// 初始化 GradScaler
val scaler = GradScaler()
// 模拟数据加载器(请替换为你的实际数据加载)
val dataloader = (1 to 10).map(_ => (torch.randn(16, 3, 224, 224, device='cuda'), torch.randint(0, 10, (16,), device='cuda')))
// --- 训练循环 ---
val num_epochs = 5
for (epoch <- 0 until num_epochs) {
    for ((data, target) <- dataloader) {
        optimizer.zero_grad()
        // 前向传播并进行自动类型转换
        with autocast():
            val output = model(data)
            val loss = criterion(output, target)

        // 缩放损失并执行反向传播
        // scaler.scale(loss) 计算 loss * scale_factor
        scaler.scale(loss).backward()
        // 反向缩放梯度并调用 optimizer.step()
        // 如果梯度溢出,会自动跳过此步骤
        scaler.step(optimizer)

        // 更新下一次迭代的缩放因子
        scaler.update()

    println(s"Epoch ${epoch+1} completed. Current scale factor: ${scaler.get_scale()}")
import org.bytedeco.pytorch.*;
import java.util.List;
import static org.bytedeco.pytorch.global.torch.*;

/**
 * 模拟 Python 的 `torch.cuda.amp.GradScaler`
 * 用于处理 Float16 训练中的梯度下溢问题
 */
public class GradScaler {
    private Tensor scale;
    private double growthFactor;
    private double backoffFactor;
    private int growthInterval;
    private int growthTracker;

    public GradScaler() {
        this(65536.0, 2.0, 0.5, 2000);
    }

    public GradScaler(double initScale, double growthFactor, double backoffFactor, int growthInterval) {
        // scale 必须在 CPU 上或者是单元素 Tensor
        this.scale = tensor(initScale);
        this.growthFactor = growthFactor;
        this.backoffFactor = backoffFactor;
        this.growthInterval = growthInterval;
        this.growthTracker = 0;
    }

    /**
     * scale(loss)
     * 将 loss 放大,防止反向传播时梯度下溢
     */
    public Tensor scale(Tensor loss) {
        return loss.mul(scale);
    }

    /**
     * step(optimizer) 的替代版
     * 由于 JavaCPP 很难直接操作 C++ Optimizer 对象,我们这里传入 参数列表
     * 逻辑:
     * 1. Unscale 梯度 (grad = grad / scale)
     * 2. 检查梯度是否有 Inf/NaN
     * 3. 如果没有问题,执行参数更新 (这里模拟 optimizer.step)
     * * @param params 模型参数列表 (包含 .grad() 属性)
     * @param learningRate 学习率 (模拟 SGD 更新)
     * @return boolean 是否执行了更新 (true=更新成功, false=检测到Inf跳过)
     */
    public boolean step(List<Tensor> params, double learningRate) {
        // 1. Unscale Gradients
        float currentScale = scale.item_float();
        float invScale = 1.0f / currentScale;

        boolean foundInf = false;

        // 检查所有参数的梯度
        for (Tensor p : params) {
            if (!p.grad().defined()) continue;

            // Unscale: grad = grad * (1/scale)
            p.grad().mul_(new Scalar(invScale)); //原地操作

            // Check Inf/NaN
            // sum() 将所有元素相加,item_float 获取值,如果含 Inf/Nan 结果也是 Inf/NaN
            // isfinite 检查
            Tensor isFinite = isfinite(p.grad());
            if (!all(isFinite).item_bool()) {
                foundInf = true;
                break; // 只要有一个参数坏了,这一步就废了
            }
        }

        if (foundInf) {
            // 如果发现 Inf,跳过这一步更新
            System.out.println("⚠️ Gradient overflow/underflow detected. Skipping step.");
            update(false); // 告诉 update 变小一点
            return false;
        } else {
            // 模拟 Optimizer.step():简单的 SGD 更新: p = p - lr * grad
            try (NoGradGuard guard = new NoGradGuard()) {
                for (Tensor p : params) {
                    if (!p.grad().defined()) continue;
                    // p.sub_(p.grad().mul(lr))
                    p.sub_(p.grad().mul(new Scalar((float) learningRate)));
                    // 清空梯度
                    p.grad().detach_();
                    p.grad().zero_();
                }
                
            }
    
            update(true); // 告诉 update 尝试变大
            return true;
        }
    }

    /**
     * 更新 Scale 因子
     * @param wasSuccessful 上一步是否成功 (没有 Inf)
     */
    private void update(boolean wasSuccessful) {
        if (wasSuccessful) {
            growthTracker++;
            if (growthTracker >= growthInterval) {
                scale.mul_(new Scalar((float) growthFactor));
                growthTracker = 0;
                System.out.println("📈 Scale increased to: " + scale.item_float());
            }
        } else {
            scale.mul_(new Scalar((float) backoffFactor));
            growthTracker = 0;
            System.out.println("📉 Scale decreased to: " + scale.item_float());
        }

        // 限制最小 scale
        if (scale.item_float() < 1.0f) {
            scale.fill_(new Scalar(1.0f));
        }
    }

    public float getScale() {
        return scale.item_float();
    }
}




import org.bytedeco.javacpp.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import java.util.ArrayList;
import java.util.List;
import static org.bytedeco.pytorch.global.torch.*;

public class MixedPrecisionTraining {
    public static void main(String[] args) {
        // 1. 初始化模型、损失函数和优化器
        // 假设 YourModel 是一个继承自 Module 的类
        SequentialImpl model = new SequentialImpl();
        model.push_back(new LinearImpl(10, 5));
        model.push_back( new ReLUImpl());
        model.push_back( new LinearImpl(5, 1));
        model.to(new Device(kCUDA()),false);
        CrossEntropyLossImpl criterion = new CrossEntropyLossImpl();
        criterion.to(new Device(kCUDA()),false);
        // 模拟优化器参数列表
        var params = model.parameters();
        // 初始化你定义的 GradScaler
        var scaler = new GradScaler();
        // 模拟数据加载
        var dataOptions = new TensorOptions()
                .layout(new LayoutOptional(Layout.Strided))
                .dtype(new ScalarTypeOptional(ScalarType.Float))
                .device(new DeviceOptional(new Device(DeviceType.CUDA)))
                .memory_format(new MemoryFormatOptional(MemoryFormat.Contiguous));
        var data = randn(new long[]{16, 10},dataOptions);
        var longOption = new TensorOptions()
                .layout(new LayoutOptional(Layout.Strided))
                .dtype(new ScalarTypeOptional(ScalarType.Long))
                .device(new DeviceOptional(new Device(DeviceType.CUDA)))
                .memory_format(new MemoryFormatOptional(MemoryFormat.Contiguous));
        var target = zeros(new long[]{16},longOption); // CrossEntropy 需要 Long 类型标签
        final int numEpochs = 5;
        final double learningRate = 1e-3;
        // 2. 训练循环
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            // 在 JDK 25 中,建议在 batch 级别使用 PointerScope 释放临时张量显存
            try (var batchScope = new PointerScope()) {
                // 模拟 optimizer.zero_grad()
                for (long i = 0; i < params.size(); i++) {
                    var p = params.get(i);
                    if (p.grad().defined()) {
                        p.grad().zero_();
                    }
                }
                // --- 开启混合精度 Autocast ---
                torch.set_autocast_enabled(DeviceType.CUDA,true);
                Tensor loss;
                try {
                    // Forward pass
                    var output = model.forward(data);
                    loss = criterion.forward(output, target);
                } finally {
                    torch.set_autocast_enabled(DeviceType.CUDA,false);
                }

                // --- 缩放并反向传播 ---
                // 使用你实现的 scaler.scale()
                var scaledLoss = scaler.scale(loss);
                scaledLoss.backward();

                // --- 梯度更新 (包含 Unscale 和 Overflow 检查) ---
                // 转换 TensorVector 为 List 以匹配你 GradScaler 的签名
                List<Tensor> paramList = new ArrayList<>();
                for (long i = 0; i < params.size(); i++) {
                    paramList.add(params.get(i));
                }

                // 这一步包含了原版 PyTorch 中 scaler.step(optimizer) 和 scaler.update() 的逻辑
                boolean stepped = scaler.step(paramList, learningRate);

                if (stepped) {
                    System.out.printf("Epoch %d: Step successful.%n", epoch + 1);
                } else {
                    System.out.printf("Epoch %d: Step skipped due to gradient overflow.%n", epoch + 1);
                }
            }
        }

        System.out.printf("Final Scale Factor: %.2f%n", scaler.getScale());
    }
}



使用 GradScaler 的步骤:

  1. 初始化: 在训练循环开始前,创建 GradScaler 的一个实例。
  2. 前向传播: 将前向传播(模型执行 + 损失计算)包装在 autocast() 上下文内。
  3. 损失缩放: 不再调用 loss.backward(),而是调用 scaler.scale(loss).backward()。这会计算 loss * scale_factor,然后对缩放后的损失调用 backward()
  4. 优化器步进:optimizer.step() 替换为 scaler.step(optimizer)。此方法会检查由缩放后损失产生的梯度中是否存在溢出(inf/NaN)。
    • 如果未发生溢出,它会首先反向缩放附加到优化器参数上的梯度(将它们除以 scale_factor),然后调用 optimizer.step()
    • 如果确实发生了溢出,scaler.step() 会跳过 optimizer.step() 调用,以防止损坏的权重更新。
  5. 更新缩放器:scaler.step() 后调用 scaler.update()。这会更新下一次迭代的缩放因子。如果 step 因溢出而跳过优化器更新,它会降低缩放;如果更新在一段时间内成功,则可能增加缩放。

优点与注意事项

  • 速度: AMP 可以在兼容硬件(NVIDIA Volta、Turing、Ampere 架构及更新版本,它们具备专用于 FP16 计算的 Tensor Cores)上提供显著的速度提升(通常是 1.5 倍到 3 倍或更多)。
  • 内存: 减少了激活值、梯度和潜在的优化器状态(如果使用 FP16 优化器,尽管标准优化器与 AMP 也能很好地配合)的内存使用,从而允许使用更大的批次大小或模型。
  • 易用性: 与手动混合精度管理相比,torch.cuda.amp 使实现相对简单。

重要注意事项:

  • 硬件: 性能提升在配备 Tensor Cores 的 GPU 上最为明显。在旧硬件上,加速效果可能不那么明显甚至可以忽略不计。
  • 数值稳定性: 尽管 GradScaler 有很大帮助,但与纯 FP32 训练相比,某些模型或特定操作仍可能在收敛或最终准确性上表现出微小差异。请密切监控你的训练。
  • 批归一化:BatchNorm 这样的层通常需要 FP32 来稳定累积统计数据。autocast 通常会正确处理这一点,但如果在实现自定义归一化层时,请注意潜在的细节。
  • 保存/加载: 在保存检查点时,建议将 GradScaler 状态与模型和优化器状态一起保存,以便恢复训练。
// 检查点保存示例
val checkpoint = Map(
    "epoch" -> epoch,
    "model_state_dict" -> model.state_dict(),
    "optimizer_state_dict" -> optimizer.state_dict(),
    "scaler_state_dict" -> scaler.state_dict(),
    "loss" -> loss,
    // ... 其他指标或状态
)
torch.save(checkpoint, "model_checkpoint.pth")

// 检查点加载示例
val checkpoint = torch.load("model_checkpoint.pth")
model.load_state_dict(checkpoint("model_state_dict"))
optimizer.load_state_dict(checkpoint("optimizer_state_dict"))
scaler.load_state_dict(checkpoint("scaler_state_dict"))
val epoch = checkpoint("epoch").toInt
val loss = checkpoint("loss").toFloat
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;

/**
 * 模型检查点保存/加载工具类
 * 完整实现PyTorch的torch.save/torch.load逻辑,支持:
 * 1. 检查点保存(包含epoch、model/optimizer/scaler状态、loss等)
 * 2. 检查点加载(支持map_location设备映射)
 * 3. 状态字典加载(model/optimizer/scaler.load_state_dict)
 * 基于JavaCPP-PyTorch 1.5.13-SNAPSHOT官方API实现
 */

 public class ModelCheckpointUtils {
     public static void torchSave(GenericDict checkpointDict, String savePath) throws Exception {
        // 1. 将GenericDict转换为IValue(pickle_save仅接收IValue)
        IValue checkpointIValue = new IValue(checkpointDict);

        // 2. 序列化IValue为字节数组(等效Python pickle.dumps)
        ByteVector byteVector = torch.pickle_save(checkpointIValue);

        // 3. 将字节写入文件(等效Python文件写入)
        byte[] fileBytes = new byte[(int) byteVector.size()];
        for (int i = 0; i < byteVector.size(); i++) {
            fileBytes[i] = byteVector.get(i);
        }
        Files.write(Paths.get(savePath), fileBytes,
                StandardOpenOption.CREATE, StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING);

        // 资源释放
        checkpointIValue.close();
        byteVector.close();
     }

    // ======================== 核心:检查点加载(对应torch.load,支持map_location) ========================
    /**
     * 加载检查点文件(等效Python torch.load(path, map_location=device))
     * @param loadPath 检查点路径
     * @param mapDevice 设备映射(CPU/CUDA,null则不映射)
     * @return 检查点字典(GenericDict)
     * @throws Exception 读写/反序列化异常
     */
     public static GenericDict torchLoad(String loadPath, Device mapDevice) throws Exception {
        // 1. 读取文件字节
        byte[] fileBytes = Files.readAllBytes(Paths.get(loadPath));
        ByteVector byteVector = new ByteVector(fileBytes.length);
        for (int i = 0; i < fileBytes.length; i++) {
            byteVector.put(i, fileBytes[i]);
        }

        // 2. 反序列化为IValue(等效Python pickle.loads)
        IValue checkpointIValue = torch.pickle_load(byteVector);

        // 3. 验证并转换为GenericDict
        if (!checkpointIValue.isGenericDict()) {
            throw new RuntimeException("检查点格式错误:不是字典类型");
        }
        GenericDict checkpointDict = checkpointIValue.toGenericDict();

        // 4. 实现map_location:将所有张量映射到指定设备
        if (mapDevice != null) {
            remapTensorsToDevice(checkpointDict, mapDevice);
        }

        // 资源释放
        byteVector.close();
        checkpointIValue.close();
        return checkpointDict;
     }

    /**
     * 递归遍历字典,将所有张量迁移到指定设备(实现map_location核心逻辑)
     * 严格使用JavaCPP官方API:迭代器+access()+erase()+insert()
     */
     private static void remapTensorsToDevice(GenericDict dict, Device targetDevice) {
        GenericDictIterator iter = dict.begin();
        while (iter.notEquals(dict.end())) {
            // 获取键值对(唯一合法方式:access())
            GenericDictEntryRef entry = iter.access();
            IValue key = entry.key();
            IValue value = entry.value();

            // 情况1:值是张量 → 迁移设备
            if (value.isTensor()) {
                Tensor originalTensor = value.toTensor();
                // 迁移到目标设备(指定Float类型,避免类型不匹配)
                Tensor newTensor = originalTensor.to(targetDevice, torch.ScalarType.Float);
      
                // 替换字典值:先删后插(GenericDict无直接修改API)
                dict.erase(key);
                dict.insert(key, new IValue(newTensor));
      
                // 资源释放
                originalTensor.close();
                newTensor.close();
            }
            // 情况2:值是字典 → 递归处理
            else if (value.isGenericDict()) {
                remapTensorsToDevice(value.toGenericDict(), targetDevice);
            }
            // 情况3:值是列表 → 递归处理
            else if (value.isList()) {
                GenericList list = value.toList();
                for (int i = 0; i < list.size(); i++) {
                    IValue elem = list.get(i);
                    if (elem.isTensor()) {
                        Tensor tensor = elem.toTensor();
                        Tensor newTensor = tensor.to(targetDevice, torch.ScalarType.Float);
                        list.set(i, new IValue(newTensor));
                        tensor.close();
                        newTensor.close();
                    } else if (elem.isGenericDict()) {
                        remapTensorsToDevice(elem.toGenericDict(), targetDevice);
                    }
                }
                // 替换列表值
                dict.erase(key);
                dict.insert(key, new IValue(list));
                list.close();
            }
      
            iter.increment(); // 迭代器前进
        }
     }

    // ======================== 辅助:从检查点提取状态字典 ========================
    /**
     * 从检查点字典中提取模型/优化器/Scaler状态字典
     * @param checkpoint 检查点字典
     * @param key 键名(model_state_dict/optimizer_state_dict/scaler_state_dict)
     * @return 状态字典(GenericDict)
     */
     public static GenericDict getStateDictFromCheckpoint(GenericDict checkpoint, String key) {
        IValue keyIValue = new IValue(key);
        if (!checkpoint.contains(keyIValue)) {
            throw new RuntimeException("检查点中未找到键:" + key);
        }

        IValue stateDictIValue = checkpoint.at(keyIValue);
        if (!stateDictIValue.isGenericDict()) {
            throw new RuntimeException(key + "不是合法的字典类型");
        }

        GenericDict stateDict = stateDictIValue.toGenericDict();
        keyIValue.close();
        stateDictIValue.close();
        return stateDict;
     }

    // ======================== 核心:加载模型状态字典(等效model.load_state_dict) ========================
    /**
     * 加载模型状态字典(替代Python model.load_state_dict())
     * @param model 目标模型
     * @param stateDict 模型状态字典(GenericDict)
     */
     public static void loadModelStateDict(Module model, GenericDict stateDict) {
        // 1. 创建InputArchive(JavaCPP中模型加载的唯一合法方式)
        InputArchive inputArchive = new InputArchive();

        // 2. 遍历状态字典,写入Archive
        GenericDictIterator iter = stateDict.begin();
        while (iter.notEquals(stateDict.end())) {
            GenericDictEntryRef entry = iter.access();
            IValue key = entry.key();
            IValue value = entry.value();

            if (value.isTensor()) {
                String paramName = key.toStringRef().getString();
                Tensor paramTensor = value.toTensor();
                inputArchive.write(paramName, paramTensor); // 按参数名写入
                paramTensor.close();
            }
            iter.increment();
        }

        // 3. 加载Archive到模型
        model.load(inputArchive);

        // 资源释放
        inputArchive.close();
     }

    // ======================== 核心:加载优化器状态字典(等效optimizer.load_state_dict) ========================
    /**
     * 加载优化器状态字典(替代Python optimizer.load_state_dict())
     * @param optimizer 目标优化器
     * @param stateDict 优化器状态字典(GenericDict)
     */
     public static void loadOptimizerStateDict(Optimizer optimizer, GenericDict stateDict) {
        // 1. 创建InputArchive
        InputArchive inputArchive = new InputArchive();

        // 2. 遍历优化器状态字典(参数组+状态)
        GenericDictIterator iter = stateDict.begin();
        while (iter.notEquals(stateDict.end())) {
            GenericDictEntryRef entry = iter.access();
            IValue key = entry.key();
            IValue value = entry.value();

            String keyStr = key.toStringRef().getString();
            // 处理参数组(param_groups)
            if (keyStr.equals("param_groups") && value.isList()) {
                GenericList paramGroups = value.toList();
                inputArchive.write(keyStr, paramGroups);
                paramGroups.close();
            }
            // 处理状态(state)
            else if (keyStr.equals("state") && value.isGenericDict()) {
                GenericDict stateDictInner = value.toGenericDict();
                inputArchive.write(keyStr, stateDictInner);
                stateDictInner.close();
            }
            iter.increment();
        }

        // 3. 加载到优化器
        optimizer.load(inputArchive);

        // 资源释放
        inputArchive.close();
     }

    // ======================== 核心:加载梯度缩放器状态字典(等效scaler.load_state_dict) ========================
    /**
     * 加载GradScaler状态字典(替代Python scaler.load_state_dict())
     * @param scaler 梯度缩放器
     * @param stateDict 状态字典(GenericDict)
     */
     public static void loadScalerStateDict(GradScaler scaler, GenericDict stateDict) {
        InputArchive inputArchive = new InputArchive();

        GenericDictIterator iter = stateDict.begin();
        while (iter.notEquals(stateDict.end())) {
            GenericDictEntryRef entry = iter.access();
            IValue key = entry.key();
            IValue value = entry.value();

            String keyStr = key.toStringRef().getString();
            if (value.isTensor()) {
                Tensor tensor = value.toTensor();
                inputArchive.write(keyStr, tensor);
                tensor.close();
            } else if (value.isBool()) {
                inputArchive.write(keyStr, value.toBool());
            } else if (value.isDouble()) {
                inputArchive.write(keyStr, value.toDouble());
            }
      
            iter.increment();
        }

        scaler.load(inputArchive);
        inputArchive.close();
     }

    // ======================== 示例:完整的保存/加载流程 ========================
    public static void main(String[] args) throws Exception {
        // --------------- 1. 初始化测试环境 ---------------
        int inputDim = 10, outputDim = 5;
        Device device = torch.cuda_is_available() ?
                new Device(torch.DeviceType.CUDA) : new Device(torch.DeviceType.CPU);

        // 模型:Linear(inputDim→outputDim)
        LinearImpl model = new LinearImpl(inputDim, outputDim);
        model.to(device, false);
     
        // 优化器:Adam
        Adam optimizer = torch.optim.Adam(model.parameters(), 0.001f);
     
        // 梯度缩放器(混合精度训练)
        GradScaler scaler = new GradScaler();
     
        // 模拟训练状态
        int epoch = 10;
        Tensor loss = torch.tensor(0.5678f).to(device, torch.ScalarType.Float);
     
        // --------------- 2. 构建检查点字典 ---------------
        GenericDict checkpointDict = new GenericDict();
        // 2.1 保存epoch(整数→IValue)
        checkpointDict.insert(new IValue("epoch"), new IValue(epoch));
        // 2.2 保存模型状态字典
        checkpointDict.insert(new IValue("model_state_dict"), new IValue(model.state_dict()));
        // 2.3 保存优化器状态字典
        checkpointDict.insert(new IValue("optimizer_state_dict"), new IValue(optimizer.state_dict()));
        // 2.4 保存Scaler状态字典
        checkpointDict.insert(new IValue("scaler_state_dict"), new IValue(scaler.state_dict()));
        // 2.5 保存loss
        checkpointDict.insert(new IValue("loss"), new IValue(loss));
     
        // --------------- 3. 保存检查点到文件 ---------------
        String checkpointPath = "model_checkpoint.pth";
        torchSave(checkpointDict, checkpointPath);
        System.out.println("检查点已保存到:" + checkpointPath);
     
        // --------------- 4. 加载检查点(模拟重新加载) ---------------
        // 4.1 定义目标设备(这里自动适配GPU/CPU)
        Device loadDevice = torch.cuda_is_available() ?
                new Device(torch.DeviceType.CUDA) : new Device(torch.DeviceType.CPU);
     
        // 4.2 加载检查点(自动映射到目标设备)
        GenericDict loadedCheckpoint = torchLoad(checkpointPath, loadDevice);
        System.out.println("检查点加载完成");
     
        // --------------- 5. 恢复模型/优化器/Scaler状态 ---------------
        // 5.1 恢复模型状态
        GenericDict modelStateDict = getStateDictFromCheckpoint(loadedCheckpoint, "model_state_dict");
        loadModelStateDict(model, modelStateDict);
     
        // 5.2 恢复优化器状态
        GenericDict optimizerStateDict = getStateDictFromCheckpoint(loadedCheckpoint, "optimizer_state_dict");
        loadOptimizerStateDict(optimizer, optimizerStateDict);
     
        // 5.3 恢复Scaler状态
        GenericDict scalerStateDict = getStateDictFromCheckpoint(loadedCheckpoint, "scaler_state_dict");
        loadScalerStateDict(scaler, scalerStateDict);
     
        // 5.4 恢复epoch和loss
        int loadedEpoch = loadedCheckpoint.at(new IValue("epoch")).toInt();
        float loadedLoss = loadedCheckpoint.at(new IValue("loss")).toTensor().item().floatValue();
     
        // --------------- 6. 验证结果 ---------------
        System.out.println("恢复的epoch:" + loadedEpoch);
        System.out.println("恢复的loss:" + String.format("%.4f", loadedLoss));
        System.out.println("模型设备:" + (model.parameters().get(0).device().type() == torch.DeviceType.CUDA ? "GPU" : "CPU"));
     
        // --------------- 7. 资源释放 ---------------
        model.close();
        optimizer.close();
        scaler.close();
        loss.close();
        checkpointDict.close();
        loadedCheckpoint.close();
        modelStateDict.close();
        optimizerStateDict.close();
        scalerStateDict.close();
        device.close();
        loadDevice.close();
    }
 }


使用 torch.cuda.amp 进行混合精度训练是一种高效的方法,可加速深度学习工作流程并训练更大模型。通过巧妙地结合 FP16 和 FP32 计算,它以最少的代码改动带来了显著的性能提升,并应对了低精度算术固有的数值稳定性问题。它已成为现代深度学习实践者工具包中的标准工具,尤其是在处理大型模型或追求更快迭代周期时。

处理大型数据集的策略

当数据集大到无法在系统内存(RAM)中轻松容纳时,使用 PyTorch 映射式 Dataset 的标准数据加载方法会成为瓶颈或完全不可行。映射式数据集实现了 __getitem____len__,通常假设可以按索引随机访问任何项目,并且可能需要预先加载整个数据集的元数据。这里说明了专门为处理这些巨型数据集而设计的策略,侧重于使用 IterableDataset 高效地传输数据。

映射式数据集处理大型数据的局限性

想象一个存储在数千个文件中的 TB 级图像数据集。一个标准 Dataset 可能会尝试在其 __init__ 方法中构建所有文件路径和对应标签的列表。即使图像本身没有加载,仅这些元数据就可能超出可用内存。此外,如果数据需要从大型压缩文件或数据库查询中按顺序读取,__getitem__ 的随机访问要求可能效率低下。打乱大型映射式数据集通常也涉及创建一个覆盖整个数据集大小 (NN) 的打乱索引列表,这对于大型 NN 来说同样需要大量内存。

使用 IterableDataset 流式传输数据

PyTorch 提供了一种替代方案:torch.utils.data.IterableDataset。您无需定义 __getitem____len__,而是实现 __iter__ 方法。此方法应返回一个迭代器,每次生成一个样本。此方法根本不同;它将数据集视为数据流,而非可索引的集合。

IterableDataset 特别适合以下情况:

  1. 数据天然可流式传输(例如,从大型文本文件读取行、传感器数据、网络日志)。
  2. 按索引访问数据成本高昂或不可能。
  3. 数据集大小未知或无限。
  4. 随机访问并非训练期间的主要要求。

这是一个从大文件中逐行读取样本的实现方式:

import torch
import torch.utils.data as data

class LargeTextFileDataset(data.IterableDataset):
    def __init__(file_path, tokenizer):
        super().__init__()
        val file_path = file_path
        val tokenizer = tokenizer

    def __iter__(self):
        // 迭代器在此处为每个 epoch/worker 创建
        val file_iterator = io.Source.fromFile(file_path).getLines()
        // map 函数将处理函数应用于迭代器中的每一行
        return file_iterator.map(tokenizer)

// 用法:
// tokenizer = lambda line: torch.tensor([int(x) for x in line.strip().split(',')])
// dataset = LargeTextFileDataset('very_large_data.csv', tokenizer)
// loader = DataLoader(dataset, batch_size=32)
//
// for batch in loader:
//     // 处理批次数据
//     pass

在此示例中,open(self.file_path, 'r') 返回一个遍历文件行的迭代器。然后 map 函数在 DataLoader 请求时,对每一行进行延迟处理(应用 tokenizer)。没有尝试将整个文件加载到内存中。

使用 IterableDataset 处理多进程加载

当使用 DataLoadernum_workers > 0 时,每个工作进程会获得 IterableDataset 实例的一个副本。一个重要方面是,需要确保每个工作进程处理数据流中不同的部分,以避免重复。如果处理不当,每个工作进程都可能从头开始读取同一个大文件,导致重复工作和不正确的有效批次构成。

解决此问题的标准方法是在 __iter__ 方法中使用 torch.utils.data.get_worker_info()

import torch
import math
import torch.utils.data as data.{IterableDataset, DataLoader, get_worker_info}

class ShardedLargeFileDataset extends IterableDataset:
    def __init__(self, file_path, processor_fn):
        super().__init__()
        val file_path = file_path
        val processor_fn = processor_fn
        // 如果需要分片,确定文件大小或行/记录数量
        // self.num_records = self._get_num_records(file_path) # 辅助函数示例

    def _get_records_iterator(self):
        // 将此替换为遍历您的特定数据记录/文件的逻辑
        with io.Source.fromFile(file_path) as f:
            for line in f.getLines():
                yield line # 生成原始记录

    def __iter__(self):
        val worker_info = get_worker_info()
        val record_iterator = _get_records_iterator()

        if worker_info is None then // 单进程加载
            val worker_id = 0
            val num_workers = 1
        else:  // 多进程加载
            val worker_id = worker_info.id
            val num_workers = worker_info.num_workers

        // 基础工作进程分片:每个工作进程处理每第 N 条记录
        // 更复杂的分片可能涉及字节偏移量或文件拆分
        val sharded_iterator = (record for i, record in enumerate(record_iterator) if i % num_workers == worker_id)

        // 在工作进程的迭代器链中应用处理
        val processed_iterator = sharded_iterator.map(processor_fn)
        return processed_iterator

// 使用示例:
// processor = lambda line: torch.tensor([float(x) for x in line.strip().split()])
// dataset = ShardedLargeFileDataset('massive_dataset.txt', processor)
// loader = DataLoader(dataset, batch_size=64, num_workers=4)
//
// for batch in loader:
//     // 训练步骤...
//     pass

package vals;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicLong;

import static vals.LinearRegressionTraining.stackData;
import static vals.LinearRegressionTraining.stackTarget;

/**
 * 函数式接口:自定义数据处理函数(替代Python的processor_fn)
 * 输入:原始字符串记录
 * 输出:处理后的Tensor
 */
@FunctionalInterface
interface ProcessorFn {
    Tensor process(String record);
}

/**
 * 分片式大文件数据集(等效Python ShardedLargeFileDataset)
 * 核心特性:
 * 1. 实现IterableDataset,支持PyTorch DataLoader加载
 * 2. 多worker分片:每个worker处理指定索引的记录
 * 3. 自定义数据处理函数,将原始文本行转换为Tensor
 * 4. 兼容单进程/多进程(多线程)加载模式
 */
public class ShardedLargeFileDataset extends JavaDataset {
    private final String filePath;       // 数据文件路径
    private final ProcessorFn processorFn; // 数据处理函数

    /**
     * 构造函数(等效Python __init__)
     * @param filePath 大文件路径
     * @param processorFn 自定义处理函数
     */
    public ShardedLargeFileDataset(String filePath, ProcessorFn processorFn) {
        this.filePath = filePath;
        this.processorFn = processorFn;
    }

    /**
     * 获取原始记录迭代器(等效Python _get_records_iterator)
     * 逐行读取文件,生成原始字符串记录
     */
    private Iterator<String> getRecordsIterator() {
        return new Iterator<String>() {
            private final BufferedReader reader;
            private String nextLine;

            // 初始化文件读取器
            {
                try {
                    this.reader = new BufferedReader(new FileReader(filePath));
                    this.nextLine = reader.readLine(); // 预读取第一行
                } catch (IOException e) {
                    throw new RuntimeException("读取文件失败:" + filePath, e);
                }
            }

            @Override
            public boolean hasNext() {
                return nextLine != null;
            }

            @Override
            public String next() {
                if (!hasNext()) {
                    throw new NoSuchElementException("没有更多记录");
                }
                String currentLine = nextLine;
                try {
                    nextLine = reader.readLine();
                    // 读完最后一行后关闭流
                    if (nextLine == null) {
                        reader.close();
                    }
                } catch (IOException e) {
                    throw new RuntimeException("读取文件行失败", e);
                }
                return currentLine;
            }

            // 防止迭代器未遍历完导致流泄漏
            @Override
            protected void finalize() throws Throwable {
                try {
                    reader.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
                super.finalize();
            }
        };
    }

    /**
     * 分片迭代器(核心:实现多worker数据分片)
     * 逻辑:i % numWorkers == workerId 的记录才会被当前worker处理
     */
    private Iterator<String> getShardedIterator(Iterator<String> recordIterator, int workerId, int numWorkers) {
        return new Iterator<String>() {
            private final AtomicLong index = new AtomicLong(0); // 记录全局索引
            private String nextRecord;

            // 预加载下一个符合条件的记录
            {
                loadNextRecord();
            }

            private void loadNextRecord() {
                nextRecord = null;
                while (recordIterator.hasNext()) {
                    long currentIndex = index.getAndIncrement();
                    String record = recordIterator.next();
                    // 分片逻辑:仅保留当前worker应处理的记录
                    if (currentIndex % numWorkers == workerId) {
                        nextRecord = record;
                        break;
                    }
                }
            }

            @Override
            public boolean hasNext() {
                return nextRecord != null;
            }

            @Override
            public String next() {
                if (!hasNext()) {
                    throw new NoSuchElementException("没有更多分片记录");
                }
                String currentRecord = nextRecord;
                loadNextRecord(); // 加载下一个符合条件的记录
                return currentRecord;
            }
        };
    }

    /**
     * 处理后的迭代器(将分片记录转换为Tensor)
     */
    private Iterator<Tensor> getProcessedIterator(Iterator<String> shardedIterator) {
        return new Iterator<Tensor>() {
            @Override
            public boolean hasNext() {
                return shardedIterator.hasNext();
            }

            @Override
            public Tensor next() {
                String record = shardedIterator.next();
                return processorFn.process(record); // 应用自定义处理函数
            }
        };
    }

    /**
     * 核心:实现IterableDataset的迭代器(等效Python __iter__)
     */
//    @Override
    public Iterator<Tensor> iterator() {
        // 1. 获取WorkerInfo(判断是否多worker模式)
//        WorkerInfo workerInfo = torch.utils.data.get_worker_info();
        int workerId = 0;
        int numWorkers = 1;

        // 2. 多worker模式:获取当前worker ID和总worker数
//        if (workerInfo != null) {
//            workerId = workerInfo.id();
//            numWorkers = workerInfo.num_workers();
//        }

        // 3. 构建迭代器链:原始记录 → 分片记录 → 处理后的Tensor
        Iterator<String> recordIterator = getRecordsIterator();
        Iterator<String> shardedIterator = getShardedIterator(recordIterator, workerId, numWorkers);
        return getProcessedIterator(shardedIterator);
    }

    // ======================== 使用示例 ========================
    public static void main(String[] args) {
        // 1. 定义数据处理函数(替代Python lambda)
        // 示例:将空格分隔的文本行转换为FloatTensor
        ProcessorFn processor = line -> {
            String[] parts = line.strip().split("\\s+");
            float[] data = new float[parts.length];
            for (int i = 0; i < parts.length; i++) {
                data[i] = Float.parseFloat(parts[i]);
            }
            // 转换为PyTorch Tensor(shape: [n])
            return torch.tensor(data).to(torch.ScalarType.Float);
        };

        // 2. 创建分片数据集
        String dataFilePath = "massive_dataset.txt"; // 替换为你的大文件路径
        ShardedLargeFileDataset dataset = new ShardedLargeFileDataset(dataFilePath, processor);
        RandomSampler sampler  = new RandomSampler(64); // 可选:随机采样器,批次大小为64

        // 3. 创建DataLoader(设置batch_size和num_workers)
        DataLoaderOptions options = new DataLoaderOptions();
        options.batch_size().put(64);        // 批次大小
        options.workers().put(4);      // 4个worker进程/线程
        options.enforce_ordering().put(true);         // 大文件一般不shuffle(可根据需求调整)
        torch.pinned_memory_or_default(new BoolOptional(true));
//        options       .pin_memory(true)        // 固定内存(加速GPU传输)
    
        JavaRandomDataLoader loader =new JavaRandomDataLoader(dataset,sampler, options);

        // 4. 遍历DataLoader进行训练
        System.out.println("开始遍历数据集...");
        int batchIdx = 0;
        ExampleVectorIterator trainIter = loader.begin();
        while (trainIter.notEquals(loader.end())) {
//        for (IValue batch : loader) {
            // batch是Tensor类型(shape: [batch_size, feature_dim])
            var batch = trainIter.access();
            Tensor batchTensor = stackData(batch); // 将ExampleVector转换为Tensor(示例方法,需实现stackData)
            Tensor batchLabel = stackTarget(batch);
            System.out.printf("批次 %d: 形状 = %s, 设备 = %s%n",
                    batchIdx++,
                    tensorShapeToString(batchTensor),
                    batchTensor.device().type() == torch.DeviceType.CUDA ? "GPU" : "CPU");

            // ======================== 训练步骤示例 ========================
            // 1. 前向传播
            // Tensor output = model.forward(batchTensor);
            // 2. 计算损失
            // Tensor loss = lossFn.forward(output, labels);
            // 3. 反向传播
            // loss.backward();
            // 4. 优化器步进
            // optimizer.step();
            // =============================================================

            // 释放批次张量资源
            batchTensor.close();
        }

        // 5. 资源释放
        loader.close();
        options.close();
    }

    /**
     * 辅助方法:打印张量形状
     */
    private static String tensorShapeToString(Tensor tensor) {
        LongVector sizes = tensor.sizes().vec();
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < sizes.size(); i++) {
            sb.append(sizes.get(i));
            if (i < sizes.size() - 1) {
                sb.append(", ");
            }
        }
        sb.append("]");
        sizes.close();
        return sb.toString();
    }
}


在此改进示例中,get_worker_info() 提供当前工作进程的 id 和总 num_workers 数量。然后代码过滤基础 record_iterator,使得工作进程 k 只处理 index % num_workers == k 的记录。这确保每个工作进程获得数据流中独有的、交错的子集。请注意,根据数据格式和存储方式,可能需要更复杂的分片(例如,将整个文件或字节范围分配给工作进程)。

打乱 Iterable 数据集

打乱 IterableDataset 实例需要不同的策略,不同于映射式数据集。由于没有全局索引,您不能简单地打乱索引。常见方法包括:

  1. 打乱缓冲区: 维护一个样本缓冲区,从中随机选择样本来生成,并从底层迭代器中重新填充缓冲区。这提供了近似的打乱效果。PyTorch 的 DataLoader 不直接提供 IterableDataset 的此功能,但像 torchdata(PyTorch 领域库生态系统的一部分)这样的库提供了具有打乱功能的 DataPipes(例如,shufflesharding_filter)。
  2. 预先打乱数据: 如果可行,在训练开始前离线打乱源数据本身。例如,如果您的数据包含许多小文件,您可以在每个 epoch 中打乱 IterableDataset 处理的文件列表。
  3. 结合 Iterable 和映射式: 使用 IterableDataset 流式传输数据块(例如文件路径或记录标识符),并在每个工作进程内使用映射式 Dataset 从该数据块加载和处理项目,从而允许在数据块内部进行打乱。

选择取决于数据规模、所需的随机性程度以及您可以承受的开销。

优化数据管道

无论您使用映射式还是可迭代数据集,优化数据加载管道对训练性能非常重要,特别是对于 I/O 可能成为瓶颈的大型数据集。

  • 预处理: 尽可能离线执行计算密集型预处理(例如复杂特征提取、大型图像大小调整或使用大型词汇表进行分词)。将预处理后的数据存储为高效格式(例如 TFRecord、HDF5、Parquet 或自定义二进制格式)。像 webdataset 这样的库旨在高效流式传输存储为 tar 归档文件的大型数据集,通常与 IterableDataset 配合使用。
  • DataLoader 参数:
    • num_workers:将 num_workers 设置为 > 0 可启用数据加载的多进程处理。最优值取决于 CPU 核心数、批次大小、数据处理复杂度和 I/O 速度。一个常见的起始点是可用 CPU 核心数,但需要进行实验。工作进程过少会导致数据加载瓶颈;过多则会引起开销或耗尽系统资源。
    • pin_memory=True:如果将数据加载到 GPU 上,将其设置为 True 会告诉 DataLoader 将获取的张量放入固定(页锁定)内存。这使得使用 tensor.to('cuda', non_blocking=True) 从 CPU 到 GPU 的异步数据传输更快。
    • prefetch_factor (PyTorch 1.7+):控制每个工作进程预取多少批次。默认值 (2) 通常足够,但如果工作进程有时速度较慢,增加此值可能有助于隐藏数据加载延迟。

以下图表说明了带有多个工作进程的 DataLoader 如何使用分片处理 IterableDataset

大型数据源(例如,文件/流)IterableDataset 实例DataLoader (num_workers=2)训练进程 (GPU)记录 1记录 2记录 3记录 4记录 5记录 6…IterableDataset__iter__()生成流工作进程 0为工作进程 0生成迭代器(记录 0, 2, 4, …)工作进程 1为工作进程 1生成迭代器(记录 1, 3, 5, …)模型训练循环生成批次 0, 批次 2, …生成批次 1, 批次 3, …

IterableDataset 和两个 DataLoader 工作进程的数据流。数据集向每个工作进程提供迭代器,并通过分片确保每个工作进程处理独特的数据,从而实现训练循环的并行数据加载。

通过使用 IterableDataset、细致的工作进程分片和优化数据加载管道,可以有效训练 PyTorch 模型,即便数据集远超系统内存容量,克服了大规模深度学习的一个重要障碍。这些技术通常与本章讨论的其他方法结合使用,例如梯度累积,以管理数据大小和计算限制。

自动化超参数调整

为深度学习模型找到最佳超参数组合可以大幅影响其表现,但手动调整通常是一个繁琐且依赖经验的过程。随着模型和数据集变得复杂,手动遍历庞大的可能配置空间变得不切实际。这里对自动化超参数优化(HPO)技术进行分析,并提供使用常用库将它们整合到PyTorch工作流程中的方法。

自动化HPO提供了一个系统方法来搜索超参数空间,旨在找到使预定义目标指标(通常与验证表现有关)最小化或最大化的配置。

自动化超参数优化的核心要点

在应用HPO工具之前,理解其基本组成部分是必不可少的:

  1. 超参数: 这些是在训练过程开始指定的配置设置。它们不像模型权重那样在训练期间学习得到。例子包括学习率、优化器类型(及其参数,如Adam的beta值)、权重衰减强度、Dropout概率、批大小、层数、每层单元数、激活函数,以及学习率调度器或数据增强策略的参数。
  2. 目标函数: 这是HPO算法旨在优化(最小化或最大化)的函数。它以一组特定的超参数作为输入,使用这些超参数训练模型,在验证集上评估模型,并返回一个表示模型表现的单一标量值(例如,验证损失、准确率、F1分数)。
  3. 搜索空间: 这定义了每个待调整超参数的可能值范围或集合。例如,学习率可以定义为对数范围内的浮点数(例如,10−510−5到10−110−1),层数可以定义为特定范围内的整数(例如,2到6),优化器类型可以定义为分类选择(例如,‘Adam’、‘AdamW’、‘SGD’)。
  4. 搜索算法/策略: 这是用于在搜索空间中寻找并选择下一组要评估的超参数的方法。不同的算法在计算成本和找到的解的质量之间提供不同的权衡。

常见的HPO策略

有几种算法可用于自动化HPO:

  • 网格搜索: 穷举评估在离散网格上定义的所有可能的超参数组合。虽然简单,但它受到“维度灾难”的困扰,其计算成本随超参数数量呈指数增长。如果某些超参数对目标影响不大,则效率低下。
  • 随机搜索: 从定义的搜索空间中随机采样超参数配置。出乎意料地有效,随机搜索在相同的计算预算下通常优于网格搜索,特别是当只有少数超参数对性能有明显影响时(如Bergstra和Bengio在2012年所展示)。
  • 贝叶斯优化: 构建目标函数f(x)f(x)的概率代理模型(通常使用高斯过程),其中xx代表一个超参数配置。它使用采集函数(例如,预期改进、上置信界)来平衡勘探(尝试不确定、可能高回报的配置)和专注于当前最佳配置附近(的尝试),以选择下一组要评估的超参数。这种方法通常比网格或随机搜索的样本效率更高,特别适用于计算成本高的目标函数。
  • 早期停止算法: 诸如HyperBand和异步逐次减半(ASHA)等技术专注于有效分配固定预算(例如,计算时间、训练周期)。它们启动许多配置,并根据它们的中间表现迭代地修剪掉前景较差的配置,将更多资源分配给表现更好的试验。这些在训练单个模型耗时较长时特别有用。

HPO算法训练与评估提出超参数(基于搜索策略)训练模型(使用提出的超参数)配置评估模型(计算目标指标)训练好的模型目标指标+ 中间结果(可选)

自动化超参数优化过程的简化视图。HPO算法建议一个配置,使用该配置训练并评估模型,得到的性能指标为算法的下一次建议提供依据。

将HPO库与PyTorch结合

像Optuna和Ray Tune这样的库简化了将HPO结合到PyTorch项目中的过程。典型的工作流程包括:

  1. 定义目标函数: 创建一个接受特殊trial对象(不同库的术语可能略有差异)的Python函数。
  2. 建议超参数: 在目标函数内部,使用trial对象提供的方法(例如,trial.suggest_floattrial.suggest_inttrial.suggest_categorical)根据定义的搜索空间为当前试验采样超参数值。
  3. 构建和训练模型: 实例化你的PyTorch模型、优化器、数据加载器等,使用建议的超参数。实现你的标准训练和验证循环。
  4. 评估并返回指标: 训练后(或在中间步骤),在验证集上评估模型,并返回HPO算法应优化的目标指标(例如,验证损失或准确率)。
  5. 实施剪枝(可选但推荐): 对于早期停止算法,定期使用trial.report(metric, step)向HPO库报告中间验证指标(例如,每个训练周期后)。然后,调用trial.should_prune()并在它返回真时抛出一个特殊异常(例如,optuna.TrialPruned)。这允许库提前停止无前景的试验,从而节省资源。
  6. 创建并运行研究: 使用库的API创建一个“研究”或实验实例。通过指定目标函数、优化方向(“最小化”或“最大化”)、搜索算法(采样器/调度器)、要运行的试验次数以及可能的并行执行设置来配置该研究。
  7. 分析结果: 研究完成后,库提供结果访问,包括找到的最佳超参数配置及其对应的目标值。
Optuna示例代码片段

这里是一个使用Optuna的示例,以说明其结构:

import torch
import torch.nn as nn
import torch.optim as optim
import optuna
 
//假设get_model, get_dataloaders, train_one_epoch, evaluate_model已在其他地方定义

def objective(trial):
    // 1. 建议超参数
    val lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    val optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW", "RMSprop"])
    val dropout_rate = trial.suggest_float("dropout", 0.1, 0.5)
    val num_layers = trial.suggest_int("num_layers", 2, 5)
    val hidden_dim = trial.suggest_int("hidden_dim", 32, 256, log=True)

    // 2. 构建模型、优化器等
    val model = get_model(num_layers=num_layers, hidden_dim=hidden_dim, dropout_rate=dropout_rate)
    val device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    // 3. 构建优化器
    val optimizer_class = getattr(optim, optimizer_name)
    val optimizer = optimizer_class(model.parameters(), lr=lr)

    // 4. 加载数据加载器
    val (train_loader, valid_loader) = get_dataloaders()
    val num_epochs = 20 // 或者也可以是一个超参数

    // 3. 带有剪枝的训练循环
    for epoch <- Range(num_epochs):
        var train_loss = train_one_epoch(model, train_loader, optimizer, device)
        var validation_accuracy = evaluate_model(model, valid_loader, device)

        // 5. 报告中间结果以进行剪枝
        trial.report(validation_accuracy, epoch)

        // 根据中间值处理剪枝。
        if trial.should_prune():
            throw optuna.TrialPruned()

    // 4. 返回最终目标值
    final_validation_accuracy = evaluate_model(model, valid_loader, device)
    return final_validation_accuracy // 如果未指定,Optuna默认最大化

// 6. 创建并运行研究
val study = optuna.create_study(
    direction="maximize", // 最大化验证准确率
    pruner=optuna.pruners.MedianPruner() // 示例剪枝器
)
study.optimize(objective, n_trials=100) # 运行100个试验

// 7. 分析结果
println("完成的试验次数:", study.trials.length)
println("最佳试验:")
val trial = study.best_trial
println("  值: ", trial.value)
println("  参数: ")
for (key, value) <- trial.params:
    println(f"    {key}: {value}")
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;
import java.util.Map;

/**
 * Optuna + PyTorch超参数调优实现
 * 核心功能:
 * 1. 超参数搜索(学习率、优化器、dropout、网络层数/维度)
 * 2. 带剪枝的训练循环(MedianPruner)
 * 3. 试验结果分析(最佳参数、最佳准确率)
 */
public class OptunaHyperparameterTuning {

    // ======================== 假设的辅助方法接口(需根据实际场景实现) ========================
    /**
     * 创建模型(根据超参数)
     */
    private static Module getModel(int numLayers, int hiddenDim, float dropoutRate) {
        // 示例:返回一个简单的多层感知机(需替换为你的实际模型)
        StringAnyModuleDict layers = new StringAnyModuleDict();
        layers.insert("fc1", new AnyModule(new LinearImpl(10, hiddenDim)));
        layers.insert("relu1", new AnyModule(new ReLUImpl()));
        layers.insert("dropout1", new AnyModule(new DropoutImpl(dropoutRate)));

        // 动态添加隐藏层
        for (int i = 1; i < numLayers; i++) {
            layers.insert("fc" + (i+1), new AnyModule(new LinearImpl(hiddenDim, hiddenDim)));
            layers.insert("relu" + (i+1), new AnyModule(new ReLUImpl()));
            layers.insert("dropout" + (i+1), new AnyModule(new DropoutImpl(dropoutRate)));
        }

        layers.insert("fc_final", new AnyModule(new LinearImpl(hiddenDim, 2)));
        return new SequentialImpl(layers);
    }

    /**
     * 获取数据加载器(训练+验证)
     */
    private static JavaRandomDataLoader[] getDataloaders() {
        // 示例:返回空加载器(需替换为你的实际数据加载逻辑)
        var option = new DataLoaderOptions();
        RandomSampler sampler = new RandomSampler(64);
        option.batch_size().put(64);
        JavaRandomDataLoader trainLoader = new JavaRandomDataLoader(new TensorDataset(),sampler, option);
        JavaRandomDataLoader validLoader = new JavaRandomDataLoader(new TensorDataset(),sampler, option);
        return new JavaRandomDataLoader[]{trainLoader, validLoader};
    }

    /**
     * 训练一个epoch
     */
    private static float trainOneEpoch(Module model, JavaRandomDataLoader trainLoader, Optimizer optimizer, Device device) {
        // 示例:返回模拟训练损失(需替换为你的实际训练逻辑)
        model.train(true);
        return 0.1234f; // 模拟训练损失
    }

    /**
     * 评估模型验证准确率
     */
    private static float evaluateModel(Module model, JavaRandomDataLoader validLoader, Device device) {
        // 示例:返回模拟验证准确率(需替换为你的实际评估逻辑)
        model.eval();
        return 0.9567f; // 模拟验证准确率
    }


Optuna目标函数的结构,与PyTorch训练工作流程相结合,包括超参数建议和剪枝。

注意事项和最佳实践

  • 搜索空间设计: 仔细定义搜索空间。过窄可能错过最佳区域;过宽则增加计算成本。对学习率等参数使用对数尺度。结合先验知识设定合理的边界。
  • 目标指标: 选择一个真正反映期望的模型行为的指标(例如,验证准确率,不平衡数据集的F1分数,验证损失)。
  • 计算预算: 根据可用资源确定试验次数或时间预算。早期停止算法(Hyperband,ASHA)以及Ray Tune等库中的并行执行支持对于有效地管理预算很有帮助。
  • 剪枝: 积极实施剪枝,通过提前停止无前景的试验来节省大量计算。根据学习动态选择合适的剪枝器。
  • 可复现性: 为PyTorch、NumPy和HPO库本身设置随机种子,以确保结果可复现。
  • 复杂度: 首先调整影响最大的超参数(通常是学习率、优化器选择、正则化),然后再扩展搜索空间。

自动化超参数优化是高级深度学习从业者的工具箱中一个有价值的工具。通过系统地审视超参数配置并使用智能搜索策略和早期停止,与手动调整相比,你可以大幅提高模型性能和开发效率,从而节省时间以专注于模型架构和训练过程的其他方面。将Optuna或Ray Tune等库整合到你的PyTorch管道中,使你能够有效应用这些技术。

动手实战:实现混合精度训练

提供一个动手操作指南,介绍如何使用 torch.cuda.amp 实现自动混合精度 (AMP) 训练。自动混合精度 (AMP) 允许在特定操作中使用低精度浮点格式(如 float16)。这种能力可以大幅提升速度并减少 GPU 内存占用,通常对模型精度影响极小。该指南将讲解如何将标准的 PyTorch 训练循环转换为使用 AMP,并说明必要的变更。

本实践练习假设您可以使用支持 CUDA 的 NVIDIA GPU,其计算能力为 7.0 或更高(这是高效进行 float16 张量核心操作的必要条件),并且 PyTorch 安装版本相对较新(1.6 或更高)。

基准:标准训练循环

我们从一个使用全精度(float32)的简化标准训练循环开始。我们将使用一个基本的卷积网络和随机数据进行演示。

import torch
import torch.nn as nn
import torch.optim as optim
import time
import contextlib # 用于计时上下文管理器

// 1. 定义一个简单模型
class SimpleCNN extends nn.Module:
    def __init__( num_classes: Int=10):
        super().__init__()
        val conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        val relu = nn.ReLU()
        val pool = nn.MaxPool2d(kernel_size=2, stride=2)
        val conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        // 为全连接层展平特征
        val fc = nn.Linear(64 * 16 * 16, num_classes) // 假设输入图像为 32x32

    def forward(x: Tensor):
        x = pool(relu(conv1(x)))
        x = pool(relu(conv2(x)))
        x = torch.flatten(x, 1) // 展平除批次维度外的所有维度
        x = fc(x)
        return x

// 2. 设置:设备、模型、数据、损失函数、优化器
val device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
println(f"使用设备: {device}")

val model = SimpleCNN().to(device)
val criterion = nn.CrossEntropyLoss()
val optimizer = optim.Adam(model.parameters(), lr=1e-3)

// 模拟数据参数
val batch_size = 64
val img_size = 32
val num_batches = 100

// 简单计时器上下文管理器
@contextlib.contextmanager
def measure_time():
    val start_time = time.time()
    yield
    val end_time = time.time()
    println(f"耗时: {end_time - start_time:.4f} 秒")

// 简单内存使用报告器
def report_memory(stage=""):
    if torch.cuda.is_available() then
        println(f"{stage} - 峰值内存分配: {torch.cuda.max_memory_allocated(device) / 1e6:.2f} MB")
        torch.cuda.reset_peak_memory_stats(device) // 为下次测量重置峰值计数器

// 3. 标准训练循环 (FP32)
println("\n--- 标准 FP32 训练 ---")
report_memory("训练前")
model.train()
with measure_time():
    for i <- 0 until num_batches:
        // 实时生成模拟数据
        val inputs = torch.randn(batch_size, 3, img_size, img_size, device=device)
        val labels = torch.randint(0, 10, (batch_size,), device=device)

        optimizer.zero_grad()

        // 前向传播
        val outputs = model(inputs)
        val loss = criterion(outputs, labels)

        // 反向传播和优化
        loss.backward()
        optimizer.step()

        if (i + 1) % 20 == 0 then
             println(f"批次 [{i+1}/{num_batches}],损失: {loss.item():.4f}")

report_memory("训练后")
println("--- 标准 FP32 训练完成 ---")
import org.bytedeco.javacpp.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;

import static org.bytedeco.pytorch.global.torch.*;

// 1. 定义 SimpleCNN 模型
class SimpleCNN extends Module {
    private Conv2dImpl conv1, conv2;
    private LinearImpl fc;
    private ReLUImpl relu;
    private MaxPool2dImpl pool;

    public SimpleCNN(int numClasses) {
        // 注册子模块,这样 parameters() 才能找到它们
        conv1 = register_module("conv1", new Conv2dImpl(new Conv2dOptions(3, 32, 3).padding(1)));
        conv2 = register_module("conv2", new Conv2dImpl(new Conv2dOptions(32, 64, 3).padding(1)));
        fc = register_module("fc", new LinearImpl(new LinearOptions(64 * 8 * 8, numClasses))); // 32->16->8 (两次池化)

        relu = new ReLUImpl();
        pool = new MaxPool2dImpl(new MaxPool2dOptions(2).stride(2));
    }

    public Tensor forward(Tensor x) {
        x = pool.forward(relu.forward(conv1.forward(x)));
        x = pool.forward(relu.forward(conv2.forward(x)));
        x = x.flatten(1, -1);
        x = fc.forward(x);
        return x;
    }
}

public class Main {
    public static void main(String[] args) {
        // 2. 设置设备与环境
        Device device = new Device(cuda_is_available() ? kCUDA() : kCPU());
        System.out.println("使用设备: " + device.type().toString());

        var model = new SimpleCNN(10);
        model.to(device, false);

        var criterion = new CrossEntropyLossImpl();
        var optimizer = new Adam(model.parameters(), new AdamOptions(1e-3));

        int batchSize = 64;
        int imgSize = 32;
        int numBatches = 100;

        // 3. 标准 FP32 训练循环
        System.out.println("\n--- 标准 FP32 训练 ---");
        reportMemory(device, "训练前");

        model.train();

        long startTime = System.nanoTime();

        for (int i = 0; i < numBatches; i++) {
            // 每个 Batch 开启独立的内存作用域,防止 GPU 显存持续累积
            try (var scope = new PointerScope()) {
                // 生成模拟数据
                var inputs = randn(new long[]{batchSize, 3, imgSize, imgSize},
                        new TensorOptions().device(new DeviceOptional(device)));
                var labels = randint(0, 10, new long[]{batchSize},
                        new TensorOptions().device(new DeviceOptional(device)).dtype(new ScalarTypeOptional(kLong())));

                optimizer.zero_grad();

                // 前向传播
                var outputs = model.forward(inputs);
                var loss = criterion.forward(outputs, labels);

                // 反向传播
                loss.backward();
                optimizer.step();

                if ((i + 1) % 20 == 0) {
                    System.out.printf("批次 [%d/%d],损失: %.4f%n", i + 1, numBatches, loss.item_float());
                }
            }
        }

        long endTime = System.nanoTime();
        System.out.printf("耗时: %.4f 秒%n", (endTime - startTime) / 1e9);

        reportMemory(device, "训练后");
    }

    // 简单内存报告器
    private static void reportMemory(Device device, String stage) {
        if (device.is_cuda()) {
            // JavaCPP 映射的 cudaGetDeviceProperties 等可获取详细显存
            // 此处模拟原 Python 逻辑打印基本信息
            System.out.println(stage + " - 内存检查点记录已完成 (CUDA 统计)");
        }
    }
}


运行此代码(如果您有合适的 GPU)。请记录报告的耗时和峰值内存使用。这作为我们的基准。

使用 torch.cuda.amp 实现混合精度

现在,我们来修改循环以集成 AMP。这需要 torch.cuda.amp 中的两个主要组件:

  1. autocast:这是一个上下文管理器,它可以在有益且安全的情况下,将张量操作自动转换为低精度类型(在兼容 GPU 上默认为 float16)。卷积和全连接层等操作通常会大幅加速,而其他操作(如归约)可能会保留在 float32 中以保持数值稳定性。
  2. GradScaler:由于 float16 的数值范围比 float32 小得多,反向传播期间计算的梯度可能变得非常小(下溢)并被置零,从而妨碍训练。GradScaler 通过在反向传播前将损失值 向上 缩放来帮助防止这种情况。这有效地将生成的梯度缩放到 float16 的可表示范围内。在优化器更新权重之前,GradScaler 会将梯度 反向缩放 回其原始值。如果在反向缩放期间检测到任何非有限梯度(NaN 或 Inf)(这有时会发生在训练不稳定或损失缩放因子较高时),则会跳过该批次的优化器步骤。GradScaler 还会随时间动态调整缩放因子。

以下是我们修改训练循环的方式:

import torch
import torch.nn as nn
import torch.optim as optim
import time
import contextlib # 用于计时上下文管理器
import torch.cuda.amp.{GradScaler, autocast}

// --- 重新初始化模型和优化器以进行公平比较 ---
val model = SimpleCNN().to(device)
val optimizer = optim.Adam(model.parameters(), lr=1e-3)
// --- 保持损失函数和模拟数据参数不变 ---
val criterion = nn.CrossEntropyLoss()
val batch_size = 64
val img_size = 32
val num_batches = 100

// --- 使用相同的计时器和内存报告器 ---
@contextlib.contextmanager
def measure_time():
    val start_time = time.time()
    yield
    val end_time = time.time()
    println(f"耗时: {end_time - start_time:.4f} 秒")

def report_memory(stage=""):
    if torch.cuda.is_available() then
        println(f"{stage} - 峰值内存分配: {torch.cuda.max_memory_allocated(device) / 1e6:.2f} MB")
        torch.cuda.reset_peak_memory_stats(device)

println("\n--- 混合精度 (AMP) 训练 ---")
// 1. 初始化 GradScaler
val scaler = GradScaler()

report_memory("训练前")
model.train()
with measure_time():
    for i <- 0 until num_batches:
        val inputs = torch.randn(batch_size, 3, img_size, img_size, device=device)
        val labels = torch.randint(0, 10, (batch_size,), device=device)

        optimizer.zero_grad()

        // 2. 使用 autocast 包装前向传播
        // 此上下文中的操作在支持的情况下以低精度 (FP16) 运行
        with autocast():
            val outputs = model(inputs)
            val loss = criterion(outputs, labels)

        // 3. 在 backward() 之前缩放损失
        // scaler.scale 将损失乘以当前缩放因子
        scaler.scale(loss).backward()

        // 4. scaler.step() 反向缩放梯度并调用 optimizer.step()
        // 反向缩放位于 optimizer.param_groups[...].grad 中的梯度
        // 如果梯度不是有限的 (NaN/Inf),则跳过 optimizer.step()
        scaler.step(optimizer)

        // 5. 为下一次迭代更新缩放因子
        // 如果发现 NaNs/Infs,则减小缩放因子,否则可能增大
        scaler.update()

        if (i + 1) % 20 == 0:
            // 注意:loss.item() 仍是未缩放的损失值
            println(f"批次 [{i+1}/{num_batches}],损失: {loss.item():.4f}")

report_memory("训练后")
println("--- 混合精度 (AMP) 训练完成 ---")
import org.bytedeco.javacpp.*;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import java.util.ArrayList;
import java.util.List;
import static org.bytedeco.pytorch.global.torch.*;

public class AmpTrainingComparison {
    public static void main(String[] args) {
        // --- 初始化设备与模型 ---
        Device device = new Device(cuda_is_available() ? kCUDA() : kCPU());
        var model = new SimpleCNN(10); // 使用上一步定义的模型
        model.to(device,false);

        // 获取参数列表供自定义 GradScaler 使用
        List<Tensor> paramList = new ArrayList<>();
        var paramsVector = model.parameters();
        for (long i = 0; i < paramsVector.size(); i++) {
            paramList.add(paramsVector.get(i));
        }

        var criterion = new CrossEntropyLossImpl();
        // 注意:这里我们使用自定义逻辑模拟 Adam,或直接调用原生 Adam
        var optimizer = new Adam(model.parameters(), new AdamOptions(1e-3));

        // 1. 初始化 GradScaler
        var scaler = new GradScaler(); // 使用你之前定义的自定义类

        int batchSize = 64;
        int imgSize = 32;
        int numBatches = 100;
        double learningRate = 1e-3;

        System.out.println("\n--- 混合精度 (AMP) 训练 ---");
        reportMemory(device, "训练前");

        model.train(true);
        long startTime = System.nanoTime();

        for (int i = 0; i < numBatches; i++) {
            // JDK 25 try-with-resources 管理每一批次的 Native 内存
            try (var batchScope = new PointerScope()) {

                optimizer.zero_grad();

                // 准备输入数据
                var inputs = torch.randn(new long[]{batchSize, 3, imgSize, imgSize},
                        new TensorOptions().device(new DeviceOptional(device)));
                var labels = torch.randint(0, 10, new long[]{batchSize},
                        new TensorOptions().device(new DeviceOptional(device)).dtype(new ScalarTypeOptional(kLong())));

                // 2. 模拟 with autocast():
                torch.set_autocast_enabled(DeviceType.CUDA,true);
                Tensor loss;
                try {
                    var outputs = model.forward(inputs);
                    loss = criterion.forward(outputs, labels);
                } finally {
                    torch.set_autocast_enabled(DeviceType.CUDA,false);
                }

                // 3. 缩放损失并反向传播
                // scaler.scale(loss) 返回 loss * current_scale
                var scaledLoss = scaler.scale(loss);
                scaledLoss.backward();

                // 4 & 5. 内部包含 Unscale, Overflow 检查, Step 以及 Update 逻辑
                // 这里调用你自定义实现的 step (它内部会处理梯度检查和 scale 更新)
                boolean stepped = scaler.step(paramList, learningRate);

                if ((i + 1) % 20 == 0) {
                    System.out.printf("批次 [%d/%d], 损失: %.4f, 缩放因子: %.1f%s%n",
                            i + 1, numBatches, loss.item_float(), scaler.getScale(),
                            stepped ? "" : " (溢出跳过)");
                }
            }
        }

        long endTime = System.nanoTime();
        System.out.printf("耗时: %.4f 秒%n", (endTime - startTime) / 1e9);
        reportMemory(device, "训练后");
    }

    private static void reportMemory(Device device, String stage) {
        if (device.is_cuda()) {
            // 在实际生产中,可调用 cudaMemGetInfo
            System.out.println("[Memory Report] " + stage + ": 检查点已记录");
        }
    }
}


分析与观察

如果您在兼容的 GPU 上(特别是带有张量核心的 GPU,如 V100、T4、A100、H100 或 RTX 20xx 系列及更高版本)运行这两个版本,您会注意到:

  1. 训练时间缩短:AMP 版本通常比标准 FP32 版本明显更快完成。加速 1.5 倍到 3 倍或更多是常见现象,具体取决于模型架构、GPU 和批次大小。
  2. 峰值内存使用降低:使用 float16 张量的操作所需的内存带宽和存储量是 float32 的一半。为反向传播存储的激活值也消耗更少内存,从而可以使用更大的批次或模型。
  3. 代码改动极少:集成 AMP 只需少量额外操作:初始化 GradScaler,使用 autocast 包装前向传播,以及修改 backward()optimizer.step() 调用以使用 scaler
  4. 数值稳定性:得益于 GradScaler,训练过程通常保持数值稳定,收敛情况与 FP32 基准相似。由于精度变化,您可能会看到批次间的损失值略有不同,但整体训练动态通常得以保持。

这里是 FP32 和 AMP 训练典型结果的示意图:

示意性对比,显示了使用 AMP 与标准 FP32 训练相比,典型的加速和内存减少情况。实际结果会因硬件和模型而异。

更多考量

  • bfloat16:在较新的硬件(例如,A100 等 Ampere 架构 GPU,或较新的 TPU)上,您可能更喜欢使用 torch.bfloat16。它具有与 float32 相同的指数范围,但尾数精度较低。这通常使其对下溢/上溢问题更具抵御能力,并有可能无需 GradScaler。您可以通过 autocast(dtype=torch.bfloat16) 启用它。请查阅您的硬件文档以获取最佳设置。
  • 性能分析:尽管 AMP 通常提供开箱即用的优势,但建议使用 PyTorch 性能分析器(在第 4 章中介绍)来确认预期加速得以实现,并找出特定于您的模型或操作、可能无法从低精度中获益的任何潜在瓶颈。
  • 梯度裁剪:AMP 有时会与梯度裁剪配合使用。通常建议在裁剪梯度 之前 对其进行反向缩放。scaler.unscale_(optimizer) 可以在调用 torch.nn.utils.clip_grad_norm_torch.nn.utils.clip_grad_value_ 之前调用。

本实践练习展现了将自动混合精度集成到您的 PyTorch 训练循环中是多么容易。借助 torch.cuda.amp,您可以大幅加速训练并减少内存消耗,从而可以在现有硬件上训练更大的模型或使用更大的批次。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐