项目地址:https://github.com/lucas-maes/le-wm

复现过程

环境配置

uv venv --python=3.10
source .venv/bin/activate
uv pip install stable-worldmodel[train,env]

uv是新一代 Python 环境工具 ,配环境更快

安装指令↓。安装完毕之后需要重启终端

curl -LsSf https://astral.sh/uv/install.sh | sh

或者如果想单纯通过pip安装。如果太慢就使用清华源

pip install "stable-worldmodel[train,env]"
pip install stable-worldmodel[train,env] -i https://pypi.tuna.tsinghua.edu.cn/simple

使用pip安装,中途可能会出现gym包版本不兼容报错的问题,这时候使用gymnasium替代即可

pip install gymnasium

安装完gymnasium,继续运行其他的安装

数据下载

数据下载地址:https://huggingface.co/collections/quentinll/lewm

一共有四个数据集,选一个自己喜欢的。各个数据集的介绍贴在文末了。

点进想下载的数据,然后点击 files and versions就可以看到下载的按钮。下载.h5.zst文件就够了

把下载完的文件丢进~/.stable-wm/里

解压

tar --zstd -xvf archive.tar.zst(文件名)

训练

如果想用单GPU训练的话,可以在指令里加上。

原项目的batch_size是128,感觉过大的话可以自行调整,在train/lewm.yaml里对batch_size进行修改即可

python train.py data=pusht
python train.py data=pusht trainer.devices=1

训练的时候可能会遇见WandB报错。

wandb (Weights & Biases) 是机器学习 / 强化学习领域最主流的 实验跟踪、可视化与协作平台,自动记录超参数、损失曲线、准确率、GPU / 内存占用、代码版本等所有训练细节。如果需要的话,需要在lewm.yaml 里填你的 wandb 账号信息

wandb:
  config:
    entity: your_entity
    project: your_project

当然,WandB需要科学上网,如果服务器的条件不允许的话,也可以去train.py代码里把这部分的相关内容注释掉

    #if cfg.wandb.enabled:
    #    logger = WandbLogger(**cfg.wandb.config)
    #    logger.log_hyperparams(OmegaConf.to_container(cfg))

然后等着训练就行了。如果机子比较烂的话可能训练的时间会比较久

规划评估

如果不改代码的话,训练好的权重文件大概在~/.stable-wm/里,不用改位置。

python eval.py --config-name=pusht.yaml policy=lewm_epoch_99

生成评估的时间也会稍微比较久

可以看到成功率success_rate

{'success_rate': 54.0, 
 'episode_successes': array([ True, True, False, True, False, ... ]),
 ...
}

找生成的视频看看

代码解读

lejepa_forward   前向 + 损失计算

def lejepa_forward(self, batch, stage, cfg):
    ...
    output = self.model.encode(batch)          # 编码器输出
    emb = output["emb"]                        # (B, T, D)  latent embeddings
    ...
    ctx_emb = emb[:, :ctx_len]                 # 上下文(历史帧)
    tgt_emb = emb[:, n_preds:]                 # 真实下一帧标签
    pred_emb = self.model.predict(ctx_emb, ctx_act)  # 预测器预测

    # LeWM loss(论文公式 (1)+(3))
    output["pred_loss"] = (pred_emb - tgt_emb).pow(2).mean()   # L_pred = MSE
    output["sigreg_loss"] = self.sigreg(emb.transpose(0, 1))    # SIGReg
    output["loss"] = output["pred_loss"] + lambd * output["sigreg_loss"]

 数据加载部分(对应论文 3.1 Offline Dataset)

dataset = swm.data.HDF5Dataset(...)          # 离线 HDF5 轨迹
# 图像预处理 + 动作/状态归一化
train_set, val_set = ...                     # 随机划分
train = DataLoader(..., shuffle=True, drop_last=True)

    模型构建部分

    encoder = spt.backbone.utils.vit_hf(...)        # ViT-Tiny(论文默认 ~5M 参数)
    predictor = ARPredictor(...)                    # 6层 Transformer + AdaLN
    action_encoder = Embedder(...)                  # 动作条件化
    projector = MLP(..., norm_fn=BatchNorm1d)       # [CLS] 投影头(关键!)
    predictor_proj = MLP(..., norm_fn=BatchNorm1d)  # 预测器输出投影头

    优化器与训练流程

    world_model = JEPA(...)                     # 把 encoder/predictor/action_encoder 封装
    world_model = spt.Module(
        model=world_model,
        sigreg=SIGReg(**cfg.loss.sigreg.kwargs),   # SIGReg 实例
        forward=partial(lejepa_forward, cfg=cfg),
        optim=optimizers,                          # LinearWarmupCosineAnnealingLR
    )

    数据集解释

    自论文 Figure 5 和 Appendix E 的描述

    环境名称 维度 任务描述 环境画面 / 视觉风格 目标 / 成功条件
    TwoRoom 2D 从一个房间穿过门到达另一个房间的随机目标位置 俯视2D简单场景,两个矩形房间被一堵竖墙隔开,中间有一个小门,背景纯黑或浅色 到达目标位置(另一个房间的随机点)
    PushT 2D 把T形块推到指定的目标位置和角度 俯视2D桌面场景,平整的浅色桌面,T形块在桌面上 T形块的位置和角度与目标完全匹配
    OGBench-Cube 3D 机械臂抓住立方块并移动到指定目标位置 3D真实感机器人场景,有机械臂、桌面、一个立方块,相机固定视角 立方块被放置到目标位置
    Reacher 2D 控制双关节机械臂让末端精确到达目标点 2D平面黑色背景,机械臂在平面上运动,目标点用小圆点标记 机械臂末端与目标点完全对齐(perfect alignment)

      Logo

      AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

      更多推荐