Author: SkyXZ
CSDN: SkyXZ~ - CSDN Blog
Cnblogs: SkyXZ - 博客园
Robot arm: LeRobot SO-101    Data-collection machine: MacBook Pro, Python 3.10
Dev machine: Ubuntu 22.04, CUDA 12.4, 8 × NVIDIA A100-SXM4-40GB
Dev board: RDK OS 4.0.2 (based on Ubuntu 22.04), Python 3.10.12, OpenExplore 3.2.0
References:
- LeRobot Doc: https://huggingface.co/docs/lerobot/main/en/index
- RDT 170M & 1B: https://github.com/thu-ml/RoboticsDiffusionTransformer
- RDT on RDK S100: "RDT on Double RDK S100P" full-pipeline guide
1. Environment Setup & Robot Arm Configuration
- All code has been uploaded to GitHub: xiongqi123123/LeRobot-VLA: Classic VLA for LeRobot
Environment Setup
We start by installing the LeRobot environment. We use conda for environment management, so first run the following command to create a Python 3.10 virtual environment:
conda create -y -n lerobot python=3.10
Then run the following commands inside the environment to install the dependencies LeRobot needs (the lerobot source used here is my modified fork, which only adds a local server/client; everything else is identical to the official source):
# step 0: install build dependencies
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
# step 1: activate the environment
conda activate lerobot
# step 2: install ffmpeg
conda install ffmpeg -c conda-forge
# choose ONE of the following two options:
# step 3 (option A): install lerobot from source
git clone https://github.com/xiongqi123123/LeRobot-VLA.git
cd LeRobot-VLA/lerobot
pip install -e .
# step 3 (option B): install from PyPI
pip install lerobot
# to install optional extras, use one of the following
pip install 'lerobot[all]' # All available features
pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
pip install 'lerobot[feetech]' # Feetech motor support
Robot Arm Assembly
- Official assembly tutorial: SO-101 - Hugging Face
Since the SO-ARM101 arm does not come with a camera mount on the arm itself, we designed an additional camera mount for the original gripper (its mounting holes line up with the gripper's). The camera we use is Yahboom's 1080P driver-free USB camera, and the print files have been added to the repository:


After unboxing the arm kit, first sort the parts into the Follower and the Leader arm. Note that the Follower arm uses twelve ST-3215-C001 (7.4 V, 1:345 gear ratio) motors, while the Leader arm uses different motor models on different joints; the per-joint motor models are shown in the figure below and the table above:


Once the arms are assembled, we can configure the motors of both arms and assign their IDs. The new LeRobot release provides CLI commands for these tasks. First plug the serial cables of both arms into the computer and run the command below, then follow the prompt: unplug one of the cables and press Enter, and the tool reports which port disappeared (it simply records all ports that were present, compares them with the list after unplugging, and the missing one is your arm). Example output is shown below:
lerobot-find-port
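If you are curious, the diff trick the CLI uses is easy to reproduce by hand. Below is a minimal sketch of the same idea, assuming pyserial is installed (pip install pyserial); it is only an illustration, not the actual lerobot implementation:
# find_port_sketch.py -- illustrate the "record, unplug, diff" idea behind lerobot-find-port
from serial.tools import list_ports  # pip install pyserial


def current_ports():
    # Device paths of every serial port currently attached (e.g. /dev/tty.usbmodem...)
    return {p.device for p in list_ports.comports()}


if __name__ == "__main__":
    before = current_ports()
    input("Unplug the USB cable of ONE arm, then press Enter...")
    after = current_ports()
    # Whatever disappeared is the port of the arm you just unplugged
    print("The unplugged arm was on:", before - after)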
LeRobot uses bus servos, where each motor is identified by a unique ID on the bus, and brand-new motors usually ship with the default ID 1. For the motors and the controller to communicate properly, every motor therefore needs its own unique ID. Once you know which serial port belongs to which arm, you can configure the arms one by one, motor by motor. With the new LeRobot release you no longer have to rerun the motor-configuration command repeatedly: just run the command below and plug each servo cable into the control board in turn, pressing Enter each time. See the following video for details:
lerobot-setup-motors \
    --robot.type=so101_follower \
    --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
With the motors configured, we can calibrate the arms. The goal is to make sure the Leader and Follower arms report the same position values when they are in the same physical pose, so that the Leader can drive the Follower during data collection. The old LeRobot required moving each arm into three different poses after starting the command, so even small differences between the calibration poses of the two arms would degrade the mapping between them. The new LeRobot uses a different procedure: start the calibration with every servo at the middle of its range of motion, then move each joint through its full range so the system automatically records the minimum and maximum reachable values. The command and a demo video are below:
lerobot-calibrate \
    --robot.type=so101_follower \
    --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
    --robot.id=my_awesome_follower_arm # <- Give the robot a unique name
LeRobot supports multiple video-capture options, including phone cameras, built-in laptop cameras, and external webcams, so we are free to choose which cameras to use and where to place them. First, test the cameras:
lerobot-find-cameras opencv


After running the command, three camera devices should show up in theory (the driver-free external camera, the MacBook camera, and the iPhone Continuity Camera), and that is indeed what we get. The captured images are saved under the outputs folder, named by camera ID. One caveat: do not connect two identical cameras to the same USB hub, or you will very likely run into camera-identification problems during data collection!
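Before recording, it is also worth checking that both cameras can stream at the same time, because that is exactly what tends to fail when two identical cameras share one USB hub. Below is a minimal OpenCV sketch, assuming the camera indices 0 and 2 found above (adjust them to your own output):
# camera_check.py -- open both cameras simultaneously and read about a second of frames
import cv2

CAM_INDICES = [0, 2]  # indices reported by lerobot-find-cameras

caps = [cv2.VideoCapture(i) for i in CAM_INDICES]
try:
    for idx, cap in zip(CAM_INDICES, caps):
        if not cap.isOpened():
            raise RuntimeError(f"Camera {idx} could not be opened")
    for _ in range(30):  # roughly one second at 30 fps
        for idx, cap in zip(CAM_INDICES, caps):
            ok, _frame = cap.read()
            if not ok:
                raise RuntimeError(f"Camera {idx} dropped a frame (USB bandwidth issue?)")
    sizes = [(int(c.get(cv2.CAP_PROP_FRAME_WIDTH)), int(c.get(cv2.CAP_PROP_FRAME_HEIGHT))) for c in caps]
    print("Both cameras are streaming fine:", sizes)
finally:
    for cap in caps:
        cap.release()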


Robot Arm Test
With calibration done, we can test the arms with the command below; once it is running, move the Leader arm and check that the Follower arm tracks it correctly:
lerobot-teleoperate \
    --robot.type=so101_follower \
    --robot.port=/dev/tty.usbmodem5AB90671801 \
    --robot.id=my_awesome_follower_arm \
    --teleop.type=so101_leader \
    --teleop.port=/dev/tty.usbmodem5AB90671501 \
    --teleop.id=my_awesome_leader_arm
If your cameras are already configured, you can run the following command to check that teleoperation with cameras also works:
lerobot-teleoperate \
    --robot.type=so101_follower \
    --robot.port=/dev/tty.usbmodem5AB90671801 \
    --robot.id=my_awesome_follower_arm \
    --robot.cameras="{ arm: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, front: {type: opencv, index_or_path: 2, width: 1920, height: 1080, fps: 30}}" \
    --teleop.type=so101_leader \
    --teleop.port=/dev/tty.usbmodem5AB90671501 \
    --teleop.id=my_awesome_leader_arm \
    --display_data=true
2. Dataset Collection
With everything configured, we can start collecting data. RDT is a dual-arm model and by default uses images from three views (front, left wrist, right wrist); since we only have a single arm, we record just two views, the top view and the gripper view:


For the tasks, we defined two common everyday tasks around a pen and a pen holder: putting a pen that lies somewhere on the desk into the pen holder, and taking a pen out of the holder and placing it in the middle of the desk. The objects used are shown below:


Run the command below to start collecting data. num_episodes sets how many episodes to record in one session, and single_task sets the language instruction for the current task. With push_to_hub set to False the dataset is stored locally instead of being uploaded to Hugging Face. episode_time_s sets the recording duration of each episode, and reset_time_s is the reset time between two episodes; set these as needed:
lerobot-record \
    --robot.type=so101_follower \
    --robot.port=/dev/tty.usbmodem5AB90671801 \
    --robot.id=my_awesome_follower_arm \
    --robot.cameras="{ arm: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, front: {type: opencv, index_or_path: 2, width: 1920, height: 1080, fps: 30}}" \
    --teleop.type=so101_leader \
    --teleop.port=/dev/tty.usbmodem5AB90671501 \
    --teleop.id=my_awesome_leader_arm \
    --display_data=true \
    --dataset.repo_id=skyxz/blackmarker_scence1 \
    --dataset.num_episodes=5 \
    --dataset.single_task="Grab the black marker and put it in the bin" \
    --dataset.push_to_hub=False \
    --dataset.episode_time_s=15 \
    --dataset.reset_time_s=5
During collection, move slowly and smoothly; do not rush through the task! The recorded data is stored under ~/.cache/huggingface/lerobot/{repo_id}, and you can open that directory to review each episode. Here we collected data for a red and a black marker, 50 episodes per marker: for each marker, 10 episodes with the pen in a roughly fixed position, 20 episodes with the pen placed randomly, and 20 episodes with a few distractor objects on the desk.
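Before converting anything, it is handy to spot-check one recorded episode straight from its Parquet file. Below is a minimal sketch; the repo_id is the one used in this post, and the column names match what the conversion script later in this post expects:
# inspect_episode.py -- quick look at one recorded LeRobot episode
import os
import pandas as pd

repo_id = "skyxz/blackmarker_scence1"  # replace with your own dataset.repo_id
root = os.path.expanduser(f"~/.cache/huggingface/lerobot/{repo_id}")

df = pd.read_parquet(os.path.join(root, "data/chunk-000/episode_000000.parquet"))
print(df.columns.tolist())                    # should include 'action' and 'observation.state'
print("frames in this episode:", len(df))
print("first action:", df["action"].iloc[0])  # 6-dim: 5 joints + gripper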
3. RDT Model Training
Training Data Conversion
We collected the data with the LeRobot library, whose dataset layout is shown below. The data directory holds the core sequence data, split per episode in Parquet format; each file records the full observations and actions of that demonstration. The meta directory stores the dataset's global description and statistics, such as the task description and step counts. The videos directory stores the visualization videos from each camera view, synchronized with the observation data.
RDT's training data, on the other hand, uses the HDF5 format, which stores all observations, actions, and robot state of a single demonstration in one file. The RDT data layout is as follows:
(RoboTwin) qi.xiong@A100-Test:~/Data_Qi/LeRobot/skyxz/blackmarker_scence1$ python3 -c "import h5py; f=h5py.File('/home/qi.xiong/Data_Qi/RDT/processed_data/place_dual_shoes-demo_clean-300/episode_3/episode_3.hdf5','r'); f.visit(print); f.close()"
action
observations
observations/images
observations/images/cam_high
observations/images/cam_left_wrist
observations/images/cam_right_wrist
observations/left_arm_dim
observations/qpos
observations/right_arm_dim
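A few lines of h5py are enough to confirm not just the layout but also the per-dataset shapes. Below is a minimal sketch against one converted episode (path taken from the example above, adjust to yours):
# inspect_rdt_hdf5.py -- print shape and dtype of every dataset in an RDT-format episode
import h5py

path = "processed_data/place_dual_shoes-demo_clean-300/episode_3/episode_3.hdf5"  # adjust to your file

def show(name, obj):
    # visititems() walks the whole file and calls us for every group/dataset
    if isinstance(obj, h5py.Dataset):
        print(f"{name}: shape={obj.shape}, dtype={obj.dtype}")

with h5py.File(path, "r") as f:
    f.visititems(show)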
So after data collection we need to convert the LeRobot episodes into the HDF5 files RDT expects, and during the conversion we also pre-encode the language instruction into a .pt file. The code used is as follows:
#!/usr/bin/env python3
"""
LeRobot到RDT数据转换脚本LeRobot机器人结构:
- 5个关节 (shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll)
- 1个夹爪 (gripper)
- 总计:6个自由度 (6DOF)维度映射(匹配RDT训练代码):
- left_arm_dim = 0 (单臂机器人,左臂不存在)
- right_arm_dim = 6 (5关节 + 1夹爪,映射到RDT的right_arm部分)
- 状态向量:6维 [joint1, joint2, joint3, joint4, joint5, gripper]
- RDT索引映射:right_arm_joint_0_pos到right_arm_joint_5_pos (索引0-5)
"""import sys
import os
import h5py
import numpy as np
import cv2
import argparse
import yaml
import json
import subprocess
from pathlib import Path
import pandas as pd
import torch

current_dir = os.path.dirname(__file__)
sys.path.append(os.path.join(current_dir, ".."))
from models.multimodal_encoder.t5_encoder import T5Embedderdef extract_frames_from_video(video_path, output_dir, episode_idx):if not os.path.exists(video_path):print(f" No video file: {video_path}")return []temp_dir = os.path.join(output_dir, f"temp_frames_{episode_idx}")if not os.path.exists(temp_dir):os.makedirs(temp_dir)output_pattern = os.path.join(temp_dir, "frame_%04d.jpg")try:cmd = ['ffmpeg', '-i', video_path,'-vf', 'fps=30','-q:v', '2',output_pattern,'-y']result = subprocess.run(cmd, capture_output=True, text=True)if result.returncode != 0:print(f" Failed to extract frames with ffmpeg: {result.stderr}")return []frames = []frame_files = sorted([f for f in os.listdir(temp_dir) if f.endswith('.jpg')])for frame_file in frame_files:frame_path = os.path.join(temp_dir, frame_file)frame = cv2.imread(frame_path)if frame is not None:frame_resized = cv2.resize(frame, (640, 480))frames.append(frame_resized)print(f" Successfully extracted {len(frames)} frames")for frame_file in frame_files:os.remove(os.path.join(temp_dir, frame_file))os.rmdir(temp_dir)return framesexcept Exception as e:print(f" Error extracting frames: {e}")return []def load_lerobot_episode(data_dir, episode_idx, output_dir):"""加载LeRobot的单个episode数据LeRobot数据结构:- action: 6维 [shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper]- observation.state: 6维 [shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper]- 图像: 高位相机 + 手臂相机"""parquet_path = os.path.join(data_dir, "data/chunk-000", f"episode_{episode_idx:06d}.parquet")if not os.path.exists(parquet_path):print(f"Episode {episode_idx} parquet file does not exist: {parquet_path}")return Nonedf = pd.read_parquet(parquet_path)actions = []qpos = []for i in range(len(df)):action = df['action'].iloc[i]state = df['observation.state'].iloc[i]if isinstance(action, np.ndarray):actions.append(action.astype(np.float32))else:actions.append(np.array(action, dtype=np.float32))if isinstance(state, np.ndarray):qpos.append(state.astype(np.float32))else:qpos.append(np.array(state, dtype=np.float32))high_cam_path = os.path.join(data_dir, "videos/chunk-000/observation.images.high", f"episode_{episode_idx:06d}.mp4")arm_cam_path = os.path.join(data_dir, "videos/chunk-000/observation.images.arm", f"episode_{episode_idx:06d}.mp4")print(f" Extracting high camera frames...")high_images = extract_frames_from_video(high_cam_path, output_dir, episode_idx)print(f" Extracting arm camera frames...")arm_images = extract_frames_from_video(arm_cam_path, output_dir, episode_idx)target_frames = len(df)if len(high_images) > target_frames:high_images = high_images[:target_frames]if len(arm_images) > target_frames:arm_images = arm_images[:target_frames]while len(high_images) < target_frames and high_images:high_images.append(high_images[-1])while len(arm_images) < target_frames and arm_images:arm_images.append(arm_images[-1])return {'actions': np.array(actions),'qpos': np.array(qpos),'high_images': high_images,'arm_images': arm_images,'episode_length': len(df)}def images_encoding(imgs):if not imgs:return [], 0encode_data = []padded_data = []max_len = 0for i in range(len(imgs)):success, encoded_image = cv2.imencode(".jpg", imgs[i])if success:jpeg_data = encoded_image.tobytes()encode_data.append(jpeg_data)max_len = max(max_len, len(jpeg_data))else:print(f" Image encoding failed: {i}")empty_data = b""encode_data.append(empty_data)for i in range(len(imgs)):padded_data.append(encode_data[i].ljust(max_len, b"\0"))return encode_data, max_lendef load_task_instructions(data_dir):tasks_file = 
os.path.join(data_dir, "meta/tasks.jsonl")if not os.path.exists(tasks_file):print(f"Warning: tasks file not found: {tasks_file}")return Noneinstructions = []with open(tasks_file, 'r') as f:for line in f:if line.strip():task_data = json.loads(line.strip())instructions.append(task_data["task"])print(f" 加载了 {len(instructions)} 个任务指令")return instructionsdef encode_language_instruction(instruction_text, t5_embedder, device):try:text_embeds, attn_mask = t5_embedder.get_text_embeddings([instruction_text])valid_embeds = text_embeds[0][attn_mask[0]].float()return valid_embeds.cpu().numpy()except Exception as e:print(f" Language encoding failed: {e}")return np.zeros((1, 4096))def convert_lerobot_to_rdt(data_dir, output_dir, episode_num, gpu=0, no_language=False):if not os.path.exists(output_dir):os.makedirs(output_dir)print(f"Start converting LeRobot data to RDT format...")print(f"Data source: {data_dir}")print(f"Output directory: {output_dir}")print(f"Processing episode number: {episode_num}")print(f"GPU device: {gpu}")scene_name = os.path.basename(data_dir)instructions = Noneif not no_language:instructions = load_task_instructions(data_dir)t5_embedder = Noneif not no_language and instructions:try:print(f" Initializing T5 encoder...")t5_model_path = "/home/qi.xiong/Data_Qi/t5-v1_1-xxl"if not os.path.exists(t5_model_path):print(f" Warning: T5 model path does not exist: {t5_model_path}")print(f" Will skip language processing")no_language = Trueelse:t5_embedder = T5Embedder(from_pretrained=t5_model_path,device=f"cuda:{gpu}" if torch.cuda.is_available() else "cpu",model_max_length=120,use_offload_folder=None,)print(f" T5 encoder initialized successfully")except Exception as e:print(f" T5 encoder initialization failed: {e}")print(f" Will skip language processing")no_language = Truefor i in range(episode_num):print(f"Processing episode {i}...")episode_data = load_lerobot_episode(data_dir, i, output_dir)if episode_data is None:print(f"Skipping episode {i}")continueepisode_output_dir = os.path.join(output_dir, f"episode_{i}")if not os.path.exists(episode_output_dir):os.makedirs(episode_output_dir)hdf5_path = os.path.join(episode_output_dir, f"episode_{i}.hdf5")with h5py.File(hdf5_path, "w") as f:f.create_dataset("action", data=episode_data['actions'])obs = f.create_group("observations")obs.create_dataset("qpos", data=episode_data['qpos'])image = obs.create_group("images")if episode_data['high_images']:print(f" Encoding high camera images...")high_enc, len_high = images_encoding(episode_data['high_images'])if high_enc and len_high > 0:image.create_dataset("cam_high", data=high_enc, dtype=f"S{len_high}")print(f" Saved high camera images: {len(episode_data['high_images'])} frames")else:print(f" Warning: High camera images encoding failed")if episode_data['arm_images']:print(f" Encoding arm camera images...")arm_enc, len_arm = images_encoding(episode_data['arm_images'])if arm_enc and len_arm > 0:image.create_dataset("cam_right_wrist", data=arm_enc, dtype=f"S{len_arm}")print(f" Saved arm camera images: {len(episode_data['arm_images'])} frames")else:print(f" Warning: Arm camera images encoding failed")# 添加机器人维度信息(LeRobot: 5个关节 + 1个夹爪)# 根据process_data.py的逻辑,每个时间步都需要记录维度信息# LeRobot是单臂机器人,只有右臂:5个关节 + 1个夹爪 = 6维# 左臂:0维(单臂机器人)# 为每个时间步记录维度信息left_arm_dim = [0] * len(episode_data['actions']) # 左臂0维(单臂机器人)right_arm_dim = [6] * len(episode_data['actions']) # 右臂6维(5关节+1夹爪)obs.create_dataset("left_arm_dim", data=np.array(left_arm_dim))obs.create_dataset("right_arm_dim", data=np.array(right_arm_dim))print(f" Episode {i} 
converted successfully: {hdf5_path}")print(f" Data length: {episode_data['episode_length']}")print(f" Action shape: {episode_data['actions'].shape}")print(f" Qpos shape: {episode_data['qpos'].shape}")print(f" High camera frames: {len(episode_data['high_images'])}")print(f" Arm camera frames: {len(episode_data['arm_images'])}")if not no_language and t5_embedder and instructions:print(f" Processing language instructions...")try:instruction = instructions[0]language_features = encode_language_instruction(instruction, t5_embedder, f"cuda:{gpu}")instructions_dir = os.path.join(episode_output_dir, "instructions")if not os.path.exists(instructions_dir):os.makedirs(instructions_dir)lang_embed_path = os.path.join(instructions_dir, "lang_embed_0.pt")torch.save(torch.from_numpy(language_features), lang_embed_path)print(f" Language instruction encoded successfully: {instruction}")print(f" Language features saved to: {lang_embed_path}")print(f" Language features shape: {language_features.shape}, data type: {language_features.dtype}")except Exception as e:print(f" Language instruction processing failed: {e}")print(f"\nConversion completed! Processed {episode_num} episodes")print(f"Output directory: {output_dir}")def main():parser = argparse.ArgumentParser(description="Convert LeRobot data to RDT format")parser.add_argument("--data_dir", type=str, required=True, help="LeRobot data directory path")parser.add_argument("--output_dir", type=str, required=True,help="Output directory path")parser.add_argument("--episode_num", type=int, default=10,help="Number of episodes to process")parser.add_argument("--gpu", type=int, default=0,help="GPU device ID")parser.add_argument("--no_language", action="store_true",help="Skip language processing")args = parser.parse_args()if not os.path.exists(args.data_dir):print(f"Error: Data directory does not exist: {args.data_dir}")returnmeta_file = os.path.join(args.data_dir, "meta/info.json")if not os.path.exists(meta_file):print(f"Error: Meta information file not found: {meta_file}")returntry:subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)print("ffmpeg is available, will use ffmpeg to extract video frames")except (subprocess.CalledProcessError, FileNotFoundError):print("Warning: ffmpeg is not available, image data may not be extracted correctly")print("Please install ffmpeg: conda install -c conda-forge ffmpeg=6.1")returnwith open(meta_file, 'r') as f:meta_info = yaml.safe_load(f)total_episodes = meta_info.get('total_episodes', 10)if args.episode_num > total_episodes:print(f"Warning: Requested episode number ({args.episode_num}) exceeds available number ({total_episodes})")args.episode_num = total_episodesconvert_lerobot_to_rdt(args.data_dir, args.output_dir, args.episode_num,args.gpu,args.no_language)if __name__ == "__main__":main()
Then simply run this script, or one of the commands below, to convert the data:
# Option 1: bash process_data_rdt.sh data_dir=${1} output_dir=${2} episode_num=${3} gpu_id=${4}
bash process_data_rdt.sh /home/qi.xiong/Data_Qi/LeRobot/skyxz/redmarker_scence4 /home/qi.xiong/DualArm/RoboTwin/policy/RDT-LeRobot/processed_data/redmarker_scence4 5 0
# Option 2: python scripts/process_data_lerobot.py --data_dir --output_dir --episode_num --gpu
python3 scripts/process_data_lerobot.py --data_dir /home/qi.xiong/Data_Qi/LeRobot/skyxz/redmarker_scence4 --output_dir /home/qi.xiong/DualArm/RoboTwin/policy/RDT-LeRobot/processed_data/redmarker_scence4
We use the RDT version modified for RoboTwin, which is simple and quick to work with (for more on RDT see: "RDT on Double RDK S100P" full-pipeline guide). RDT is a dual-arm model by default, while the LeRobot data we collected is single-arm, and we only recorded two camera views instead of RDT's default three, so training directly would fail with index-mismatch errors. We therefore also have to modify the dataset-loading code. First, the action normalization: we divide the data by [[180, 180, 180, 180, 180, 180]] at load time. Then we map the LeRobot single arm onto the action dimensions of RDT's right arm, drop the left arm entirely, and replace the right-arm image RDT loads by default with a ground (background) image. Replace RDT/data/hdf5_vla_dataset.py with the modified code below:
import os
import fnmatch
import json
import h5py
import yaml
import cv2
import numpy as npfrom configs.state_vec import STATE_VEC_IDX_MAPPINGclass HDF5VLADataset:"""This class is used to sample episodes from the embododiment datasetstored in HDF5."""def __init__(self, model_config_path) -> None:# [Modify] The path to the HDF5 dataset directory# Each HDF5 file contains one episodewith open(model_config_path, "r") as f:model_config = yaml.safe_load(f)HDF5_DIR = model_config["data_path"]self.DATASET_NAME = "agilex"self.file_paths = []for root, _, files in os.walk(HDF5_DIR):for filename in fnmatch.filter(files, "*.hdf5"):file_path = os.path.join(root, filename)self.file_paths.append(file_path)# Load the configwith open("configs/base.yaml", "r") as file:config = yaml.safe_load(file)self.CHUNK_SIZE = config["common"]["action_chunk_size"]self.IMG_HISORY_SIZE = config["common"]["img_history_size"]self.STATE_DIM = config["common"]["state_dim"]# Get each episode's len (use original length, not standardized length)episode_lens = []for file_path in self.file_paths:try:with h5py.File(file_path, "r") as f:qpos = f["observations"]["qpos"][:]num_steps = qpos.shape[0]episode_lens.append(num_steps)except Exception as e:print(f"Warning: Could not read {file_path}: {e}")episode_lens.append(0)self.episode_sample_weights = np.array(episode_lens) / np.sum(episode_lens)def __len__(self):return len(self.file_paths)def get_dataset_name(self):return self.DATASET_NAMEdef get_item(self, index: int = None, state_only=False):"""Get a training sample at a random timestep.Args:index (int, optional): the index of the episode.If not provided, a random episode will be selected.state_only (bool, optional): Whether to return only the state.In this way, the sample will contain a complete trajectory ratherthan a single timestep. Defaults to False.Returns:sample (dict): a dictionary containing the training sample."""while True:if index is None:file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)else:file_path = self.file_paths[index]valid, sample = (self.parse_hdf5_file(file_path)if not state_only else self.parse_hdf5_file_state_only(file_path))if valid:return sampleelse:index = np.random.randint(0, len(self.file_paths))def parse_hdf5_file(self, file_path):"""[Modify] Parse a hdf5 file to generate a training sample ata random timestep.Args:file_path (str): the path to the hdf5 fileReturns:valid (bool): whether the episode is valid, which is useful for filtering.If False, this episode will be dropped.dict: a dictionary containing the training sample,{"meta": {"dataset_name": str, # the name of your dataset."#steps": int, # the number of steps in the episode,# also the total timesteps."instruction": str # the language instruction for this episode.},"step_id": int, # the index of the sampled step,# also the timestep t."state": ndarray, # state[t], (1, STATE_DIM)."state_std": ndarray, # std(state[:]), (STATE_DIM,)."state_mean": ndarray, # mean(state[:]), (STATE_DIM,)."state_norm": ndarray, # norm(state[:]), (STATE_DIM,)."actions": ndarray, # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM)."state_indicator", ndarray, # indicates the validness of each dim, (STATE_DIM,)."cam_high": ndarray, # external camera image, (IMG_HISORY_SIZE, H, W, 3)# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable."cam_high_mask": ndarray, # indicates the validness of each timestep, (IMG_HISORY_SIZE,) boolean array.# For the first IMAGE_HISTORY_SIZE-1 timesteps, the mask should be False."cam_left_wrist": ndarray, # left wrist camera image, (IMG_HISORY_SIZE, H, W, 3).# or (IMG_HISORY_SIZE, 0, 0, 0) if 
unavailable."cam_left_wrist_mask": ndarray,"cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISORY_SIZE, H, W, 3).# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.# If only one wrist, make it right wrist, plz."cam_right_wrist_mask": ndarray} or None if the episode is invalid."""with h5py.File(file_path, "r") as f:qpos = f["observations"]["qpos"][:]left_arm_dim = f["observations"]["left_arm_dim"][:]right_arm_dim = f["observations"]["right_arm_dim"][:]num_steps = qpos.shape[0]action_dim = qpos# [Optional] We drop too-short episode# if num_steps < 128:# return False, None# [Optional] We skip the first few still stepsEPS = 1e-2# Get the idx of the first qpos whose delta exceeds the thresholdqpos_delta = np.abs(qpos - qpos[0:1])indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]if len(indices) > 0:first_idx = indices[0]else:raise ValueError("Found no qpos that exceeds the threshold.")# We randomly sample a timestepstep_id = np.random.randint(first_idx - 1, num_steps)# Load the instructiondir_path = os.path.dirname(file_path)# with open(os.path.join(dir_path, 'instruction.json'), 'r') as f_instr:# instruction_dict = json.load(f_instr)# # We have 1/3 prob to use original instruction,# # 1/3 to use simplified instruction,# # and 1/3 to use expanded instruction.# instruction_type = np.random.choice([# 'instruction', 'expanded_instruction'])# instruction = instruction_dict[instruction_type]# if isinstance(instruction, list):# instruction = np.random.choice(instruction)# You can also use precomputed language embeddings (recommended)# instruction = "path/to/lang_embed.pt"instructions_path = os.path.join(dir_path, "instructions")instructions_names = []for filename in os.listdir(instructions_path):# 检查文件名是否以.pt结尾if filename.endswith(".pt"):instructions_names.append(os.path.join(instructions_path, filename))instruction = np.random.choice(instructions_names)# print(f"choose {instruction} file as instruction.")# Assemble the metameta = {"dataset_name": self.DATASET_NAME,"#steps": num_steps,"step_id": step_id,"instruction": instruction,}# Rescale gripper to [0, 1]# qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]])# target_qpos = f["action"][step_id:step_id + self.CHUNK_SIZE] / np.array(# [[1 for i in range(left_arm_dim[0] + 1 + right_arm_dim[0] + 1)]])qpos = qpos / np.array(# [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]] [[180, 180, 180, 180, 180, 180]])target_qpos = f['action'][step_id:step_id + self.CHUNK_SIZE] / np.array(# [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]] [[180, 180, 180, 180, 180, 180]])# Parse the state and actionstate = qpos[step_id:step_id + 1]state_std = np.std(qpos, axis=0)state_mean = np.mean(qpos, axis=0)state_norm = np.sqrt(np.mean(qpos**2, axis=0))actions = target_qposif actions.shape[0] < self.CHUNK_SIZE:# Pad the actions using the last actionactions = np.concatenate([actions,np.tile(actions[-1:], (self.CHUNK_SIZE - actions.shape[0], 1)),],axis=0,)# Fill the state/action into the unified vectordef fill_in_state(values):# Target indices corresponding to your state space# In this example: 6 joints + 1 gripper for each armUNI_STATE_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)# ] + [# STATE_VEC_IDX_MAPPING["right_gripper_open"]]uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))uni_vec[..., UNI_STATE_INDICES] = valuesreturn uni_vecstate = fill_in_state(state)state_indicator = fill_in_state(np.ones_like(state_std))state_std = fill_in_state(state_std)state_mean = 
fill_in_state(state_mean)state_norm = fill_in_state(state_norm)# If action's format is different from state's,# you may implement fill_in_action()actions = fill_in_state(actions)# Parse the imagesdef parse_img(key):imgs = []for i in range(max(step_id - self.IMG_HISORY_SIZE + 1, 0), step_id + 1):img_bits = f["observations"]["images"][key][i]img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)imgs.append(img)imgs = np.stack(imgs)if imgs.shape[0] < self.IMG_HISORY_SIZE:# Pad the images using the first imageimgs = np.concatenate([np.tile(imgs[:1],(self.IMG_HISORY_SIZE - imgs.shape[0], 1, 1, 1),),imgs,],axis=0,)return imgs# `cam_high` is the external camera imagecam_high = parse_img("cam_high")# For step_id = first_idx - 1, the valid_len should be onevalid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)cam_high_mask = np.array([False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len)# cam_left_wrist = parse_img("cam_left_wrist")# cam_left_wrist_mask = cam_high_mask.copy()cam_left_wrist = np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0))#parse_img('cam_right_wrist')cam_left_wrist_mask = np.array([False] * self.IMG_HISORY_SIZE)#cam_high_mask.copy()cam_right_wrist = parse_img("cam_right_wrist")cam_right_wrist_mask = cam_high_mask.copy() # 使用相同的掩码逻辑# Return the resulting sample# For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0)# E.g., return np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0)) for the key "cam_left_wrist",# if the left-wrist camera is unavailable on your robotreturn True, {"meta": meta,"state": state,"state_std": state_std,"state_mean": state_mean,"state_norm": state_norm,"actions": actions,"state_indicator": state_indicator,"cam_high": cam_high,"cam_high_mask": cam_high_mask,"cam_left_wrist": cam_left_wrist,"cam_left_wrist_mask": cam_left_wrist_mask,"cam_right_wrist": cam_right_wrist,"cam_right_wrist_mask": cam_right_wrist_mask,}def parse_hdf5_file_state_only(self, file_path):"""[Modify] Parse a hdf5 file to generate a state trajectory.Args:file_path (str): the path to the hdf5 fileReturns:valid (bool): whether the episode is valid, which is useful for filtering.If False, this episode will be dropped.dict: a dictionary containing the training sample,{"state": ndarray, # state[:], (T, STATE_DIM)."action": ndarray, # action[:], (T, STATE_DIM).} or None if the episode is invalid."""with h5py.File(file_path, "r") as f:qpos = f["observations"]["qpos"][:]left_arm_dim = f["observations"]["left_arm_dim"][:]right_arm_dim = f["observations"]["right_arm_dim"][:]num_steps = qpos.shape[0]# [Optional] We drop too-short episode# if num_steps < 128:# return False, None# [Optional] We skip the first few still stepsEPS = 1e-2# Get the idx of the first qpos whose delta exceeds the thresholdqpos_delta = np.abs(qpos - qpos[0:1])indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]if len(indices) > 0:first_idx = indices[0]else:raise ValueError("Found no qpos that exceeds the threshold.")# Rescale gripper to [0, 1]# qpos = qpos / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]])# target_qpos = f["action"][:] / np.array([[1 for i in range(left_arm_dim[0] + right_arm_dim[0] + 2)]])qpos = qpos / np.array(# [[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]] [[180, 180, 180, 180, 180, 180]])target_qpos = f['action'][first_idx - 1:] / np.array(# [[1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 13.9231]] [[180, 180, 180, 180, 180, 180]])# Parse the state and actionstate = qpos[first_idx - 1:]action = target_qpos[first_idx - 1:]# 
Standardize trajectory length to avoid batch size mismatch# Use a fixed length (e.g., 128) or pad/truncate to matchtarget_length = 128 # You can adjust this valueif state.shape[0] > target_length:# Truncate to target lengthstate = state[:target_length]action = action[:target_length]elif state.shape[0] < target_length:# Pad with the last state/actionpad_length = target_length - state.shape[0]state = np.concatenate([state, np.tile(state[-1:], (pad_length, 1))], axis=0)action = np.concatenate([action, np.tile(action[-1:], (pad_length, 1))], axis=0)# Fill the state/action into the unified vectordef fill_in_state(values):# Target indices corresponding to your state space# In this example: 6 joints + 1 gripper for each armUNI_STATE_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)# ] + [# STATE_VEC_IDX_MAPPING["right_gripper_open"]]uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM, ))uni_vec[..., UNI_STATE_INDICES] = valuesreturn uni_vecstate = fill_in_state(state)action = fill_in_state(action)# Return the resulting samplereturn True, {"state": state, "action": action}if __name__ == "__main__":ds = HDF5VLADataset()for i in range(len(ds)):print(f"Processing episode {i}/{len(ds)}...")ds.get_item(i)
Start Training
Training Environment Setup
First we install the environment required for RDT training. Enter the RDT directory and install the following packages in order:
# step 1: install torch and torchvision
pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
# step 2: install packaging
pip install packaging==24.0
# step 3: install the remaining dependencies
pip install -r requirements.txt
Besides the dependencies above we also need flash_attn for acceleration. To avoid network issues we download a pre-built wheel manually from https://github.com/Dao-AILab/flash-attention/releases. Pick the wheel that matches your installed torch and CUDA versions, and also check how your PyTorch build was compiled to decide whether the cxx11abi flag in the wheel name should be TRUE or FALSE:
$ python3 -c "import torch ; print(torch.__config__.show())"
From this output we can tell which wheel to download: as shown below, our current PyTorch was built with CXX11_ABI = 0, so we need the cxx11abiFALSE .whl file.
PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.4.2 (Git Hash 1137e04ec0b5251ca2b4400a4fd3c667ce843d67)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 90.1 (built against CUDA 12.4)
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=9.1.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.4.1, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,
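If you only need the ABI flag rather than the whole build report, PyTorch also exposes it directly; True means pick the cxx11abiTRUE wheel, False means cxx11abiFALSE:
import torch

print(torch.compiled_with_cxx11_abi())        # False -> download the cxx11abiFALSE wheel
print(torch.__version__, torch.version.cuda)  # double-check torch/CUDA versions for the wheel name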
If your setup matches mine, then the wheel you need is the one below; install it after downloading:
pip3 install flash_attn-2.7.2.post1+cu12torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
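After installing, a quick import test confirms the wheel matches your torch/CUDA build; an ABI or version mismatch usually surfaces right here as an ImportError (the __version__ attribute is assumed to be present, as in recent flash-attention releases):
import torch
import flash_attn

print("flash_attn:", flash_attn.__version__)
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)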
Pretrained Model Download
RDT comes in a 1B version (SigLIP + DiT + adaptors) and a separate 170M (DiT-only) version; the only difference is the hidden_size and depth of the final DiT, which are halved in the 170M version compared to 1B. If you plan to deploy on the RDK S100, use the RDT-170M version described below to get acceptable performance. Before training, download the pretrained models and finish setting up the training environment as follows:
# Option 1: run the download script provided in my repository
cd weights/RDT
bash _download.sh
# Option 2: download manually
export HF_ENDPOINT=https://hf-mirror.com # mirror for faster downloads in mainland China
huggingface-cli download google/t5-v1_1-xxl --local-dir t5-v1_1-xxl
huggingface-cli download google/siglip-so400m-patch14-384 --local-dir siglip-so400m-patch14-384
huggingface-cli download robotics-diffusion-transformer/rdt-1b --local-dir rdt-1b
huggingface-cli download robotics-diffusion-transformer/rdt-170m --local-dir rdt-170m
RDT-1B Fine-tuning
To train the 1B version, modify the rdt parameters under the model section of RDT/configs/base.yaml:
rdt:
  # 1B: num_head 32 hidden_size 2048 depth: 28
  # 170M: num_head 32 hidden_size 1024 depth: 14
  hidden_size: 2048
  depth: 28
  num_heads: 32
  cond_pos_embed_type: multimodal
Then generate the training parameter yml file and point pretrained_model_name_or_path to the 1B model downloaded earlier:
bash generate.sh RDT1B_LeRobot
# Generated on 2025-08-28 17:14:20
model: RDT1B_LeRobot
data_path: training_data/RDT1B_LeRobot
checkpoint_path: checkpoints/RDT1B_LeRobot
pretrained_model_name_or_path: ../weights/RDT/rdt-1b
cuda_visible_device: '0,1,2,3'
train_batch_size: 16
sample_batch_size: 32
max_train_steps: 10000
checkpointing_period: 2500
sample_period: 100
checkpoints_total_limit: 40
learning_rate: 0.0001
dataloader_num_workers: 8
state_noise_snr: 40
gradient_accumulation_steps: 1
Then start training directly:
bash finetune.sh RDT1B_LeRobot
RDT-170M Fine-tuning
To train the 170M version, modify the rdt parameters under the model section of RDT/configs/base.yaml:
rdt:
  # 1B: num_head 32 hidden_size 2048 depth: 28
  # 170M: num_head 32 hidden_size 1024 depth: 14
  hidden_size: 1024
  depth: 14
  num_heads: 32
  cond_pos_embed_type: multimodal
Then generate the training parameter yml file and point pretrained_model_name_or_path to the 170M model downloaded earlier:
bash generate.sh RDT170M_LeRobot
# Generated on 2025-08-28 17:14:20
model: RDT170M_LeRobot
data_path: training_data/RDT170M_LeRobot
checkpoint_path: checkpoints/RDT170M_LeRobot
pretrained_model_name_or_path: ../weights/RDT/rdt-170m
cuda_visible_device: '0,1'
train_batch_size: 16
sample_batch_size: 32
max_train_steps: 10000
checkpointing_period: 2500
sample_period: 100
checkpoints_total_limit: 40
learning_rate: 0.0001
dataloader_num_workers: 8
state_noise_snr: 40
gradient_accumulation_steps: 1
Then start training directly:
bash finetune.sh RDT170M_LeRobot
On two A100-40GB GPUs with a total batch size of 16 (8 per GPU), the memory usage and training speed are roughly as shown below. Following the RDT paper, we track the overall_avg_sample_mse metric; with only 100 episodes of data, both the RDT-170M and the 1B version converge, with the metric dropping to the 1e-4 range at around 7,000 steps.
4. Real-World Evaluation
Once training is finished we can start the real-robot evaluation. Since the LeRobot arm is connected to and controlled from the Mac, while inference runs on the A100 server or the RDK S100, we still need communication code between the two ends. Here we use the simplest possible socket-based server/client; better options such as ZMQ exist but are not covered here. The implementation is below, and this file must be placed on both the local machine and the inference server:
import socket
import numpy as np
import zlib
import json
import base64
import time
from typing import Any
import torchclass NumpyEncoder(json.JSONEncoder):"""Enhanced json encoder for numpy types and PyTorch tensors with array reconstruction info"""def default(self, obj):if isinstance(obj, np.ndarray):return {'__numpy_array__': True,'data': base64.b64encode(obj.tobytes()).decode('ascii'),'dtype': str(obj.dtype),'shape': obj.shape}elif torch is not None and isinstance(obj, torch.Tensor):# 将 PyTorch Tensor 转换为 numpy 数组numpy_array = obj.cpu().detach().numpy()return {'__numpy_array__': True,'data': base64.b64encode(numpy_array.tobytes()).decode('ascii'),'dtype': str(numpy_array.dtype),'shape': numpy_array.shape}elif isinstance(obj, (np.integer, np.floating, np.bool_)):return obj.item()return super().default(obj)def numpy_to_json(data: Any) -> str:return json.dumps(data, cls=NumpyEncoder)def json_to_numpy(json_str: str) -> Any:def hook(dct):if '__numpy_array__' in dct:data = base64.b64decode(dct['data'])return np.frombuffer(data, dtype=dct['dtype']).reshape(dct['shape'])return dctreturn json.loads(json_str, object_hook=hook)class CommonUtils:@staticmethoddef serialize(data: Any) -> bytes:return zlib.compress(numpy_to_json(data).encode('utf-8'))@staticmethoddef deserialize(data: bytes) -> Any:return json_to_numpy(zlib.decompress(data).decode('utf-8'))def send_all(sock, payload):sock.sendall(len(payload).to_bytes(8, 'big') + payload)def recv_all(sock) :length_bytes = sock.recv(8)if not length_bytes:return Nonelength = int.from_bytes(length_bytes, 'big')buf = b''while len(buf) < length:chunk = sock.recv(length - len(buf))if not chunk:return Nonebuf += chunkreturn bufclass ServerClient:def __init__(self, host='localhost', port=5000, is_server=True):self.host, self.port, self.is_server = host, port, is_serverself.utils = CommonUtils()self._connect()def _connect(self):if self.is_server:self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)self.sock.bind((self.host, self.port))self.sock.listen(1)print(f"[ServerClient] Listening on {self.host}:{self.port}")self.conn, addr = self.sock.accept()print(f"[ServerClient] Connected by {addr}")else:while True:try:self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)self.sock.connect((self.host, self.port))self.conn = self.sockprint(f"[ServerClient] Connected to {self.host}:{self.port}")breakexcept (ConnectionRefusedError, OSError):print("[ServerClient] Waiting for server...")time.sleep(2)def send(self, data):payload = self.utils.serialize(data)try:send_all(self.conn, payload)except (BrokenPipeError, ConnectionResetError, OSError):print("[ServerClient] Connection lost. Reconnecting...")self._connect()send_all(self.conn, payload)def receive(self):try:buf = recv_all(self.conn)return self.utils.deserialize(buf) if buf else Noneexcept (BrokenPipeError, ConnectionResetError, OSError):print("[ServerClient] Connection lost. Reconnecting...")self._connect()return Nonedef close(self):self.conn.close()self.sock.close()class Client:def __init__(self, host='127.0.0.1', port=5000):self.host, self.port = host, portself.utils = CommonUtils()self.connect()def connect(self):while True:try:self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)self.sock.connect((self.host, self.port))print(f"[Client] Connected to {self.host}:{self.port}")breakexcept (ConnectionRefusedError, OSError):print("[Client] Waiting for server...")time.sleep(2)def send(self, data):payload = self.utils.serialize(data)try:send_all(self.sock, payload)except (BrokenPipeError, ConnectionResetError, OSError):print("[Client] Connection lost. 
Reconnecting...")self.connect()send_all(self.sock, payload)def receive(self):try:buf = recv_all(self.sock)return self.utils.deserialize(buf) if buf else Noneexcept (BrokenPipeError, ConnectionResetError, OSError):print("[Client] Connection lost. Reconnecting...")self.connect()return Nonedef close(self):self.sock.close()print("[Client] Closed.")class Server:def __init__(self, host='0.0.0.0', port=5000):self.host, self.port = host, portself.utils = CommonUtils()self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)self.sock.bind((self.host, self.port))self.sock.listen(1)print(f"[Server] Listening on {self.host}:{self.port}")self._wait_client()def _wait_client(self):print("[Server] Waiting for client...")self.conn, addr = self.sock.accept()print(f"[Server] Connected by {addr}")def send(self, data: Any):payload = self.utils.serialize(data)try:send_all(self.conn, payload)except (BrokenPipeError, ConnectionResetError, OSError):print("[Server] Client disconnected. Waiting new client...")self._wait_client()send_all(self.conn, payload)def receive(self):try:buf = recv_all(self.conn)return self.utils.deserialize(buf) if buf else Noneexcept (BrokenPipeError, ConnectionResetError, OSError):print("[Server] Client disconnected. Waiting new client...")self._wait_client()return Nonedef close(self):self.conn.close()self.sock.close()print("[Server] Closed.")
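Based on the Server and Client classes above, a minimal smoke test of the link could look like the following (run the server part on the inference machine first, then the client part on the laptop; port 5002 matches the one used later, the IP address is a placeholder, and the zero arrays are dummy data):
# link_smoke_test.py -- sanity-check the socket link with dummy data
import numpy as np

# --- on the inference server ---
# from server_client import Server
# server = Server(host="0.0.0.0", port=5002)          # blocks until a client connects
# msg = server.receive()                              # numpy arrays are restored automatically
# server.send({"message_id": msg["message_id"], "actions": np.zeros((1, 64, 6))})

# --- on the laptop ---
from server_client import Client

client = Client(host="192.168.1.100", port=5002)      # replace with your server's IP
client.send({"message_id": 0, "proprio": np.zeros((1, 6), dtype=np.float32)})
reply = client.receive()
print("got actions with shape:", np.array(reply["actions"]).shape)  # expect (1, 64, 6)
client.close()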
Single-GPU Server Deployment
Next we write the inference code on the server. Following RDT/scripts/agilex_model.py from RDT, we implement our lerobot_rdt_server: we adapt the RoboticDiffusionTransformerModel class, run inference through its step method, and integrate our ServerClient to receive the robot observation data sent from the local machine. The full code is as follows:
import os, sys

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
import yaml
from pathlib import Path

# get current workspace
current_file = Path(__file__)
sys.path.append(os.path.join(current_file.parent.parent, "models"))
sys.path.append(os.path.join(current_file.parent.parent, "models"))
sys.path.append(os.path.join(current_file.parent.parent))
from configs.state_vec import STATE_VEC_IDX_MAPPING
from multimodal_encoder.siglip_encoder import SiglipVisionTower
from multimodal_encoder.t5_encoder import T5Embedder
from rdt_runner import RDTRunner
from server_client import ServerClient

# The indices that the raw vector should be mapped to in the unified action vector
AGILEX_STATE_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
]


# Create the RDT model
def create_model(args, **kwargs):model = RoboticDiffusionTransformerModel(args, **kwargs)pretrained = kwargs.get("pretrained", None)if pretrained is not None and os.path.isfile(pretrained):model.load_pretrained_weights(pretrained)return modelclass RoboticDiffusionTransformerModel(object):"""A wrapper for the RDT model, which handles1. Model initialization2. Encodings of instructions3. Model inference"""def __init__(self,args,device="cuda",dtype=torch.bfloat16,image_size=None,control_frequency=25,pretrained=None,pretrained_vision_encoder_name_or_path=None,):self.args = argsself.dtype = dtypeself.image_size = image_sizeself.device = deviceself.control_frequency = control_frequency# We do not use the text encoder due to limited GPU memory# self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)self.policy = self.get_policy(pretrained)self.reset()def get_policy(self, pretrained):"""Initialize the model."""# Initialize model with argumentsif pretrained is None or os.path.isfile(pretrained):img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] *self.vision_model.num_patches)_model = RDTRunner(action_dim=self.args["common"]["state_dim"],pred_horizon=self.args["common"]["action_chunk_size"],config=self.args["model"],lang_token_dim=self.args["model"]["lang_token_dim"],img_token_dim=self.args["model"]["img_token_dim"],state_token_dim=self.args["model"]["state_token_dim"],max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],img_cond_len=img_cond_len,img_pos_embed_config=[# No initial pos embed in the last grid size# since we've already done in ViT("image",(self.args["common"]["img_history_size"],self.args["common"]["num_cameras"],-self.vision_model.num_patches,),),],lang_pos_embed_config=[# Similarly, no initial pos embed for language("lang", -self.args["dataset"]["tokenizer_max_length"]),],dtype=self.dtype,)else:_model = RDTRunner.from_pretrained(pretrained)return _modeldef get_text_encoder(self, pretrained_text_encoder_name_or_path):text_embedder = T5Embedder(from_pretrained=pretrained_text_encoder_name_or_path,model_max_length=self.args["dataset"]["tokenizer_max_length"],device=self.device,)tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.modelreturn tokenizer, text_encoderdef get_vision_encoder(self, pretrained_vision_encoder_name_or_path):vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)image_processor = vision_encoder.image_processorreturn image_processor, vision_encoderdef reset(self):"""Set model to evaluation mode."""device = self.deviceweight_dtype = self.dtypeself.policy.eval()# self.text_model.eval()self.vision_model.eval()self.policy = self.policy.to(device, dtype=weight_dtype)# self.text_model = self.text_model.to(device, dtype=weight_dtype)self.vision_model = self.vision_model.to(device, dtype=weight_dtype)def load_pretrained_weights(self, pretrained=None):if pretrained is None:returnprint(f"Loading weights from {pretrained}")filename = os.path.basename(pretrained)if filename.endswith(".pt"):checkpoint = torch.load(pretrained)self.policy.load_state_dict(checkpoint["module"])elif filename.endswith(".safetensors"):from safetensors.torch import load_modelload_model(self.policy, pretrained)else:raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")def encode_instruction(self, instruction, device="cuda"):"""Encode 
string instruction to latent embeddings.Args:instruction: a string of instructiondevice: a string of deviceReturns:pred: a tensor of latent embeddings of shape (text_max_length, 512)"""tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest",truncation=True)["input_ids"].to(device)tokens = tokens.view(1, -1)with torch.no_grad():pred = self.text_model(tokens).last_hidden_state.detach()return preddef _format_joint_to_state(self, joints):"""Format the joint proprioception into the unified action vector.Args:joints (torch.Tensor): The joint proprioception to be formatted.qpos ([B, N, 14]).Returns:state (torch.Tensor): The formatted vector for RDT ([B, N, 128])."""# Rescale the gripper to the range of [0, 1]joints = joints / torch.tensor([[[180, 180, 180, 180, 180, 180]]],device=joints.device,dtype=joints.dtype,)B, N, _ = joints.shapestate = torch.zeros((B, N, self.args["model"]["state_token_dim"]),device=joints.device,dtype=joints.dtype,)# Fill into the unified state vectorstate[:, :, AGILEX_STATE_INDICES] = joints# Assemble the mask indicating each dimension's availabilitystate_elem_mask = torch.zeros((B, self.args["model"]["state_token_dim"]),device=joints.device,dtype=joints.dtype,)state_elem_mask[:, AGILEX_STATE_INDICES] = 1return state, state_elem_maskdef _unformat_action_to_joint(self, action):"""Unformat the unified action vector into the joint action to be executed.Args:action (torch.Tensor): The unified action vector to be unformatted.([B, N, 128])Returns:joints (torch.Tensor): The unformatted robot joint action.qpos ([B, N, 14])."""action_indices = AGILEX_STATE_INDICESjoints = action[:, :, action_indices]# Rescale the gripper back to the action range# Note that the action range and proprioception range are different# for Mobile ALOHA robotjoints = joints * torch.tensor([[[180, 180, 180, 180, 180, 180]]],device=joints.device,dtype=joints.dtype,)return joints@torch.no_grad()def step(self, proprio, images, text_embeds):"""Predict the next action chunk given theproprioceptive states, images, and instruction embeddings.Args:proprio: proprioceptive statesimages: RGB images, the order should be[ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1},ext_{t}, right_wrist_{t}, left_wrist_{t}]text_embeds: instruction embeddingsReturns:action: predicted action"""device = self.devicedtype = self.dtype# The background image used for paddingbackground_color = np.array([int(x * 255) for x in self.image_processor.image_mean],dtype=np.uint8).reshape(1, 1, 3)background_image = (np.ones((self.image_processor.size["height"],self.image_processor.size["width"],3,),dtype=np.uint8,) * background_color)# Preprocess the images by order and encode themimage_tensor_list = []for image in images:if image is None:# Replace it with the background imageimage = Image.fromarray(background_image)else:# Convert numpy array to PIL Image if neededif isinstance(image, np.ndarray):image = Image.fromarray(image)if self.image_size is not None:image = transforms.Resize(self.image_size)(image)if self.args["dataset"].get("auto_adjust_image_brightness", False):pixel_values = list(image.getdata())average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)if average_brightness <= 0.15:image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":def expand2square(pil_img, background_color):width, height = pil_img.sizeif width == height:return pil_imgelif width > height:result = Image.new(pil_img.mode, (width, 
width), background_color)result.paste(pil_img, (0, (width - height) // 2))return resultelse:result = Image.new(pil_img.mode, (height, height), background_color)result.paste(pil_img, ((height - width) // 2, 0))return resultimage = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]image_tensor_list.append(image)image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)image_embeds = self.vision_model(image_tensor).detach()image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)# Prepare the proprioception states and the control frequency# Convert numpy array to tensor if neededif isinstance(proprio, np.ndarray):# Copy the array to make it writableproprio = torch.from_numpy(proprio.copy())joints = proprio.to(device).unsqueeze(0) # (1, 1, 14)states, state_elem_mask = self._format_joint_to_state(joints) # (1, 1, 128), (1, 128)states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)states = states[:, -1:, :] # (1, 1, 128)ctrl_freqs = torch.tensor([self.control_frequency]).to(device)text_embeds = text_embeds.to(device, dtype=dtype)# Predict the next action chunk given the inputstrajectory = self.policy.predict_action(lang_tokens=text_embeds,lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),img_tokens=image_embeds,state_tokens=states,action_mask=state_elem_mask.unsqueeze(1),ctrl_freqs=ctrl_freqs,)trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)return trajectoryclass LERobotRDTServer:def __init__(self, pretrained_vision_encoder_name_or_path, pretrained, args, lang_model):self.policy = create_model(args=args,dtype=torch.bfloat16,pretrained=pretrained,pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,control_frequency=30,)self.server = ServerClient(host="0.0.0.0", port=5002, is_server=True)# Load and debug language embeddingsself.lang_embeddings = torch.load(lang_model)print(f"Loaded language embeddings shape: {self.lang_embeddings.shape}")print(f"Model expects tokenizer_max_length: {self.policy.args['dataset']['tokenizer_max_length']}")print(f"Model lang_token_dim: {self.policy.args['model']['lang_token_dim']}")# Check if dimensions matchexpected_seq_len = self.policy.args["dataset"]["tokenizer_max_length"]expected_hidden_dim = self.policy.args["model"]["lang_token_dim"]# Handle different embedding formatsif len(self.lang_embeddings.shape) == 2:# Format: [seq_len, hidden_dim]actual_seq_len, actual_hidden_dim = self.lang_embeddings.shapeif actual_seq_len != expected_seq_len:print(f"WARNING: Sequence length mismatch! Expected {expected_seq_len}, got {actual_seq_len}")if actual_hidden_dim != expected_hidden_dim:print(f"WARNING: Hidden dimension mismatch! Expected {expected_hidden_dim}, got {actual_hidden_dim}")elif len(self.lang_embeddings.shape) == 3:# Format: [batch_size, seq_len, hidden_dim]actual_batch, actual_seq_len, actual_hidden_dim = self.lang_embeddings.shapeif actual_seq_len != expected_seq_len:print(f"WARNING: Sequence length mismatch! Expected {expected_seq_len}, got {actual_seq_len}")if actual_hidden_dim != expected_hidden_dim:print(f"WARNING: Hidden dimension mismatch! 
Expected {expected_hidden_dim}, got {actual_hidden_dim}")else:print(f"WARNING: Unexpected embedding shape: {self.lang_embeddings.shape}")def run(self):print("LERobot RDT Server started, waiting for messages...")try:while True:print("Waiting for RDT data...")rdt_data = self.server.receive()print(f"Received RDT data, message_id: {rdt_data['message_id']}")# Perform inference# Ensure language embeddings have correct shapeif len(self.lang_embeddings.shape) == 2:# [seq_len, hidden_dim] -> [1, seq_len, hidden_dim]text_embeds = self.lang_embeddings.unsqueeze(0)else:# Already [batch_size, seq_len, hidden_dim]text_embeds = self.lang_embeddingsaction = self.policy.step(proprio=rdt_data["proprio"],images=rdt_data["images"],text_embeds=text_embeds,)# Prepare response - use 'actions' key to match client expectationmessage_id = rdt_data["message_id"]action_data = {"message_id": message_id,"actions": action, # Changed from 'action' to 'actions'}# Send responseprint(f"send action data, action_data: {action_data}")self.server.send(action_data)print(f"Sent action data for message_id: {message_id}")except KeyboardInterrupt:print("\nServer stopped by user")self.server.close()except Exception as e:print(f"Error in server loop: {e}")self.server.close()raiseif __name__ == "__main__":path_to_rdt_model_wights = "/home/qi.xiong/DualArm/RoboTwin/policy/RDT/checkpoints/RDT_LeRobot/checkpoint-7500/pytorch_model/mp_rank_00_model_states.pt"path_to_vision_encoder_model = "/home/qi.xiong/DualArm/RoboTwin/policy/weights/RDT/siglip-so400m-patch14-384"lang_model = "/home/qi.xiong/DualArm/RoboTwin/policy/RDT/scripts/lerobot_rdt_data/greenmarker_scene1/episode_4/instructions/lang_embed_0.pt"with open("/home/qi.xiong/DualArm/RoboTwin/policy/RDT/configs/base.yaml", "r") as fp:config = yaml.safe_load(fp)rdt_server = LERobotRDTServer(path_to_vision_encoder_model, path_to_rdt_model_wights, config, lang_model)rdt_server.run()
The new LeRobot replaces the old observation-retrieval function and also changes the data structures involved, so we copy the record script and reduce it to a minimal state-retrieval and arm-control loop that handles the local data transfer and communication. The code is below; the file lives at lerobot/src/lerobot/record_rdt.py:
import logging
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat
from collections import deque
import torch
from PIL import Image
import numpy as np
from lerobot.cameras import (  # noqa: F401
    CameraConfig,  # noqa: F401
)
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.robots import (  # noqa: F401
    Robot,
    RobotConfig,
    bi_so100_follower,
    hope_jr,
    koch_follower,
    make_robot_from_config,
    so100_follower,
    so101_follower,
)
from lerobot.robots.so101_follower.so101_follower import SO101Follower
from lerobot.robots.so101_follower.config_so101_follower import SO101FollowerConfig
from lerobot.teleoperators import (  # noqa: F401
    Teleoperator,
    TeleoperatorConfig,
    bi_so100_leader,
    homunculus,
    koch_leader,
    make_teleoperator_from_config,
    so100_leader,
    so101_leader,
)
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
from lerobot.utils.control_utils import (
    init_keyboard_listener,
    is_headless,
    predict_action,
    sanity_check_dataset_name,
    sanity_check_dataset_robot_compatibility,
)
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
from server_client import *

debug_save_img = True
_action_queue = deque([], maxlen=64)
_message_id = 0
_last_cam_high = None
_last_cam_right_wrist = None


@safe_stop_image_writer
def record_loop(
    robot: Robot,
    client: Client,
):global _action_queue, _message_id, _last_cam_high, _last_cam_right_wristobservation = robot.get_observation()cam_high = observation['high']cam_right_wrist = observation['arm']image_arrs = [_last_cam_high,_last_cam_right_wrist,None,cam_high,cam_right_wrist,None]images = [arr if arr is not None else Nonefor arr in image_arrs]joint_positions = [observation[key] for key in observation.keys() if key.endswith('.pos')]proprio = torch.tensor(joint_positions, dtype=torch.float32).unsqueeze(0)###################Debug图像########################if debug_save_img:imgs_to_show = [cam_high, cam_right_wrist, _last_cam_high, _last_cam_right_wrist]if all(img is not None for img in imgs_to_show):pil_imgs = []for img in imgs_to_show:if img.dtype != np.uint8:img = np.clip(img, 0, 1)img = (img * 255).astype(np.uint8)if img.ndim == 2:img = np.stack([img]*3, axis=-1) elif img.shape[-1] == 1:img = np.repeat(img, 3, axis=-1)pil_imgs.append(Image.fromarray(img))w, h = pil_imgs[0].sizefor i in range(4):if pil_imgs[i].size != (w, h):pil_imgs[i] = pil_imgs[i].resize((w, h))new_img = Image.new('RGB', (w*2, h*2))new_img.paste(pil_imgs[0], (0, 0)) # 左上:新highnew_img.paste(pil_imgs[1], (w, 0)) # 右上:新wristnew_img.paste(pil_imgs[2], (0, h)) # 左下:老highnew_img.paste(pil_imgs[3], (w, h)) # 右下:老wristdebug_save_path = "debug_2x2.png"new_img.save(debug_save_path)print(f"Have been saved at: {debug_save_path}")# new_img.show()
###################Debug图像########################rdt_data = {'message_id': _message_id,'proprio': proprio,'images': images,'text_embeds': ""}client.send(rdt_data)_message_id += 1 print(f"send new rdt data done, message_id: {_message_id-1}")action_data = client.receive()if action_data is None:print("ERROR: Server returned None. Is the RDT server running?")print("Please start the RDT server first!")raise ConnectionError("Failed to receive response from RDT server")actions = action_data['actions']action_message_id = action_data["message_id"]print(f"receive actions done, message_id: {action_message_id}")# print(f"receive actions contents: {actions}")actions_array = np.array(actions)if len(actions_array.shape) == 3: action_sequence = actions_array[0, :, :] # 取第一个batch的所有时间步else:print(f"action shape should be 3 dim, but get {actions_array.shape} ")joint_names = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]for step_idx in range(0, len(action_sequence), 4): # 64个动作隔4个执行一次动作action_values = action_sequence[step_idx]action_dict = {f"{joint}.pos": float(action_values[i]) for i, joint in enumerate(joint_names)}sent_action = robot.send_action(action_dict)time.sleep(0.1) _last_cam_high = cam_high_last_cam_right_wrist = cam_right_wristdef main():robot = SO101Follower(SO101FollowerConfig(port="/dev/tty.usbmodem5AB90671801",id="my_awesome_follower_arm",cameras={"arm": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30),"high": OpenCVCameraConfig(index_or_path=2, width=1920, height=1080, fps=30)}))robot.connect()client = Client(host="localhost", port=5002)try:while True:record_loop(robot,client)time.sleep(0.1)except KeyboardInterrupt:passrobot.disconnect()if __name__ == "__main__":main()
Next, update the camera and serial-port settings in main() to match your own hardware, then start the server and the local client in turn; the arm should begin moving and completing the task:
# Server
python3 RDT/scripts/lerobot_rdt_server.py
# Client
python3 lerobot/src/lerobot/record_rdt.py
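For reference, here is a minimal, self-contained sketch of the message format exchanged between record_rdt.py and the RDT server. The field names come from the record_loop code above; the image resolution and the 64-step action horizon are assumptions for illustration only:
import numpy as np
import torch

# Request sent by the client once per control cycle (see record_loop above).
request = {
    "message_id": 0,                                    # monotonically increasing counter
    "proprio": torch.zeros(1, 6, dtype=torch.float32),  # 6 joint positions of the SO-101 follower
    "images": [
        None, None, None,                               # previous cam_high / cam_right_wrist / unused slot
        np.zeros((1080, 1920, 3), np.uint8),            # current cam_high frame (assumed resolution)
        np.zeros((1080, 1920, 3), np.uint8),            # current cam_right_wrist frame
        None,
    ],
    "text_embeds": "",                                  # empty: the server loads the language embedding itself
}

# Response returned by the server: an action chunk for the 6 joints.
response = {
    "message_id": request["message_id"],
    "actions": np.zeros((1, 64, 6), np.float32),        # (batch, horizon, joints); horizon of 64 is assumed
}

# As in record_loop, keep the first batch and execute every 4th step of the chunk.
joint_names = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
for step in np.asarray(response["actions"])[0][::4]:
    action_dict = {f"{name}.pos": float(v) for name, v in zip(joint_names, step)}
    # robot.send_action(action_dict) would be called here on the real robot
print("last action:", action_dict)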
RDK S100 Deployment
At the moment only RDT-170M is supported for deployment on the RDK S100. Following the reference document RDT on Double RDK S100P 全流程文档, we will walk through bringing the LeRobot-trained RDT onto the board step by step. First, use the script RDT/export_all.py below to export all of the ONNX models used by RDT:
# Copyright (c) 2025, Cauchy WuChao D-Robotics.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import yaml
import logging
import argparse
from time import time
from collections import OrderedDict

import cv2
import numpy as np
import torch
import h5py
from PIL import Image as PImage

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler

from scripts.agilex_model import create_model
from configs.state_vec import STATE_VEC_IDX_MAPPING
from models.hub_mixin import CompatiblePyTorchModelHubMixin
from models.rdt.blocks import (FinalLayer, RDTBlock, TimestepEmbedder,
                               get_1d_sincos_pos_embed_from_grid, get_multimodal_cond_pos_embed)
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.multimodal_encoder.t5_encoder import T5Embedder

# taskset -c 0-7 docker run [--gpus all] -it -v <ws>:/open_explorer ai_toolchain_ubuntu_22_s100_gpu:v3.2.0

logging.basicConfig(level=logging.DEBUG,
                    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
                    datefmt='%H:%M:%S')
logger = logging.getLogger("RDK_RDT")def main():parser = argparse.ArgumentParser()parser.add_argument('--export_path', type=str, default="rdt_export_ws_test", help="")parser.add_argument('--config_path', type=str, default="configs/base.yaml", help="")parser.add_argument('--pretrained_vision_encoder', type=str, default="../weights/RDT/siglip-so400m-patch14-384", help="")parser.add_argument('--pretrained_model', type=str, default="checkpoints/RDT170M-LeRobot/checkpoint-9000/pytorch_model/mp_rank_00_model_states.pt", help="")parser.add_argument('--train_data', type=str, default="training_data/RDT170M-LeRobot", help="")# --train_data exampe:# .# ├── adjust_bottle-demo_clean-300# │ ├── episode_0# │ │ ├── episode_0.hdf5# │ │ └── instructions# │ ├── episode_1# │ │ ├── episode_1.hdf5# │ │ └── instructions# ...# ├── place_dual_shoes-demo_clean-300# │ ├── episode_0# │ │ ├── episode_0.hdf5# │ │ └── instructions# ...# └── place_empty_cup-demo_clean-300# ├── episode_0# │ ├── episode_0.hdf5# │ └── instructions# ...parser.add_argument('--num_samples', type=int, default=50, help="")parser.add_argument('--jobs', type=int, default=8, help="")parser.add_argument('--optimized_level', type=str, default="O2", help="")parser.add_argument('--ctrl_freq', type=int, default=25, help="")parser.add_argument('--left_arm_dim', type=int, default=6, help="")parser.add_argument('--right_arm_dim', type=int, default=6, help="")parser.add_argument('--cal_data_device', type=str, default='cuda:6', help="")parser.add_argument('--instructions_per_episode', type=int, default=1, help="")opt = parser.parse_args()logger.info(opt)# Create WorkSpaceos.makedirs(opt.export_path, exist_ok=True)## BPU_RDT_Policybpu_rdt_name = "BPU_RDT_Policy"bpu_rdt_path = os.path.join(opt.export_path, bpu_rdt_name)os.makedirs(bpu_rdt_path, exist_ok=True)os.system(f"cp {opt.config_path} {bpu_rdt_path}")bash_build_all_name = "build_all.sh"## test datastest_data_name = "test_data"test_data_path = os.path.join(opt.export_path, test_data_name)os.makedirs(test_data_path, exist_ok=True)## instructioninstruction_ws_name = "instructions"instruction_ws_path = os.path.join(opt.export_path, instruction_ws_name)os.makedirs(instruction_ws_path, exist_ok=True)for name in os.listdir(opt.train_data):os.makedirs(os.path.join(instruction_ws_path, name), exist_ok=True)## image adaptorimg_adaptor_ws_name = "img_adaptor_WorkSpace"img_adaptor_cal_name = "rdt_image_adaptor_calibration"img_adaptor_name = "rdt_image_adaptor.onnx"img_adaptor_config_name = "config.yaml"img_adaptor_bash_name = "build.sh"img_adaptor_path = os.path.join(opt.export_path, img_adaptor_ws_name, img_adaptor_name)img_adaptor_ws = os.path.join(opt.export_path, img_adaptor_ws_name)os.makedirs(img_adaptor_ws, exist_ok=True)global img_adaptor_cal_wsimg_adaptor_cal_ws = os.path.join(img_adaptor_ws, img_adaptor_cal_name)os.makedirs(img_adaptor_cal_ws, exist_ok=True)## action adaptorstate_adaptor_name1 = "rdt_state_adaptor_1x1x256.onnx"state_adaptor_path1 = os.path.join(opt.export_path, bpu_rdt_name, state_adaptor_name1)state_adaptor_name2 = "rdt_state_adaptor_1x64x256.onnx"state_adaptor_path2 = os.path.join(opt.export_path, bpu_rdt_name, state_adaptor_name2)## lang adaptor lang_adaptor_name = "rdt_lang_adaptor.onnx"lang_adaptor_path = os.path.join(opt.export_path, bpu_rdt_name, lang_adaptor_name)## DiT Policydit_ws_name = "DiT_WorkSpace"dit_cal_name = "rdt_dit_calibration"dit_name = "rdt_dit.onnx"dit_config_name = "config.yaml"dit_json_name = "quant_config.json"dit_bash_name = "build.sh"dit_path = 
os.path.join(opt.export_path, dit_ws_name, dit_name)dit_ws = os.path.join(opt.export_path, dit_ws_name)os.makedirs(dit_ws, exist_ok=True)dit_cal_path = os.path.join(opt.export_path, dit_ws_name, dit_cal_name)os.makedirs(dit_cal_path, exist_ok=True)global dit_cal_path_x, dit_cal_path_freq, dit_cal_path_t, dit_cal_path_lang_c, dit_cal_path_img_c, dit_cal_path_lang_maskdit_cal_path_x = os.path.join(opt.export_path, dit_ws_name, dit_cal_name, "x")os.makedirs(dit_cal_path_x, exist_ok=True)dit_cal_path_freq = os.path.join(opt.export_path, dit_ws_name, dit_cal_name, "freq")os.makedirs(dit_cal_path_freq, exist_ok=True)dit_cal_path_t = os.path.join(opt.export_path, dit_ws_name, dit_cal_name, "t")os.makedirs(dit_cal_path_t, exist_ok=True)dit_cal_path_lang_c = os.path.join(opt.export_path, dit_ws_name, dit_cal_name, "lang_c")os.makedirs(dit_cal_path_lang_c, exist_ok=True)dit_cal_path_img_c = os.path.join(opt.export_path, dit_ws_name, dit_cal_name, "img_c")os.makedirs(dit_cal_path_img_c, exist_ok=True)dit_cal_path_lang_mask = os.path.join(opt.export_path, dit_ws_name, dit_cal_name, "lang_mask")os.makedirs(dit_cal_path_lang_mask, exist_ok=True)# Prepare Calibrate Datawith open(opt.config_path, "r") as fp:config_base_yaml = yaml.safe_load(fp)config_base_yaml["arm_dim"] = {"left_arm_dim": opt.left_arm_dim, "right_arm_dim": opt.right_arm_dim}dump_model = create_dump_model(args=config_base_yaml,dtype=torch.float32,pretrained=opt.pretrained_model,pretrained_vision_encoder_name_or_path=opt.pretrained_vision_encoder,control_frequency=opt.ctrl_freq,device=opt.cal_data_device)# Prepare Calbriation Data# load training data global dump_cnt, dump_dataset_nametest_data_cnt = 0for dump_dataset_name in os.listdir(opt.train_data):dump_dataset_path = os.path.join(opt.train_data, dump_dataset_name)training_samples = get_training_samples(dump_dataset_path, num_samples=opt.num_samples, instructions_per_episode=opt.instructions_per_episode)for dump_cnt in range(min(opt.num_samples, len(training_samples))):sample = training_samples[dump_cnt]instruction_emb = {"lang_cond": sample['lang_embed'].float().cpu(),"lang_str": sample['lang_str']}ins_str_name = sample['lang_str'].replace(" ", "_")+"__"torch.save(instruction_emb, os.path.join(instruction_ws_path, dump_dataset_name, f"{ins_str_name}.pt"))# 兼容缺失相机:按键取值,缺失则用 [None, None]cam_high_imgs = sample['multi_cam_images'].get('cam_high', [None, None])cam_right_wrist_imgs = sample['multi_cam_images'].get('cam_right_wrist', [None, None])cam_left_wrist_imgs = sample['multi_cam_images'].get('cam_left_wrist', [None, None])image_arrs = [cam_high_imgs[0],cam_right_wrist_imgs[0],cam_left_wrist_imgs[0],cam_high_imgs[1],cam_right_wrist_imgs[1],cam_left_wrist_imgs[1],]test_data_cnt += 1# 仅对存在的相机保存if cam_high_imgs[0] is not None:np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_high_0.npy"), cam_high_imgs[0])if cam_right_wrist_imgs[0] is not None:np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_right_wrist_0.npy"), cam_right_wrist_imgs[0])if cam_left_wrist_imgs[0] is not None:np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_left_wrist_0.npy"), cam_left_wrist_imgs[0])if cam_high_imgs[1] is not None:np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_high_1.npy"), cam_high_imgs[1])if cam_right_wrist_imgs[1] is not None:np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_right_wrist_1.npy"), cam_right_wrist_imgs[1])if cam_left_wrist_imgs[1] is not None:np.save(os.path.join(test_data_path, f"{test_data_cnt}_cam_left_wrist_1.npy"), 
cam_left_wrist_imgs[1])images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]proprio = torch.from_numpy(sample['joints']).float().unsqueeze(0).to(opt.cal_data_device) np.save(os.path.join(test_data_path, f"{test_data_cnt}_joints.npy"), sample['joints'])lang_embeddings = sample['lang_embed'].float().unsqueeze(0).to(opt.cal_data_device) torch.save(lang_embeddings, os.path.join(test_data_path, f"{test_data_cnt}_lang_embeddings.pt"))dump_model.reset()begin_time = time()actions = dump_model.step(proprio=proprio, images=images, text_embeds=lang_embeddings).squeeze(0).cpu().numpy()np.save(os.path.join(test_data_path, f"{test_data_cnt}_actions.npy"), actions)logger.debug(f"Dump: Cost {(1000*(time() - begin_time)):.1f} ms, cnt: {dump_cnt}, name: {dump_dataset_name}")logger.info("End Generate Calibration Data.")del dump_model# Load RDT Policy: CPU Model For ONNX Exportwith open(opt.config_path, "r") as fp:config_base_yaml = yaml.safe_load(fp)config_base_yaml["arm_dim"] = {"left_arm_dim": opt.left_arm_dim, "right_arm_dim": opt.right_arm_dim}model = create_model(args=config_base_yaml,dtype=torch.float32,pretrained=opt.pretrained_model,pretrained_vision_encoder_name_or_path=opt.pretrained_vision_encoder,control_frequency=opt.ctrl_freq,device="cpu")bash_build_all = ""# image adaptor: ONNX Modelm = model.policy.img_adaptorm.eval()input_data = torch.randn(1, 4374, 1152) # 假设批量大小为1output = m(input_data)torch.onnx.export(m, # 要转换的模型input_data, # 模型的输入img_adaptor_path, # 输出文件名export_params=True, # 存储训练后的参数opset_version=17, # ONNX版本(降低以兼容 ReduceMean axes 属性)do_constant_folding=True, # 是否执行常量折叠优化input_names=['img_tokens'], # 输入节点名称output_names=['adpated_img'], # 输出节点名称dynamic_axes=None,verbose=False)logger.info("Export RDT [img_adaptor] Model Success.")yaml_str = f'''
model_parameters:onnx_model: '{img_adaptor_name}'march: nash-mlayer_out_dump: Falseworking_dir: bpu_outputoutput_model_file_prefix: rdt_img_adaptorenable_vpu: True
input_parameters:input_name: ''input_type_rt: 'featuremap;'input_layout_rt: 'NCHW;'input_type_train: 'featuremap;'input_layout_train: 'NCHW;'norm_type: 'no_preprocess;'
calibration_parameters:cal_data_dir: '{img_adaptor_cal_name}'cal_data_type: 'float32'calibration_type: 'default'# quant_config: {{"op_config": {{"softmax": {{"qtype": "int8"}}}}}}quant_config: {{"model_config": {{"all_node_type": "int16","model_output_type": "int16",}}}}
compiler_parameters:extra_params: {{'input_no_padding': True, 'output_no_padding': True}}jobs: {opt.jobs}compile_mode: 'latency'debug: Trueadvice: 1optimize_level: 'O2'core_num: 2'''with open(os.path.join(opt.export_path, img_adaptor_ws_name, img_adaptor_config_name), "w", encoding="utf-8") as f:f.write(yaml_str)bash_str = f'''hb_compile --config {img_adaptor_config_name}cp bpu_output/*.hbm ../{bpu_rdt_name}/'''with open(os.path.join(opt.export_path, img_adaptor_ws_name, img_adaptor_bash_name), "w", encoding="utf-8") as f:f.write(bash_str)bash_build_all += f"cd {img_adaptor_ws_name}" + "\n"bash_build_all += f"bash {img_adaptor_bash_name}" + "\n"bash_build_all += f"cd .." + "\n"# DiTm = model.policy.modelm = m.eval().cpu()x = torch.randn(1, 65, 1024)freq = torch.tensor([1], dtype=torch.int32)t = torch.tensor([10], dtype=torch.int32)lang_c = torch.randn(1, 64, 1024)img_c = torch.randn(1, 4374, 1024)lang_mask = torch.ones(1, 64, dtype=torch.float32)dummy_inputs = (x, freq, t, lang_c, img_c, lang_mask)outputs = m(x, freq, t, lang_c, img_c, lang_mask)torch.onnx.export(m, # 要导出的模型dummy_inputs, # 模型的输入dit_path, # 保存路径# export_params=True, # 是否导出训练参数opset_version=17, # ONNX 的版本,降低以兼容 ReduceMean axes 属性do_constant_folding=True, # 是否执行常量折叠优化input_names=["x", "freq", "t", "lang_c", "img_c", "lang_mask"], # 输入名称output_names=["actions"], # 输出名称verbose=False,)logger.info("Export RDT [dit] Model Success.")yaml_str = f'''
calibration_parameters:cal_data_dir: '{dit_cal_name}/x/;{dit_cal_name}/freq/;{dit_cal_name}/t/;{dit_cal_name}/lang_c/;{dit_cal_name}/img_c/;{dit_cal_name}/lang_mask/;'quant_config: {dit_json_name}run_on_cpu: '/t_embedder/Cos;/t_embedder/Sin;/freq_embedder/Cos;/freq_embedder/Sin'
compiler_parameters:compile_mode: latencycore_num: 1debug: truejobs: {opt.jobs}max_time_per_fc: 0optimize_level: O2advice: 1
input_parameters:input_layout_rt: NCHW;NCHW;NCHW;NCHW;NCHW;NCHWinput_layout_train: NCHW;NCHW;NCHW;NCHW;NCHW;NCHWinput_name: x;freq;t;lang_c;img_c;lang_mask;input_shape: 1x65x1024;1;1;1x64x1024;1x4374x1024;1x64input_space_and_range: ''input_type_rt: featuremap;featuremap;featuremap;featuremap;featuremap;featuremapinput_type_train: featuremap;featuremap;featuremap;featuremap;featuremap;featuremapnorm_type: no_preprocess;no_preprocess;no_preprocess;no_preprocess;no_preprocess;no_preprocess
model_parameters:layer_out_dump: falsedebug_mode: "dump_calibration_data"enable_vpu: Truemarch: nash-monnx_model: {dit_name}output_model_file_prefix: rdt_ditworking_dir: bpu_output'''with open(os.path.join(opt.export_path, dit_ws_name, dit_config_name), "w", encoding="utf-8") as f:f.write(yaml_str)json_str = '''{"model_config": {"all_node_type": "int16","model_output_type": "float32","activation": {"calibration_type": ["max"],"num_bin": [1024, 2048, 4096],"max_num_bin": 16384,"max_percentile": 1.0,"per_channel": true,"asymmetric": [true]},"weight": {"bias_correction": {"metric": "mae"}},"modelwise_search": {"metric": "mae"}},"op_config": {"ReduceMean": {"qtype": "int16"},"Sub": {"qtype": "int16"},"Softmax": {"qtype": "int16"}},"node_config": {"/t_embedder/Mul": {"qtype": "float32"},"/t_embedder/Cos": {"qtype": "float32"},"/t_embedder/Sin": {"qtype": "float32"},"/t_embedder/Concat": {"qtype": "float32"},"/freq_embedder/Mul": {"qtype": "float32"},"/freq_embedder/Cos": {"qtype": "float32"},"/freq_embedder/Sin": {"qtype": "float32"},"/freq_embedder/Concat": {"qtype": "float32"},"/blocks.0/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.0/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.0/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.0/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.0/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.0/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.0/ffn/act/Mul": {"qtype": "int16"},"/blocks.0/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.0/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.0/ffn/act/Add": {"qtype": "int16"},"/blocks.0/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.0/ffn/act/Tanh": {"qtype": "int16"},"/blocks.0/norm1/Mul_2": {"qtype": "int16"},"/blocks.0/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.0/Add": {"qtype": "int16"},"/blocks.1/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.1/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.1/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.1/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.1/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.1/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.1/ffn/act/Mul": {"qtype": "int16"},"/blocks.1/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.1/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.1/ffn/act/Add": {"qtype": "int16"},"/blocks.1/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.1/ffn/act/Tanh": {"qtype": "int16"},"/blocks.1/norm1/Mul_2": {"qtype": "int16"},"/blocks.1/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.1/Add": {"qtype": "int16"},"/blocks.2/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.2/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.2/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.2/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.2/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.2/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.2/ffn/act/Mul": {"qtype": "int16"},"/blocks.2/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.2/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.2/ffn/act/Add": {"qtype": "int16"},"/blocks.2/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.2/ffn/act/Tanh": {"qtype": "int16"},"/blocks.2/norm1/Mul_2": {"qtype": "int16"},"/blocks.2/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.2/Add": 
{"qtype": "int16"},"/blocks.3/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.3/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.3/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.3/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.3/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.3/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.3/ffn/act/Mul": {"qtype": "int16"},"/blocks.3/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.3/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.3/ffn/act/Add": {"qtype": "int16"},"/blocks.3/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.3/ffn/act/Tanh": {"qtype": "int16"},"/blocks.3/norm1/Mul_2": {"qtype": "int16"},"/blocks.3/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.3/Add": {"qtype": "int16"},"/blocks.4/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.4/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.4/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.4/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.4/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.4/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.4/ffn/act/Mul": {"qtype": "int16"},"/blocks.4/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.4/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.4/ffn/act/Add": {"qtype": "int16"},"/blocks.4/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.4/ffn/act/Tanh": {"qtype": "int16"},"/blocks.4/norm1/Mul_2": {"qtype": "int16"},"/blocks.4/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.4/Add": {"qtype": "int16"},"/blocks.5/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.5/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.5/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.5/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.5/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.5/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.5/ffn/act/Mul": {"qtype": "int16"},"/blocks.5/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.5/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.5/ffn/act/Add": {"qtype": "int16"},"/blocks.5/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.5/ffn/act/Tanh": {"qtype": "int16"},"/blocks.5/norm1/Mul_2": {"qtype": "int16"},"/blocks.5/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.5/Add": {"qtype": "int16"},"/blocks.6/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.6/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.6/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.6/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.6/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.6/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.6/ffn/act/Mul": {"qtype": "int16"},"/blocks.6/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.6/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.6/ffn/act/Add": {"qtype": "int16"},"/blocks.6/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.6/ffn/act/Tanh": {"qtype": "int16"},"/blocks.6/norm1/Mul_2": {"qtype": "int16"},"/blocks.6/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.6/Add": {"qtype": "int16"},"/blocks.7/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.7/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.7/cross_attn/MatMul": {"InputType0": "int16", "InputType1": 
"int16"},"/blocks.7/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.7/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.7/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.7/ffn/act/Mul": {"qtype": "int16"},"/blocks.7/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.7/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.7/ffn/act/Add": {"qtype": "int16"},"/blocks.7/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.7/ffn/act/Tanh": {"qtype": "int16"},"/blocks.7/norm1/Mul_2": {"qtype": "int16"},"/blocks.7/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.7/Add": {"qtype": "int16"},"/blocks.8/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.8/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.8/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.8/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.8/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.8/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.8/ffn/act/Mul": {"qtype": "int16"},"/blocks.8/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.8/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.8/ffn/act/Add": {"qtype": "int16"},"/blocks.8/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.8/ffn/act/Tanh": {"qtype": "int16"},"/blocks.8/norm1/Mul_2": {"qtype": "int16"},"/blocks.8/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.8/Add": {"qtype": "int16"},"/blocks.9/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.9/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.9/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.9/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.9/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.9/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.9/ffn/act/Mul": {"qtype": "int16"},"/blocks.9/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.9/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.9/ffn/act/Add": {"qtype": "int16"},"/blocks.9/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.9/ffn/act/Tanh": {"qtype": "int16"},"/blocks.9/norm1/Mul_2": {"qtype": "int16"},"/blocks.9/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.9/Add": {"qtype": "int16"},"/blocks.10/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.10/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.10/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.10/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.10/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.10/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.10/ffn/act/Mul": {"qtype": "int16"},"/blocks.10/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.10/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.10/ffn/act/Add": {"qtype": "int16"},"/blocks.10/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.10/ffn/act/Tanh": {"qtype": "int16"},"/blocks.10/norm1/Mul_2": {"qtype": "int16"},"/blocks.10/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.10/Add": {"qtype": "int16"},"/blocks.11/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.11/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.11/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.11/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.11/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.11/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.11/ffn/act/Mul": {"qtype": 
"int16"},"/blocks.11/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.11/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.11/ffn/act/Add": {"qtype": "int16"},"/blocks.11/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.11/ffn/act/Tanh": {"qtype": "int16"},"/blocks.11/norm1/Mul_2": {"qtype": "int16"},"/blocks.11/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.11/Add": {"qtype": "int16"},"/blocks.12/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.12/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.12/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.12/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.12/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.12/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.12/ffn/act/Mul": {"qtype": "int16"},"/blocks.12/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.12/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.12/ffn/act/Add": {"qtype": "int16"},"/blocks.12/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.12/ffn/act/Tanh": {"qtype": "int16"},"/blocks.12/norm1/Mul_2": {"qtype": "int16"},"/blocks.12/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.12/Add": {"qtype": "int16"},"/blocks.13/attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.13/attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.13/cross_attn/MatMul": {"InputType0": "int16", "InputType1": "int16"},"/blocks.13/cross_attn/MatMul_1": {"InputType0": "int16", "InputType1": "int16"},"/blocks.13/cross_attn/k_norm/Mul_1": {"qtype": "int16"},"/blocks.13/ffn/fc1/MatMul": {"qtype": "int16"},"/blocks.13/ffn/act/Mul": {"qtype": "int16"},"/blocks.13/ffn/act/Mul_1": {"qtype": "int16"},"/blocks.13/ffn/act/Mul_2": {"qtype": "int16"},"/blocks.13/ffn/act/Add": {"qtype": "int16"},"/blocks.13/ffn/act/Mul_3": {"qtype": "int16"},"/blocks.13/ffn/act/Tanh": {"qtype": "int16"},"/blocks.13/norm1/Mul_2": {"qtype": "int16"},"/blocks.13/cross_attn/k_norm/Div_1_reciprocal": {"qtype": "int16"},"/blocks.13/Add": {"qtype": "int16"},"/blocks.13/norm3/Div_1_reciprocal": {"qtype": "int16"},"/final_layer/ffn_final/act/Mul_1": {"qtype": "int16"},"/final_layer/ffn_final/act/Mul_2 ": {"qtype": "int16"},"/final_layer/norm_final/Div_1_reciprocal": {"qtype": "float32"}}}'''with open(os.path.join(opt.export_path, dit_ws_name, dit_json_name), "w", encoding="utf-8") as f:f.write(json_str)bash_str = f'''hb_compile --config {dit_config_name}cp bpu_output/*.hbm ../{bpu_rdt_name}/'''with open(os.path.join(opt.export_path, dit_ws_name, dit_bash_name), "w", encoding="utf-8") as f:f.write(bash_str)bash_build_all += f"cd {dit_ws_name}" + "\n"bash_build_all += f"bash {dit_bash_name}" + "\n"bash_build_all += f"cd .." 
+ "\n"with open(os.path.join(opt.export_path, bash_build_all_name), "w", encoding="utf-8") as f:f.write(bash_build_all)# state adaptorm = model.policy.state_adaptorm.eval()input_data = torch.randn(1, 1, 256) # 假设批量大小为1output = m(input_data)torch.onnx.export(m, # 要转换的模型input_data, # 模型的输入state_adaptor_path1, # 输出文件名export_params=True, # 存储训练后的参数opset_version=17, # ONNX版本(降低以兼容 ReduceMean axes 属性)do_constant_folding=True, # 是否执行常量折叠优化input_names=['state_tokens'], # 输入节点名称output_names=['state_traj'], # 输出节点名称dynamic_axes=None,verbose=False)logger.info("Export RDT [state 1x1x256] Model Success.")input_data = torch.randn(1, 64, 256) # 假设批量大小为1output = m(input_data)torch.onnx.export(m, # 要转换的模型input_data, # 模型的输入state_adaptor_path2, # 输出文件名export_params=True, # 存储训练后的参数opset_version=17, # ONNX版本(降低以兼容 ReduceMean axes 属性)do_constant_folding=True, # 是否执行常量折叠优化input_names=['state_tokens'], # 输入节点名称output_names=['state_traj'], # 输出节点名称dynamic_axes=None,verbose=False)logger.info("Export RDT [state 1x64x256] Model Success.")# lang adaptorm = model.policy.lang_adaptorm.eval()input_data = torch.randn(1, 14, 4096) # 假设批量大小为1output = m(input_data)torch.onnx.export(m, # 要转换的模型input_data, # 模型的输入lang_adaptor_path, # 输出文件名export_params=True, # 存储训练后的参数opset_version=17, # ONNX版本(降低以兼容 ReduceMean axes 属性)do_constant_folding=True, # 是否执行常量折叠优化input_names=['text_embeds'], # 输入节点名称output_names=['lang_cond'], # 输出节点名称dynamic_axes={"text_embeds": {1: "N"},"lang_cond": {1: "N"}},verbose=False)logger.info("Export RDT [lang adaptor] Model Success.")######## Prepare Calbibration Datadef create_dump_model(args, **kwargs):left_arm_dim, right_arm_dim = (args["arm_dim"]["left_arm_dim"], args["arm_dim"]["right_arm_dim"],)# 仅右臂6关节映射到 [0, 6) 位置AGILEX_STATE_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)]model = RoboticDiffusionTransformerModel_Dump(args, **kwargs)pretrained = kwargs.get("pretrained", None)if pretrained is not None and os.path.isfile(pretrained):model.load_pretrained_weights(pretrained)return modelclass RDT_Dump(nn.Module):def __init__(self,output_dim=128,horizon=32,hidden_size=1152,depth=28,num_heads=16,max_lang_cond_len=1024,img_cond_len=4096,lang_pos_embed_config=None,img_pos_embed_config=None,dtype=torch.bfloat16):super().__init__()self.horizon = horizonself.hidden_size = hidden_sizeself.max_lang_cond_len = max_lang_cond_lenself.img_cond_len = img_cond_lenself.dtype = dtypeself.lang_pos_embed_config = lang_pos_embed_configself.img_pos_embed_config = img_pos_embed_configself.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype)self.freq_embedder = TimestepEmbedder(hidden_size, dtype=dtype)# We will use trainable sin-cos embeddings# [timestep; state; action]self.x_pos_embed = nn.Parameter(torch.zeros(1, horizon + 3, hidden_size))# Language conditionsself.lang_cond_pos_embed = nn.Parameter(torch.zeros(1, max_lang_cond_len, hidden_size))# Image conditionsself.img_cond_pos_embed = nn.Parameter(torch.zeros(1, img_cond_len, hidden_size))self.blocks = nn.ModuleList([RDTBlock(hidden_size, num_heads) for _ in range(depth)])self.final_layer = FinalLayer(hidden_size, output_dim)self.initialize_weights()def initialize_weights(self):# Initialize transformer layers:def _basic_init(module):if isinstance(module, nn.Linear):torch.nn.init.xavier_uniform_(module.weight)if module.bias is not None:nn.init.constant_(module.bias, 0)self.apply(_basic_init)# Initialize pos_embed by sin-cos embeddingx_pos_embed = 
get_multimodal_cond_pos_embed(embed_dim=self.hidden_size,mm_cond_lens=OrderedDict([('timestep', 1),('ctrl_freq', 1),('state', 1),('action', self.horizon),]))self.x_pos_embed.data.copy_(torch.from_numpy(x_pos_embed).float().unsqueeze(0))if self.lang_pos_embed_config is None:lang_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.max_lang_cond_len))else:lang_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size, mm_cond_lens=OrderedDict(self.lang_pos_embed_config), embed_modality=False)self.lang_cond_pos_embed.data.copy_(torch.from_numpy(lang_cond_pos_embed).float().unsqueeze(0))if self.img_pos_embed_config is None:img_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(self.hidden_size, torch.arange(self.img_cond_len))else:img_cond_pos_embed = get_multimodal_cond_pos_embed(embed_dim=self.hidden_size, mm_cond_lens=OrderedDict(self.img_pos_embed_config), embed_modality=False)self.img_cond_pos_embed.data.copy_(torch.from_numpy(img_cond_pos_embed).float().unsqueeze(0))# Initialize timestep and control freq embedding MLPnn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)nn.init.normal_(self.freq_embedder.mlp[0].weight, std=0.02)nn.init.normal_(self.freq_embedder.mlp[2].weight, std=0.02)# Initialize the final layer: zero-out the final linear layernn.init.constant_(self.final_layer.ffn_final.fc2.weight, 0)nn.init.constant_(self.final_layer.ffn_final.fc2.bias, 0)# Move all the params to given data type:self.to(self.dtype)def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):t = self.t_embedder(t).unsqueeze(1) # (B, 1, D) or (1, 1, D)freq = self.freq_embedder(freq).unsqueeze(1) # (B, 1, D)# Append timestep to the input tokensif t.shape[0] == 1:t = t.expand(x.shape[0], -1, -1)x = torch.cat([t, freq, x], dim=1) # (B, T+1, D)# Add multimodal position embeddingsx = x + self.x_pos_embed# Note the lang is of variable lengthlang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]img_c = img_c + self.img_cond_pos_embed# Forward passconds = [lang_c, img_c]masks = [lang_mask, img_mask]for i, block in enumerate(self.blocks):c, mask = conds[i % 2], masks[i % 2]x = block(x, c, mask) # (B, T+1, D)# Inject the language condition at the final layerx = self.final_layer(x) # (B, T+1, out_channels)# Only preserve the action tokensx = x[:, -self.horizon:]return xdef dump_dit(state_action_traj, ctrl_freqs, t, lang_cond, img_cond, lang_attn_mask):t_str = str(t)x = state_action_traj.float().contiguous().cpu().detach().numpy()freq = ctrl_freqs.float().contiguous().cpu().detach().numpy().astype(np.int32).copy()t_ = t.float().contiguous().cpu().detach().numpy()t_ = np.expand_dims(t_.astype(np.int32), axis=0).copy()lang_c = lang_cond.float().contiguous().cpu().detach().numpy()img_c = img_cond.float().contiguous().cpu().detach().numpy()lang_mask = lang_attn_mask.float().contiguous().cpu().detach().numpy()pad_rows = 64 - lang_mask.shape[1]padded = np.pad(lang_mask, ((0,0), (0,pad_rows)), mode="constant")mask_float = np.where(padded, 0.0, -512.0).astype(np.float32)lang_cond_padded = np.pad(lang_c, pad_width=((0, 0), (0, pad_rows), (0,0)), mode="constant", constant_values=0)global dit_cal_path_x, dit_cal_path_freq, dit_cal_path_t, dit_cal_path_lang_c, dit_cal_path_img_c, dit_cal_path_lang_maskglobal dump_cnt, dump_dataset_namenp.save(os.path.join(dit_cal_path_x, f"x_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), x)np.save(os.path.join(dit_cal_path_freq, 
f"freq_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), freq)np.save(os.path.join(dit_cal_path_t, f"t_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), t_)np.save(os.path.join(dit_cal_path_lang_c, f"lang_c_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), lang_cond_padded)np.save(os.path.join(dit_cal_path_img_c, f"img_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), img_c)np.save(os.path.join(dit_cal_path_lang_mask, f"lang_mask_{t_str}_{dump_dataset_name}_{dump_cnt}.npy"), mask_float)def dump_img_adaptor(img_tokens):global img_adaptor_cal_wsglobal dump_cnt, dump_dataset_namenp.save(os.path.join(img_adaptor_cal_ws, f"img_adaptor_{dump_dataset_name}_{dump_cnt}.npy"), img_tokens.float().contiguous().cpu().detach().numpy())class RDTRunner_Dump(nn.Module,CompatiblePyTorchModelHubMixin,repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"):def __init__(self,*,action_dim,pred_horizon,config,lang_token_dim,img_token_dim,state_token_dim,max_lang_cond_len,img_cond_len,lang_pos_embed_config=None,img_pos_embed_config=None,dtype=torch.bfloat16):super(RDTRunner_Dump, self).__init__()# Create diffusion modelhidden_size = config['rdt']['hidden_size']self.model = RDT_Dump(output_dim=action_dim,horizon=pred_horizon,hidden_size=hidden_size,depth=config['rdt']['depth'],num_heads=config['rdt']['num_heads'],max_lang_cond_len=max_lang_cond_len,img_cond_len=img_cond_len,lang_pos_embed_config=lang_pos_embed_config,img_pos_embed_config=img_pos_embed_config,dtype=dtype,)# Create adpators for various conditional inputsself.lang_adaptor = self.build_condition_adapter(config['lang_adaptor'], in_features=lang_token_dim, out_features=hidden_size)self.img_adaptor = self.build_condition_adapter(config['img_adaptor'], in_features=img_token_dim, out_features=hidden_size)# A `state` refers to an action or a proprioception vectorself.state_adaptor = self.build_condition_adapter(config['state_adaptor'],in_features=state_token_dim * 2, # state + state mask (indicator)out_features=hidden_size)# Create the noise schedulernoise_scheduler_config = config['noise_scheduler']self.noise_scheduler = DDPMScheduler(num_train_timesteps=noise_scheduler_config['num_train_timesteps'],beta_schedule=noise_scheduler_config['beta_schedule'],prediction_type=noise_scheduler_config['prediction_type'],clip_sample=noise_scheduler_config['clip_sample'],)self.noise_scheduler_sample = DPMSolverMultistepScheduler(num_train_timesteps=noise_scheduler_config['num_train_timesteps'],beta_schedule=noise_scheduler_config['beta_schedule'],prediction_type=noise_scheduler_config['prediction_type'],)self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']self.prediction_type = noise_scheduler_config['prediction_type']self.pred_horizon = pred_horizonself.action_dim = action_dimprint("Diffusion params: %e" %sum([p.numel() for p in self.model.parameters()] + [p.numel() for p in self.lang_adaptor.parameters()] +[p.numel()for p in self.img_adaptor.parameters()] + [p.numel() for p in self.state_adaptor.parameters()]))def build_condition_adapter(self, projector_type, in_features, out_features):projector = Noneif projector_type == 'linear':projector = nn.Linear(in_features, out_features)else:mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)if mlp_gelu_match:mlp_depth = int(mlp_gelu_match.group(1))modules = [nn.Linear(in_features, out_features)]for _ in range(1, mlp_depth):modules.append(nn.GELU(approximate="tanh"))modules.append(nn.Linear(out_features, 
out_features))projector = nn.Sequential(*modules)if projector is None:raise ValueError(f'Unknown projector type: {projector_type}')return projectordef adapt_conditions(self, lang_tokens, img_tokens, state_tokens):adpated_lang = self.lang_adaptor(lang_tokens)dump_img_adaptor(img_tokens)adpated_img = self.img_adaptor(img_tokens)adpated_state = self.state_adaptor(state_tokens)return adpated_lang, adpated_img, adpated_statedef conditional_sample(self, lang_cond, lang_attn_mask, img_cond, state_traj, action_mask, ctrl_freqs):device = state_traj.devicedtype = state_traj.dtypenoisy_action = torch.randn(size=(state_traj.shape[0], self.pred_horizon, self.action_dim), dtype=dtype, device=device)action_mask = action_mask.expand(-1, self.pred_horizon, -1)# Set step valuesself.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)for t in self.noise_scheduler_sample.timesteps:# Prepare state-action trajectoryaction_traj = torch.cat([noisy_action, action_mask], dim=2)action_traj = self.state_adaptor(action_traj)state_action_traj = torch.cat([state_traj, action_traj], dim=1)# dumpdump_dit(state_action_traj, ctrl_freqs, t, lang_cond, img_cond, lang_attn_mask)# Predict the model outputmodel_output = self.model(state_action_traj,ctrl_freqs,t.unsqueeze(-1).to(device),lang_cond,img_cond,lang_mask=lang_attn_mask)# Compute previous actions: x_t -> x_t-1noisy_action = self.noise_scheduler_sample.step(model_output, t, noisy_action).prev_samplenoisy_action = noisy_action.to(state_traj.dtype)# Finally apply the action mask to mask invalid action dimensionsnoisy_action = noisy_action * action_maskreturn noisy_actiondef compute_loss(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_gt, action_mask,ctrl_freqs) -> torch.Tensor:'''lang_tokens: (batch_size, lang_len, lang_token_dim)lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,which should be True-False bool tensor.img_tokens: (batch_size, img_len, img_token_dim)state_tokens: (batch_size, 1, state_token_dim)action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervisionaction_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.ctrl_freqs: (batch_size,), control frequency for each sample.return: loss_value, a scalar tensor'''batch_size = lang_tokens.shape[0]device = lang_tokens.device# Sample noise that we'll add to the actionsnoise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device)# Sample random diffusion timestepstimesteps = torch.randint(0, self.num_train_timesteps, (batch_size, ), device=device).long()# Add noise to the clean actions according to the noise magnitude at each timestep# (this is the forward diffusion process)noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps)# Concatenate the state and action tokens to form the input sequencestate_action_traj = torch.cat([state_tokens, noisy_action], dim=1)# Append the action mask to the input sequenceaction_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)# Align the dimension with the hidden sizelang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj)# Predict the denoised resultpred = self.model(state_action_traj, ctrl_freqs, timesteps, lang_cond, img_cond, lang_mask=lang_attn_mask)pred_type = self.prediction_typeif pred_type == 'epsilon':target = noiseelif pred_type == 'sample':target = action_gtelse:raise ValueError(f"Unsupported prediction type 
{pred_type}")loss = F.mse_loss(pred, target)return lossdef predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens, action_mask, ctrl_freqs):'''lang_tokens: (batch_size, lang_len, lang_token_dim)lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,which should be True-False bool tensor.img_tokens: (batch_size, img_len, img_token_dim)state_tokens: (batch_size, 1, state_token_dim)action_mask: (batch_size, 1, action_dim),which should be a 0-1 **float** tensor.ctrl_freqs: (batch_size,), control frequency for each sample.return: (batch_size, horizon, action_dim), predicted action sequence'''# Prepare the state and conditionsstate_tokens = torch.cat([state_tokens, action_mask], dim=2)lang_cond, img_cond, state_traj = self.adapt_conditions(lang_tokens, img_tokens, state_tokens)# Run samplingaction_pred = self.conditional_sample(lang_cond,lang_attn_mask,img_cond,state_traj,action_mask,ctrl_freqs,)return action_preddef forward(self, *args, **kwargs) -> torch.Tensor:return self.compute_loss(*args, **kwargs)class RoboticDiffusionTransformerModel_Dump(object):def __init__(self,args,device="cuda",dtype=torch.bfloat16,image_size=None,control_frequency=25,pretrained=None,pretrained_vision_encoder_name_or_path=None,):self.args = argsself.dtype = dtypeself.image_size = image_sizeself.device = deviceself.control_frequency = control_frequency# We do not use the text encoder due to limited GPU memory# self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)self.policy = self.get_policy(pretrained)self.left_arm_dim, self.right_arm_dim = (args["arm_dim"]["left_arm_dim"],args["arm_dim"]["right_arm_dim"],)self.reset()def get_policy(self, pretrained):# Initialize model with argumentsif pretrained is None or os.path.isfile(pretrained):img_cond_len = (self.args["common"]["img_history_size"] * self.args["common"]["num_cameras"] * self.vision_model.num_patches)_model = RDTRunner_Dump(action_dim=self.args["common"]["state_dim"],pred_horizon=self.args["common"]["action_chunk_size"],config=self.args["model"],lang_token_dim=self.args["model"]["lang_token_dim"],img_token_dim=self.args["model"]["img_token_dim"],state_token_dim=self.args["model"]["state_token_dim"],max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],img_cond_len=img_cond_len,img_pos_embed_config=[# No initial pos embed in the last grid size# since we've already done in ViT("image",(self.args["common"]["img_history_size"],self.args["common"]["num_cameras"],-self.vision_model.num_patches,),),],lang_pos_embed_config=[# Similarly, no initial pos embed for language("lang", -self.args["dataset"]["tokenizer_max_length"]),],dtype=self.dtype,)else:_model = RDTRunner_Dump.from_pretrained(pretrained)return _modeldef get_text_encoder(self, pretrained_text_encoder_name_or_path):text_embedder = T5Embedder(from_pretrained=pretrained_text_encoder_name_or_path,model_max_length=self.args["dataset"]["tokenizer_max_length"],device=self.device,)tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.modelreturn tokenizer, text_encoderdef get_vision_encoder(self, pretrained_vision_encoder_name_or_path):vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)image_processor = vision_encoder.image_processorreturn image_processor, vision_encoderdef reset(self):device = self.deviceweight_dtype = self.dtypeself.policy.eval()# 
self.text_model.eval()self.vision_model.eval()self.policy = self.policy.to(device, dtype=weight_dtype)# self.text_model = self.text_model.to(device, dtype=weight_dtype)self.vision_model = self.vision_model.to(device, dtype=weight_dtype)def load_pretrained_weights(self, pretrained=None):if pretrained is None:returnprint(f"Loading weights from {pretrained}")filename = os.path.basename(pretrained)if filename.endswith(".pt"):checkpoint = torch.load(pretrained)self.policy.load_state_dict(checkpoint["module"])elif filename.endswith(".safetensors"):from safetensors.torch import load_modelload_model(self.policy, pretrained)else:raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")def encode_instruction(self, instruction, device="cuda"):tokens = self.text_tokenizer(instruction, return_tensors="pt", padding="longest", truncation=True)["input_ids"].to(device)tokens = tokens.view(1, -1)with torch.no_grad():pred = self.text_model(tokens).last_hidden_state.detach()return preddef _format_joint_to_state(self, joints):AGILEX_STATE_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)]# Rescale the gripper to the range of [0, 1]joints = joints / torch.tensor([[[180, 180, 180, 180, 180, 180]]],device=joints.device,dtype=joints.dtype,)B, N, _ = joints.shapestate = torch.zeros((B, N, self.args["model"]["state_token_dim"]),device=joints.device,dtype=joints.dtype,)# Fill into the unified state vectorstate[:, :, AGILEX_STATE_INDICES] = joints# Assemble the mask indicating each dimension's availabilitystate_elem_mask = torch.zeros((B, self.args["model"]["state_token_dim"]),device=joints.device,dtype=joints.dtype,)state_elem_mask[:, AGILEX_STATE_INDICES] = 1return state, state_elem_maskdef _unformat_action_to_joint(self, action):AGILEX_STATE_INDICES = [STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)]action_indices = AGILEX_STATE_INDICESjoints = action[:, :, action_indices]# Rescale the gripper back to the action range# Note that the action range and proprioception range are different# for Mobile ALOHA robotjoints = joints * torch.tensor([[[180, 180, 180, 180, 180, 180]]],device=joints.device,dtype=joints.dtype,)return joints@torch.no_grad()def step(self, proprio, images, text_embeds):device = self.devicedtype = self.dtype# The background image used for paddingbackground_color = np.array([int(x * 255) for x in self.image_processor.image_mean], dtype=np.uint8).reshape(1, 1, 3)background_image = (np.ones((self.image_processor.size["height"],self.image_processor.size["width"],3,),dtype=np.uint8,) * background_color)# Preprocess the images by order and encode themimage_tensor_list = []for image in images:if image is None:# Replace it with the background imageimage = PImage.fromarray(background_image)if self.image_size is not None:image = transforms.Resize(self.data_args.image_size)(image)if self.args["dataset"].get("auto_adjust_image_brightness", False):pixel_values = list(image.getdata())average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)if average_brightness <= 0.15:image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)if self.args["dataset"].get("image_aspect_ratio", "pad") == "pad":def expand2square(pil_img, background_color):width, height = pil_img.sizeif width == height:return pil_imgelif width > height:result = PImage.new(pil_img.mode, (width, width), background_color)result.paste(pil_img, (0, (width - height) // 2))return resultelse:result = PImage.new(pil_img.mode, (height, height), 
background_color)result.paste(pil_img, ((height - width) // 2, 0))return resultimage = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))image = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]image_tensor_list.append(image)image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)image_embeds = self.vision_model(image_tensor).detach()image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)# Prepare the proprioception states and the control frequencyjoints = proprio.to(device).unsqueeze(0) # (1, 1, 14)states, state_elem_mask = self._format_joint_to_state(joints) # (1, 1, 128), (1, 128)states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)states = states[:, -1:, :] # (1, 1, 128)ctrl_freqs = torch.tensor([self.control_frequency]).to(device)text_embeds = text_embeds.to(device, dtype=dtype)# Predict the next action chunk given the inputstrajectory = self.policy.predict_action(lang_tokens=text_embeds,lang_attn_mask=torch.ones(text_embeds.shape[:2], dtype=torch.bool, device=text_embeds.device),img_tokens=image_embeds,state_tokens=states,action_mask=state_elem_mask.unsqueeze(1),ctrl_freqs=ctrl_freqs,)trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)return trajectory######## Prepare Training Data
def fill_in_state(values, left_arm_dim, right_arm_dim):# RDT的Action及状态需要按照要求映射导一个128维的统一向量uni_vec = np.zeros(values.shape[:-1] + (128,))# 仅右臂6关节映射到 [0, 6) 位置for i in range(min(6, right_arm_dim[0])):uni_vec[..., i] = values[..., i]return uni_vecdef get_training_samples(data_dir, num_samples=100, instructions_per_episode=1):training_samples = []logger.info(f"Get Training Data From: {data_dir}.")for root, dirs, files in os.walk(data_dir):for file in files:if file.endswith('.hdf5') and len(training_samples) < num_samples:file_path = os.path.join(root, file)try:with h5py.File(file_path, 'r') as f:observations = f['observations']actions = f['action'][:]images = observations['images']# left_arm_dim = observations['left_arm_dim'][:]# right_arm_dim = observations['right_arm_dim'][:]qpos = observations['qpos'][:]episode_dir = os.path.dirname(file_path)instructions_dir = os.path.join(episode_dir, 'instructions')num_steps = len(qpos)if num_steps > 1:# 收集该 episode 可用的 instruction 索引step_indices = []if os.path.isdir(instructions_dir):for name in os.listdir(instructions_dir):if name.startswith('lang_embed_') and name.endswith('.pt'):try:idx = int(name[len('lang_embed_'):-3])if 0 <= idx < num_steps:step_indices.append(idx)except Exception:continuestep_indices = sorted(set(step_indices))# 按每个 episode 上限采样/截断if len(step_indices) == 0:# 回退策略:如果没有任何配套 instruction,就尝试用一个随机 step,但很可能会在后续被过滤step_indices = []for step_idx in step_indices[:max(1, instructions_per_episode)]:if len(training_samples) >= num_samples:break# 获取多摄像头多历史帧图像multi_cam_images = {}for cam_name in ['cam_high', 'cam_left_wrist', 'cam_right_wrist']:if cam_name in images:cam_images = []# 取 step_idx 的前一帧与当前帧,共两帧for i in range(max(step_idx - 1, 0), step_idx + 1):img_bits = images[cam_name][i]img = cv2.imdecode(np.frombuffer(img_bits, np.uint8), cv2.IMREAD_COLOR)if img is not None:cam_images.append(img)if len(cam_images) == 1:cam_images = [cam_images[0], cam_images[0]]if len(cam_images) >= 2:multi_cam_images[cam_name] = cam_images[:2]if len(multi_cam_images) == 0:logger.warning(f"Skip sample (missing images): {file_path}, step {step_idx}")continue# 语言嵌入与文本lang_embed = Nonelang_str = ""lang_embed_path = os.path.join(instructions_dir, f'lang_embed_{step_idx}.pt')if os.path.exists(lang_embed_path):try:lang_embed = torch.load(lang_embed_path)except Exception as e:logger.error(f"Error reading text file {lang_embed_path}: {e}")lang_embed = Nonelang_str_path = os.path.join(instructions_dir, f'txt_lang_embed_{step_idx}.txt')if os.path.exists(lang_str_path):try:with open(lang_str_path, 'r', encoding='utf-8') as f:lang_str = f.read().strip()except Exception as e:logger.error(f"Error reading text file {lang_str_path}: {e}")lang_str = ""if lang_embed is None:logger.warning(f"Skip sample (missing lang): {file_path}, step {step_idx}")continuetraining_samples.append({'multi_cam_images': multi_cam_images,'joints': actions[step_idx],'lang_embed': lang_embed,'lang_str': lang_str,'source': file_path,'step': step_idx})logger.debug(f"TimeStep: {step_idx}, Sample: {file_path}")except Exception as e:logger.error(f"Faild: {file_path} : {e}")continuelogger.info(f"Total Num: {len(training_samples)}.")return training_samplesif __name__ == "__main__":main()
The script exposes the following configurable parameters; set them according to your own paths and hardware (the --left_arm_dim / --right_arm_dim options control how the joints are mapped into RDT's unified state vector, as illustrated in the sketch after the parameter list):
parser.add_argument('--export_path', type=str, default="rdt_export_ws", help="")
parser.add_argument('--config_path', type=str, default="configs/base.yaml", help="")
parser.add_argument('--pretrained_vision_encoder', type=str, default="../weights/RDT/siglip-so400m-patch14-384", help="")
parser.add_argument('--pretrained_model', type=str, default="checkpoints/RDT170M_LeRobot/checkpoint-9000/pytorch_model/mp_rank_00_model_states.pt", help="")
parser.add_argument('--train_data', type=str, default="training_data/RDT170M_LeRobot", help="")
parser.add_argument('--num_samples', type=int, default=100, help="")
parser.add_argument('--jobs', type=int, default=8, help="")
parser.add_argument('--optimized_level', type=str, default="O2", help="")
parser.add_argument('--ctrl_freq', type=int, default=25, help="")
parser.add_argument('--left_arm_dim', type=int, default=6, help="")
parser.add_argument('--right_arm_dim', type=int, default=6, help="")
parser.add_argument('--cal_data_device', type=str, default='cuda', help="")
parser.add_argument('--instructions_per_episode', type=int, default=1, help="")
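To make --left_arm_dim / --right_arm_dim more concrete: for the single SO-101 arm, only six right-arm joint slots of RDT's 128-dimensional unified state vector are filled, together with an availability mask, mirroring the fill_in_state and _format_joint_to_state logic in the export script. The sketch below uses illustrative slot indices 0 to 5; the real positions come from STATE_VEC_IDX_MAPPING in configs/state_vec.py:
import numpy as np

STATE_DIM = 128                       # RDT's unified state/action dimension
RIGHT_ARM_SLOTS = list(range(6))      # illustrative slots for right_arm_joint_{0..5}_pos

def joints_to_unified_state(joints_deg: np.ndarray):
    """Pad 6 SO-101 joint angles (degrees) into the 128-dim vector plus a validity mask."""
    state = np.zeros(STATE_DIM, dtype=np.float32)
    mask = np.zeros(STATE_DIM, dtype=np.float32)
    state[RIGHT_ARM_SLOTS] = joints_deg / 180.0   # same rescaling as _format_joint_to_state
    mask[RIGHT_ARM_SLOTS] = 1.0                   # only these 6 dimensions are valid for a single arm
    return state, mask

state, mask = joints_to_unified_state(np.array([0, 45, -30, 10, 5, 90], dtype=np.float32))
print(state[:8], int(mask.sum()))                 # 6 valid dimensions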
After running the script, the exported ONNX models together with the dumped calibration data and related files are generated in the export workspace under the RDT directory:
(RoboTwin) qi.xiong@A100-Test:~/DualArm/LeRobot-VLA/RDT/rdt_export_ws_test$ tree -L 1
.
├── BPU_RDT_Policy
├── build_all.sh
├── DiT_WorkSpace
├── img_adaptor_WorkSpace
├── instructions
└── test_data

5 directories, 1 file
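Before moving on to quantization, it is worth loading a couple of the dumped reference files to confirm their shapes; here is a minimal sketch, assuming the default workspace layout and sample index 1:
import numpy as np
import torch

# Reference data dumped by export_all.py (run from inside rdt_export_ws_test).
joints = np.load("test_data/1_joints.npy")           # (6,)  follower joint state fed to the model
actions = np.load("test_data/1_actions.npy")         # (64, 6)  action chunk predicted by the GPU float model
lang = torch.load("test_data/1_lang_embeddings.pt")  # (1, N, 4096)  precomputed T5 language embedding
cam_high = np.load("test_data/1_cam_high_0.npy")     # (H, W, 3)  frame decoded from the HDF5 episode

print(joints.shape, actions.shape, lang.shape, cam_high.shape)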
Next we use the Docker environment officially delivered with the RDK S100 algorithm toolchain to quantize and compile the BPU models. The commands to download, load, and run the Docker image are as follows:
# Download the official RDK S100 Docker toolchain deliverables
# CPU Docker image
wget -c ftp://oeftp@sdk.d-robotics.cc/oe_v3.2.0/ai_toolchain_ubuntu_22_s100_cpu_v3.2.0.tar --ftp-password=Oeftp~123$%
# GPU Docker image
wget -c ftp://oeftp@sdk.d-robotics.cc/oe_v3.2.0/ai_toolchain_ubuntu_22_s100_gpu_v3.2.0.tar --ftp-password=Oeftp~123$%
# Load the image
docker load < ai_toolchain_ubuntu_22_s100_gpu_v3.2.0.tar
# Run the container with the workspace mounted
[sudo] docker run [--gpus all] -it -v <BPU_Work_Space>:/open_explorer REPOSITORY:TAG
# Example:
docker run -it --rm --gpus device=6 --shm-size=32g \
    -v /home/qi.xiong/DualArm/LeRobot-VLA/RDT/rdt_export_ws_test:/open_explorer \
    ai_toolchain_ubuntu_22_s100_gpu:v3.2.0
Inside the Docker container, run bash build_all.sh; compilation starts automatically and exports the HBM models that run on the RDK S100 board.
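Once build_all.sh finishes, the BPU_RDT_Policy folder should contain the compiled HBM models next to the ONNX adaptors exported earlier. A quick, hedged check (the file names follow the export and build scripts above) could look like this:
from pathlib import Path

# Artifacts expected in the deliverable folder before copying it to the board.
expected = [
    "rdt_dit.hbm", "rdt_img_adaptor.hbm",          # compiled by hb_compile and copied in by build.sh
    "rdt_lang_adaptor.onnx",                       # runs via onnxruntime on the CPU
    "rdt_state_adaptor_1x1x256.onnx", "rdt_state_adaptor_1x64x256.onnx",
    "base.yaml",                                   # copy of the training config
]
policy_dir = Path("rdt_export_ws_test/BPU_RDT_Policy")
for name in expected:
    print(f"{name:40s} {'OK' if (policy_dir / name).exists() else 'MISSING'}")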
Below are some reference metrics from the quantization of the img_adaptor and DiT models (the lang_adaptor and state_adaptor are kept as ONNX and run on the CPU, so they are not quantized):
####################### rdt_img_adaptor #######################
2025-09-17 18:47:06,602 INFO +-------------+-------------------+------------------+
2025-09-17 18:47:06,603 INFO | TensorName | Calibrated Cosine | Quantized Cosine |
2025-09-17 18:47:06,603 INFO +-------------+-------------------+------------------+
2025-09-17 18:47:06,603 INFO | adpated_img | 0.999790 | 0.999789 |
2025-09-17 18:47:06,603 INFO +-------------+-------------------+------------------+

####################### rdt_dit #######################
2025-09-17 18:58:29,082 INFO ================Bias_Correction_Info(mae)================
Node Before-Bias-Correlation After-Bias-Correlation
---------------------------------------------------------
output 0.000461 0.000070
2025-09-17 19:08:38,073 INFO +------------+-------------------+------------------+
2025-09-17 19:08:38,073 INFO | TensorName | Calibrated Cosine | Quantized Cosine |
2025-09-17 19:08:38,073 INFO +------------+-------------------+------------------+
2025-09-17 19:08:38,073 INFO | actions | 0.999881 | 0.999972 |
2025-09-17 19:08:38,073 INFO +------------+-------------------+------------------+
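The "Quantized Cosine" column above is simply the cosine similarity between the float and quantized outputs. If you want to compute a comparable number against the board results yourself, a small sketch follows (the bpu_actions_1.npy file name is hypothetical, standing in for however you dump the on-board action chunk):
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    a = a.reshape(-1).astype(np.float64)
    b = b.reshape(-1).astype(np.float64)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

ref = np.load("test_data/1_actions.npy")   # reference chunk dumped by export_all.py on the GPU
bpu = np.load("bpu_actions_1.npy")         # hypothetical dump of the corresponding board-side result
print(f"cosine = {cosine_similarity(ref, bpu):.6f}")   # values close to 1.0 indicate healthy quantization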
Next, run the following command to download the already quantized and compiled SigLIP model into the build output folder BPU_RDT_Policy, then copy this folder, the test_data folder, and all files under RDKS_ModelRun/RDT in the LeRobot-VLA repository to the RDK S100 board, as shown in the figure below:
wget https://archive.d-robotics.cc/downloads/rdk_model_zoo/rdk_s100/RoboticsDiffusionTransformers/bpu_siglip_so400m_patch14_nashm_384x384_featuremaps.hbm
Then run the following command to verify model inference with the calibration test data. If it completes without errors, the models run correctly on the board; if a Python dependency is missing at runtime, simply install it:
python3 BPU_RDT_policy.py
(RDT) root@ubuntu:~/WorkSpace/RDT# python3 BPU_RDT_policy.py
[UCP]: log level = 3
[UCP]: UCP version = 3.7.4
[VP]: log level = 3
[DNN]: log level = 3
[HPL]: log level = 3
[UCPT]: log level = 6
[RDK_RDT] [22:02:07.183] [INFO] Namespace(bpu_rdt_path='./BPU_RDT_Policy/', test_data_path='./test_data/', ctrl_freq=25, left_arm_dim=0, right_arm_dim=6)
[RDK_RDT] [22:02:07.193] [INFO] Using Single RDK S100(P) mode.
[RDK_RDT] [22:02:07.199] [INFO] Loading dit ...
[BPU][[BPU_MONITOR]][281473781271296][INFO]BPULib verison(2, 1, 2)[0d3f195]!
[DNN] HBTL_EXT_DNN log level:6
[DNN]: 3.7.4_(4.3.2 HBRT)
================== Model Summarys ==================
Model File: ./BPU_RDT_Policy/rdt_dit.hbm
Model Names:
0: rdt_dit [*Select]
Task N: 1
Inputs/Outputs AlignedByteSize: 17.6215MB.
Inputs Info:
[0][x]: float32, (1, 65, 1024, )
[1][freq]: int32, (1, )
[2][t]: int32, (1, )
[3][lang_c]: float32, (1, 64, 1024, )
[4][img_c]: float32, (1, 4374, 1024, )
[5][lang_mask]: float32, (1, 64, )
Outputs Info:
[0][actions]: float32, (1, 64, 128, )
====================================================
[RDK_RDT] [22:02:07.666] [INFO] Loading img adaptor ...
================== Model Summarys ==================
Model File: ./BPU_RDT_Policy/rdt_img_adaptor.hbm
Model Names:
0: rdt_img_adaptor [*Select]
Task N: 1
Inputs/Outputs AlignedByteSize: 36.3076MB.
Inputs Info:
[0][img_tokens]: float32, (1, 4374, 1152, )
Outputs Info:
[0][adpated_img]: float32, (1, 4374, 1024, )
====================================================
[RDK_RDT] [22:02:07.671] [INFO] loading lang adaptor ...
[RDK_RDT] [22:02:07.708] [INFO] NodeArg(name='text_embeds', type='tensor(float)', shape=[1, 'N', 4096])
[RDK_RDT] [22:02:07.708] [INFO] NodeArg(name='lang_cond', type='tensor(float)', shape=[1, 'N', 1024])
[RDK_RDT] [22:02:07.709] [INFO] Loading state sdaptor 1x1x256 ...
[RDK_RDT] [22:02:07.727] [INFO] NodeArg(name='state_tokens', type='tensor(float)', shape=[1, 1, 256])
[RDK_RDT] [22:02:07.727] [INFO] NodeArg(name='state_traj', type='tensor(float)', shape=[1, 1, 1024])
[RDK_RDT] [22:02:07.727] [INFO] Loading state sdaptor 1x64x256 ...
[RDK_RDT] [22:02:07.749] [INFO] NodeArg(name='state_tokens', type='tensor(float)', shape=[1, 64, 256])
[RDK_RDT] [22:02:07.749] [INFO] NodeArg(name='state_traj', type='tensor(float)', shape=[1, 64, 1024])
[RDK_RDT] [22:02:07.749] [INFO] Loading bpu_siglip_so400m_patch14_nashm_384x384_featuremaps.hbm ... (Please wait for 20 seconds.)
================== Model Summarys ==================
Model File: ./BPU_RDT_Policy/bpu_siglip_so400m_patch14_nashm_384x384_featuremaps.hbm
Model Names:
0: SiglipVisionModel [*Select]
Task N: 4
Inputs/Outputs AlignedByteSize: 19.5645MB.
Inputs Info:
[0][_input_0]: float32, (1, 3, 384, 384, )
Outputs Info:
[0][_output_0]: float32, (1, 729, 1152, )
====================================================
[RDK_RDT] [22:02:28.359] [INFO] === Compare Actions: 1_actions.npy ===
[RDK_RDT] [22:02:28.363] [DEBUG] proprio.shape = (6,)
[RDK_RDT] [22:02:28.367] [DEBUG] lang_embeddings.shape = (1, 11, 4096)
[RDK_RDT] [22:02:28.370] [DEBUG] set_lang_condition time = 2.80 ms
[RDK_RDT] [22:02:28.371] [DEBUG] Language Condition Shape: torch.Size([1, 11, 1024])
[RDK_RDT] [22:02:29.521] [DEBUG] SigLIP BPU Forward time = 1134.01 ms
[RDK_RDT] [22:02:29.538] [DEBUG] image adaptor time = 10.09 ms
[RDK_RDT] [22:02:29.539] [DEBUG] state adaptor time = 0.64 ms
[RDK_RDT] [22:02:29.556] [DEBUG] DiT conditional_sample time = 16.19 ms
[RDK_RDT] [22:02:29.572] [DEBUG] DiT PreProcess time = 15.98 ms
[RDK_RDT] [22:02:29.775] [DEBUG] DiT BPU Forward time = 202.91 ms
[RDK_RDT] [22:02:29.778] [DEBUG] DiT BPU PostProcess time = 2.06 ms
[RDK_RDT] [22:02:29.788] [DEBUG] DiT PreProcess time = 10.55 ms
[RDK_RDT] [22:02:29.983] [DEBUG] DiT BPU Forward time = 194.35 ms
[RDK_RDT] [22:02:29.984] [DEBUG] DiT BPU PostProcess time = 0.80 ms
[RDK_RDT] [22:02:29.994] [DEBUG] DiT PreProcess time = 9.84 ms
[RDK_RDT] [22:02:30.188] [DEBUG] DiT BPU Forward time = 194.06 ms
[RDK_RDT] [22:02:30.189] [DEBUG] DiT BPU PostProcess time = 0.78 ms
[RDK_RDT] [22:02:30.201] [DEBUG] DiT PreProcess time = 11.46 ms
[RDK_RDT] [22:02:30.395] [DEBUG] DiT BPU Forward time = 193.90 ms
[RDK_RDT] [22:02:30.396] [DEBUG] DiT BPU PostProcess time = 0.79 ms
[RDK_RDT] [22:02:30.404] [DEBUG] DiT PreProcess time = 7.87 ms
[RDK_RDT] [22:02:30.602] [DEBUG] DiT BPU Forward time = 198.14 ms
[RDK_RDT] [22:02:30.603] [DEBUG] DiT BPU PostProcess time = 0.59 ms
[RDK_RDT] [22:02:30.605] [INFO] BPU RDT time = 2233.59 ms
COS: 0.98739, A: -0.52879 ~ 0.54263, B: -94.66862 ~ 93.62534, Error: -93.09312 ~ 94.17428
[RDK_RDT] [22:02:30.606] [INFO] === Compare Actions: 2_actions.npy ===
[RDK_RDT] [22:02:30.609] [DEBUG] proprio.shape = (6,)
[RDK_RDT] [22:02:30.610] [DEBUG] lang_embeddings.shape = (1, 11, 4096)
[RDK_RDT] [22:02:30.625] [DEBUG] set_lang_condition time = 14.65 ms
[RDK_RDT] [22:02:30.625] [DEBUG] Language Condition Shape: torch.Size([1, 11, 1024])
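The "COS / A / B / Error" line compares each BPU action chunk against the reference *_actions.npy from the test data: a cosine similarity plus the value ranges of both trajectories and of their element-wise difference. A simplified sketch that mirrors the printed format (the actual comparison is implemented inside BPU_RDT_policy.py):
import numpy as np

def compare_actions(bpu_actions: np.ndarray, ref_actions: np.ndarray) -> None:
    # Illustrative version of the 'COS / A / B / Error' report in the log above.
    a = bpu_actions.reshape(-1).astype(np.float64)
    b = ref_actions.reshape(-1).astype(np.float64)
    cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12)
    err = a - b
    print(f"COS: {cos:.5f}, "
          f"A: {a.min():.5f} ~ {a.max():.5f}, "
          f"B: {b.min():.5f} ~ {b.max():.5f}, "
          f"Error: {err.min():.5f} ~ {err.max():.5f}")

# Example (path assumed, matching the sample name in the log above):
# compare_actions(bpu_output, np.load("test_data/1_actions.npy"))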
Once the test passes, we can follow the same pattern as the server-side inference above: first start the on-board inference code so it acts as an inference server on the board, then run the local robot-arm control code:
# Run on the board
python3 rdks100_server.py
# Run locally
python3 lerobot/src/lerobot/record_rdt.py
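Conceptually this is the same server/client split as before: the board exposes the RDT policy as an inference service, while the Mac-side record script streams observations (camera frames plus joint state) and executes the returned action chunks on the follower arm. Below is a heavily simplified sketch of such a round trip, assuming a hypothetical JSON-over-TCP protocol; the real protocol lives in rdks100_server.py and record_rdt.py:
# policy_bridge_sketch.py -- illustrative only; port, message format and helpers are assumptions
import json
import socket
import threading
import time

import numpy as np

HOST, PORT = "127.0.0.1", 9000  # loopback for the sketch; in practice the board's IP and a chosen port

def dummy_policy(obs: dict) -> list:
    # Stand-in for the BPU RDT inference: return a 64-step x 6-DoF action chunk.
    return np.zeros((64, 6)).tolist()

def serve_once() -> None:
    # "Board" side: accept one client, read one observation, reply with one action chunk.
    with socket.socket() as srv:
        srv.bind((HOST, PORT))
        srv.listen(1)
        conn, _ = srv.accept()
        with conn, conn.makefile("rwb") as f:
            obs = json.loads(f.readline())
            f.write((json.dumps(dummy_policy(obs)) + "\n").encode())
            f.flush()

def client_step() -> None:
    # "Mac" side: send the current observation, receive and (in reality) execute the actions.
    with socket.create_connection((HOST, PORT)) as cli, cli.makefile("rwb") as f:
        obs = {"state": [0.0] * 6, "images": "<encoded camera frames>"}
        f.write((json.dumps(obs) + "\n").encode())
        f.flush()
        actions = json.loads(f.readline())
        print("received action chunk:", len(actions), "steps")

threading.Thread(target=serve_once, daemon=True).start()
time.sleep(0.5)  # give the server a moment to start listening
client_step()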
For reference, the RDK S100's on-board resource usage is shown below (and do check out my Dtop: Jetson有Jtop,Linux有Htop,RDK也有Dtop! - SkyXZ - 博客园, haha):
The RDK S100 can drive the LeRobot arm directly, but since I had no spare camera on hand while writing this document, the head camera had to be provided through Apple Continuity Camera. If the RDK were connected to LeRobot directly it would not be able to access that head camera, so I took a detour and controlled the arm from the local Mac instead...