【Scala PyTorch深度学习】PyTorch On Scala系列课程 第十章 21 :PyTorch微分【AI Infra 3.0】[PyTorch Scala 高校计算机硕士研一课程]

PyTorch Scala 高校计算机硕士研一课程
神经常微分方程
传统深度神经网络,例如残差网络 (ResNets),通过离散的层序列处理输入。我们可以将ResNet块视为连续变换的欧拉离散化:ht+1=ht+f(ht,θt)h**t+1=h**t+f(h**t,θ**t)。这种观点自然引出一个问题:我们能否连续地建模这种变换?神经常微分方程 (Neural ODEs) 给出了肯定的答复,它将网络深度定义为连续时间区间,而非层数。
连续深度框架
不同于离散变换,神经ODE使用常微分方程 (ODE) 建模隐藏状态 h(t)h(t) 随连续时间变量 tt 的演变。其主要思想是,使用神经网络 ff(以权重 θθ 为参数)来定义隐藏状态随时间的变化率:
dh(t)dt=f(h(t),t,θ)dtd**h(t)=f(h(t),t,θ)
在这里,h(t)h(t) 表示时间 tt 时的隐藏状态,而 ff 通常是一个标准神经网络(例如,一个MLP)它以当前状态 h(t)h(t)、当前时间 tt 和参数 θθ 作为输入,输出状态的变化率。
将输入 z0z0(即 h(t0)h(t0))转换为输出 z1z1(即 h(t1)h(t1))的整体过程,是通过在指定时间区间 [t0,t1][t0,t1] 上求解此ODE初始值问题得到的:
h(t1)=h(t0)+∫t0t1f(h(t),t,θ)dth(t1)=h(t0)+∫t0t1f(h(t),t,θ)d**t
这个积分通过ODE求解器进行数值计算。神经网络 ff 定义了向量场,求解器模拟了隐藏状态通过该向量场从起始时间 t0t0 到结束时间 t1t1 的路径。
神经ODE的优势
这种连续的表述方式具有多项有益的特点:
- 训练时的内存效率: 标准反向传播需要存储每一层的激活值来计算梯度。对于层数多的网络(或等同于ODE求解器正向传播中的许多步骤),这会消耗大量内存。神经ODE采用伴随敏感度方法来计算梯度。该方法涉及逆向求解第二个相关的ODE。重要的是,它计算参数 θθ 和初始状态 h(t0)h(t0) 所需的梯度时,内存使用量相对于“深度”或积分时间近似为常数。这使得训练具有复杂变换潜力的模型成为可能,而无需承担存储中间状态带来的内存负担。
- 自适应计算: 现代ODE求解器在积分过程中会自动调整步长。当动态 ff 变化迅速时,它们会采取较小的步长;当动态平滑时,则采取较大的步长。这意味着计算工作量可以适应所学函数的复杂性,与ResNets等固定步长架构相比,可能会带来更高的计算效率。
- 处理不规则时间序列: 神经ODE天然适合建模连续过程和在不规则时间点采样的数据。模型可以通过将ODE积分到任意时间 tt 来评估该点的隐藏状态。
使用PyTorch实现
实现神经ODE通常需要一个提供可微分ODE求解器的外部库。一个常用的选择是 torchdiffeq。
通常的工作流程包括:
-
定义动态函数: 创建一个标准
torch.nn.Module来表示函数 f(h(t),t,θ)f(h(t),t,θ)。这个模块以当前状态h和时间t作为输入,并返回计算出的导数dh/dt。import torch import torch.nn as nn class ODEFunc extends nn.Module: def __init__(hidden_dim: Int): super(ODEFunc, self).__init__() val net = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, hidden_dim), ) def forward(t: Float, h: Tensor): // t:当前时间(标量) // h:当前隐藏状态(张量) // 返回 dh/dt return net(h) -
使用ODE求解器: 使用
torchdiffeq中的odeint等函数。该函数接收动态函数func、初始状态h0、要评估解的时间点t(例如torch.tensor([t0, t1]))以及可选的求解器参数。它返回在指定时间点计算出的隐藏状态。// 假设 torchdiffeq 已安装:pip install torchdiffeq from torchdiffeq import odeint_adjoint as odeint // 使用伴随方法以节省内存 // 示例用法: val func = ODEFunc(hidden_dim=20) val h0 = torch.randn(batch_size, 20) // 初始状态 val t_span = torch.tensor([0.0, 1.0]) // 从 t=0 积分到 t=1 // 计算最终状态 h(t1) // odeint 通过伴随方法处理数值积分和梯度计算 val h1 = odeint(func, h0, t_span)[-1] // 获取最后一个时间点 (t1) 的状态 // h1 现在可以在后续层或损失函数中使用 // 可以通过 h1.backward() 计算 func.parameters() 和 h0 的梯度注意
odeint_adjoint的使用。该版本实现了内存高效的伴随反向传播方法。标准的odeint也可用,但可能占用更多内存。
伴随敏感度方法
直接通过ODE求解器的操作进行反向传播,可能会耗费大量计算资源和内存,因为它需要存储求解器计算的所有中间状态。伴随方法提供了一种替代方案。
它定义了伴随状态 a(t)=∂L∂h(t)a(t)=∂h(t)∂L,这表示最终损失 LL 对隐藏状态 h(t)h(t) 的梯度。该伴随状态的演变由另一个逆时间(从 t1t1 到 t0t0)运行的ODE控制:
da(t)dt=−a(t)T∂f(h(t),t,θ)∂hdtd**a(t)=−a(t)T∂h∂f(h(t),t,θ)
损失对参数 θθ 的梯度可以通过逆时间积分另一个相关量来计算:
∂L∂θ=∫t1t0a(t)T∂f(h(t),t,θ)∂θdt∂θ∂L=∫t1t0a(t)T∂θ∂f(h(t),t,θ)d**t
求解这些逆向ODE需要在反向传播过程中获取 h(t)h(t) 的值。然而,可以通过再次求解原始正向ODE dh(t)dt=f(h(t),t,θ)dtd**h(t)=f(h(t),t,θ) 来实时重新计算这些值,这次是从 h(t1)h(t1) 到 h(t0)h(t0) 逆向进行。这种重新计算避免了存储整个正向轨迹,从而大大节省了内存,通常将内存成本从 O(Nt)O(N**t) 降低到 O(1)O(1),其中 NtN**t 是求解器步数。
求解器选择与实际考量
torchdiffeq 等库提供了多种ODE求解器:
- 显式固定步长求解器: 欧拉法、中点法、RK4(四阶龙格-库塔法)。简单,但为了稳定性和准确性可能需要较小的步长。
- 自适应步长求解器: Dormand-Prince (
dopri5)、Adams方法。自动调整步长,对于平滑问题通常更高效、更准确。dopri5常作为不错的默认选项。 - 隐式求解器: 对于“刚性”ODE很有用,因为如果步长不大,显式方法会变得不稳定。它们通常在每一步都涉及方程求解,计算开销可能更大。
求解器的选择会影响准确性、稳定性和计算速度。它通常被视为一个需要调整的超参数。
挑战:
- 计算成本: 尽管内存高效,但求解ODE可能比固定数量的离散层计算速度慢,特别是使用伴随方法进行反向传播时。
- 数值稳定性: 求解器的选择以及所学动态函数 ff 的性质会影响积分过程中的数值稳定性。如果 ff 表现不佳,仍然可能出现激活值爆炸或消失的问题。
- 训练动态: 优化神经ODE有时比优化标准网络更具挑战性。
神经ODE展现了深度学习与微分方程之间引人入胜的联系。它们提供了一种内存高效的方式来建模复杂、连续的变换,并为涉及连续动态或不规则时间序列数据的问题提供了一种独特的工具,从而扩展了PyTorch中可用的高级网络架构种类。
元学习算法
“标准深度学习模型在大型标记数据集上训练时通常表现优异。然而,许多情况下需要用少量示例快速适应新任务,这种场景称为少样本学习。元学习,即“学习如何学习”,提供了一个训练模型的体系,使其能够有效泛化到数据有限的新任务。元学习算法不是学习如何很好地执行一个特定任务,而是学习一个过程或一个初始化方法,从而能够快速适应新的、相关任务。”
主要说明如何在PyTorch中实现元学习算法,特别是介绍一种流行且多功能的方案:模型无关元学习(MAML)。
元学习问题设置
在典型的监督学习设置中,我们有一个数据集 D={(xi,yi)}D={(x**i,y**i)},目标是学习一个由 θθ 参数化的函数 fθf**θ,使其在数据集上最小化损失 L(fθ(xi),yi)L(f**θ(x**i),y**i)。
元学习重新定义了这个问题。我们假设存在任务分布 p(T)p(T)。在元训练期间,我们从 p(T)p(T) 中采样批次任务 TiTi。对于每个任务 TiTi,我们通常有一个小的支持集 DisuppDisupp 用于任务内部学习,以及一个查询集 DiqueryDiquery 用于评估该任务的学习效果。目标是学习模型参数 θθ(通常称为元参数),使得模型能够利用新的、以前未见的任务 TnewTnew 的支持集快速适应,从而在其查询集 DnewqueryDnewquery 上获得良好性能。
模型无关元学习(MAML)
MAML 由 Finn 等人于 2017 年提出,其目标是找到对任务变化敏感的元参数 θθ,仅用少量梯度步长就能在小支持集上进行有效微调。它之所以“模型无关”,是因为它不对模型架构 fθf**θ 做强假设;它可以应用于 CNN 或 RNN 等多种模型。
其核心思想涉及一个两层优化过程:
-
内循环(任务特定适应): 对于每个采样的任务 TiTi,从当前元参数 θθ 开始。仅使用任务的支持集 DisuppDisupp 执行一次或几次梯度下降步骤,以获得任务特定参数 θi′θ**i′。对于学习率为 αα 的单个梯度步长:
θi′=θ−α∇θLTi(fθ(Disupp))θ**i′=θ−α∇θLTi(f**θ(Disupp))
这里,LTiLTi 是任务 TiTi 的损失函数,fθ(Disupp)f**θ(Disupp) 表示模型使用参数 θθ 对支持集进行预测的结果。请注意,此梯度是相对于初始参数 θθ 计算的。
-
外循环(元优化): 评估已适应参数 θi′θ**i′ 在任务查询集 DiqueryDiquery 上的表现。元目标是在适应之后最小化跨任务的损失。元参数 θθ 根据这些适应后查询集损失的总和(或平均值)进行更新,使用元学习率 ββ:
θ←θ−β∇θ∑Ti∼p(T)LTi(fθi′(Diquery))θ←θ−β∇θTi∼p(T)∑LTi(fθi′(Diquery))
关键在于,外循环中的梯度 ∇θ∑LTi(fθi′(…))∇θ∑LTi(fθi′(…)) 涉及到对内循环更新步骤的求导。这意味着我们需要计算相对于 θθ 的梯度,并考虑 θi′θ**i′ 是如何从 θθ 推导出来的。这导致梯度计算涉及二阶导数(梯度的梯度)。
内循环(任务 Ti)外循环(跨任务)元参数θ∇_θ L_supp(θ)计算支持集上的梯度已适应参数θ’适应步骤θ’ = θ - α∇_θ∇_θ L_query(θ’)计算查询集上的梯度元更新(使用 ∇_θ)元梯度(涉及 ∇_θ’)更新 θ
流程图说明了 MAML 优化过程。内循环使用支持集损失,将参数 θθ 适应为任务特定的 θ′θ′。外循环根据使用已适应参数 θ′θ′ 的查询集损失计算元梯度,该元梯度随后用于更新原始元参数 θθ。
在 PyTorch 中实现 MAML
实现外循环的梯度计算需要谨慎。标准的 PyTorch backward() 调用会丢弃梯度中梯度计算所需的中间图信息。
有两种主要的方法来处理这个问题:
-
使用
torch.autograd.grad: 使用torch.autograd.grad并设置create_graph=True参数来手动计算内部梯度。这会告诉 PyTorch 为梯度计算本身构建一个计算图,从而允许稍后进行反向传播。
示意图:内循环梯度计算
val inner_loss = calculate_loss(model(support_x), support_y)
val grads = torch.autograd.grad(inner_loss, model.parameters(), create_graph=True)
// 计算已适应参数(函数式方法在这里通常更简单)
val adapted_params = [p - alpha * g for p, g in zip(model.parameters(), grads)]
// 使用 adapted_params 计算外部损失(需要函数式模型调用)
// ... outer_loss = calculate_loss(functional_model(adapted_params, query_x), query_y) ...
// 外循环梯度计算稍后会汇总跨任务的 outer_loss
// 并对总和调用 backward()。
- 使用高阶梯度库: 像
higher这样的库能显著简化此过程。higher提供了上下文管理器,让你可以创建模型的临时可微分副本。你在此临时副本上执行内循环更新,库会自动处理外循环梯度所需的计算跟踪。
使用 ‘higher’ 的示意图
import higher
val meta_optimizer.zero_grad()
val total_outer_loss = 0.0
for task_i <- batch_of_tasks:
val (support_x, support_y, query_x, query_y) = get_task_data(task_i)
with higher.innerloop_ctx(model, inner_optimizer, copy_initial_weights=True) as (fmodel, diffopt):
// 内循环更新
for _ <- range(num_inner_steps):
val inner_loss = calculate_loss(fmodel(support_x), support_y)
diffopt.step(inner_loss) // 更新 fmodel 的参数
// 外循环评估
val outer_loss = calculate_loss(fmodel(query_x), query_y)
total_outer_loss += outer_loss
// 反向传播元目标
total_outer_loss.backward()
meta_optimizer.step()
higher 方法因其更简洁的实现而常受青睐,它抽象了 create_graph=True 的手动处理和函数式参数更新。
MAML 变体
- 一阶 MAML (FOMAML): 计算二阶导数在计算上可能开销很大。FOMAML 通过忽略二阶项来近似 MAML 更新。本质上,它计算内部梯度 ∇θLTi(fθ(Disupp))∇θLTi(f**θ(Disupp)),然后使用已适应参数计算外部梯度 ∇θ′LTi(fθi′(Diquery))∇θ′LTi(fθi′(Diquery)),但在外部反向传播期间,它将内部梯度步骤视为与初始 θθ 无关。这更快,但性能可能略逊于完整的 MAML。在 PyTorch 的手动实现中,这对应于调用
torch.autograd.grad不带create_graph=True。 - Reptile: 另一种一阶元学习算法(Nichol 等人,2018),它简化了更新过程。它在内循环中执行多个梯度步骤,然后通过简单地将元参数 θθ 稍微朝已适应参数 θi′θ**i′ 的方向移动来更新它们:θ←θ+β(θi′−θ)θ←θ+β(θ**i′−θ)。这完全避免了显式的二阶导数计算。
应用与考虑
元学习,特别是 MAML 及其变体,已在以下方面获得应用:
- 少样本图像分类: 学习能够从极少量示例中识别新物体类别的分类器。
- 强化学习: 训练能够快速适应新环境或动态变化的智能体。
- 域适应: 将在一个数据分布(源域)上训练的模型适应到相关但不同分布(目标域)上,并使用有限的目标数据实现良好性能。
挑战:
- 计算成本: MAML 的二阶梯度计算和存储开销可能很大,特别是对于大型模型。FOMAML 和 Reptile 提供了替代方案。
- 训练稳定性: 元学习优化场景可能很复杂,有时需要仔细调整超参数(例如,内外部学习率、内循环步数)。
- 任务定义: 元学习的有效性很大程度上取决于任务 p(T)p(T) 的定义和分布。任务需要共享某种元学习器可以加以运用的潜在结构。
元学习代表了一种转变,从训练单个任务的模型转向训练具备高效学习能力的模型。像 MAML 这样的算法为实现这一目标提供了一个具体机制,通过优化可适应的初始化,使模型能够在数据稀缺的情况下快速适应。实现这些需要仔细处理梯度计算,这通常通过专门的库或手动应用 PyTorch 的自动求导功能来简化。
实践:实现自定义GNN层
PyTorch Geometric等库提供强大的预构建层来构建图神经网络(GNN)。虽然这些层非常有用,但了解如何使用核心PyTorch操作从头开始构建GNN层,能提供更全面的理解,并具备实现新颖或定制化消息传递方案的灵活性。本次实践练习将指导你创建一个简单的自定义GNN层。
许多GNN层背后的基本思想是消息传递,即节点迭代地从邻居节点聚合信息并更新自身的表示。我们可以将其分解为每个节点 ii 的两个主要步骤:
- 聚合: 从邻居节点 j∈N(i)j∈N(i) 收集特征或“消息”。
- 更新: 将聚合的信息与节点当前的特征向量 hih**i 结合,以生成更新后的特征向量 hi′h**i′。
让我们实现一个执行这些步骤的基本层。我们将定义一个层,它使用可学习的权重矩阵转换节点特征,使用简单的求和从邻居聚合转换后的特征,然后应用激活函数。
从数学上讲,对于节点 ii,此操作可以描述为:
ai=∑j∈N(i)∪{i}Whja**i=j∈N(i)∪{i}∑Whjhi′=σ(ai)h**i′=σ(a**i)
这里,hjh**j 表示节点 jj 的特征向量,WW 是一个可学习权重矩阵,N(i)N(i) 是节点 ii 的邻居集合,σσ 是一个非线性激活函数(如ReLU)。注意,我们将节点自身(ii)也包含在聚合中,这通常被称为添加自环。这确保了节点原始特征在更新时得到考量。
设置自定义层
首先,请确保已导入PyTorch。我们将把自定义层定义为一个继承自 torch.nn.Module 的Python类。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleGNNLayer extends nn.Module:
"""
一个实现消息传递的基本图神经网络层。
Args:
in_features (int): 每个输入节点特征向量的大小。
out_features (int): 每个输出节点特征向量的大小。
"""
def __init__( in_features: Int, out_features: Int):
super(SimpleGNNLayer, self).__init__()
val in_features = in_features
val out_features = out_features
// 定义可学习的权重矩阵
val linear = nn.Linear(in_features, out_features, bias=False)
// 初始化权重(可选但通常是好的做法)
nn.init.xavier_uniform_(linear.weight)
def forward(x: torch.Tensor, edge_index: torch.Tensor):
"""
定义每次调用时执行的计算。
Args:
x (torch.Tensor): 节点特征张量,形状为 [num_nodes, in_features]。
edge_index (torch.Tensor): COO格式的图连接信息,形状为 [2, num_edges]。
edge_index[0] = 源节点,edge_index[1] = 目标节点。
Returns:
torch.Tensor: 更新后的节点特征张量,形状为 [num_nodes, out_features]。
"""
val num_nodes = x.size(0)
// 1. 为edge_index表示的邻接矩阵添加自环
// 创建节点索引张量 [0, 1, ..., num_nodes-1]
val self_loops = torch.arange(0, num_nodes, device=x.device).unsqueeze(0)
val self_loops = self_loops.repeat(2, 1) // 形状 [2, num_nodes]
// 将原始边与自环拼接
val edge_index_with_self_loops = torch.cat([edge_index, self_loops], dim=1)
// 提取源节点和目标节点索引
val row = edge_index_with_self_loops(0)
val col = edge_index_with_self_loops(1)
// 2. 线性变换节点特征
val x_transformed = linear(x) // 形状: [num_nodes, out_features]
// 3. 聚合来自邻居(包括自身)的特征
// 我们希望对每个目标节点(col)求和源节点(row)的特征
// 使用零初始化输出张量
val aggregated_features = torch.zeros(num_nodes, out_features, device=x.device)
// 使用 index_add_ 进行高效聚合(散列求和)
// 将 x_transformed[row] 的元素添加到 aggregated_features 中由 col 指定的索引处
// index_add_(维度, 索引张量, 要添加的张量)
aggregated_features.index_add_(0, col, x_transformed(row))
// 4. 应用最终激活函数(可选)
// 在此示例中,我们使用ReLU
val output_features = F.relu(aggregated_features)
return output_features
def __repr__(self):
return f'{self.__class__.__name__}({self.in_features}, {self.out_features})'
理解实现
- 初始化 (
__init__):我们定义一个nn.Linear层。此层将把可学习权重变换 WW 应用于输入节点特征。为简单起见,我们设置bias=False,这与一些GNN公式(如基本GCN)一致。使用nn.init.xavier_uniform_进行权重初始化有助于稳定训练。 - 前向传播 (
forward):这是消息传递逻辑的所在。- 自环:我们显式地将自环添加到
edge_index。这确保了在为节点聚合邻居特征时,节点自身的转换特征也包含在内。我们创建一个表示从每个节点到自身的边的边索引,并将其与原始edge_index拼接。 - 特征变换:我们同时对所有节点特征
x应用线性变换 (self.linear)。 - 聚合:这是GNN的主要步骤。我们需要对每个目标节点 (
col) 求和源节点 (x_transformed[row]) 的转换特征。torch.index_add_是一种高效执行此“散列-求和”操作的方法。它接受要累积到的张量 (aggregated_features)、进行索引的维度(节点为0)、要添加到的索引 (col,即目标节点),以及要添加的值 (x_transformed[row],即源节点的转换特征)。 - 激活:最后,逐元素应用一个非线性激活函数 (
F.relu)。
- 自环:我们显式地将自环添加到
这里有一个小型图可视化,用以显示 edge_index 格式和邻居的思想:
0123
对于上面的图,一个可能的
edge_index(表示用于消息传递的有向边,假设无向原始边意味着消息双向传递)可能是:tensor([[0, 0, 1, 2, 1, 2, 3, 3], [1, 2, 0, 0, 3, 3, 1, 2]])。 第一行包含源节点,第二行包含目标节点。当为节点3聚合时,我们会查看来自源节点1和2的消息。
使用自定义层
现在,让我们看看如何使用这个 SimpleGNNLayer。我们需要一些示例节点特征和一个 edge_index。
// 示例用法
// 定义图数据
val num_nodes = 4
val num_features = 8
val out_layer_features = 16
// 节点特征(随机)
val x = torch.randn(num_nodes, num_features)
// 边索引表示连接(例如,0->1, 0->2, 1->3, 2->3;对于无向图则反之)
val edge_index = torch.tensor(Seq(
Seq(0, 0, 1, 2, 1, 2, 3, 3), // 源节点
Seq(1, 2, 0, 0, 3, 3, 1, 2) // 目标节点
), dtype=torch.long)
// 实例化层
val gnn_layer = SimpleGNNLayer(in_features=num_features, out_features=out_layer_features)
println(s"已实例化层: $gnn_layer")
// 将数据通过该层
val output_node_features = gnn_layer(x, edge_index)
// 检查输出形状
println(s"\n输入节点特征形状: ${x.shape}")
println(s"边索引形状: ${edge_index.shape}")
println(s"输出节点特征形状: ${output_node_features.shape}")
// 验证输出形状是否符合预期: [num_nodes, out_features]
assert output_node_features.shape == (num_nodes, out_layer_features)
print("\n数据已成功通过自定义GNN层。")
// 显示节点0的前几个输出特征
println(s"节点0的输出特征(前5维): ${output_node_features(0, 0 until 5)}")
此示例展示了创建随机节点特征和示例 edge_index,实例化我们的 SimpleGNNLayer,并执行前向传播。输出形状 [num_nodes, out_features] 确认该层按预期运行,为每个节点根据其邻域生成新的嵌入。
潜在的扩展
这个简单的层可作为根本。你可以通过多种方式对其进行扩展:
- 不同聚合方式:将
index_add_(求和聚合)替换为平均或最大值聚合。平均聚合通常需要知道每个节点的度。 - 边特征:修改
forward传播以接受和运用边特征,并可能在聚合前将其加入到消息计算中。 - 标准化:添加标准化步骤,例如GCN层中常见的对称标准化,这通常涉及节点度。
- 偏置项:在
nn.Linear层中包含一个偏置项,或在聚合后添加。 - 多层堆叠:堆叠这些层,可能加入标准化或跳跃连接,以构建更深的GNN模型。
构建这样的自定义层是一项很有价值的技能。它使你能够直接根据研究论文实现前沿GNN架构,或在必要时精确地根据问题需求定制消息传递方案。构建自定义 nn.Module 组件的这一相同原理,也适用于在本课程中实现的Transformer、归一化流或其他高级架构中的独特机制。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)