LeWorldModel的理解4——训练目标
文章目录
- 一、前言
- 二、LeWorldModel论文
-
- 3 方法: LeWorldModel
-
- 3.1 学习潜在世界模型
- 【公式含义】:
- 【符号解释】:
- 【公式解释】:
- 【总结】:
- 问题1:单位范数方向、 d − 1 d-1 d−1 维单位超球面 S d − 1 \mathbb{S}^{d-1} Sd−1</font>
- 1. 单位范数方向 (Unit Norm Direction)
- 2. d − 1 d-1 d−1 维单位超球面 S d − 1 \mathbb{S}^{d-1} Sd−1
- 3. 均匀采样 (Uniform Sampling)
- 4. 为什么这样做?(Cramér-Wold 定理的应用)
- 总结
- 通俗解释
- 一句话总结
- 类比理解
- 1. 单位范数方向 = 手电筒的光束方向
- 2. d − 1 d-1 d−1 维单位超球面 = 所有可能方向的集合
- 3. 均匀采样 = 公平地随机选方向
- 为什么要这样做?(Cramér-Wold 定理)
- 核心思想:降维打击
- 具体场景
- 最直观的类比
- 总结
- 回到原文
- 【公式含义】:
- 【符号解释】:
- 【公式解释】:
- 【总结】:
- 回到原文
- 代码解释
- 整体架构
- 逐行解释
- 输入
- 1. 编码器(Encoder)
- 2. 预测器(Predictor)
- 3. 预测损失(主任务)
- 4. SIGReg 正则化(防止坍缩)
- 5. 总损失
- 为什么需要 SIGReg?(直观例子)
- 总结流程图
- 一句话总结
一、前言
仅供参考,未经实验验证。
二、LeWorldModel论文
论文标题:LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels(LeWorldModel: 从像素到像素的稳定端到端联合嵌入预测架构)
作者:Lucas Maes*, Quentin Le Lidec*, Damien Scieur, Yann LeCun, Randall Balestriero
论文地址: https://arxiv.org/pdf/2603.19312
GitHub 地址: https://github.com/lucas-maes/le-wm
3 方法: LeWorldModel
3.1 学习潜在世界模型
Training Objective. Our objective is to learn latent representations useful for predicting the future, i.e., modeling the environment dynamics. LeWorldModel training objective is the sum of two terms: a prediction loss and a regularization loss. The prediction loss L pred \mathcal{L}_{\text{pred}} Lpred (teacher-forcing) computes the error between the predicted embedding of consecutive time-steps:
训练目标 我们的目标是学习对预测未来有用的潜在表征,即对环境动力学进行建模。LeWorldModel 的训练目标是两个项的总和:预测损失和正则化损失。预测损失 L pred \mathcal{L}_{\text{pred}} Lpred (教师强制)计算了连续时间步的预测嵌入之间的误差:
L pred ≜ ∥ z ^ t + 1 − z t + 1 ∥ 2 2 , z ^ t + 1 = pred ϕ ( z t , a t ) . (1) \mathcal{L}_{\text{pred}} \triangleq \|\hat{\mathbf{z}}_{t+1} - \mathbf{z}_{t+1}\|_2^2, \quad \hat{\mathbf{z}}_{t+1} = \text{pred}_\phi(\mathbf{z}_t, \mathbf{a}_t). \tag{1} Lpred≜∥z^t+1−zt+1∥22,z^t+1=predϕ(zt,at).(1)
Through the prediction loss, the encoder is incentivized to learn a predictable representation for the predictor.
通过预测损失,编码器被激励去学习一个可预测的表示,以供预测器使用。
However, this loss alone leads to representation collapse, yielding a trivial solution in which the encoder maps all inputs to a constant representation. To prevent this behavior, we introduce an anti-collapse regularization term that promotes feature diversity in the embedding space.
然而,仅此一项损失就会导致表示坍塌,产生一个平凡解,其中编码器将所有输入映射到一个常数表示。为防止此行为,我们引入了一个反坍塌正则化项,以促进嵌入空间中的特征多样性。
Specifically, we adopt the Sketched-Isotropic-Gaussian Regularizer (SIGReg) [25] due to its simplicity, scalability, and stability. SIGReg encourages the latent embeddings to match an isotropic Gaussian target distribution.
具体而言,我们采用草图等向高斯正则化器(SIGReg)[25],因为它具有简单性、可扩展性和稳定性。SIGReg 鼓励潜在嵌入匹配等向高斯目标分布。
Let Z ∈ R N × B × d \mathbf{Z} \in \mathbb{R}^{N \times B \times d} Z∈RN×B×d denote the tensor of latent embeddings collected over the history length N N N , the batch size B B B , and where d d d denotes the embedding dimension. Assessing normality directly in high-dimensional spaces is challenging, as most classical normality tests are designed for univariate data and do not scale reliably with dimensionality.
令 Z ∈ R N × B × d \mathbf{Z} \in \mathbb{R}^{N \times B \times d} Z∈RN×B×d 表示在历史长度 N、批量大小 B 和嵌入维度 d 上收集的潜在嵌入张量。在高维空间中直接评估正态性是具有挑战性的,因为大多数经典的正态性检验是为单变量数据设计的,并且不能可靠地随维度扩展。
SIGReg circumvents this limitation by projecting embeddings onto M M M random unit-norm directions u ( m ) ∈ S d − 1 \mathbf{u}^{(m)} \in \mathbb{S}^{d-1} u(m)∈Sd−1 and optimizing the univariate Epps–Pulley [38] test statistic T ( ⋅ ) T(\cdot) T(⋅) along the resulting one-dimensional projections h ( m ) = Z u ( m ) \mathbf{h}^{(m)} = \mathbf{Z}\mathbf{u}^{(m)} h(m)=Zu(m) , as illustrated in Fig.1. By the Cramér–Wold theorem [39], matching all one-dimensional marginals is equivalent to matching the full joint distribution.
SIGReg 通过将嵌入投影到 M M M 个随机单位范数方向 u ( m ) ∈ S d − 1 \mathbf{u}^{(m)} \in \mathbb{S}^{d-1} u(m)∈Sd−1 并优化一元 Epps–Pulley [38] 检验统计量 T ( ⋅ ) T(\cdot) T(⋅) 沿所得的一维投影 h ( m ) = Z u ( m ) \mathbf{h}^{(m)} = \mathbf{Z}\mathbf{u}^{(m)} h(m)=Zu(m)来规避此限制,如图 1 所示。根据 Cramér–Wold 定理 [39],匹配所有一维边际分布等价于匹配联合分布。
SIGReg ( Z ) ≜ 1 M ∑ m = 1 M T ( h ( m ) ) . (2) \text{SIGReg}(\mathbf{Z}) \triangleq \frac{1}{M} \sum_{m=1}^M T(\mathbf{h}^{(m)}). \tag{2} SIGReg(Z)≜M1m=1∑MT(h(m)).(2)
【公式含义】:
该公式定义了 Sketched-Isotropic-Gaussian Regularizer (SIGReg) 损失函数,其核心目的是强制潜在嵌入(latent embeddings)的分布符合各向同性高斯(isotropic Gaussian)目标分布,从而防止表示坍塌(representation collapse),并促进特征多样性。
【符号解释】:
- SIGReg ( Z ) \text{SIGReg}(\mathbf{Z}) SIGReg(Z): 表示 Sketched-Isotropic-Gaussian Regularizer 损失函数,其输入是潜在嵌入张量 Z \mathbf{Z} Z。
- Z ∈ R N × B × d \mathbf{Z} \in \mathbb{R}^{N \times B \times d} Z∈RN×B×d: 论文中定义为“在历史长度 N N N、批次大小 B B B 和嵌入维度 d d d 上收集的潜在嵌入张量”。具体来说,它包含了批次中所有样本在一段时间步内的潜在嵌入。
- M M M: 表示随机单位范数方向(random unit-norm directions)的数量。这些方向用于将高维潜在嵌入投影到一维空间。
- 1 M ∑ m = 1 M \frac{1}{M} \sum_{m=1}^{M} M1∑m=1M: 表示对 M M M 个随机投影方向上的统计量进行平均。这是为了通过这些单维投影的统计量来近似评估高维分布的性质。
- T ( h ( m ) ) T(\mathbf{h}^{(m)}) T(h(m)): 表示 Epps-Pulley 统计检验(Epps-Pulley statistical test)的函数,应用于沿着第 m m m 个随机方向投影后得到的一维数据 h ( m ) \mathbf{h}^{(m)} h(m)。Epps-Pulley 检验用于评估给定一维数据分布与标准高斯分布(或任何目标分布)的匹配程度。
- h ( m ) \mathbf{h}^{(m)} h(m): 是通过将潜在嵌入 Z \mathbf{Z} Z 投影到第 m m m 个随机单位范数方向 u ( m ) \mathbf{u}^{(m)} u(m) 上得到的一维投影数据。根据论文描述, h ( m ) = Z u ( m ) \mathbf{h}^{(m)} = \mathbf{Z}\mathbf{u}^{(m)} h(m)=Zu(m),其中 u ( m ) ∈ S d − 1 \mathbf{u}^{(m)} \in \mathbb{S}^{d-1} u(m)∈Sd−1( d − 1 d-1 d−1 维单位超球面上的向量),且这些方向是均匀采样的。
【公式解释】:
该公式本身不是一个推导公式,而是一个定义了 SIGReg 损失的计算方法。理解这个公式需要以下背景知识:
-
表示坍塌(Representation Collapse):
在自监督学习中,模型可能倾向于将所有输入映射到相同的(或非常相似的)潜在表示,从而使得预测目标变得微不足道,但学到的表示却毫无用处。这被称为表示坍塌。SIGReg 的目的就是通过强制潜在表示具有多样性(即符合各向同性高斯分布)来避免这种坍塌。 -
各向同性高斯分布(Isotropic Gaussian Distribution):
各向同性高斯分布是一种在所有维度上具有相同方差且各维度之间不相关的多元高斯分布。它的概率密度函数在各个方向上都是对称的,形状像一个球体。强制潜在嵌入符合这种分布,可以鼓励特征在各个维度上具有丰富的变化性,并且彼此独立,从而增强表示的多样性和信息量。 -
Cramér-Wold 定理:
这是一个重要的数学定理,它指出:如果一个多维概率分布的所有一维投影都与某个目标一维分布相匹配,那么这个多维分布本身就与相应的目标多维分布相匹配。在 SIGReg 中,这意味着如果我们将高维潜在嵌入投影到足够多的随机一维方向上,并确保每个一维投影都符合标准高斯分布,那么原始的高维潜在嵌入的联合分布也将近似于各向同性高斯分布(目标是标准高斯分布 N ( 0 , I ) N(0, I) N(0,I),即均值为0,协方差矩阵为单位矩阵)。 -
Epps-Pulley 统计检验:
这是一种用于检验一维数据是否服从特定目标分布(例如标准高斯分布)的统计方法。论文中提到,SIGReg 应用 Epps-Pulley 检验统计量 T ( ⋅ ) T(\cdot) T(⋅) 到每个一维投影 h ( m ) \mathbf{h}^{(m)} h(m) 上。这个检验统计量会量化 h ( m ) \mathbf{h}^{(m)} h(m) 的经验分布与标准高斯分布之间的差异。
公式的计算步骤和解释:
-
获取潜在嵌入张量 Z \mathbf{Z} Z:
在 LeWorldModel 的训练过程中,编码器会将原始像素观测数据 O t : T \mathbf{O}_{t:T} Ot:T 映射到低维潜在表示 Z t : T \mathbf{Z}_{t:T} Zt:T。这些潜在表示在每个时间步 t t t、每个批次样本 i i i 以及每个嵌入维度 d d d 上构成张量 Z ∈ R N × B × d \mathbf{Z} \in \mathbb{R}^{N \times B \times d} Z∈RN×B×d。 -
生成随机投影方向 u ( m ) \mathbf{u}^{(m)} u(m):
为了应用 Cramér-Wold 定理,SIGReg 首先生成 M M M 个随机的单位范数方向 u ( m ) \mathbf{u}^{(m)} u(m)。这些方向是从 d − 1 d-1 d−1 维单位超球面 S d − 1 \mathbb{S}^{d-1} Sd−1 上均匀采样的。 -
进行一维投影 h ( m ) \mathbf{h}^{(m)} h(m):
对于每一个生成的随机方向 u ( m ) \mathbf{u}^{(m)} u(m),潜在嵌入张量 Z \mathbf{Z} Z 被投影到这个方向上,生成一维数据 h ( m ) = Z u ( m ) \mathbf{h}^{(m)} = \mathbf{Z}\mathbf{u}^{(m)} h(m)=Zu(m)。这个操作将高维的潜在嵌入降维到一维,便于进行单变量的统计检验。 -
计算 Epps-Pulley 检验统计量 T ( h ( m ) ) T(\mathbf{h}^{(m)}) T(h(m)):
对于每个一维投影数据 h ( m ) \mathbf{h}^{(m)} h(m),应用 Epps-Pulley 统计检验函数 T ( ⋅ ) T(\cdot) T(⋅)。这个函数会计算 h ( m ) \mathbf{h}^{(m)} h(m) 的经验分布与目标标准高斯分布之间的“距离”或不匹配程度。如果 h ( m ) \mathbf{h}^{(m)} h(m) 越接近标准高斯分布,则 T ( h ( m ) ) T(\mathbf{h}^{(m)}) T(h(m)) 的值就越小。 -
对所有投影方向的统计量进行平均:
最后,将所有 M M M 个随机投影方向上计算得到的 Epps-Pulley 统计检验值 T ( h ( m ) ) T(\mathbf{h}^{(m)}) T(h(m)) 求和并取平均,得到最终的 SIGReg 损失值: SIGReg ( Z ) = 1 M ∑ m = 1 M T ( h ( m ) ) \text{SIGReg}(\mathbf{Z}) = \frac{1}{M} \sum_{m=1}^{M} T(\mathbf{h}^{(m)}) SIGReg(Z)=M1∑m=1MT(h(m))。
如何帮助反坍塌:
通过最小化这个 SIGReg ( Z ) \text{SIGReg}(\mathbf{Z}) SIGReg(Z) 损失,模型被激励调整其编码器,使得其生成的潜在嵌入 Z \mathbf{Z} Z 在所有随机一维投影上都尽可能地接近标准高斯分布。根据 Cramér-Wold 定理,这等价于鼓励整个高维潜在嵌入的联合分布接近各向同性高斯分布。这种强制性的分布匹配有效地防止了表示坍塌,因为一个各向同性高斯分布的特性是其在所有方向上都具有非零方差,且各维度之间不相关,这意味着潜在空间中的特征必须是多样化且信息丰富的,而不是集中在一点。
【总结】:
SIGReg 损失通过结合 Cramér-Wold 定理和 Epps-Pulley 统计检验,将高维潜在嵌入投影到多个随机一维方向上,并对这些一维投影应用正态性检验。通过最小化这些检验统计量的平均值,SIGReg 损失有效地强制模型的潜在嵌入分布接近各向同性高斯分布,从而在无需复杂启发式方法或多项损失项的情况下,稳定地防止表示坍塌并促进学习到多样化的、有意义的特征表示。
问题1:单位范数方向、 d − 1 d-1 d−1 维单位超球面 S d − 1 \mathbb{S}^{d-1} Sd−1
为了应用 Cramér-Wold 定理,SIGReg 首先生成 M M M 个随机的单位范数方向 u ( m ) \mathbf{u}^{(m)} u(m)。这些方向是从 d − 1 d-1 d−1 维单位超球面 S d − 1 \mathbb{S}^{d-1} Sd−1 上均匀采样的。 什么意思?单位范数方向什么意思? d − 1 d-1 d−1 维单位超球面 S d − 1 \mathbb{S}^{d-1} Sd−1是什么?
我来解释这些数学概念,它们涉及概率论/统计中的分布表征和高维几何:
1. 单位范数方向 (Unit Norm Direction)
“方向” 指的是一个向量,“单位范数” 意味着这个向量的长度为 1。
数学上,向量 u \mathbf{u} u 满足:
∥ u ∥ 2 = 1 \|\mathbf{u}\|_2 = 1 ∥u∥2=1
即 u 1 2 + u 2 2 + ⋯ + u d 2 = 1 \sqrt{u_1^2 + u_2^2 + \cdots + u_d^2} = 1 u12+u22+⋯+ud2=1
直观理解:想象你站在原点,向某个方向射出一支箭,这支箭的长度恰好是 1。这个"方向"就是单位范数方向。
为什么需要单位范数? 因为 Cramér-Wold 定理关心的是投影后的分布,如果向量长度不为 1,投影会同时受到方向和长度两个因素的影响。固定长度为 1,就只剩下"方向"这一个变量。
2. d − 1 d-1 d−1 维单位超球面 S d − 1 \mathbb{S}^{d-1} Sd−1
这是**高维空间中的"球面"**概念:
| 维度 | 符号 | 直观例子 | 方程 |
|---|---|---|---|
| 2维 | S 1 \mathbb{S}^1 S1 | 单位圆(平面上的圆) | x 2 + y 2 = 1 x^2 + y^2 = 1 x2+y2=1 |
| 3维 | S 2 \mathbb{S}^2 S2 | 单位球面(我们熟悉的球表面) | x 2 + y 2 + z 2 = 1 x^2 + y^2 + z^2 = 1 x2+y2+z2=1 |
| d d d维 | S d − 1 \mathbb{S}^{d-1} Sd−1 | d − 1 d-1 d−1 维球面(嵌入在 d d d 维空间中) | ∑ i = 1 d x i 2 = 1 \sum_{i=1}^d x_i^2 = 1 ∑i=1dxi2=1 |
注意:虽然球面"生活"在 d d d 维空间里,但球面本身只有 d − 1 d-1 d−1 个自由度(因为你被约束在 ∥ x ∥ = 1 \|\mathbf{x}\|=1 ∥x∥=1 这个条件上,相当于一个方程约束)。
SIGReg 中的场景:假设你的特征空间是 d d d 维的(比如图像嵌入是 512 维),那么方向向量 u \mathbf{u} u 就在 S d − 1 \mathbb{S}^{d-1} Sd−1 上采样,即 S 511 \mathbb{S}^{511} S511。
3. 均匀采样 (Uniform Sampling)
"均匀"意味着球面上每个区域被采到的概率与它的面积成正比。
在 3D 球面上,不能简单地独立均匀采样 θ \theta θ 和 ϕ \phi ϕ,那样会导致两极聚集。正确的方法是:
- 先采样 z ∼ Uniform [ − 1 , 1 ] z \sim \text{Uniform}[-1, 1] z∼Uniform[−1,1]
- 再采样 θ ∼ Uniform [ 0 , 2 π ] \theta \sim \text{Uniform}[0, 2\pi] θ∼Uniform[0,2π]
- 然后 x = 1 − z 2 cos θ x = \sqrt{1-z^2}\cos\theta x=1−z2cosθ, y = 1 − z 2 sin θ y = \sqrt{1-z^2}\sin\theta y=1−z2sinθ
在高维空间中,通常使用正态分布采样再归一化:
u ∼ N ( 0 , I d ) , u ( m ) = u ∥ u ∥ \mathbf{u} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_d), \quad \mathbf{u}^{(m)} = \frac{\mathbf{u}}{\|\mathbf{u}\|} u∼N(0,Id),u(m)=∥u∥u
这在高维下近似均匀(因为高维高斯分布集中在球面上)。
4. 为什么这样做?(Cramér-Wold 定理的应用)
Cramér-Wold 定理说:一个 d d d 维随机向量 X \mathbf{X} X 的分布,完全由它在所有方向上的一维投影分布决定。
即:如果 X \mathbf{X} X 和 Y \mathbf{Y} Y 在所有方向 u \mathbf{u} u 上的投影 u ⊤ X \mathbf{u}^\top\mathbf{X} u⊤X 和 u ⊤ Y \mathbf{u}^\top\mathbf{Y} u⊤Y 分布相同,那么 X \mathbf{X} X 和 Y \mathbf{Y} Y 的联合分布也相同。
SIGReg 的实用做法:不是检查"所有"方向(无穷多个),而是:
- 随机采样 M M M 个方向 u ( 1 ) , … , u ( M ) \mathbf{u}^{(1)}, \dots, \mathbf{u}^{(M)} u(1),…,u(M)
- 对每个方向,比较投影后的分布(比如用 MMD 或 Sinkhorn 距离)
- 综合这些一维比较结果来约束高维分布
这样做的好处是:把高维分布匹配问题转化为一维分布匹配问题,计算上更可行。
总结
| 术语 | 含义 |
|---|---|
| 单位范数方向 | 长度为1的向量,只表示方向 |
| S d − 1 \mathbb{S}^{d-1} Sd−1 | d d d 维空间中的单位球面( d − 1 d-1 d−1 维流形) |
| 均匀采样 | 球面上按面积比例均匀随机取点 |
| 目的 | 用 Cramér-Wold 定理将高维分布匹配转化为多个一维投影分布的匹配 |
这是分布对齐/正则化中常见的技巧,避免了直接在高维空间操作带来的计算困难。
通俗解释
一句话总结
SIGReg 想比较两个高维数据分布是否一样,但直接比较高维太麻烦,所以改成"从很多个随机角度分别看过去,比较一维投影分布"。
类比理解
1. 单位范数方向 = 手电筒的光束方向
想象你有一个手电筒,站在原点向黑暗中照射:
- 单位范数 = 手电筒的光束长度固定为 1(不关心照多远,只关心朝向哪个方向)
- 方向 = 手电筒指向哪里(东、西、上、下…)
你随机转动手电筒,指向各个方向。
2. d − 1 d-1 d−1 维单位超球面 = 所有可能方向的集合
| 维度 | 你能指向的范围 | 就像… |
|---|---|---|
| 2维 | 平面上的360°一圈 | 站在地上,只能水平转圈 |
| 3维 | 空间中的任意方向 | 站在空中,可以上下左右任意指 |
| 512维 | 512维空间里的任意方向 | 抽象的高维空间,数学上一样定义 |
为什么叫 d − 1 d-1 d−1 维?
虽然你在 d d d 维空间里,但"方向"只有 d − 1 d-1 d−1 个自由度——就像地球表面(2维)嵌在3维空间里,但你在地球表面只需要经度和纬度两个数就能定位。
3. 均匀采样 = 公平地随机选方向
不是只挑几个喜欢的方向看,而是完全随机、不偏不倚地选 M M M 个方向。
就像:
- ❌ 不公平:只往东边看
- ✅ 均匀:闭眼随机转,指哪算哪
为什么要这样做?(Cramér-Wold 定理)
核心思想:降维打击
问题:直接比较两个高维分布(比如512维图像特征)很难,计算量大,还容易维度灾难。
Cramér-Wold 定理说:
如果两个高维分布在所有方向上的投影都一样,那这两个高维分布本身就一样。
SIGReg 的做法:
- 随机选 M M M 个方向(手电筒指向)
- 把高维数据"压扁"到这些方向上(用手电筒照,看墙上的影子)
- 比较这些一维影子的分布
- 如果所有方向的影子都差不多,那原分布也差不多
具体场景
假设你有两个图像编码器(比如教师模型和学生模型),想让学生学到的特征分布和老师一样:
| 步骤 | 做什么 |
|---|---|
| 1. 随机选方向 | 生成 M M M 个随机单位向量 u ( m ) \mathbf{u}^{(m)} u(m) |
| 2. 投影 | 教师特征 z t \mathbf{z}_t zt 投影: u ⊤ z t \mathbf{u}^\top \mathbf{z}_t u⊤zt(得到一个数) |
| 3. 投影 | 学生特征 z s \mathbf{z}_s zs 投影: u ⊤ z s \mathbf{u}^\top \mathbf{z}_s u⊤zs(得到一个数) |
| 4. 比较 | 这两个一维数列的分布是否相同? |
| 5. 重复 | 换另一个方向,再比较… |
最终目标:让所有方向上的投影分布都匹配。
最直观的类比
想象你要比较两座形状复杂的高山(两个高维分布)是否一样:
- 直接比:把两座山3D扫描,逐点比较——太麻烦
- Cramér-Wold 方法:从山脚下随机选 M M M 个位置,用手电筒往山上照,看墙上的轮廓影子。如果所有角度的影子都一样,那两座山就是一样的。
总结
| 术语 | 通俗说法 |
|---|---|
| 单位范数方向 | 长度为1的方向向量(手电筒指向) |
| S d − 1 \mathbb{S}^{d-1} Sd−1 | 所有可能方向的集合(方向球) |
| 均匀采样 | 公平随机地选方向 |
| Cramér-Wold 定理 | “所有角度的影子都一样 → 原物体一样” |
| SIGReg 的目的 | 用很多个一维比较,代替一个高维比较,省事又有效 |
回到原文
Additional details on SIGReg and the definition of the Epps–Pulley statistical test are provided in appendix A.
SIGReg 的附加细节以及 Epps–Pulley 统计检验的定义在附录 A 中提供。
The complete LeWM training objective is defined as:
完整的LeWM训练目标定义为:
L LeWM ≜ L pred + λ SIGReg ( Z ) . (3) \mathcal{L}_{\text{LeWM}} \triangleq \mathcal{L}_{\text{pred}} + \lambda \text{SIGReg}(\mathbf{Z}). \tag{3} LLeWM≜Lpred+λSIGReg(Z).(3)
【公式含义】:
该公式定义了 LeWorldModel (LeWM) 的完整训练目标函数,它旨在通过结合未来嵌入预测损失和潜在空间正则化损失来学习环境的动态模型,同时有效地防止表示坍塌。
【符号解释】:
- L LeWM \mathcal{L}_{\text{LeWM}} LLeWM: 表示 LeWorldModel 的总训练目标函数(Total Training Objective)。模型训练的目标是最小化此函数。
- ≜ \triangleq ≜: 定义符号,表示左侧的项被定义为右侧的表达式。
- L pred \mathcal{L}_{\text{pred}} Lpred: 表示预测损失(Prediction Loss)。根据论文中的公式 (1) 定义,它计算的是预测的下一时间步潜在嵌入 z ^ t + 1 \hat{\mathbf{z}}_{t+1} z^t+1 与真实的下一时间步潜在嵌入 z t + 1 \mathbf{z}_{t+1} zt+1 之间的均方误差(MSE)。
L pred ≜ ∣ ∣ z ^ t + 1 − z t + 1 ∣ ∣ 2 , 其中 z ^ t + 1 = pred ϕ ( z t , a t ) . \mathcal{L}_{\text{pred}} \triangleq ||\hat{\mathbf{z}}_{t+1} - \mathbf{z}_{t+1}||^2, \quad \text{其中 } \hat{\mathbf{z}}_{t+1} = \text{pred}_{\phi}(\mathbf{z}_t, \mathbf{a}_t). Lpred≜∣∣z^t+1−zt+1∣∣2,其中 z^t+1=predϕ(zt,at).- z ^ t + 1 \hat{\mathbf{z}}_{t+1} z^t+1: 预测器根据当前潜在状态 z t \mathbf{z}_t zt 和动作 a t \mathbf{a}_t at 预测的下一时间步潜在嵌入。
- z t + 1 \mathbf{z}_{t+1} zt+1: 编码器对实际的下一时间步观测数据编码得到的真实潜在嵌入。
- ∣ ∣ ⋅ ∣ ∣ 2 ||\cdot||^2 ∣∣⋅∣∣2: 表示L2范数(欧几里得距离)的平方,通常用于计算均方误差。
- λ \lambda λ: 表示正则化权重(Regularization Weight),是一个可调节的超参数,用于平衡预测损失 L pred \mathcal{L}_{\text{pred}} Lpred 和正则化损失 SIGReg ( Z ) \text{SIGReg}(\mathbf{Z}) SIGReg(Z) 在总损失中的相对重要性。论文中默认设置为 λ = 0.1 \lambda = 0.1 λ=0.1。
- SIGReg ( Z ) \text{SIGReg}(\mathbf{Z}) SIGReg(Z): 表示 Sketched-Isotropic-Gaussian Regularizer 损失函数。它作用于潜在嵌入张量 Z \mathbf{Z} Z,其核心目的是强制潜在嵌入的分布符合各向同性高斯(isotropic Gaussian)目标分布,从而防止表示坍塌并促进特征多样性。该损失的详细定义在论文的公式 (2) 和附录 A 中给出。
SIGReg ( Z ) ≜ 1 M ∑ m = 1 M T ( h ( m ) ) , 其中 h ( m ) = Z u ( m ) . \text{SIGReg}(\mathbf{Z}) \triangleq \frac{1}{M} \sum_{m=1}^{M} T(\mathbf{h}^{(m)}), \quad \text{其中 } \mathbf{h}^{(m)} = \mathbf{Z}\mathbf{u}^{(m)}. SIGReg(Z)≜M1m=1∑MT(h(m)),其中 h(m)=Zu(m).- Z ∈ R N × B × d \mathbf{Z} \in \mathbb{R}^{N \times B \times d} Z∈RN×B×d: 潜在嵌入张量,包含批次中所有样本在一段时间步内的潜在嵌入。
- M M M: 随机单位范数方向的数量。
- T ( h ( m ) ) T(\mathbf{h}^{(m)}) T(h(m)): Epps-Pulley 统计检验函数,用于评估一维投影数据 h ( m ) \mathbf{h}^{(m)} h(m) 与标准高斯分布的匹配程度。
- h ( m ) \mathbf{h}^{(m)} h(m): 潜在嵌入 Z \mathbf{Z} Z 沿着第 m m m 个随机单位范数方向 u ( m ) \mathbf{u}^{(m)} u(m) 投影后得到的一维数据。
【公式解释】:
这个公式不是一个推导公式,而是 LeWorldModel 训练过程中所使用的总损失函数定义。它结合了两个核心组件来优化模型:
1. 背景知识:
- 联合嵌入预测架构 (JEPA):JEPA 旨在学习一个世界模型,通过预测未来观测的潜在表示来捕捉环境动态。这种方法避免了对像素空间进行显式生成,从而可能更高效。
- 表示坍塌 (Representation Collapse):JEPA 方法面临的一个主要挑战是表示坍塌。在这种失败模式中,模型会将所有输入映射到几乎相同的潜在表示,从而 trivially 满足预测目标,但学到的表示却无法用于下游任务。
- 自监督学习中的正则化:为了避免表示坍塌,自监督学习方法通常需要引入各种正则化技术,如对比损失、不变性损失、方差损失或停止梯度等启发式方法。然而,这些方法可能引入额外的超参数、训练不稳定性和复杂性。
2. 公式的工作原理:
公式 (3) 将 LeWM 的训练目标分解为两个关键部分,并通过一个超参数 λ \lambda λ 进行加权:
-
核心预测任务( L pred \mathcal{L}_{\text{pred}} Lpred):
- 这是 LeWM 的主要学习信号,它鼓励编码器学习能够使预测器准确预测未来潜在状态的表示。
- 具体而言,编码器将当前观测 o t o_t ot 映射到潜在嵌入 z t \mathbf{z}_t zt,预测器则根据 z t \mathbf{z}_t zt 和当前动作 a t a_t at 预测下一个潜在嵌入 z ^ t + 1 \hat{\mathbf{z}}_{t+1} z^t+1。
- L pred \mathcal{L}_{\text{pred}} Lpred 的最小化意味着模型能够有效地捕捉环境的动力学,使得模型对未来状态的预测尽可能准确。
- 设计动机:通过预测未来嵌入而非原始像素,模型被鼓励关注环境中最相关的、对预测有用的特征,而忽略不必要的细节。
-
防止表示坍塌的正则化( λ SIGReg ( Z ) \lambda \text{SIGReg}(\mathbf{Z}) λSIGReg(Z)):
- 仅仅依靠 L pred \mathcal{L}_{\text{pred}} Lpred 很容易导致表示坍塌,即编码器将所有输入都映射到常量表示。为了解决这个问题,LeWM 引入了 Sketched-Isotropic-Gaussian Regularizer (SIGReg)。
- SIGReg 的核心思想:它强制潜在嵌入 Z \mathbf{Z} Z 的分布匹配一个各向同性高斯目标分布。各向同性高斯分布的特点是在所有维度上均值均为零,方差相同,且各维度之间不相关。强制潜在嵌入符合这种分布,可以促进特征的多样性,确保潜在空间中的信息是丰富且去相关的。
- 实现机制:SIGReg 利用 Cramér-Wold 定理,通过将高维潜在嵌入投影到多个随机的一维方向上,然后对每个一维投影应用 Epps-Pulley 统计检验来评估其与标准高斯分布的匹配程度。最小化这些单维检验统计量的平均值,等价于鼓励整个高维潜在分布接近各向同性高斯。
- 设计动机:与现有 JEPA 方法通常依赖多项复杂的损失函数、启发式正则化或预训练编码器来防止坍塌不同,SIGReg 提供了一种简单、可扩展且稳定的单项正则化方法,避免了引入额外的训练不稳定性和超参数。
3. LeWM 的整体训练流程:
- 数据输入:模型接收原始像素观测序列 O 1 : T O_{1:T} O1:T 和动作序列 a 1 : T a_{1:T} a1:T。
- 编码:编码器将每个观测 o t o_t ot 映射到低维潜在表示 z t \mathbf{z}_t zt。
- 预测:预测器根据当前的潜在状态 z t \mathbf{z}_t zt 和动作 a t a_t at 预测下一个潜在状态 z ^ t + 1 \hat{\mathbf{z}}_{t+1} z^t+1。
- 计算损失:
- 计算预测损失 L pred \mathcal{L}_{\text{pred}} Lpred,衡量 z ^ t + 1 \hat{\mathbf{z}}_{t+1} z^t+1 与真实 z t + 1 \mathbf{z}_{t+1} zt+1 之间的差异。
- 计算 SIGReg 损失 SIGReg ( Z ) \text{SIGReg}(\mathbf{Z}) SIGReg(Z),衡量潜在嵌入 Z \mathbf{Z} Z 的分布与各向同性高斯分布的匹配程度。
- 总损失优化:将这两个损失项按权重 λ \lambda λ 相加,得到总损失 L LeWM \mathcal{L}_{\text{LeWM}} LLeWM,并通过反向传播优化编码器和预测器的所有参数。
这个两项损失设计是 LeWM 稳定训练和学习有效世界模型的关键,它简化了训练过程,减少了需要调整的超参数,并提供了形式化的反坍塌保证。
【总结】:
公式 L LeWM ≜ L pred + λ SIGReg ( Z ) \mathcal{L}_{\text{LeWM}} \triangleq \mathcal{L}_{\text{pred}} + \lambda \text{SIGReg}(\mathbf{Z}) LLeWM≜Lpred+λSIGReg(Z) 简洁而有效地定义了 LeWorldModel 的训练目标。它将准确预测未来潜在状态(通过 L pred \mathcal{L}_{\text{pred}} Lpred)作为核心任务,同时通过强制潜在嵌入符合各向同性高斯分布(通过 λ SIGReg ( Z ) \lambda \text{SIGReg}(\mathbf{Z}) λSIGReg(Z))来解决表示坍塌的根本问题。这种设计使得 LeWM 能够从原始像素端到端稳定训练,学习到具有物理意义的紧凑潜在表示,而无需复杂的启发式方法或多项损失项。
回到原文
Algorithm 1. Pseudo-code for the training procedure of LeWorldModel. Pixel observations are encoded into latent embeddings, and a predictor estimates the dynamics by predicting the next-step embedding conditioned on actions. The model is optimized end-to-end using a next-embedding prediction loss together with a step-wise SIGReg regularization term to prevent representation collapse.
算法 1 LeWorldModel 训练过程的伪代码。像素观测被编码为潜在嵌入,并且预测器通过在给定动作的情况下预测下一步嵌入来估计动力学。该模型通过下一步嵌入预测损失以及分步 SIGReg 正则化项进行端到端优化,以防止表示坍塌。
def LeWorldModel(obs, actions, lambd=0.1):
"""
obs: (B, T, C, H, W) raw pixels sequence
actions: (B, T, A) action sequence
lambd: (float) SIGReg loss weight
"""
emb = encoder(obs) # (B, T, D)
next_emb = predictor(emb, actions) #(B, T, D)
# — LeWorldModel training loss
# next-embedding prediction loss
pred_loss = F.mse_loss(emb[:, 1:] - next_emb[:, :-1])
# step-wise sigreg (anti-collapse)
sigreg_loss = mean(SIGReg(emb.transpose(0, 1)))
return pred_loss + lambd * sigreg_loss
The method introduces only two training hyperparameters: the number of random projections M M M used in SIGReg and the regularization weight λ \lambda λ . Unless otherwise specified, we use M = 1024 M = 1024 M=1024 projections and λ = 0.1 \lambda = 0.1 λ=0.1 . In practice, we observe that the number of projections has negligible impact on downstream performance (see Sec. 4 and App. G), making λ \lambda λ the only effective hyperparameter to tune.
该方法仅引入两个训练超参数:SIGReg中使用的随机投影数量 M M M 和正则化权重 λ \lambda λ 。除非另有说明,我们使用 M = 1024 M = 1024 M=1024 个投影和 λ = 0.1 \lambda = 0.1 λ=0.1 。实践中,我们观察到投影数量对下游性能影响可忽略不计(见第4节和附录G),使得 λ \lambda λ 成为唯一有效的待调超参数。
This greatly simplifies hyperparameter selection, as λ \lambda λ can be efficiently optimized using a simple bisection search with logarithmic complexity. We do not employ stop-gradient, exponential moving averages, or additional stabilization heuristics. Gradients are propagated through all components of the loss, and all parameters are optimized jointly in an end-to-end manner, resulting in a streamlined and easy-to-implement training procedure.
这大大简化了超参数的选择,因为可以使用具有对数复杂度的简单二分搜索来有效地优化 λ \lambda λ。我们不采用停止梯度、指数移动平均或额外的稳定启发式方法。梯度会传播到损失的所有组成部分,并且所有参数都以端到端的方式联合优化,从而实现了简化且易于实现的训练过程。
The training logic is summarized in Alg. 1.
训练逻辑总结于算法 1。
代码解释
这段代码是 LeWorldModel (LeWM) 的训练流程,我来逐行解释:
整体架构
原始视频帧 → [Encoder] → 嵌入向量 → [Predictor] → 预测下一帧嵌入
↑___________________________________________↓
比较预测 vs 实际,计算损失
逐行解释
输入
obs: (B, T, C, H, W) # B=批次, T=时间步, C=通道(3=RGB), H=高, W=宽
actions: (B, T, A) # A=动作维度
obs 是一段视频(比如连续 T 帧),actions 是对应的动作指令。
1. 编码器(Encoder)
emb = encoder(obs) # (B, T, D)
把原始像素视频压缩成嵌入向量序列:
- 输入:视频帧
(B, T, 3, H, W) - 输出:特征
(B, T, D),每帧变成一个 D 维向量
类似把每帧图片"翻译"成神经网络能理解的数字指纹。
2. 预测器(Predictor)
next_emb = predictor(emb, actions) # (B, T, D)
根据当前嵌入 + 动作,预测下一时刻的嵌入:
- 输入:当前特征序列 + 动作序列
- 输出:预测的未来特征序列
想象成"看了视频前几帧 + 知道动作,猜下一帧长什么样"。
3. 预测损失(主任务)
pred_loss = F.mse_loss(emb[:, 1:] - next_emb[:, :-1])
| 符号 | 含义 |
|---|---|
emb[:, 1:] |
真实的第 1,2,3…T 帧嵌入(去掉第0帧) |
next_emb[:, :-1] |
预测的第 1,2,3…T 帧嵌入(基于 0,1,2…T-1 预测) |
emb[:, 1:] - next_emb[:, :-1] |
真实 vs 预测的差距 |
MSE = 均方误差,让预测尽量接近真实。
就像让你看视频前10帧,猜第11帧,然后看实际第11帧,算猜错了多少。
4. SIGReg 正则化(防止坍缩)
sigreg_loss = mean(SIGReg(emb.transpose(0, 1)))
这是关键!防止模型偷懒作弊:
如果不加这个,模型可能"坍缩"(Collapse):
- 作弊方式:把所有帧都编码成同一个向量(比如全0)
- 结果:
emb[:, t]都一样,预测器啥也不用学,直接输出同一个向量 - 损失
pred_loss也能很低,但模型什么都没学到!
SIGReg 的作用:
- 强制不同时间步的嵌入分布保持"丰富"(高方差、非坍缩)
- 用之前讲的 Cramér-Wold 定理,从多个随机方向检查分布是否退化
emb.transpose(0, 1) # 从 (B, T, D) → (T, B, D)
# 按时间步分组,每个时间步有 B 个样本
对每个时间步的嵌入集合做 SIGReg,确保它们不是都挤在一起。
5. 总损失
return pred_loss + lambd * sigreg_loss
| 损失 | 作用 | 权重 |
|---|---|---|
pred_loss |
学预测下一帧 | 1.0 |
sigreg_loss |
防止嵌入坍缩 | lambd (默认0.1) |
为什么需要 SIGReg?(直观例子)
假设你在学"预测小球运动":
| 情况 | 嵌入空间 | 结果 |
|---|---|---|
| ✅ 正常 | 不同位置的小球 → 不同向量 | 能预测运动 |
| ❌ 坍缩 | 所有位置 → 同一个向量 | 预测器无法区分状态,学不到物理 |
SIGReg 就像老师盯着学生说:“不许偷懒!每个状态必须给我不同的表示!”
总结流程图
视频帧 obs ──┐
├──→ [Encoder] ──→ emb (B,T,D)
动作 actions ─┘ │
↓
[Predictor] ──→ next_emb (预测的未来)
│
┌────────────────────────┘
↓
真实 emb[:, 1:] vs 预测 next_emb[:, :-1]
│
├──→ pred_loss (MSE,让预测准)
│
└──→ sigreg_loss (SIGReg,让表示丰富不坍缩)
↓
总损失 = pred_loss + 0.1 * sigreg_loss
一句话总结
LeWorldModel 学"看视频猜未来",但光靠猜不准不够,还要用 SIGReg 逼着模型把每个时刻的状态编码得不一样,防止它偷懒把所有帧都当成同一个东西处理。
问题
next_emb[:, :-1]) 这个是什么意思啊?:-1是不包括最后一帧吗?
是的,你的理解完全正确!
[:, :-1] 的含义
这是 Python 切片语法:
| 写法 | 意思 |
|---|---|
a[start:end] |
取从 start 到 end-1(左闭右开) |
a[:-1] |
从开头取到倒数第二个(不包括最后一个) |
a[1:] |
从第二个取到末尾(不包括第一个) |
代码中的对应关系
emb[:, 1:] # 真实帧:第 1, 2, 3, ..., T-1 帧(去掉第0帧)
next_emb[:, :-1] # 预测帧:第 1, 2, 3, ..., T-1 帧(去掉最后一帧)
为什么要错位?
时间轴: t=0 t=1 t=2 ... t=T-1
──── ──── ──── ────
emb: 帧0 帧1 帧2 ... 帧T-1
↓ ↓ ↓ ↓
next_emb: 预测1 预测2 预测3 ... 预测T (基于当前帧+动作,预测下一帧)
对齐方式:
| 真实目标 | 预测输出 | 基于什么预测 |
|---|---|---|
emb[:, 1] (帧1) |
next_emb[:, 0] (预测1) |
帧0 + 动作0 |
emb[:, 2] (帧2) |
next_emb[:, 1] (预测2) |
帧1 + 动作1 |
| … | … | … |
所以:
emb[:, 1:]→ 真实的目标帧(从第1帧开始)next_emb[:, :-1]→ 对应的预测帧(去掉最后一个无目标的预测)
为什么去掉最后一个预测?
next_emb[:, T-1] 是基于最后一帧 emb[:, T-1] 预测的"第 T 帧",但视频里没有第 T 帧(只有 0 到 T-1),所以没有真实值可以比较,直接丢弃。
一句话总结
:-1就是不要最后一个,因为最后一个预测没有对应的"下一帧真实值"来比较。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)