一、前言

仅供参考,未经实验验证。

二、LeWorldModel论文

论文标题:LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels(LeWorldModel: 从像素到像素的稳定端到端联合嵌入预测架构)
作者:Lucas Maes*, Quentin Le Lidec*, Damien Scieur, Yann LeCun, Randall Balestriero
论文地址: https://arxiv.org/pdf/2603.19312
GitHub 地址: https://github.com/lucas-maes/le-wm

3 方法:LeWorldModel

In this section, we introduce LeWorldModel (LeWM). We first describe the streamlined training procedure used to learn the latent world model from offline data, including the dataset, model architecture, and training objective. We then explain how the learned model can be leveraged for decision making through latent planning using model predictive control (MPC).
在本节中,我们介绍LeWorldModel(LeWM)。我们首先描述用于从离线数据学习潜在世界模型的优化训练过程,包括数据集、模型架构和训练目标。然后,我们解释如何通过使用模型预测控制(MPC)的潜在规划来利用学习到的模型进行决策。

3.1 学习潜在世界模型

Offline Dataset. We consider a fully offline and reward-free setting. LeWorldModel is trained solely from unannotated trajectories of observations and actions, without access to reward signals or task specifications. This setup aligns with the JEPA line of work [18, 14], which aims to learn generic, task-agnostic world models from observational data.
离线数据集。我们考虑一个完全离线且无奖励的设置。LeWorldModel 仅从无标注的观测和动作轨迹中进行训练,无法访问奖励信号或任务规范。这种设置与 JEPA 系列工作 [18, 14] 一致,该系列工作旨在从观测数据中学习通用的、与任务无关的世界模型。

Our objective is not to optimize behavior for a specific task, but to learn representations that capture environment dynamics and can later be controlled or adapted to a diverse set of tasks.
我们的目标不是优化特定任务的行为,而是学习能够捕捉环境动力学并随后可用于控制或适应各种任务的表示。

The training data consists of trajectories of length T T T composed of raw pixel observations o 1 : T o_{1:T} o1:T and associated actions a 1 : T a_{1:T} a1:T . Trajectories are collected offline from behavior policies with no optimality requirements; they may be pseudo-expert or exploratory, as long as they sufficiently cover the environment dynamics.
训练数据由长度为 T T T的轨迹组成,这些轨迹包含原始像素观测值 o 1 : T o_{1:T} o1:T和相关的动作 a 1 : T a_{1:T} a1:T。轨迹是从行为策略离线收集的,不要求最优性;只要它们充分覆盖环境动力学,就可以是伪专家或探索性的。

Additional implementation details (batch size, resolution, and sub-trajectory construction) are provided in App. D.
其他实现细节(批大小、分辨率和子轨迹构建)在附录 D 中提供。

Model Architecture. LeWM is built upon two components: an encoder and a predictor. The encoder maps a given frame observation o t \mathbf{o}_t ot into a compact, low-dimensional latent representation z t \mathbf{z}_t zt . The predictor models the environment dynamics in latent space by predicting the embedding of the next frame observation z ^ t + 1 \hat{\mathbf{z}}_{t+1} z^t+1 given the latent embedding z t \mathbf{z}_t zt and an action a t \mathbf{a}_t at .
模型架构。LeWM 由两个组件构成:一个编码器和一个预测器。编码器将给定的帧观测 o t \mathbf{o}_t ot 映射到一个紧凑的、低维的潜在表示 z t \mathbf{z}_t zt。预测器通过在给定潜在嵌入 z t \mathbf{z}_t zt和动作 a t \mathbf{a}_t at 的情况下预测下一帧观测的嵌入 z ^ t + 1 \hat{\mathbf{z}}_{t+1} z^t+1 来模拟潜在空间中的环境动力学。
Encoder:  z t = enc θ ( o t ) Predictor:  z ^ t + 1 = pred ϕ ( z t , a t ) (LeWM) \begin{aligned} \text{Encoder: } \mathbf{z}_t &= \text{enc}_\theta(\mathbf{o}_t) \\ \text{Predictor: } \hat{\mathbf{z}}_{t+1} &= \text{pred}_\phi(\mathbf{z}_t, \mathbf{a}_t) \end{aligned} \tag{LeWM} Encoder: ztPredictor: z^t+1=encθ(ot)=predϕ(zt,at)(LeWM)

【公式含义】:

这两个公式共同定义了 LeWorldModel (LeWM) 的核心架构,它是一个联合嵌入预测架构 (Joint Embedding Predictive Architecture, JEPA)。它们描述了 LeWM 如何将原始像素观测转换为紧凑的潜在表示,并在此潜在空间中预测环境的动态,是构建世界模型的基础。

【符号解释】:

  • z t \mathbf{z}_t zt: 表示在时间步 t t t 时的潜在状态(latent representation)潜在嵌入(latent embedding)。它是一个低维的向量,捕捉了原始高维视觉观测 o t \mathbf{o}_t ot 的关键信息。
  • enc θ ( ⋅ ) \text{enc}_\theta(\cdot) encθ(): 表示**编码器(Encoder)**函数,其参数为 θ \theta θ。编码器的作用是将原始像素观测 o t \mathbf{o}_t ot 映射到低维的潜在空间,生成潜在状态 z t \mathbf{z}_t zt
  • o t \mathbf{o}_t ot: 表示在时间步 t t t 时的原始像素观测(raw pixel observation),例如图像帧。
  • z ^ t + 1 \hat{\mathbf{z}}_{t+1} z^t+1: 表示在时间步 t + 1 t+1 t+1预测的潜在状态。这是由预测器根据当前潜在状态和动作估计出的未来潜在状态。
  • pred ϕ ( ⋅ , ⋅ ) \text{pred}_\phi(\cdot, \cdot) predϕ(,): 表示**预测器(Predictor)**函数,其参数为 ϕ \phi ϕ。预测器的作用是模拟环境动态,根据当前潜在状态 z t \mathbf{z}_t zt 和在时间步 t t t 采取的动作 a t \mathbf{a}_t at,预测下一个时间步 t + 1 t+1 t+1 的潜在状态 z ^ t + 1 \hat{\mathbf{z}}_{t+1} z^t+1
  • a t \mathbf{a}_t at: 表示在时间步 t t t采取的动作(action)
  • θ \theta θ: 编码器 enc \text{enc} enc 的可学习参数集合。
  • ϕ \phi ϕ: 预测器 pred \text{pred} pred 的可学习参数集合。
  • (LeWM): 表示 LeWorldModel,即论文中提出的稳定端到端联合嵌入预测架构。

【公式解释】

这两个公式是 LeWorldModel 整个训练和规划过程的基石,它们定义了模型的两个核心组成部分及其功能。

背景知识:

  • 世界模型 (World Models, WMs):旨在学习环境的动态预测模型,使智能体能够在“想象空间”中规划和改进自身行为。
  • 联合嵌入预测架构 (Joint Embedding Predictive Architectures, JEPAs):一种学习世界模型的方法,它不是试图建模环境的每一个方面,而是聚焦于捕获预测未来状态所需的最相关特征。JEPA 通过将观测编码到紧凑、低维的潜在空间中,并在此空间中建模时间动态来预测未来的潜在表示。
  • 表示坍塌 (Representation Collapse):JEPA 方法训练中常见的一个问题,指模型将所有输入都映射到几乎相同的表示,导致潜在空间失去信息多样性,变得无用。

公式拆解与解释:

  1. 编码器公式: z t = enc θ ( o t ) \mathbf{z}_t = \text{enc}_\theta(\mathbf{o}_t) zt=encθ(ot)

    • 功能: 这个公式描述了编码器的作用。它接收一个原始的像素观测 o t \mathbf{o}_t ot(例如,来自摄像头的图像),然后通过一个参数为 θ \theta θ 的神经网络(即 enc θ \text{enc}_\theta encθ)将其转换为一个紧凑的、低维的潜在表示 z t \mathbf{z}_t zt
    • 设计动机:
      • 降维与特征提取: 原始像素数据通常维度非常高,直接处理计算成本大。编码器将高维像素数据压缩到低维潜在空间,同时提取出环境中与动态预测最相关的语义特征物理结构(如物体位置、形状等)。这使得后续的预测和规划任务能够在更高效、更抽象的层次上进行。
      • 去除冗余信息: 潜在表示旨在去除原始视觉观测中的冗余细节,只保留对理解环境状态和预测未来至关重要的信息。
    • 在 LeWM 中的实现: 根据论文第5页“Model Architecture”部分的描述,编码器被实现为一个Vision Transformer (ViT) 模型(具体为 ViT-Tiny 配置),它将输入帧编码为低维潜在表示。
  2. 预测器公式: z ^ t + 1 = pred ϕ ( z t , a t ) \hat{\mathbf{z}}_{t+1} = \text{pred}_\phi(\mathbf{z}_t, \mathbf{a}_t) z^t+1=predϕ(zt,at)

    • 功能: 这个公式描述了预测器的作用。它接收当前的潜在状态 z t \mathbf{z}_t zt 和在当前时间步采取的动作 a t \mathbf{a}_t at,然后通过一个参数为 ϕ \phi ϕ 的神经网络(即 pred ϕ \text{pred}_\phi predϕ)来预测下一个时间步 t + 1 t+1 t+1潜在状态 z ^ t + 1 \hat{\mathbf{z}}_{t+1} z^t+1
    • 设计动机:
      • 建模环境动态: 预测器是世界模型的“核心”,它学习了环境中状态如何随时间和动作演变。通过预测未来的潜在状态,模型能够理解“如果我执行这个动作,环境会变成什么样”。
      • 在潜在空间进行预测: 在潜在空间中进行预测比在原始像素空间中进行预测更高效且更稳定,因为潜在空间已经抽象出了高层次的语义信息,减少了需要建模的复杂性。
      • 支持规划: 预测器能够“滚动”生成一系列未来的潜在状态(如图4所示),这使得代理可以在不与真实环境交互的情况下,在潜在空间中“想象”和规划一系列动作,以达到某个目标。
    • 在 LeWM 中的实现: 根据论文第5页的描述,预测器被实现为一个Transformer 模型,它将动作通过自适应层归一化 (AdaLN) 融入到预测过程中,并使用时间因果掩码 (temporal causal masking) 自回归地预测下一个帧表示。

LeWM 整体架构联系:
这两个公式共同构成了 LeWM 的端到端训练框架。编码器负责将原始视觉信息转换为可预测的潜在表示,而预测器则利用这些潜在表示来建模环境的动态。它们通过一个预测损失(prediction loss)(例如,均方误差 L pred = ∣ ∣ z ^ t + 1 − z t + 1 ∣ ∣ 2 2 \mathcal{L}_{\text{pred}} = ||\hat{\mathbf{z}}_{t+1} - \mathbf{z}_{t+1}||_2^2 Lpred=∣∣z^t+1zt+122)和正则化损失(regularization loss)(LeWM 中使用 SIGReg)进行联合优化。预测损失促使编码器学习可预测的表示,而正则化损失则通过强制潜在嵌入遵循各向同性高斯分布来防止表示坍塌,从而确保潜在空间的特征多样性稳定性

【总结】:

这两个公式是 LeWorldModel 的核心骨架,它们定义了编码器如何将原始像素观测压缩为有意义的潜在状态,以及预测器如何基于这些潜在状态和动作预测未来的环境动态。通过这种方式,LeWM 能够在紧凑的潜在空间中学习一个稳定、高效且能够进行规划的世界模型,同时通过精心设计的损失函数有效避免了表示坍塌问题。

The encoder is implemented as a Vision Transformer (ViT) [34]. Unless otherwise specified, we use the tiny configuration ( ∼ 5 M \sim 5\text{M} 5M parameters) with a patch size of 14, 12 layers, 3 attention heads, and hidden dimensions of 192. The observation embedding z t \mathbf{z}_t zt is constructed from the [CLS] token embedding of the last layer, followed by a projection step.
编码器实现为一个视觉 Transformer (ViT) [34]。除非另有说明,我们使用 tiny 配置( ∼ 5 M \sim 5\text{M} 5M 参数),补丁大小为 14,12 层,3 个注意力头,隐藏维度为 192。观测嵌入 z t \mathbf{z}_t zt 由最后一层的 [CLS] 标记嵌入构成,然后进行投影。

The projection step maps the [CLS] token embedding into a new representation space using a 1-layer MLP with Batch Normalization [35]. This step is necessary because the final ViT layer applies a Layer Normalization [36], which prevents our anti-collapse objective from being optimized effectively.
该投影步骤使用一个单层 MLP 和批量归一化 [35] 将 [CLS] 标记嵌入映射到一个新的表示空间。此步骤是必需的,因为最终的 ViT 层应用了层归一化 [36],这会阻碍我们的反崩溃目标被有效优化。

问题1:单层 MLP 和批量归一化

【总结】:

这段文本阐述了 LeWorldModel (LeWM) 编码器架构中,在 Vision Transformer (ViT) 最终层之后,通过一个包含单层多层感知机 (MLP) 和批量归一化 (Batch Normalization) 的投影步骤,将 [CLS] 标记嵌入映射到新的表示空间,以解决 ViT 最终层所使用的层归一化 (Layer Normalization) 对模型反崩溃目标优化造成的阻碍。

【解释】:

在 LeWorldModel 的编码器架构中,为了从 Vision Transformer (ViT) 输出的 [CLS] 标记嵌入中提取用于后续任务的稳定且可优化的潜在表示,引入了一个关键的投影步骤。该步骤的具体实现是利用一个单层多层感知机 (MLP) 对 [CLS] 标记嵌入进行变换,并结合批量归一化 (Batch Normalization)。此设计是必要的,因为 ViT 模型的最后一个层通常会应用层归一化 (Layer Normalization)。层归一化是一种在单个样本的特征维度上进行归一化的技术,它倾向于将激活值的均值和方差规范化。然而,这种规范化行为可能与 LeWM 所采用的反崩溃目标(如 SIGReg,它旨在强制潜在嵌入在整个批次上遵循各向同性高斯分布以促进特征多样性)产生冲突,从而阻碍反崩溃目标被有效优化。通过在 ViT 最终层之后引入一个额外的投影层(包含 MLP 和 Batch Normalization),模型能够将 [CLS] 标记嵌入转换到一个新的表示空间,该空间更适合于反崩溃目标的优化,从而确保潜在表示的稳定性和有效性。

【关键词】:

  • 投影步骤 (Projection step):解释编码器中将 [CLS] 标记嵌入映射到新表示空间的过程。
  • 单层 MLP (Single-layer MLP):构成投影步骤的神经网络组件,用于特征变换。
  • 批量归一化 (Batch Normalization):投影步骤中使用的归一化技术,有助于稳定训练和加速收敛。
  • [CLS] 标记嵌入 ([CLS] token embedding):Vision Transformer 模型输出的代表整个序列信息的特殊标记的向量表示。
  • 表示空间 (Representation space):数据被编码或映射到的一个抽象的、低维的向量空间。
  • ViT 层 (ViT layer):Vision Transformer 模型中的一个变换器层。
  • 层归一化 (Layer Normalization):一种在神经网络层中对单个样本的特征维度进行归一化的技术。
  • 反崩溃目标 (Anti-collapse objective):模型训练中旨在防止潜在表示坍塌,确保特征多样性的损失函数或正则化项。

我们用一个更简单的比喻来解释一下:

想象一下,你有一台非常厉害的相机 (Vision Transformer),它能把看到的图片(比如一辆车、一只猫)拍下来,然后不仅仅是保存图片,还会生成一个很精炼的“总结” (CLS 标记嵌入)。这个“总结”就像是图片的“核心信息”,用来代表这张图片。

这台相机在给出“总结”之前,会有一个内部的“调整”过程 (Layer Normalization),这个调整是为了让它自己内部的工作更顺畅,更稳定。

问题来了:
我们现在想用这些“总结”来做一件特殊的事情:我们希望所有的“总结”都足够多样化,不能都长得一样。如果所有图片的“总结”都差不多,那我们就分不清哪张图是车,哪张图是猫了,这就是所谓的“表示坍塌”——信息丢失了,大家看起来都一样。我们有一个**“反坍塌”的规则 (SIGReg)**,专门来检查和强制这些“总结”保持多样性。

但是,相机的那个内部“调整”过程 (Layer Normalization),虽然对相机自己好,却**“干扰”了我们的“反坍塌”规则**。它把“总结”调整得太“平整”了,使得我们的“反坍塌”规则很难有效地去“塑造”这些总结,让它们变得多样化。

所以,我们的解决方案是:
在相机给出“总结”之后,但在我们应用“反坍塌”规则之前,我们加了一个小小的“修饰”步骤 (投影步骤)
这个“修饰”步骤就像是一个**“翻译官” (单层 MLP),它把相机给出的“总结”稍微“重新组织”一下。同时,我们还给这个“翻译官”配备了一个更适合我们“反坍塌”规则的“校准器” (Batch Normalization)**。

这样做的好处是:
经过这个“翻译官”和“校准器”的“修饰”之后,这些“总结”不仅依然包含着图片的核心信息,而且它们变得更容易被我们的“反坍塌”规则所“塑造”,从而保证了我们最终得到的图片“总结”是稳定、有用且多样化的,不会出现信息坍塌的问题。

简而言之:ViT相机内部的“整理”方式(Layer Norm)跟我们想让“总结”保持多样性的“外部规则”(反坍塌目标)有点冲突。所以,我们加了一个中间的“小改造”环节,用一个“小翻译官”(MLP)和另一个“更合适的校准器”(Batch Norm)来重新整理“总结”,让它既保留信息,又能顺利地被我们的“多样化规则”所管理。

问题2:为什么 Layer Normalization 会干扰反坍塌目标(SIGReg)的有效性

首先,我们简要回顾一下 Layer Normalization 和 SIGReg 的核心作用:

  1. Layer Normalization (LN) 的作用:

    • 目标: 稳定神经网络的训练,加速收敛,减少内部协变量漂移。
    • 工作原理:每个样本独立地,沿着特征维度(而不是批次维度)计算均值和方差,然后将该样本的特征向量标准化,使其均值为 0,方差为 1。
    • 关键特性: LN 使得每个样本的特征向量在标准化后,都具有相似的统计特性(均值接近 0,方差接近 1)。它强制性地移除了每个样本特征向量的尺度(scale)和偏移(shift)信息
  2. SIGReg (Sketched-Isotropic-Gaussian Regularizer) 的作用:

    • 目标: 强制潜在嵌入(latent embeddings)服从各向同性高斯分布 (Isotropic Gaussian distribution),从而促进特征多样性,防止表示坍塌。表示坍塌是指模型将所有输入都映射到相似的潜在表示,导致无法区分不同输入。
    • 工作原理(简化): SIGReg 不直接在原始高维空间评估高斯性。相反,它将潜在嵌入 Z 投影到 M 个随机的一维方向上,然后对每个一维投影的分布应用一个正态性检验(Epps-Pulley 检验)。根据 Cramér-Wold 定理,如果所有一维投影都服从高斯分布,那么原始高维分布也服从高斯分布。
    • 关键特性: SIGReg 关心的是潜在嵌入的整体统计分布(跨越批次中的所有样本),而不是单个样本的统计特性。它需要潜在空间有足够的“自由度”来形成一个高斯形状的分布,这意味着不同样本的嵌入需要有足够的差异性,并且这些差异性需要以**特定的统计方式(高斯)**进行分布。

Layer Normalization 如何干扰 SIGReg?

Layer Normalization 的“每个样本独立标准化”的特性与 SIGReg 的“整体分布高斯化”目标之间存在根本性冲突。具体来说:

  1. 去除尺度信息,阻碍多样性:

    • LN 的影响: LN 将每个潜在嵌入向量强制缩放到单位方差。这意味着在 LN 之后,所有样本的嵌入向量在各自的特征维度上都失去了其原始的尺度信息,它们都被“压”到了一个近似单位球面上。
    • 对 SIGReg 的干扰: SIGReg 的目标是让整个批次的潜在嵌入在各个方向上都呈现高斯分布。高斯分布的特征是其方差,即数据点在均值周围的散布程度。如果每个样本都被 LN 强制为单位方差,那么整个批次在各个特征维度上的总体方差就会受到限制,或者说其形成高斯分布所需的自然尺度和散布特性被削弱了。SIGReg 难以在这种被 LN 强加了“单位尺度”约束的空间中有效地“塑造”出各向同性高斯分布,因为它需要更大的自由度来控制整个分布的协方差结构和散布范围
  2. 强制零均值,限制分布中心:

    • LN 的影响: LN 还将每个潜在嵌入向量的均值强制设为 0。
    • 对 SIGReg 的干扰: 虽然各向同性高斯分布的均值通常也是 0,但 LN 强制每个独立样本的均值为 0,而不是让整个数据集的均值自然地趋向于 0。这听起来可能微不足道,但当模型试图通过学习来形成一个高斯分布时,它需要潜在嵌入的值域和中心能够自由地调整,以适应数据本身的结构。LN 的强制性均值归零,可能限制了潜在嵌入在学习过程中探索和形成多样化分布的能力。
  3. 弱化样本间的相对差异:

    • LN 的影响: 由于每个样本都独立地被标准化,LN 会削弱样本之间原始的、绝对的数值差异。它更强调每个样本内部的特征模式,而不是样本之间的比较。
    • 对 SIGReg 的干扰: SIGReg 依赖于批次中所有样本的统计信息来评估和优化分布。如果样本之间的相对尺度和差异被 LN 抹平,那么 SIGReg 就更难捕捉到足以形成高斯分布所需的全局统计特征。换句话说,SIGReg 需要看到数据点之间更“原始”的散布和关系,以便将其推向高斯形状。LN 就像给每个数据点穿上了统一的“制服”,使得 SIGReg 难以分辨它们在整体分布中的“个性”和“位置”。

论文中的解决方案及其原因:

论文中提到:

“The projection step maps the [CLS] token embedding into a new representation space using a 1-layer MLP with Batch Normalization [35]. This step is necessary because the final ViT layer applies a Layer Normalization [36], which prevents our anti-collapse objective from being optimized effectively.”

这意味着:

  1. ViT 输出层有 Layer Normalization: ViT 模型的最后一层通常会包含 Layer Normalization。因此,直接从 ViT 的 CLS 令牌嵌入中获取的特征,已经经过了 LN 处理。
  2. 添加投影层(MLP + Batch Normalization): 在 ViT 输出的 CLS 令牌嵌入之后,添加了一个由单层 MLP 和 Batch Normalization 组成的投影层。
  3. Batch Normalization 的作用: Batch Normalization(BN)与 LN 不同,它是在批次维度上进行标准化。这意味着它会计算整个批次在每个特征维度上的均值和方差,然后进行标准化。BN 在一定程度上保留了样本间的相对差异,并且它自身也具有正则化作用,有助于防止坍塌。
  4. 解决冲突: 通过将经过 ViT 的 LN 处理的嵌入再次通过一个带有 Batch Normalization 的 MLP 投影层,模型实际上是在为 SIGReg 重新提供一个更适合其作用机制的输入。BN 的批次维度标准化特性,允许 SIGReg 在一个新的表示空间中,更好地捕捉和塑造整个批次数据的全局统计分布,使其趋向于各向同性高斯分布,从而有效地实现反坍塌。

总结来说: Layer Normalization 强制每个样本的特征向量具有相似的统计属性(均值 0,方差 1),这抹去了样本间形成高斯分布所需的关键尺度和散布信息。SIGReg 需要在整个批次的维度上操作,以塑造一个全局的高斯分布。ViT 输出层自带的 LN 使得其输出的嵌入对于 SIGReg 来说“过于规范化”且失去了关键的全局统计自由度。通过引入一个带有 Batch Normalization 的投影层,模型有效地将嵌入转换为一个更“灵活”的表示空间,使得 SIGReg 能够在此空间中高效地施加其反坍塌约束,从而达到稳定且多样化的潜在表示学习。

问题3:这个“投影步骤”中的单层 MLP 和 Batch Normalization 具体是如何帮助“反坍塌”规则的?它们在数学上是如何实现的?

背景回顾:ViT 的 Layer Normalization (LN) 问题

在 LeWorldModel 的架构中,编码器(Encoder)是一个 Vision Transformer (ViT)。ViT 的最后一层通常会应用 Layer Normalization (LN)。如前所述,LN 的核心作用是对每个样本独立地,沿着特征维度进行标准化,使其均值接近 0,方差接近 1。这导致了一个问题:

  • 问题所在: LN 消除了每个样本自身特征向量的尺度和偏移信息。这使得批次中不同样本之间的原始统计差异被削弱,也限制了潜在嵌入在学习过程中形成多样化分布的自由度。而 SIGReg 需要在整个批次的潜在嵌入上施加高斯分布约束,这要求潜在空间有足够的“灵活性”来形成这种全局统计特性。经过 LN 处理的嵌入,其“局部”统计特性过于统一,导致 SIGReg 难以有效优化“全局”高斯分布目标。

投影步骤 (Projection Step) 的作用与数学实现

为了解决上述问题,LeWM 在 ViT 的 CLS 令牌嵌入之后,引入了一个投影步骤,包含一个单层 MLPBatch Normalization (BN)

假设 ViT 最后一层输出的 CLS 令牌嵌入为 z C L S ∈ R D z_{CLS} \in \mathbb{R}^D zCLSRD,其中 D D D 是特征维度。这个 z C L S z_{CLS} zCLS 已经经过了 ViT 内部的 LN 处理。

  1. 单层 MLP (Multi-Layer Perceptron):

    • 作用: MLP 的主要作用是将 ViT 提供的嵌入 z C L S z_{CLS} zCLS 映射到一个新的表示空间。这个映射可以是线性的(如果 MLP 没有激活函数)或非线性的(如果 MLP 包含激活函数)。通过这个映射,MLP 能够学习转换特征,改变特征之间的关系,并可能调整嵌入的维度。
    • 数学实现: 一个单层 MLP 可以表示为:
      h = W ⋅ z C L S + b h = W \cdot z_{CLS} + b h=WzCLS+b
      其中, W ∈ R D ′ × D W \in \mathbb{R}^{D' \times D} WRD×D 是权重矩阵, b ∈ R D ′ b \in \mathbb{R}^{D'} bRD 是偏置向量, D ′ D' D 是新的特征维度(通常与 D D D 相同或不同)。如果 MLP 包含激活函数(如 ReLU),则为:
      h = ReLU ( W ⋅ z C L S + b ) h = \text{ReLU}(W \cdot z_{CLS} + b) h=ReLU(WzCLS+b)
    • 如何帮助反坍塌: MLP 本身并不直接防止坍塌,但它提供了一个可学习的转换,能够将 ViT 的 LN 输出(其统计特性被限制)映射到一个新的空间。这个新空间为接下来的 Batch Normalization 和 SIGReg 提供了更大的灵活性,使得模型可以学习如何“解开”LN 的限制,从而更好地适应 SIGReg 的目标。它允许模型重新学习特征的尺度和相对重要性,为 BN 创造条件。
  2. Batch Normalization (BN):

    • 作用: BN 对**整个批次(Batch)**的输入进行标准化。它计算每个特征维度在整个批次上的均值和方差,然后进行标准化。与 LN 不同,BN 保留了不同样本之间的相对尺度信息,因为它对整个批次采用相同的统计量进行标准化。
    • 数学实现: 假设我们有一个批次 B B B 个样本,经过 MLP 后的输出为 { h 1 , h 2 , … , h B } \{h_1, h_2, \dots, h_B\} {h1,h2,,hB},其中 h i ∈ R D ′ h_i \in \mathbb{R}^{D'} hiRD。对于每个特征维度 j ∈ { 1 , … , D ′ } j \in \{1, \dots, D'\} j{1,,D}
      • 批次均值: μ j = 1 B ∑ i = 1 B h i , j \mu_j = \frac{1}{B} \sum_{i=1}^B h_{i,j} μj=B1i=1Bhi,j
      • 批次方差: σ j 2 = 1 B ∑ i = 1 B ( h i , j − μ j ) 2 \sigma_j^2 = \frac{1}{B} \sum_{i=1}^B (h_{i,j} - \mu_j)^2 σj2=B1i=1B(hi,jμj)2
      • 标准化: h ^ i , j = h i , j − μ j σ j 2 + ϵ \hat{h}_{i,j} = \frac{h_{i,j} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}} h^i,j=σj2+ϵ hi,jμj ϵ \epsilon ϵ 是为数值稳定性添加的小常数)
      • 缩放和偏移: y i , j = γ j h ^ i , j + β j y_{i,j} = \gamma_j \hat{h}_{i,j} + \beta_j yi,j=γjh^i,j+βj
        其中, γ j \gamma_j γj β j \beta_j βj 是可学习的缩放和偏移参数。最终输出的 y i y_i yi 就是用于 SIGReg 的潜在嵌入 z t z_t zt
    • 如何帮助反坍塌 (SIGReg):
      • 提供全局统计信息: BN 的核心在于它基于整个批次的统计量进行标准化。这意味着它不再像 LN 那样独立地对待每个样本,而是考虑了样本间的整体分布关系。这种全局的标准化为 SIGReg 提供了它所需的工作基础,因为它可以在一个批次一致性更强的空间中,更有效地将整个批次的潜在嵌入推向各向同性高斯分布。
      • 恢复尺度和方差的灵活性: 尽管 BN 也进行标准化,但它通过可学习的 γ \gamma γ β \beta β 参数,允许网络重新学习和调整每个特征维度的尺度和偏移。这使得潜在嵌入在标准化后,依然可以维持足够的差异性和可塑性,以满足 SIGReg 所需的全局高斯分布特征(包括其方差和协方差结构)。LN 强制的单位方差限制被 BN 的批次统计量和可学习参数所“软化”和“取代”。
      • 平滑梯度,稳定训练: BN 众所周知能够平滑损失函数,减少内部协变量漂移,从而稳定训练并允许使用更大的学习率。这对于像 SIGReg 这样旨在塑造复杂高维分布的正则化目标尤其重要,因为稳定的梯度有助于其有效优化。

总结:协同作用

这个投影步骤中的单层 MLP 和 Batch Normalization 共同作用,克服了 ViT 最后一层 LN 对 SIGReg 的干扰:

  1. MLP 提供了特征转换的灵活性,将 ViT 的原始 LN 嵌入映射到一个新的、可能更适合分布塑造的空间。
  2. Batch Normalization 是关键,它通过批次级别的标准化,重新引入了全局统计信息和可学习的尺度/偏移参数。这使得潜在嵌入具备了足够的“自由度”和“可塑性”,让 SIGReg 能够有效地在整个批次上施加其各向同性高斯分布的约束,从而实现有效的反坍塌,促进特征多样性,并最终学习到有意义的潜在表示。

简而言之,BN 提供了一个统计上更“友好”的环境给 SIGReg,让 SIGReg 能够通过操纵批次级的均值和方差(通过学习 γ \gamma γ β \beta β),将潜在嵌入的整体分布有效地引导到各向同性高斯分布,从而防止模型将所有输入映射到相似的表示,避免表示坍塌。


回到原文
The predictor is a transformer with 6 layers, 16 attention heads, and 10% dropout ( ∼ 10 M \sim 10\text{M} 10M parameters). Actions are incorporated into the predictor through Adaptive Layer Normalization (AdaLN) [37] applied at each layer. The AdaLN parameters are initialized to zero to stabilize training and ensure that action conditioning impacts the predictor training progressively.
该预测器是一个具有6层、16个注意力头和10% dropout(约10M参数)的Transformer。动作通过自适应层归一化(AdaLN)[37]被整合到预测器中,该归一化应用于每一层。AdaLN参数被初始化为零,以稳定训练并确保动作条件对预测器训练产生渐进式影响。

The predictor takes as input a history of N N N frame representations and predicts the next frame representation auto-regressively with temporal causal masking to avoid looking at future embeddings. The predictor is also followed by a projector network with the same implementation as the one used for the encoder.
预测器接收 N N N 帧表示的历史作为输入,并通过时间因果掩码自回归地预测下一帧表示,以避免查看未来的嵌入。预测器之后还跟着一个投影器网络,其实现与编码器使用的相同。

All components of our world model are learned jointly using the loss described in the following paragraph.
我们世界模型的所有组成部分都使用下一段中描述的损失函数进行联合学习。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐