LeWorldModel 复现过程+代码解读
项目地址:https://github.com/lucas-maes/le-wm
复现过程
环境配置
uv venv --python=3.10
source .venv/bin/activate
uv pip install stable-worldmodel[train,env]
uv是新一代 Python 环境工具 ,配环境更快
安装指令↓。安装完毕之后需要重启终端
curl -LsSf https://astral.sh/uv/install.sh | sh
或者如果想单纯通过pip安装。如果太慢就使用清华源
pip install "stable-worldmodel[train,env]"
pip install stable-worldmodel[train,env] -i https://pypi.tuna.tsinghua.edu.cn/simple
使用pip安装,中途可能会出现gym包版本不兼容报错的问题,这时候使用gymnasium替代即可
pip install gymnasium
安装完gymnasium,继续运行其他的安装
数据下载
数据下载地址:https://huggingface.co/collections/quentinll/lewm

一共有四个数据集,选一个自己喜欢的。各个数据集的介绍贴在文末了。
点进想下载的数据,然后点击 files and versions就可以看到下载的按钮。下载.h5.zst文件就够了
把下载完的文件丢进~/.stable-wm/里
解压
tar --zstd -xvf archive.tar.zst(文件名)
训练
如果想用单GPU训练的话,可以在指令里加上。
原项目的batch_size是128,感觉过大的话可以自行调整,在train/lewm.yaml里对batch_size进行修改即可
python train.py data=pusht
python train.py data=pusht trainer.devices=1
训练的时候可能会遇见WandB报错。
wandb (Weights & Biases) 是机器学习 / 强化学习领域最主流的 实验跟踪、可视化与协作平台,自动记录超参数、损失曲线、准确率、GPU / 内存占用、代码版本等所有训练细节。如果需要的话,需要在lewm.yaml 里填你的 wandb 账号信息
wandb:
config:
entity: your_entity
project: your_project
当然,WandB需要科学上网,如果服务器的条件不允许的话,也可以去train.py代码里把这部分的相关内容注释掉
#if cfg.wandb.enabled:
# logger = WandbLogger(**cfg.wandb.config)
# logger.log_hyperparams(OmegaConf.to_container(cfg))
然后等着训练就行了。如果机子比较烂的话可能训练的时间会比较久
规划评估
如果不改代码的话,训练好的权重文件大概在~/.stable-wm/里,不用改位置。
python eval.py --config-name=pusht.yaml policy=lewm_epoch_99
生成评估的时间也会稍微比较久
可以看到成功率success_rate
{'success_rate': 54.0,
'episode_successes': array([ True, True, False, True, False, ... ]),
...
}
找生成的视频看看

代码解读
lejepa_forward 前向 + 损失计算
def lejepa_forward(self, batch, stage, cfg):
...
output = self.model.encode(batch) # 编码器输出
emb = output["emb"] # (B, T, D) latent embeddings
...
ctx_emb = emb[:, :ctx_len] # 上下文(历史帧)
tgt_emb = emb[:, n_preds:] # 真实下一帧标签
pred_emb = self.model.predict(ctx_emb, ctx_act) # 预测器预测
# LeWM loss(论文公式 (1)+(3))
output["pred_loss"] = (pred_emb - tgt_emb).pow(2).mean() # L_pred = MSE
output["sigreg_loss"] = self.sigreg(emb.transpose(0, 1)) # SIGReg
output["loss"] = output["pred_loss"] + lambd * output["sigreg_loss"]
数据加载部分(对应论文 3.1 Offline Dataset)
dataset = swm.data.HDF5Dataset(...) # 离线 HDF5 轨迹
# 图像预处理 + 动作/状态归一化
train_set, val_set = ... # 随机划分
train = DataLoader(..., shuffle=True, drop_last=True)
模型构建部分
encoder = spt.backbone.utils.vit_hf(...) # ViT-Tiny(论文默认 ~5M 参数)
predictor = ARPredictor(...) # 6层 Transformer + AdaLN
action_encoder = Embedder(...) # 动作条件化
projector = MLP(..., norm_fn=BatchNorm1d) # [CLS] 投影头(关键!)
predictor_proj = MLP(..., norm_fn=BatchNorm1d) # 预测器输出投影头
优化器与训练流程
world_model = JEPA(...) # 把 encoder/predictor/action_encoder 封装
world_model = spt.Module(
model=world_model,
sigreg=SIGReg(**cfg.loss.sigreg.kwargs), # SIGReg 实例
forward=partial(lejepa_forward, cfg=cfg),
optim=optimizers, # LinearWarmupCosineAnnealingLR
)
数据集解释
自论文 Figure 5 和 Appendix E 的描述
| 环境名称 | 维度 | 任务描述 | 环境画面 / 视觉风格 | 目标 / 成功条件 |
|---|---|---|---|---|
| TwoRoom | 2D | 从一个房间穿过门到达另一个房间的随机目标位置 | 俯视2D简单场景,两个矩形房间被一堵竖墙隔开,中间有一个小门,背景纯黑或浅色 | 到达目标位置(另一个房间的随机点) |
| PushT | 2D | 把T形块推到指定的目标位置和角度 | 俯视2D桌面场景,平整的浅色桌面,T形块在桌面上 | T形块的位置和角度与目标完全匹配 |
| OGBench-Cube | 3D | 机械臂抓住立方块并移动到指定目标位置 | 3D真实感机器人场景,有机械臂、桌面、一个立方块,相机固定视角 | 立方块被放置到目标位置 |
| Reacher | 2D | 控制双关节机械臂让末端精确到达目标点 | 2D平面黑色背景,机械臂在平面上运动,目标点用小圆点标记 | 机械臂末端与目标点完全对齐(perfect alignment) |
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)