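This post walks through fine-tuning the π0.5 model on data collected with the DROID platform. The workflow has three stages:

  • Data collection: build a robot system that follows the DROID hardware platform, then teleoperate it to record multi-view video, robot state, and language instructions.
  • Format conversion: convert the raw DROID data to the standard LeRobot format so that it can enter the training pipeline.
  • Model fine-tuning: fine-tune from the π0.5-DROID pretrained weights, focusing on the policy head, so that the model adapts to the specific task.

Together these stages form an end-to-end reference for robot manipulation tasks.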

[Figure: data example]

1. Collecting data

Following DROID, we can collect our own robot manipulation data.

For purchasing and assembling the hardware, see: https://droid-dataset.github.io/droid/docs/hardware-setup

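| Component | Specification |
| --- | --- |
| Robot arm | Franka Emika Panda, 7 DoF |
| Gripper | Robotiq 2F-85 |
| External cameras | 2× ZED 2 stereo cameras (on adjustable tripods) |
| Wrist camera | ZED-Mini stereo camera |
| Teleoperation device | Meta Quest 2 headset and controllers |
| Controller | Polymetis (15 Hz control frequency) |
| Compute | NUC (Polymetis server) + Alienware laptop (data-collection GUI) |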

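[Figure: the hardware platform]

Next, set up the development environment, preferably in Docker: https://droid-dataset.github.io/droid/software-setup/docker.html

Collecting data by teleoperation: https://droid-dataset.github.io/droid/example-workflows/teleoperation.html

With the DROID software set up on the host or in Docker, teleoperate the arm with the Oculus (Meta) Quest 2 to record manipulation data.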

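Data collection procedure

  1. Scene setup: move the robot to a new scene and adjust the camera views.

  2. Camera calibration: calibrate the camera extrinsics with a checkerboard and OpenCV.

  3. Task entry: enter every task that is possible in the scene into the GUI (pick from a list or type freely).

  4. Automatic sampling: the GUI draws task instructions at random to ensure task coverage.

  5. Scene augmentation: the GUI periodically prompts for scene changes (move the base, adjust the cameras, change the lighting, add or remove objects).

  6. Recording: the data described below is logged at 15 Hz.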

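What gets recorded

Observations

  • 3 stereo RGB camera streams (1280×720)

  • Robot joint positions and velocities (7-D)

  • End-effector pose and velocity (6-D)

  • Gripper position and velocity (1-D)

Actions (several action spaces)

  • Joint space: joint position / velocity commands

  • Cartesian space: end-effector pose / velocity commands, plus gripper control

Metadata

  • 1 to 3 natural-language instructions (crowdsourced annotations)

  • Camera extrinsic matrices, building name, collector ID

  • Scene type (classified with GPT-4V)

  • Success / failure flag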
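The low-dimensional quantities listed above can be pictured as one record per 15 Hz timestep. The following is only a sketch for orientation, not DROID's actual schema: the class name `TimestepRecord` and all field names are hypothetical, and only the dimensions are taken from the list above.

```python
from dataclasses import dataclass

# Dimensions taken from the list above; every name here is hypothetical.
EXPECTED_DIMS = {
    "joint_positions": 7,
    "joint_velocities": 7,
    "ee_pose": 6,
    "ee_velocity": 6,
    "gripper_position": 1,
    "gripper_velocity": 1,
}


@dataclass
class TimestepRecord:
    """One 15 Hz sample of low-dimensional observations."""

    joint_positions: list
    joint_velocities: list
    ee_pose: list
    ee_velocity: list
    gripper_position: list
    gripper_velocity: list

    def validate(self) -> None:
        """Raise ValueError if any vector has an unexpected length."""
        for name, dim in EXPECTED_DIMS.items():
            value = getattr(self, name)
            if len(value) != dim:
                raise ValueError(f"{name}: expected {dim} values, got {len(value)}")
```

A check like this is handy when you log data from your own robot and want to catch a mis-sized vector before it reaches the conversion step.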

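For each episode, the low-dimensional data is stored in an HDF5 file and the camera streams are stored as video files:

episode/
├── metadata_*.json      # metadata such as scene and collector
├── trajectory.h5         # low-dimensional data (actions, proprioception)
└── recordings/
    ├── MP4/             # HD video (left view / stereo)
    └── SVO/             # raw ZED SVO files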

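2. Downloading the sample data

If you do not have the hardware, that is fine: fine-tune on the sample data first to get the whole pipeline working.

For setting up the openpi environment, see my previous post: 《VLA 系列》复现 π0.5、π0-FAST、π0 | 环境搭建 | 模型推理

Change into the openpi source directory and create a droid_examples_1_0_1 folder to hold the data (about 1.6 GiB).

Download the sample data with gsutil: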

gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04  droid_examples_1_0_1

Wait for the download to finish. The output looks like this:

Copying gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04/Mon_Dec__4_16:18:46_2023/recordings/MP4/19824535.mp4...
Copying gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04/Mon_Dec__4_16:18:46_2023/recordings/MP4/23404442.mp4...
Copying gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04/Mon_Dec__4_16:18:46_2023/recordings/MP4/29838012.mp4...
Copying gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04/Mon_Dec__4_16:18:46_2023/recordings/SVO/19824535.svo...
Copying gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04/Mon_Dec__4_16:18:46_2023/recordings/SVO/23404442.svo...
Copying gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04/Mon_Dec__4_16:18:46_2023/recordings/SVO/29838012.svo...
Copying gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04/Mon_Dec__4_16:18:46_2023/trajectory.h5...
| [240/240 files][  1.6 GiB/  1.6 GiB] 100% Done  29.8 KiB/s ETA 00:00:00       
Operation completed over 240 objects/1.6 GiB.  

Then download the language annotations (task instructions) that go with these videos:

gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json  droid_examples_1_0_1

Output:

Google recommends using Gcloud storage CLI (https://docs.cloud.google.com/storage/docs/discover-object-storage-gcloud) instead of gsutil. Please refer to migration guide (https://docs.cloud.google.com/storage/docs/gsutil-transition-to-gcloud) for assistance.
Copying gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json...
\ [1/1 files][ 11.5 MiB/ 11.5 MiB] 100% Done                                    
Operation completed over 1 objects/11.5 MiB. 

The data directory now looks like this:

droid_examples_1_0_1/
├── aggregated-annotations-030724.json    # language annotations (task instructions)
└── <dataset_name>/
    ├── recordings/
    │   ├── MP4/
    │   │   ├── <camera_id>.mp4           # camera videos (three cameras)
    │   │   └── ...
    │   └── SVO/                          # raw ZED SVO files
    ├── trajectory.h5                     # trajectory data (HDF5)
    └── metadata_<episode_id>.json        # metadata file
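Before converting, it can help to confirm that every downloaded episode has the files shown in the layout above. The sketch below is one way to do that with the standard library only. The helper names `find_episodes` and `check_episode` are made up for this example, and it assumes that stereo videos, if present, carry "stereo" in their file name (the conversion script later in this post skips such files).

```python
from pathlib import Path


def find_episodes(data_dir: Path) -> list:
    """Treat every directory that contains a recordings/ folder as an episode."""
    return sorted(p.parent for p in data_dir.glob("**/recordings") if p.is_dir())


def check_episode(episode_dir: Path) -> list:
    """Return a list of problems found in one raw episode directory."""
    problems = []
    if not (episode_dir / "trajectory.h5").is_file():
        problems.append("missing trajectory.h5")
    if not list(episode_dir.glob("metadata_*.json")):
        problems.append("missing metadata_*.json")
    mp4_dir = episode_dir / "recordings" / "MP4"
    # Count single-view videos only; stereo files are ignored (assumption).
    mp4_files = [p for p in mp4_dir.glob("*.mp4") if "stereo" not in p.name]
    if len(mp4_files) < 3:
        problems.append(f"expected 3 camera videos, found {len(mp4_files)}")
    return problems


def report(data_dir: str) -> None:
    """Print one line per episode with either OK or the list of problems."""
    for episode_dir in find_episodes(Path(data_dir)):
        problems = check_episode(episode_dir)
        print(episode_dir.name, "OK" if not problems else "; ".join(problems))
```

Calling `report("droid_examples_1_0_1")` from the openpi directory prints the status of each episode.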

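[Figure: sample frames from the MP4 videos]

The trajectory data in trajectory.h5 deserves a closer look. The H5Web extension for VS Code makes it easy to visualize .h5 files.

For example, the gripper signal can be plotted directly:

[Figure: H5Web plot of the gripper data]

The raw joint values can be inspected as well:

[Figure: H5Web view of the joint values]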
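The same inspection can be done from a script. The sketch below walks a group-like object and lists every dataset with its shape and dtype. It relies on duck typing (groups expose `.items()`, datasets expose `.shape` and `.dtype`), which is the interface that h5py's `File`, `Group`, and `Dataset` objects provide. The function name `summarize` is made up for this example.

```python
def summarize(node, prefix=""):
    """Yield (path, shape, dtype) for every leaf below a group-like node."""
    for key, child in node.items():
        path = f"{prefix}/{key}"
        if hasattr(child, "items"):
            # Group: recurse into it.
            yield from summarize(child, path)
        else:
            # Dataset: report its shape and dtype.
            yield path, tuple(child.shape), str(child.dtype)


# Usage with a real file (requires h5py; the path is a placeholder):
#   import h5py
#   with h5py.File("<episode_dir>/trajectory.h5", "r") as f:
#       for path, shape, dtype in summarize(f):
#           print(f"{path:60s} {shape} {dtype}")
```

The first axis of each dataset is the timestep axis, so a quick scan of the shapes shows whether all signals in an episode have the same length.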

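3. Converting the data to the LeRobot format

The raw data contains the following key items:

| Data type | Source | Description |
| --- | --- | --- |
| Video frames | recordings/MP4/*.mp4 | Multi-view camera images (wrist camera, external cameras) |
| Robot state | trajectory.h5 | Joint positions joint_positions (7-D) and gripper position gripper_position (1-D) |
| Actions | trajectory.h5 | Joint velocity joint_velocity (7-D) + gripper position (1-D) |
| Language instruction | aggregated-annotations-030724.json | Text description of the task |
| Timestamps | trajectory.h5 | Timestamp information for each camera frame |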

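After conversion, the LeRobot dataset has the following feature structure, which is declared through LeRobotDataset.create():

{
    # Image data (3 views, resolution 180×320)
    "exterior_image_1_left":  image  (180, 320, 3)  # external camera 1
    "exterior_image_2_left":  image  (180, 320, 3)  # external camera 2
    "wrist_image_left":       image  (180, 320, 3)  # wrist camera

    # State data
    "joint_position":         float32 (7,)          # 7-D joint positions
    "gripper_position":       float32 (1,)          # 1-D gripper position

    # Action data (used for training)
    "actions":                float32 (8,)          # 7-D joint velocity + 1-D gripper
}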
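When adapting the converter to your own robot, a small shape check on each frame can catch mistakes before they are written to disk. The sketch below is not part of LeRobot or openpi. The helper `validate_frame` is made up for this example, the shapes come from the structure above, and `uint8` for images is an assumption based on the arrays that the conversion script produces.

```python
import numpy as np

# Shapes copied from the feature structure above; uint8 images are an assumption.
FEATURES = {
    "exterior_image_1_left": ((180, 320, 3), np.uint8),
    "exterior_image_2_left": ((180, 320, 3), np.uint8),
    "wrist_image_left": ((180, 320, 3), np.uint8),
    "joint_position": ((7,), np.float32),
    "gripper_position": ((1,), np.float32),
    "actions": ((8,), np.float32),
}


def validate_frame(frame: dict) -> None:
    """Raise if a frame is missing a feature or has a wrong shape or dtype."""
    for name, (shape, dtype) in FEATURES.items():
        if name not in frame:
            raise KeyError(f"missing feature: {name}")
        arr = np.asarray(frame[name])
        if arr.shape != shape:
            raise ValueError(f"{name}: expected shape {shape}, got {arr.shape}")
        if arr.dtype != np.dtype(dtype):
            raise ValueError(f"{name}: expected dtype {np.dtype(dtype)}, got {arr.dtype}")
```

Such a check would be called on the frame dictionary (without the "task" entry) just before it is handed to the dataset writer.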

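Key conversion steps:

| Step | Description |
| --- | --- |
| BGR → RGB | Reverse the colour channels with [..., ::-1] |
| Image resize | Resize to 320×180 with Image.BICUBIC |
| Action concatenation | joint_velocity (7-D) + gripper_position (1-D) = 8-D |
| Frame rate | 15 FPS |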
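The channel flip, the action concatenation, and the frame-rate arithmetic from the steps above can be tried in isolation with NumPy. This is a minimal sketch with made-up helper names (`bgr_to_rgb`, `build_action`, `episode_seconds`). The resize step is left out because it needs Pillow.

```python
import numpy as np


def bgr_to_rgb(image: np.ndarray) -> np.ndarray:
    """Reverse the last axis so that BGR channel order becomes RGB."""
    return image[..., ::-1]


def build_action(joint_velocity, gripper_position: float) -> np.ndarray:
    """Concatenate 7 joint velocities and 1 gripper position into an 8-D action."""
    joints = np.asarray(joint_velocity, dtype=np.float32)
    gripper = np.asarray([gripper_position], dtype=np.float32)
    return np.concatenate([joints, gripper])


def episode_seconds(num_frames: int, fps: int = 15) -> float:
    """Duration of an episode recorded at a fixed frame rate."""
    return num_frames / fps
```

For example, an episode with 150 frames at 15 FPS lasts 10 seconds, and `build_action` always returns a float32 vector of length 8 when given 7 joint velocities.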

[Figure: conversion flow]

The converted dataset is written to: ~/.cache/huggingface/lerobot/your_hf_username/my_droid_dataset
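The path above is the default LeRobot cache location joined with the dataset name. The sketch below reproduces that path so a script can check whether the converted dataset already exists. It assumes that LeRobot honours an `HF_LEROBOT_HOME` environment variable as an override of the default location. Treat that as an assumption and verify it against your installed LeRobot version. The function name is made up for this example.

```python
import os
from pathlib import Path


def lerobot_dataset_path(repo_id: str) -> Path:
    """Resolve where a locally created LeRobot dataset is expected to live."""
    default_home = Path.home() / ".cache" / "huggingface" / "lerobot"
    # Assumption: HF_LEROBOT_HOME overrides the default cache location.
    home = Path(os.environ.get("HF_LEROBOT_HOME", default_home)).expanduser()
    return home / repo_id
```

`lerobot_dataset_path("your_hf_username/my_droid_dataset").exists()` then tells you whether a previous conversion is already on disk. Keep in mind that the conversion script deletes that directory before it starts.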

Using the sample data above as an example, run:

uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir droid_examples_1_0_1

[Figure: console output of the conversion]

Reference code:

"""
A minimal example that converts data collected on the DROID platform to the LeRobot format.

"""

from collections import defaultdict
import copy
import glob
import json
from pathlib import Path
import shutil

import cv2
import h5py
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
from PIL import Image
from tqdm import tqdm
import tyro

# Name of the output dataset, also used for the Hugging Face Hub
REPO_NAME = "your_hf_username/my_droid_dataset"


def resize_image(image, size):
    """Resize an image to `size`, given as (width, height)."""
    image = Image.fromarray(image)
    return np.array(image.resize(size, resample=Image.BICUBIC))


def main(data_dir: str, *, push_to_hub: bool = False):
    # Remove any existing dataset at the output path
    output_path = HF_LEROBOT_HOME / REPO_NAME
    if output_path.exists():
        shutil.rmtree(output_path)
    data_dir = Path(data_dir)

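    # Create the LeRobot dataset and define the features to store
    # We follow the DROID naming conventions here
    # LeRobot expects image data to use the `image` dtype
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        robot_type="panda",
        fps=15,  # DROID data is typically recorded at 15 fps
        features={
            # Named "left" because only the left stereo camera is used (following the DROID RLDS convention)
            "exterior_image_1_left": {
                "dtype": "image",
                "shape": (180, 320, 3),  # resolution used by the DROID RLDS dataset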
                "names": ["height", "width", "channel"],
            },
            "exterior_image_2_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "wrist_image_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "joint_position": {
                "dtype": "float32",
                "shape": (7,),
                "names": ["joint_position"],
            },
            "gripper_position": {
                "dtype": "float32",
                "shape": (1,),
                "names": ["gripper_position"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (8,),  # joint *velocity* actions (7-D) + gripper position (1-D)
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

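    # Load the language annotations
    # Note: this example loads the DROID language annotations, but you can define them manually for your own data
    with (data_dir / "aggregated-annotations-030724.json").open() as f:
        language_annotations = json.load(f)

    # Iterate over the raw DROID fine-tuning dataset and write the episodes to the LeRobot dataset
    # We assume the following directory structure:
    # RAW_DROID_PATH/
    #   - <...>/
    #     - recordings/
    #        - MP4/
    #          - <camera_id>.mp4  # single-view video of the left stereo camera
    #     - trajectory.h5
    #   - <...>/
    episode_paths = list(data_dir.glob("**/trajectory.h5"))
    print(f"Found {len(episode_paths)} episodes to convert")

    # Iterate over the episodes and write each one to the LeRobot dataset
    for episode_path in tqdm(episode_paths, desc="Converting episodes"):
        # Load the raw data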
        recording_folderpath = episode_path.parent / "recordings" / "MP4"
        trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))

        # To load the language instruction, parse the episode_id from the metadata file name
        # Again, adapt this step to load your own language instructions for your own data
        metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
        episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
        language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
            "language_instruction1"
        ]
        print(f"Converting episode with language instruction: {language_instruction}")

        # Write to the LeRobot dataset
        for step in trajectory:
            camera_type_dict = step["observation"]["camera_type"]
            wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
            exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
            dataset.add_frame(
                {
                    # Note: the loaded images must be flipped from BGR to RGB
                    "exterior_image_1_left": resize_image(
                        step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
                    ),
                    "exterior_image_2_left": resize_image(
                        step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
                    ),
                    "wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
                    "joint_position": np.asarray(
                        step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
                    ),
                    "gripper_position": np.asarray(
                        step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
                    ),
                    # Important: joint velocity actions are used here because pi05-droid was pretrained on joint velocity actions
                    "actions": np.concatenate(
                        [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
                    ),
                    "task": language_instruction,
                }
            )
        dataset.save_episode()

    # Optional: push to the Hugging Face Hub
    if push_to_hub:
        dataset.push_to_hub(
            tags=["droid", "panda", "rlds"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )


##########################################################################################################
################ The rest of this file contains functions for parsing raw DROID data #########################
################ You do not need to understand this part #########################
################ Copied from: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
##########################################################################################################


camera_type_dict = {
    "hand_camera_id": 0,
    "varied_camera_1_id": 1,
    "varied_camera_2_id": 1,
}

camera_type_to_string_dict = {
    0: "hand_camera",
    1: "varied_camera",
    2: "fixed_camera",
}


def get_camera_type(cam_id):
    if cam_id not in camera_type_dict:
        return None
    type_int = camera_type_dict[cam_id]
    return camera_type_to_string_dict[type_int]


class MP4Reader:
    def __init__(self, filepath, serial_number):
        # Save parameters #
        self.serial_number = serial_number
        self._index = 0

        # Open the video reader #
        self._mp4_reader = cv2.VideoCapture(filepath)
        if not self._mp4_reader.isOpened():
            raise RuntimeError("Corrupted MP4 file")

    def set_reading_parameters(
        self,
        image=True,  # noqa: FBT002
        concatenate_images=False,  # noqa: FBT002
        resolution=(0, 0),
        resize_func=None,
    ):
        # Save parameters #
        self.image = image
        self.concatenate_images = concatenate_images
        self.resolution = resolution
        self.resize_func = cv2.resize
        self.skip_reading = not image
        if self.skip_reading:
            return

    def get_frame_resolution(self):
        width = self._mp4_reader.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = self._mp4_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)
        return (width, height)

    def get_frame_count(self):
        if self.skip_reading:
            return 0
        return int(self._mp4_reader.get(cv2.CAP_PROP_FRAME_COUNT))

    def set_frame_index(self, index):
        if self.skip_reading:
            return

        if index < self._index:
            self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
            self._index = index

        while self._index < index:
            self.read_camera(ignore_data=True)

    def _process_frame(self, frame):
        frame = copy.deepcopy(frame)
        if self.resolution == (0, 0):
            return frame
        return self.resize_func(frame, self.resolution)

    def read_camera(self, ignore_data=False, correct_timestamp=None):  # noqa: FBT002
        # Skip if reading is not required #
        if self.skip_reading:
            return {}

        # Read the camera #
        success, frame = self._mp4_reader.read()

        self._index += 1
        if not success:
            return None
        if ignore_data:
            return None

        # Return the data #
        data_dict = {}

        if self.concatenate_images or "stereo" not in self.serial_number:
            data_dict["image"] = {self.serial_number: self._process_frame(frame)}
        else:
            single_width = frame.shape[1] // 2
            data_dict["image"] = {
                self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
                self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
            }

        return data_dict

    def disable_camera(self):
        if hasattr(self, "_mp4_reader"):
            self._mp4_reader.release()


class RecordedMultiCameraWrapper:
    def __init__(self, recording_folderpath, camera_kwargs={}):  # noqa: B006
        # Save camera info #
        self.camera_kwargs = camera_kwargs

        # Open the camera readers #
        mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
        all_filepaths = mp4_filepaths

        self.camera_dict = {}
        for f in all_filepaths:
            serial_number = f.split("/")[-1][:-4]
            cam_type = get_camera_type(serial_number)
            camera_kwargs.get(cam_type, {})

            if f.endswith(".mp4"):
                Reader = MP4Reader  # noqa: N806
            else:
                raise ValueError

            self.camera_dict[serial_number] = Reader(f, serial_number)

    def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}):  # noqa: B006
        full_obs_dict = defaultdict(dict)

        # Read the cameras in a fixed order (the shuffle below is disabled) #
        all_cam_ids = list(self.camera_dict.keys())
        # random.shuffle(all_cam_ids)

        for cam_id in all_cam_ids:
            if "stereo" in cam_id:
                continue
            try:
                cam_type = camera_type_dict[cam_id]
            except KeyError:
                print(f"{self.camera_dict} -- {camera_type_dict}")
                raise ValueError(f"Camera type for {cam_id} not found in camera_type_dict")  # noqa: B904
            curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
            self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)

            timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
            if index is not None:
                self.camera_dict[cam_id].set_frame_index(index)

            data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)

            # Process the returned data #
            if data_dict is None:
                return None
            for key in data_dict:
                full_obs_dict[key].update(data_dict[key])

        return full_obs_dict


def get_hdf5_length(hdf5_file, keys_to_ignore=[]):  # noqa: B006
    length = None

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            curr_length = len(curr_data)
        else:
            raise ValueError

        if length is None:
            length = curr_length
        assert curr_length == length

    return length


def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]):  # noqa: B006
    data_dict = {}

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            data_dict[key] = curr_data[index]
        else:
            raise ValueError

    return data_dict


class TrajectoryReader:
    def __init__(self, filepath, read_images=True):  # noqa: FBT002
        self._hdf5_file = h5py.File(filepath, "r")
        is_video_folder = "observations/videos" in self._hdf5_file
        self._read_images = read_images and is_video_folder
        self._length = get_hdf5_length(self._hdf5_file)
        self._video_readers = {}
        self._index = 0

    def length(self):
        return self._length

    def read_timestep(self, index=None, keys_to_ignore=[]):  # noqa: B006
        # Make sure we read within range #
        if index is None:
            index = self._index
        else:
            assert not self._read_images
            self._index = index
        assert index < self._length

        # Load the low-dimensional data #
        keys_to_ignore = [*keys_to_ignore.copy(), "videos"]
        timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)

        # Increment the read index #
        self._index += 1

        # Return the timestep #
        return timestep

    def close(self):
        self._hdf5_file.close()


def load_trajectory(
    filepath=None,
    read_cameras=True,  # noqa: FBT002
    recording_folderpath=None,
    camera_kwargs={},  # noqa: B006
    remove_skipped_steps=False,  # noqa: FBT002
    num_samples_per_traj=None,
    num_samples_per_traj_coeff=1.5,
):
    read_recording_folderpath = read_cameras and (recording_folderpath is not None)

    traj_reader = TrajectoryReader(filepath)
    if read_recording_folderpath:
        camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)

    horizon = traj_reader.length()
    timestep_list = []

    # Select which timesteps to save #
    if num_samples_per_traj:
        num_to_save = num_samples_per_traj
        if remove_skipped_steps:
            num_to_save = int(num_to_save * num_samples_per_traj_coeff)
        max_size = min(num_to_save, horizon)
        indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
    else:
        indices_to_save = np.arange(horizon)

    # Iterate over the trajectory #
    for i in indices_to_save:
        # Get the HDF5 data #
        timestep = traj_reader.read_timestep(index=i)

        # If applicable, get the recorded camera data #
        if read_recording_folderpath:
            timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
            camera_type_dict = {
                k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
            }
            camera_obs = camera_reader.read_cameras(
                index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
            )
            camera_failed = camera_obs is None

            # Stop if the camera read failed; otherwise add the data to the timestep #
            if camera_failed:
                break
            timestep["observation"].update(camera_obs)

        # Filter steps #
        step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
        delete_skipped_step = step_skipped and remove_skipped_steps

        # Save the filtered timestep #
        if delete_skipped_step:
            del timestep
        else:
            timestep_list.append(timestep)

    # Remove extra transitions #
    timestep_list = np.array(timestep_list)
    if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
        ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
        timestep_list = timestep_list[ind_to_keep]

    # Close the reader #
    traj_reader.close()

    # Return the data #
    return timestep_list


if __name__ == "__main__":
    tyro.cli(main)
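The step that links an episode to its language instruction is easy to get wrong with your own data, so it is worth isolating. The sketch below mirrors the file-name parsing used in the script and collects every instruction available for an episode. The keys `language_instruction2` and `language_instruction3` are an assumption based on the "1 to 3 instructions" description (the script itself only reads `language_instruction1`). The helper names and the sample ID in the usage note are hypothetical.

```python
def episode_id_from_metadata(filename: str) -> str:
    """Mirror the script's parsing: drop the extension, keep the last '_' field."""
    return filename.split(".")[0].split("_")[-1]


def instructions_for(annotations: dict, episode_id: str, default: str = "Do something") -> list:
    """Return all non-empty instructions for an episode, or a default one."""
    entry = annotations.get(episode_id, {})
    keys = ("language_instruction1", "language_instruction2", "language_instruction3")
    found = [entry[k] for k in keys if entry.get(k)]
    return found or [default]
```

For a hypothetical file `metadata_LAB+abc123+2023-12-04-16h-18m-46s.json`, the parsed ID is `LAB+abc123+2023-12-04-16h-18m-46s`. Note that this parsing only works while the ID itself contains neither "." nor "_". If your own episode IDs do, change the parsing accordingly.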

4. Fine-tuning the model

By default we fine-tune from the pi05_droid (π0.5-DROID) weights. Other checkpoints can be used as well.

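| Model | Use case | Description | Checkpoint path |
| --- | --- | --- | --- |
| π0-FAST-DROID | Inference | π0-FAST model fine-tuned on the DROID dataset: performs a wide range of simple table-top manipulation tasks zero-shot in new scenes on the DROID robot platform | gs://openpi-assets/checkpoints/pi0_fast_droid |
| π0-DROID | Fine-tuning | π0 model fine-tuned on the DROID dataset: faster inference than π0-FAST-DROID, but may not follow language commands as well | gs://openpi-assets/checkpoints/pi0_droid |
| π0-ALOHA-towel | Inference | π0 model fine-tuned on internal ALOHA data: folds a variety of towels zero-shot on the ALOHA robot platform | gs://openpi-assets/checkpoints/pi0_aloha_towel |
| π0-ALOHA-tupperware | Inference | π0 model fine-tuned on internal ALOHA data: can take food out of a tupperware container | gs://openpi-assets/checkpoints/pi0_aloha_tupperware |
| π0-ALOHA-pen-uncap | Inference | π0 model fine-tuned on public ALOHA data: can uncap a pen | gs://openpi-assets/checkpoints/pi0_aloha_pen_uncap |
| π0.5-LIBERO | Inference | π0.5 model fine-tuned for the LIBERO benchmark: achieves state-of-the-art performance (see the LIBERO README) | gs://openpi-assets/checkpoints/pi05_libero |
| π0.5-DROID | Inference / Fine-tuning | π0.5 model fine-tuned on the DROID dataset with knowledge insulation: fast inference and good language following | gs://openpi-assets/checkpoints/pi05_droid |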

Fine-tuning command:

uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite

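Key parameters (the values are my guesses, so check them against the training config):

| Config item | Guessed value | Notes |
| --- | --- | --- |
| model | π0.5 architecture | Flow-matching VLA model |
| weight_loader | Pretrained π0.5 weights | Loaded from the base checkpoint |
| freeze_filter | May freeze the vision encoder | Fine-tune only the policy head / LoRA |
| trainable_filter | Policy-related parameters | Selects the trainable part |
| optimizer | AdamW | With weight decay |
| lr_schedule | Cosine decay / constant | Learning rate for fine-tuning |
| ema_decay | 0.9999 | Smoothing of the model parameters |
| batch_size | Must be divisible by the number of GPUs | FSDP / data parallelism |
| data_loader | LeRobot-format dataset | Loads the converted DROID data |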
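To make two rows of the table concrete, here is a small sketch of a warmup-plus-cosine learning-rate schedule and of an exponential moving average (EMA) update. It only illustrates the general formulas. The function names and every number are illustrative and are not read from the openpi training config.

```python
import math


def warmup_cosine_lr(step: int, *, peak_lr: float, warmup_steps: int,
                     decay_steps: int, end_lr: float) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay down to end_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    span = max(1, decay_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / span)
    return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + math.cos(math.pi * progress))


def ema_update(ema_value: float, new_value: float, decay: float) -> float:
    """One EMA step: keep `decay` of the old value, blend in the rest."""
    return decay * ema_value + (1.0 - decay) * new_value


def check_batch_size(batch_size: int, num_devices: int) -> int:
    """Return the per-device batch size, or raise if it does not divide evenly."""
    if batch_size % num_devices != 0:
        raise ValueError(f"batch_size {batch_size} is not divisible by {num_devices} devices")
    return batch_size // num_devices
```

With a decay close to 1 the averaged parameters change slowly, which is why EMA weights are often smoother than the raw weights. The divisibility check corresponds to the batch_size row of the table.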

Wait for fine-tuning to finish. The output looks like this:

repack_transforms=Group(inputs=[RepackTransform(structure={'observation/exterior_image_1_left': 'exterior_image_1_left', 'observation/exterior_image_2_left': 'exterior_image_2_left', 'observation/wrist_image_left': 'wrist_image_left', 'observation/joint_position': 'joint_position', 'observation/gripper_position': 'gripper_position', 'actions': 'actions', 'prompt': 'prompt'})], outputs=()), data_transforms=Group(inputs=[DroidInputs(model_type=<ModelType.PI05: 'pi05'>)], outputs=[DroidOutputs()]), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x7f8fd525c210>, discrete_state_input=True), PadStatesAndActions(model_action_dim=32)], outputs=()), use_quantile_norm=True, action_sequence_keys=('actions',), prompt_from_task=True, rlds_data_dir=None, action_space=None, datasets=()) (14605:data_loader.py:243)
Resolving data files: 100%|█████████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 391991.03it/s]
Downloading data: 100%|██████████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 466033.78files/s]
Generating train split: 7726 examples [00:01, 5662.68 examples/s]
11:09:27.192 [I] local_batch_size: 32                                                             (14605:data_loader.py:324)
2026-04-04 11:09:29.646418: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
........
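Two entries in the log above are worth unpacking: the `RepackTransform` structure, which renames dataset keys into the names the policy expects, and `PadStatesAndActions(model_action_dim=32)`, which widens the 8-D action to the model's action dimension. The sketch below imitates both with plain Python to show the idea. It is not openpi's implementation, the function names are made up, and zero padding is an assumption.

```python
def repack(sample: dict, structure: dict) -> dict:
    """Build a dict keyed by `structure` keys, reading values from `sample`
    under the names given by `structure` values (new_key -> source_key)."""
    return {new_key: sample[source_key] for new_key, source_key in structure.items()}


def pad_to_dim(vector, target_dim: int) -> list:
    """Right-pad a vector with zeros up to target_dim (zero padding assumed)."""
    values = list(vector)
    if len(values) > target_dim:
        raise ValueError(f"vector has {len(values)} values, more than {target_dim}")
    return values + [0.0] * (target_dim - len(values))
```

For instance, a structure of `{"observation/joint_position": "joint_position"}` turns a sample with a `joint_position` entry into one with an `observation/joint_position` entry, and `pad_to_dim(action, 32)` turns the 8-D action from the converted dataset into a 32-D vector whose last 24 entries are zero.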

That's all for this post.
