From dda0c15c99eee5e7efb7a4b0cc9f94cae9469f96 Mon Sep 17 00:00:00 2001 From: sriramprasadkothapalli Date: Tue, 16 Jul 2024 22:25:32 -0400 Subject: [PATCH 1/4] first try jai shree ram git add . --- aloha_scripts/auto_record.sh | 15 + aloha_scripts/constants.py | 69 ++++ aloha_scripts/real_env.py | 121 +++++++ aloha_scripts/robot_utils.py | 181 ++++++++++ aloha_scripts/teleop.py | 208 ++++++++++++ demo_real_robot.py | 177 ++++++---- .../config/task/real_pusht_image.yaml | 29 +- ...n_diffusion_unet_real_image_workspace.yaml | 20 +- diffusion_policy/real_world/real_env.py | 308 ++++++++++-------- .../rtde_interpolation_controller.py | 4 +- .../real_world/single_realsense.py | 218 ++++++++++--- .../real_world/spacemouse_shared_memory.py | 32 +- eval_real_robot.py | 126 +++---- visualization/ft.py | 135 ++++++++ visualization/visualize_episode_length.py | 22 ++ visualization/visualize_ft_data.py | 64 ++++ visualization/visualize_robot_calibration.py | 79 +++++ visualization/visualize_robot_trajectory.py | 61 ++++ visualization/visualize_robot_tree.py | 89 +++++ visualization/viz_constants.py | 1 + 20 files changed, 1606 insertions(+), 353 deletions(-) create mode 100644 aloha_scripts/auto_record.sh create mode 100644 aloha_scripts/constants.py create mode 100644 aloha_scripts/real_env.py create mode 100644 aloha_scripts/robot_utils.py create mode 100644 aloha_scripts/teleop.py create mode 100644 visualization/ft.py create mode 100644 visualization/visualize_episode_length.py create mode 100644 visualization/visualize_ft_data.py create mode 100644 visualization/visualize_robot_calibration.py create mode 100644 visualization/visualize_robot_trajectory.py create mode 100644 visualization/visualize_robot_tree.py create mode 100644 visualization/viz_constants.py diff --git a/aloha_scripts/auto_record.sh b/aloha_scripts/auto_record.sh new file mode 100644 index 00000000..cc8748b7 --- /dev/null +++ b/aloha_scripts/auto_record.sh @@ -0,0 +1,15 @@ +if [ "$2" -lt 0 ]; then + echo "# of episodes not valid" + exit +fi + +echo "Task: $1" +for (( i=0; i<$2; i++ )) +do + echo "Starting episode $i" + python3 record_episodes.py --task "$1" + if [ $? -ne 0 ]; then + echo "Failed to execute command. 
Returning" + exit + fi +done \ No newline at end of file diff --git a/aloha_scripts/constants.py b/aloha_scripts/constants.py new file mode 100644 index 00000000..61e91be5 --- /dev/null +++ b/aloha_scripts/constants.py @@ -0,0 +1,69 @@ +### Task parameters +import numpy as np + +DATA_DIR = '/home/bmv/aloha/src/dataset' +TASK = 'Rice_scoop_master_150_ft' +TASK_CONFIGS = { + TASK:{ + 'dataset_dir': DATA_DIR + '/' + TASK, + 'num_episodes': 150, + 'episode_len': 700, + 'camera_names': ['cam_low', 'cam_high', 'cam_wrist'], + 'json_dir' : '/home/bmv/act_dec2023/segmentation' + '/' + TASK + '_json' + }, +} + +### ALOHA fixed constants +MASTER_IP = "192.168.137.12" +FOLLOWER_IP = "192.168.137.1" + +FREQUENCY = 50 #Hz +DT = 1/FREQUENCY + +CAM_WIDTH = 640 +CAM_HEIGHT = 480 + +FRONT_CAM_ID = '936322071211' +WRIST_CAM_ID = '128422271715' +WRIST_CAM_MASTER_ID = '127122270237' + +VIZ_DIR = "base_press_ft" + +JOINT_NAMES = ["base", "shoulder", "elbow", "wrist_3", "wrist_2", "wrist_1"] +HOME_POSE = np.deg2rad([-93,-54,-109,-76,90,0]) + +# Left finger position limits (qpos[7]), right_finger = -1 * left_finger +MASTER_GRIPPER_POSITION_OPEN = 0.02417 +MASTER_GRIPPER_POSITION_CLOSE = 0.01244 +PUPPET_GRIPPER_POSITION_OPEN = 0.05800 +PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 + +# Gripper joint limits (qpos[6]) +MASTER_GRIPPER_JOINT_OPEN = 0.3083 +MASTER_GRIPPER_JOINT_CLOSE = -0.6842 +PUPPET_GRIPPER_JOINT_OPEN = 1.4910 +PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 + +############################ Helper functions ############################ + +MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) +PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) +MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE +PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE +MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x)) + +MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) +PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) +MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE +PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE +MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)) + +MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) +PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + +MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE +MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)) +PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) 
* (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE +PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)) + +MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2 \ No newline at end of file diff --git a/aloha_scripts/real_env.py b/aloha_scripts/real_env.py new file mode 100644 index 00000000..96bcb7a1 --- /dev/null +++ b/aloha_scripts/real_env.py @@ -0,0 +1,121 @@ +import sys +directory_to_add = "~/act_dec2023" +sys.path.append(directory_to_add) + + +import time +import numpy as np +import collections +import matplotlib.pyplot as plt +import dm_env + +from aloha_scripts.constants import DT, HOME_POSE, MASTER_IP, FOLLOWER_IP +from aloha_scripts.robot_utils import Recorder, ImageRecorder +from aloha_scripts.teleop import Follower, Master, generateTrajectory +# from interbotix_xs_modules.arm import InterbotixManipulatorXS +# from interbotix_xs_msgs.msg import JointSingleCommand + +# import IPython +# e = IPython.embed + + +class RealEnv: + """ + Environment for real robot bi-manual manipulation + Action space: [left_arm_qpos (6), # absolute joint position + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + + Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8' + "cam_low": (480x640x3), # h, w, c, dtype='uint8' + "cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8' + "cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8' + """ + + def __init__(self): + self.puppet = Follower(FOLLOWER_IP) + self.n_obs_steps = 2 + # self.recorder = Recorder('left', init_node=False) + self.image_recorder = ImageRecorder(init_node=True) + + + def get_qpos(self): + qpos = self.puppet.getJointAngles() + return np.array(qpos) + + def get_qvel(self): + qvel = self.puppet.getJointVelocity() + return np.array(qvel) + + def get_effort(self): + qeff = self.puppet.getJointEffort() + return np.array(qeff) + + def get_images(self): + return self.image_recorder.get_images() + + def get_timestamp(self): + return time.time() + + def get_TCP(self): + tcp = self.puppet.getTCPPosition() + return np.array(tcp) + + def get_observation(self): + obs = dict(self.get_images()) + obs['robot_joint'] = self.get_qpos() + obs['robot_eef_pose'] = self.get_TCP() + obs['timestamp'] = self.get_timestamp() + return obs + + def get_reward(self): + return 0 + + def reset(self, fake=False): + if not fake: + initial_observation = {} + for step in range(self.n_obs_steps): + obs_data = self.get_observation() + for key, value in obs_data.items(): + # Check if the key is already in the dictionary + if key not in initial_observation: + initial_observation[key] = [value] + else: + initial_observation[key].append(value) + + return 
dm_env.TimeStep( + step_type=dm_env.StepType.FIRST, + reward=self.get_reward(), + discount=None, + observation=initial_observation) + + def step(self, action): + self.puppet.operate(action) + # time.sleep(DT) + return dm_env.TimeStep( + step_type=dm_env.StepType.MID, + reward=self.get_reward(), + discount=None, + observation=self.get_observation()) + + +def get_action(master): + action = np.zeros(6) # 6 joint + 1 gripper, for two arms + # Arm actions + action = master.getJointAngles() + + return action + + +def make_real_env(): + env = RealEnv() + return env diff --git a/aloha_scripts/robot_utils.py b/aloha_scripts/robot_utils.py new file mode 100644 index 00000000..06b52227 --- /dev/null +++ b/aloha_scripts/robot_utils.py @@ -0,0 +1,181 @@ +import numpy as np +import time +from .constants import DT +from collections import deque +import rospy +from cv_bridge import CvBridge +from sensor_msgs.msg import Image +class ImageRecorder: + def __init__(self, init_node=True, is_debug=False): + + self.is_debug = is_debug + self.bridge = CvBridge() + self.camera_names = ['cam_high', 'cam_low', 'cam_wrist'] + if init_node: + rospy.init_node('image_recorder', anonymous=True) + for cam_name in self.camera_names: + setattr(self, f'{cam_name}_image', None) + setattr(self, f'{cam_name}_secs', None) + setattr(self, f'{cam_name}_nsecs', None) + if cam_name == 'cam_high': + callback_func = self.image_cb_cam_high + elif cam_name == 'cam_low': + callback_func = self.image_cb_cam_low + elif cam_name == 'cam_wrist': + callback_func = self.image_cb_cam_wrist + else: + raise NotImplementedError + rospy.Subscriber(f"/usb_{cam_name}/image_raw", Image, callback_func) + if self.is_debug: + setattr(self, f'{cam_name}_timestamps', deque(maxlen=50)) + time.sleep(0.5) + + def image_cb(self, cam_name, data): + setattr(self, f'{cam_name}_image', self.bridge.imgmsg_to_cv2(data, desired_encoding='bgr8')) + setattr(self, f'{cam_name}_secs', data.header.stamp.secs) + setattr(self, f'{cam_name}_nsecs', data.header.stamp.nsecs) + + + if self.is_debug: + getattr(self, f'{cam_name}_timestamps').append(data.header.stamp.secs + data.header.stamp.secs * 1e-9) + + def image_cb_cam_high(self, data): + cam_name = 'cam_high' + return self.image_cb(cam_name, data) + + def image_cb_cam_low(self, data): + cam_name = 'cam_low' + return self.image_cb(cam_name, data) + + def image_cb_cam_wrist(self, data): + cam_name = 'cam_wrist' + return self.image_cb(cam_name, data) + + def get_images(self): + image_dict = dict() + new_img_dict = dict() + for cam_name in self.camera_names: + image_dict[cam_name] = getattr(self, f'{cam_name}_image') + for i,v in enumerate(image_dict): + new_img_dict[f'camera_{i}'] = image_dict[v] + return new_img_dict + + def print_diagnostics(self): + def dt_helper(l): + l = np.array(l) + diff = l[1:] - l[:-1] + return np.mean(diff) + for cam_name in self.camera_names: + image_freq = 1 / dt_helper(getattr(self, f'{cam_name}_timestamps')) + print(f'{cam_name} {image_freq=:.2f}') + print() + +class Recorder: + def __init__(self, side, init_node=True, is_debug=False): + from collections import deque + import rospy + from sensor_msgs.msg import JointState + from teleop import Follower, Master, generateTrajectory + + self.secs = None + self.nsecs = None + self.qpos = None + self.effort = None + self.arm_command = None + # self.gripper_command = None + self.is_debug = is_debug + + if init_node: + rospy.init_node('recorder', anonymous=True) + # rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb) + # 
rospy.Subscriber(f"/puppet_{side}/commands/joint_group", JointGroupCommand, self.puppet_arm_commands_cb) + # rospy.Subscriber(f"/puppet_{side}/commands/joint_single", JointSingleCommand, self.puppet_gripper_commands_cb) + if self.is_debug: + self.joint_timestamps = deque(maxlen=50) + self.arm_command_timestamps = deque(maxlen=50) + # self.gripper_command_timestamps = deque(maxlen=50) + time.sleep(0.1) + + def puppet_state_cb(self, data): + self.qpos = data.position + self.qvel = data.velocity + self.effort = data.effort + self.data = data + if self.is_debug: + self.joint_timestamps.append(time.time()) + + def puppet_arm_commands_cb(self, data): + self.arm_command = data.cmd + if self.is_debug: + self.arm_command_timestamps.append(time.time()) + + # def puppet_gripper_commands_cb(self, data): + # self.gripper_command = data.cmd + # if self.is_debug: + # self.gripper_command_timestamps.append(time.time()) + + def print_diagnostics(self): + def dt_helper(l): + l = np.array(l) + diff = l[1:] - l[:-1] + return np.mean(diff) + + joint_freq = 1 / dt_helper(self.joint_timestamps) + arm_command_freq = 1 / dt_helper(self.arm_command_timestamps) + # gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps) + + print(f'{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n') + +def get_arm_joint_positions(bot): + return bot.arm.core.joint_states.position[:6] + +# def get_arm_gripper_positions(bot): +# joint_position = bot.gripper.core.joint_states.position[6] +# return joint_position + +def move_arms(bot_list, target_pose_list, move_time=1): + num_steps = int(move_time / DT) + curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list] + traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)] + for t in range(num_steps): + for bot_id, bot in enumerate(bot_list): + bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False) + time.sleep(DT) + +# def move_grippers(bot_list, target_pose_list, move_time): +# gripper_command = JointSingleCommand(name="gripper") +# num_steps = int(move_time / DT) +# curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list] +# traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)] +# for t in range(num_steps): +# for bot_id, bot in enumerate(bot_list): +# gripper_command.cmd = traj_list[bot_id][t] +# bot.gripper.core.pub_single.publish(gripper_command) +# time.sleep(DT) + +# def setup_puppet_bot(bot): +# bot.dxl.robot_reboot_motors("single", "gripper", True) +# bot.dxl.robot_set_operating_modes("group", "arm", "position") +# bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position") +# torque_on(bot) + +# def setup_master_bot(bot): +# bot.dxl.robot_set_operating_modes("group", "arm", "pwm") +# bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position") +# torque_off(bot) + +# def set_standard_pid_gains(bot): +# bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 800) +# bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0) + +# def set_low_pid_gains(bot): +# bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 100) +# bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0) + +# def torque_off(bot): +# bot.dxl.robot_torque_enable("group", "arm", False) +# bot.dxl.robot_torque_enable("single", "gripper", False) + +# def torque_on(bot): +# bot.dxl.robot_torque_enable("group", "arm", True) +# 
bot.dxl.robot_torque_enable("single", "gripper", True) diff --git a/aloha_scripts/teleop.py b/aloha_scripts/teleop.py new file mode 100644 index 00000000..c8203dcf --- /dev/null +++ b/aloha_scripts/teleop.py @@ -0,0 +1,208 @@ +import socket +import time +import rtde_receive +import rtde_control +import numpy as np +import struct +from datetime import datetime +import csv +# import keyboard as key +from visual_kinematics.RobotSerial import * +from math import pi +import h5py +from aloha_scripts.constants import DT, HOME_POSE, MASTER_IP, FOLLOWER_IP + + +""" +Class to control the actual robot (UR5e) +""" +class Follower(): + def __init__(self,ip): + self.ip = ip + self.receive = rtde_receive.RTDEReceiveInterface(self.ip) + self.control = rtde_control.RTDEControlInterface(self.ip) + + def move2Home(self): + print("Moving Follower to Home Pose") + self.control.moveJ(HOME_POSE, 1.500, 0.05, False) + + def move2Pose(self,trajectoryList): + for trajectory in trajectoryList: + self.control.moveJ(trajectory, 3.14, 0.5, False) + + def getJointAngles(self): + angles = self.receive.getActualQ() + return angles + + def getTCPPosition(self): + pose = self.receive.getActualTCPPose() + return pose + + def getJointVelocity(self): + vels = self.receive.getActualQd() + return vels + + def getJointEffort(self): + # efforts = self.receive.getJointTorques() + return np.zeros(6) + + def operate(self,masterJoints): + velocity = 0.5 + acceleration = 0.5 + dt = 1.0/50 # 2ms can use 1/50 + lookahead_time = 0.2 # can use 0.09 + gain = 300 + # t_start = self.control.initPeriod() + self.control.servoJ(masterJoints, velocity, acceleration, dt, lookahead_time, gain) + # self.control.waitPeriod(t_start) + + def disconnect(self): + # print("Disconnecting robot") + self.control.stopScript() + self.control.disconnect() + self.receive.disconnect() + + + +""" +Class to control the replica robot +""" +class Master(): + def __init__(self,ip): + + self.ip = ip + self.receiver_address = (ip, 5000) + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.bind(self.receiver_address) + self.dh_params = np.array( + [ + [0.1625, 0.0, 0.5 * pi, 0.0], + [0.0, -0.425, 0, 0.0], + [0.0, -0.3922, 0, 0.0], + [0.1333, 0.0, 0.5 * pi, 0.0], + [0.0997, 0.0, -0.5 * pi, 0.0], + [0.0996, 0.0, 0.0, 0.0], + ] + ) + self.replica = RobotSerial(self.dh_params) + + def connect(self): + self.server_socket.listen(1) + print("Waiting for a connection...") + self.client_socket, self.client_address = self.server_socket.accept() + print("Connected by", self.client_address) + + def disconnect(self): + self.client_socket.close() + self.server_socket.close() + + def getTCPPosition(self,masterJointAngles): + forward = self.replica.forward(masterJointAngles) + xyz = forward.t_3_1.reshape([3,]) + rxryrz = forward.r_3 + replicaPose = np.concatenate((xyz,rxryrz)) + return replicaPose + + def getJointAngles(self): + data = self.client_socket.recv(24) + encoder2Angles = list(struct.unpack('6f', data)) + return encoder2Angles + + +""" +Class to record and save observations +""" +class Data(): + def __init__(self,fileName): + self.fileName = fileName + + def write2CSV(self,dataArray): + current_date = datetime.now().strftime('%m-%d_%H-%M') + filename = f"{self.fileName}_{current_date}.csv" + heading = ['Sr.no.', 'J1', ' J2', 'J3', 'J4', 'J5', 'J6', 'Timestamp'] + with open(filename, 'w', newline='') as csvfile: + csv_writer = csv.writer(csvfile) + csv_writer.writerow(heading) + for row in dataArray: + csv_writer.writerow(row) + + ## 
TODO: + def write2h5py(self): + return + +## TODO: +""" +Class to check for collisions +""" +class Collision(): + def __init__(self, check,xlim,ylim,zlim,axis): + self.check = check + self.xlim = xlim + self.ylim = ylim + self.zlim = zlim + self.axis = axis + + def detect(self,followerTCP,masterTCP): + if self.check: + if self.axis == "pos": + pass + if self.axis == "neg": + pass + else: + return + + +""" +Function to generate trajectory from follower to master position +""" +def generateTrajectory(masterPose, followerPose, moveTime = 1): + steps = int(moveTime/0.02) + waypoints = np.linspace(followerPose,masterPose,steps) + return waypoints + + + +def main(): + # initialize and connect master + master = Master(MASTER_IP) + master.connect() + + # iniialize and connect follower + follower = Follower(FOLLOWER_IP) + + # move follower to home pose + follower.move2Home() + + # move follower arm to same position as master + masterJoints = master.getJointAngles() + followerJoints = follower.getJointAngles() + safeTrajectory = generateTrajectory(masterJoints,followerJoints) + for joint in safeTrajectory: + follower.operate(joint) + + # teleoperation + while True: + try: + joints2Follow = master.getJointAngles() + follower.operate(joints2Follow) + qPos = follower.getJointAngles() + qVel = follower.getJointVelocity() + + # if key.is_pressed('c'): + # pass + + except (KeyboardInterrupt,BrokenPipeError,ConnectionResetError): + print("bye bye ") + follower.stop() + master.disconnect() + break + + # except follower.receive.RTDEException as e: + # print("RTDE error: {}".format(e)) + # follower.stop() + # master.disconnect() + # break + +if __name__ == "__main__": + + main() \ No newline at end of file diff --git a/demo_real_robot.py b/demo_real_robot.py index 846badef..32decc7b 100644 --- a/demo_real_robot.py +++ b/demo_real_robot.py @@ -28,45 +28,75 @@ from diffusion_policy.real_world.keystroke_counter import ( KeystrokeCounter, Key, KeyCode ) +from aloha_scripts.teleop import * +from aloha_scripts.constants import * @click.command() -@click.option('--output', '-o', required=True, help="Directory to save demonstration dataset.") -@click.option('--robot_ip', '-ri', required=True, help="UR5's IP address e.g. 192.168.0.204") -@click.option('--vis_camera_idx', default=0, type=int, help="Which RealSense camera to visualize.") -@click.option('--init_joints', '-j', is_flag=True, default=False, help="Whether to initialize robot joint configuration in the beginning.") -@click.option('--frequency', '-f', default=10, type=float, help="Control frequency in Hz.") -@click.option('--command_latency', '-cl', default=0.01, type=float, help="Latency between receiving SapceMouse command to executing on Robot in Sec.") +@click.option( + "--output", "-o", required=True, help="Directory to save demonstration dataset." +) +@click.option( + "--robot_ip", "-ri", required=True, help="UR5's IP address e.g. 192.168.0.204" +) +@click.option( + "--vis_camera_idx", default=2, type=int, help="Which RealSense camera to visualize." +) +@click.option( + "--init_joints", + "-j", + is_flag=True, + default=False, + help="Whether to initialize robot joint configuration in the beginning.", +) +@click.option( + "--frequency", "-f", default=10, type=float, help="Control frequency in Hz." 
+) +@click.option( + "--command_latency", + "-cl", + default=0.01, + type=float, + help="Latency between receiving SpaceMouse command to executing on Robot in Sec.", +) def main(output, robot_ip, vis_camera_idx, init_joints, frequency, command_latency): - dt = 1/frequency + # TODO: look into rtde_interpolation_controller.py and pose_trajectory_interpolator.py + dt = 1 / frequency + with SharedMemoryManager() as shm_manager: - with KeystrokeCounter() as key_counter, \ - Spacemouse(shm_manager=shm_manager) as sm, \ - RealEnv( - output_dir=output, - robot_ip=robot_ip, - # recording resolution - obs_image_resolution=(1280,720), - frequency=frequency, - init_joints=init_joints, - enable_multi_cam_vis=True, - record_raw_video=True, - # number of threads per camera view for video recording (H.264) - thread_per_video=3, - # video recording quality, lower is better (but slower). - video_crf=21, - shm_manager=shm_manager - ) as env: + with KeystrokeCounter() as key_counter, Spacemouse( + shm_manager=shm_manager + ) as sm, RealEnv( + output_dir=output, + robot_ip=robot_ip, + # recording resolution + obs_image_resolution=(640, 480), + frequency=frequency, + init_joints=init_joints, + enable_multi_cam_vis=True, + record_raw_video=True, + # number of threads per camera view for video recording (H.264) + thread_per_video=3, + # video recording quality, lower is better (but slower). + video_crf=21, + shm_manager=shm_manager, + camera_serial_numbers=["cam_high","cam_low", "cam_wrist"], + video_capture_fps=30 + ) as env: cv2.setNumThreads(1) + # connect to replica + replica = Master(MASTER_IP) + replica.connect() # realsense exposure - env.realsense.set_exposure(exposure=120, gain=0) - # realsense white balance - env.realsense.set_white_balance(white_balance=5900) + # env.realsense.set_exposure(exposure=120, gain=0) + # # realsense white balance + # env.realsense.set_white_balance(white_balance=5900) time.sleep(1.0) - print('Ready!') + print("Ready!") state = env.get_robot_state() - target_pose = state['TargetTCPPose'] + # print("state:", state) + target_pose = state["ActualTCPPose"] t_start = time.monotonic() iter_idx = 0 stop = False @@ -79,82 +109,95 @@ def main(output, robot_ip, vis_camera_idx, init_joints, frequency, command_laten # pump obs obs = env.get_obs() - + # print(obs.keys()) # handle key presses press_events = key_counter.get_press_events() for key_stroke in press_events: - if key_stroke == KeyCode(char='q'): + if key_stroke == KeyCode(char="r"): # Exit program stop = True - elif key_stroke == KeyCode(char='c'): + elif key_stroke == KeyCode(char="c"): # Start recording - env.start_episode(t_start + (iter_idx + 2) * dt - time.monotonic() + time.time()) + env.start_episode( + t_start + + (iter_idx + 2) * dt + - time.monotonic() + + time.time() + ) key_counter.clear() is_recording = True - print('Recording!') - elif key_stroke == KeyCode(char='s'): + print("Recording!") + elif key_stroke == KeyCode(char="b"): # Stop recording env.end_episode() key_counter.clear() is_recording = False - print('Stopped.') + print("Stopped.") elif key_stroke == Key.backspace: # Delete the most recent recorded episode - if click.confirm('Are you sure to drop an episode?'): + if click.confirm("Are you sure to drop an episode?"): env.drop_episode() key_counter.clear() is_recording = False # delete stage = key_counter[Key.space] - + # print("Stage:", stage) # visualize - vis_img = obs[f'camera_{vis_camera_idx}'][-1,:,:,::-1].copy() + vis_img = obs[f"camera_{vis_camera_idx}"][-1, :, :, ::-1].copy() episode_id = 
env.replay_buffer.n_episodes - text = f'Episode: {episode_id}, Stage: {stage}' + text = f"Episode: {episode_id}, Stage: {stage}" if is_recording: - text += ', Recording!' + text += ", Recording!" cv2.putText( vis_img, text, - (10,30), + (10, 30), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, thickness=2, - color=(255,255,255) + color=(0, 0, 0), ) - - cv2.imshow('default', vis_img) + cv2.imshow("default", vis_img) cv2.pollKey() precise_wait(t_sample) + # get teleop command - sm_state = sm.get_motion_state_transformed() - # print(sm_state) - dpos = sm_state[:3] * (env.max_pos_speed / frequency) - drot_xyz = sm_state[3:] * (env.max_rot_speed / frequency) - - if not sm.is_button_pressed(0): - # translation mode - drot_xyz[:] = 0 - else: - dpos[:] = 0 - if not sm.is_button_pressed(1): - # 2D translation mode - dpos[2] = 0 + # sm_state = sm.get_motion_state_transformed() + # # print(sm_state) + # dpos = sm_state[:3] * (env.max_pos_speed / frequency) + # drot_xyz = sm_state[3:] * (env.max_rot_speed / frequency) - drot = st.Rotation.from_euler('xyz', drot_xyz) - target_pose[:3] += dpos - target_pose[3:] = (drot * st.Rotation.from_rotvec( - target_pose[3:])).as_rotvec() + # if not sm.is_button_pressed(0): + # # translation mode + # drot_xyz[:] = 0 + # else: + # dpos[:] = 0 + # if not sm.is_button_pressed(1): + # # 2D translation mode + # dpos[2] = 0 + + # drot = st.Rotation.from_euler('xyz', drot_xyz) + # target_pose[:3] += dpos + # target_pose[3:] = (drot * st.Rotation.from_rotvec( + # target_pose[3:])).as_rotvec() + + # use replica to get TCP actions + replica_joint = replica.getJointAngles() + replica_joint[-1] -= 0.587999344 + target_pose = replica.getTCPPosition(replica_joint) # execute teleop command env.exec_actions( - actions=[target_pose], - timestamps=[t_command_target-time.monotonic()+time.time()], - stages=[stage]) + actions=[target_pose], + timestamps=[t_command_target - time.monotonic() + time.time()], + replica_joint=[replica_joint], + stages=[stage], + ) precise_wait(t_cycle_end) iter_idx += 1 + replica.disconnect() # %% -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/diffusion_policy/config/task/real_pusht_image.yaml b/diffusion_policy/config/task/real_pusht_image.yaml index a3f7c3f0..f7254a24 100644 --- a/diffusion_policy/config/task/real_pusht_image.yaml +++ b/diffusion_policy/config/task/real_pusht_image.yaml @@ -1,31 +1,30 @@ name: real_image +data_name: rice_scoop_teleop_org_repo -image_shape: [3, 240, 320] -dataset_path: data/pusht_real/real_pusht_20230105 +task_type: all_cameras +image_shape: [3,480,640] +dataset_path: /home/bmv/diffusion_policy_new/data/rice_scoop_teleop_org_repo shape_meta: &shape_meta # acceptable types: rgb, low_dim obs: - # camera_0: - # shape: ${task.image_shape} - # type: rgb + camera_0: + shape: ${task.image_shape} + type: rgb camera_1: shape: ${task.image_shape} type: rgb - # camera_2: - # shape: ${task.image_shape} - # type: rgb - camera_3: + camera_2: shape: ${task.image_shape} type: rgb - # camera_4: - # shape: ${task.image_shape} - # type: rgb robot_eef_pose: - shape: [2] + shape: [6,] type: low_dim + # ft_data: + # shape: [6,] + # type: low_dim action: - shape: [2] + shape: [6,] env_runner: _target_: diffusion_policy.env_runner.real_pusht_image_runner.RealPushTImageRunner @@ -41,7 +40,7 @@ dataset: n_latency_steps: ${n_latency_steps} use_cache: True seed: 42 - val_ratio: 0.00 + val_ratio: 0.1 max_train_episodes: null delta_action: False diff --git 
a/diffusion_policy/config/train_diffusion_unet_real_image_workspace.yaml b/diffusion_policy/config/train_diffusion_unet_real_image_workspace.yaml index f6187ef1..99f157a0 100644 --- a/diffusion_policy/config/train_diffusion_unet_real_image_workspace.yaml +++ b/diffusion_policy/config/train_diffusion_unet_real_image_workspace.yaml @@ -6,6 +6,8 @@ name: train_diffusion_unet_image _target_: diffusion_policy.workspace.train_diffusion_unet_image_workspace.TrainDiffusionUnetImageWorkspace task_name: ${task.name} +data_name: ${task.data_name} +task_type: ${task.task_type} shape_meta: ${task.shape_meta} exp_name: "default" @@ -43,8 +45,8 @@ policy: _target_: diffusion_policy.model.vision.model_getter.get_resnet name: resnet18 weights: null - resize_shape: [240, 320] - crop_shape: [216, 288] # ch, cw 240x320 90% + resize_shape: [480, 640] + crop_shape: [432, 576] # ch, cw 240x320 90% random_crop: True use_group_norm: True share_rgb_model: False @@ -74,14 +76,14 @@ ema: max_value: 0.9999 dataloader: - batch_size: 64 + batch_size: 16 num_workers: 8 shuffle: True pin_memory: True persistent_workers: True val_dataloader: - batch_size: 64 + batch_size: 16 num_workers: 8 shuffle: False pin_memory: True @@ -102,7 +104,7 @@ training: # optimization lr_scheduler: cosine lr_warmup_steps: 500 - num_epochs: 600 + num_epochs: 300 gradient_accumulate_every: 1 # EMA destroys performance when used with BatchNorm # replace BatchNorm with GroupNorm. @@ -139,14 +141,14 @@ checkpoint: save_last_snapshot: False multi_run: - run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} - wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} + run_dir: data/outputs/${now:%Y.%m.%d}_${data_name}/${now:%H.%M.%S}_${name}_${task_type} + wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${data_name} hydra: job: override_dirname: ${name} run: - dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} + dir: data/outputs/${now:%Y.%m.%d}_${data_name}/${now:%H.%M.%S}_${name}_${task_type} sweep: - dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} + dir: data/outputs/${now:%Y.%m.%d}_${data_name}/${now:%H.%M.%S}_${name}_${task_type} subdir: ${hydra.job.num} diff --git a/diffusion_policy/real_world/real_env.py b/diffusion_policy/real_world/real_env.py index b731205f..d999e35e 100644 --- a/diffusion_policy/real_world/real_env.py +++ b/diffusion_policy/real_world/real_env.py @@ -5,72 +5,74 @@ import shutil import math from multiprocessing.managers import SharedMemoryManager -from diffusion_policy.real_world.rtde_interpolation_controller import RTDEInterpolationController +from diffusion_policy.real_world.rtde_interpolation_controller import ( + RTDEInterpolationController, +) from diffusion_policy.real_world.multi_realsense import MultiRealsense, SingleRealsense from diffusion_policy.real_world.video_recorder import VideoRecorder from diffusion_policy.common.timestamp_accumulator import ( - TimestampObsAccumulator, + TimestampObsAccumulator, TimestampActionAccumulator, - align_timestamps + align_timestamps, ) from diffusion_policy.real_world.multi_camera_visualizer import MultiCameraVisualizer from diffusion_policy.common.replay_buffer import ReplayBuffer -from diffusion_policy.common.cv2_util import ( - get_image_transform, optimal_row_cols) - +from diffusion_policy.common.cv2_util import get_image_transform, optimal_row_cols DEFAULT_OBS_KEY_MAP = { # robot - 'ActualTCPPose': 'robot_eef_pose', - 'ActualTCPSpeed': 'robot_eef_pose_vel', - 'ActualQ': 'robot_joint', - 'ActualQd': 
'robot_joint_vel', + "ActualTCPPose": "robot_eef_pose", + "ActualTCPSpeed": "robot_eef_pose_vel", + "ActualQ": "robot_joint", + "ActualQd": "robot_joint_vel", # timestamps - 'step_idx': 'step_idx', - 'timestamp': 'timestamp' + "step_idx": "step_idx", + "timestamp": "timestamp", } + class RealEnv: - def __init__(self, - # required params - output_dir, - robot_ip, - # env params - frequency=10, - n_obs_steps=2, - # obs - obs_image_resolution=(640,480), - max_obs_buffer_size=30, - camera_serial_numbers=None, - obs_key_map=DEFAULT_OBS_KEY_MAP, - obs_float32=False, - # action - max_pos_speed=0.25, - max_rot_speed=0.6, - # robot - tcp_offset=0.13, - init_joints=False, - # video capture params - video_capture_fps=30, - video_capture_resolution=(1280,720), - # saving params - record_raw_video=True, - thread_per_video=2, - video_crf=21, - # vis params - enable_multi_cam_vis=True, - multi_cam_vis_resolution=(1280,720), - # shared memory - shm_manager=None - ): + def __init__( + self, + # required params + output_dir, + robot_ip, + # env params + frequency=50, + n_obs_steps=2, + # obs + obs_image_resolution=(640, 480), + max_obs_buffer_size=30, + camera_serial_numbers=None, + obs_key_map=DEFAULT_OBS_KEY_MAP, + obs_float32=False, + # action + max_pos_speed=0.25, + max_rot_speed=0.6, + # robot + tcp_offset=0.0, + init_joints=False, + # video capture params + video_capture_fps=30, + video_capture_resolution=(640, 480), + # saving params + record_raw_video=True, + thread_per_video=2, + video_crf=21, + # vis params + enable_multi_cam_vis=True, + multi_cam_vis_resolution=(640, 480), + # shared memory + shm_manager=None, + # ft = True, + ): assert frequency <= video_capture_fps output_dir = pathlib.Path(output_dir) assert output_dir.parent.is_dir() - video_dir = output_dir.joinpath('videos') + video_dir = output_dir.joinpath("videos") video_dir.mkdir(parents=True, exist_ok=True) - zarr_path = str(output_dir.joinpath('replay_buffer.zarr').absolute()) - replay_buffer = ReplayBuffer.create_from_path( - zarr_path=zarr_path, mode='a') - + zarr_path = str(output_dir.joinpath("replay_buffer.zarr").absolute()) + replay_buffer = ReplayBuffer.create_from_path(zarr_path=zarr_path, mode="a") + # print("level 1") if shm_manager is None: shm_manager = SharedMemoryManager() shm_manager.start() @@ -79,46 +81,47 @@ def __init__(self, color_tf = get_image_transform( input_res=video_capture_resolution, - output_res=obs_image_resolution, + output_res=obs_image_resolution, # obs output rgb - bgr_to_rgb=True) + bgr_to_rgb=True, + ) color_transform = color_tf if obs_float32: color_transform = lambda x: color_tf(x).astype(np.float32) / 255 def transform(data): - data['color'] = color_transform(data['color']) + data["color"] = color_transform(data["color"]) return data - + rw, rh, col, row = optimal_row_cols( n_cameras=len(camera_serial_numbers), - in_wh_ratio=obs_image_resolution[0]/obs_image_resolution[1], - max_resolution=multi_cam_vis_resolution + in_wh_ratio=obs_image_resolution[0] / obs_image_resolution[1], + max_resolution=multi_cam_vis_resolution, ) vis_color_transform = get_image_transform( - input_res=video_capture_resolution, - output_res=(rw,rh), - bgr_to_rgb=False + input_res=video_capture_resolution, output_res=(rw, rh), bgr_to_rgb=False ) + def vis_transform(data): - data['color'] = vis_color_transform(data['color']) + data["color"] = vis_color_transform(data["color"]) return data recording_transfrom = None recording_fps = video_capture_fps - recording_pix_fmt = 'bgr24' + recording_pix_fmt = "bgr24" if not 
record_raw_video: recording_transfrom = transform recording_fps = frequency - recording_pix_fmt = 'rgb24' - + recording_pix_fmt = "rgb24" + # print("level 2") video_recorder = VideoRecorder.create_h264( - fps=recording_fps, - codec='h264', - input_pix_fmt=recording_pix_fmt, + fps=recording_fps, + codec="h264", + input_pix_fmt=recording_pix_fmt, crf=video_crf, - thread_type='FRAME', - thread_count=thread_per_video) + thread_type="FRAME", + thread_count=thread_per_video, + ) realsense = MultiRealsense( serial_numbers=camera_serial_numbers, @@ -138,33 +141,29 @@ def vis_transform(data): vis_transform=vis_transform, recording_transform=recording_transfrom, video_recorder=video_recorder, - verbose=False - ) - + verbose=False, + ) + # print("level 3") multi_cam_vis = None if enable_multi_cam_vis: multi_cam_vis = MultiCameraVisualizer( - realsense=realsense, - row=row, - col=col, - rgb_to_bgr=False + realsense=realsense, row=row, col=col, rgb_to_bgr=False ) - - cube_diag = np.linalg.norm([1,1,1]) - j_init = np.array([0,-90,-90,-90,90,0]) / 180 * np.pi + # print("level 4") + cube_diag = np.linalg.norm([1, 1, 1]) + j_init = np.array([0, -90, -90, -90, 90, 0]) / 180 * np.pi if not init_joints: j_init = None - robot = RTDEInterpolationController( shm_manager=shm_manager, robot_ip=robot_ip, - frequency=125, # UR5 CB3 RTDE + frequency=500, # UR5e RTDE lookahead_time=0.1, gain=300, - max_pos_speed=max_pos_speed*cube_diag, - max_rot_speed=max_rot_speed*cube_diag, + max_pos_speed=max_pos_speed * cube_diag, + max_rot_speed=max_rot_speed * cube_diag, launch_timeout=3, - tcp_offset_pose=[0,0,tcp_offset,0,0,0], + tcp_offset_pose=[0, 0, tcp_offset, 0, 0, 0], payload_mass=None, payload_cog=None, joints_init=j_init, @@ -172,8 +171,8 @@ def vis_transform(data): soft_real_time=False, verbose=False, receive_keys=None, - get_max_k=max_obs_buffer_size - ) + get_max_k=max_obs_buffer_size, + ) self.realsense = realsense self.robot = robot self.multi_cam_vis = multi_cam_vis @@ -194,14 +193,19 @@ def vis_transform(data): self.obs_accumulator = None self.action_accumulator = None self.stage_accumulator = None - + # self.replica_joint_accumulator = None self.start_time = None - + # print("level 7") + # if ft: + # ft_sensor = FTSensor(shm_manager=shm_manager, + # get_max_k=max_obs_buffer_size) + # self.ft_sensor = ft_sensor + # self.ft = ft # ======== start-stop API ============= @property def is_ready(self): return self.realsense.is_ready and self.robot.is_ready - + def start(self, wait=True): self.realsense.start(wait=False) self.robot.start(wait=False) @@ -224,7 +228,7 @@ def start_wait(self): self.robot.start_wait() if self.multi_cam_vis is not None: self.multi_cam_vis.start_wait() - + def stop_wait(self): self.robot.stop_wait() self.realsense.stop_wait() @@ -235,7 +239,7 @@ def stop_wait(self): def __enter__(self): self.start() return self - + def __exit__(self, exc_type, exc_val, exc_tb): self.stop() @@ -247,22 +251,23 @@ def get_obs(self) -> dict: # get data # 30 Hz, camera_receive_timestamp k = math.ceil(self.n_obs_steps * (self.video_capture_fps / self.frequency)) - self.last_realsense_data = self.realsense.get( - k=k, - out=self.last_realsense_data) - + # print("before realsense data:",self.last_realsense_data) + self.last_realsense_data = self.realsense.get(k=k, out=self.last_realsense_data) + # print("after realsense data:",self.last_realsense_data) # 125 hz, robot_receive_timestamp last_robot_data = self.robot.get_all_state() # both have more than n_obs_steps data # align camera obs timestamps dt = 1 / 
self.frequency - last_timestamp = np.max([x['timestamp'][-1] for x in self.last_realsense_data.values()]) + last_timestamp = np.max( + [x["timestamp"][-1] for x in self.last_realsense_data.values()] + ) obs_align_timestamps = last_timestamp - (np.arange(self.n_obs_steps)[::-1] * dt) camera_obs = dict() for camera_idx, value in self.last_realsense_data.items(): - this_timestamps = value['timestamp'] + this_timestamps = value["timestamp"] this_idxs = list() for t in obs_align_timestamps: is_before_idxs = np.nonzero(this_timestamps < t)[0] @@ -271,10 +276,10 @@ def get_obs(self) -> dict: this_idx = is_before_idxs[-1] this_idxs.append(this_idx) # remap key - camera_obs[f'camera_{camera_idx}'] = value['color'][this_idxs] + camera_obs[f"camera_{camera_idx}"] = value["color"][this_idxs] # align robot obs - robot_timestamps = last_robot_data['robot_receive_timestamp'] + robot_timestamps = last_robot_data["robot_receive_timestamp"] this_timestamps = robot_timestamps this_idxs = list() for t in obs_align_timestamps: @@ -288,33 +293,42 @@ def get_obs(self) -> dict: for k, v in last_robot_data.items(): if k in self.obs_key_map: robot_obs_raw[self.obs_key_map[k]] = v - + robot_obs = dict() for k, v in robot_obs_raw.items(): robot_obs[k] = v[this_idxs] - # accumulate obs + # print("last robot data keys:", last_robot_data.keys()) + # print("robot_obs_raw keys:", robot_obs_raw.keys()) + # print("robot_obs keys:", robot_obs.keys()) + + # accumulate robot obs if self.obs_accumulator is not None: - self.obs_accumulator.put( - robot_obs_raw, - robot_timestamps - ) + self.obs_accumulator.put(robot_obs_raw, robot_timestamps) # return obs obs_data = dict(camera_obs) obs_data.update(robot_obs) - obs_data['timestamp'] = obs_align_timestamps + obs_data["timestamp"] = obs_align_timestamps return obs_data - - def exec_actions(self, - actions: np.ndarray, - timestamps: np.ndarray, - stages: Optional[np.ndarray]=None): + + def exec_actions( + self, + actions: np.ndarray, + timestamps: np.ndarray, + replica_joint: Optional[np.ndarray] = None, # added to capture replica joints 4/2/24 Abhi + stages: Optional[np.ndarray] = None, + ): assert self.is_ready if not isinstance(actions, np.ndarray): actions = np.array(actions) if not isinstance(timestamps, np.ndarray): timestamps = np.array(timestamps) + + # # added to capture replica joints 4/2/24 Abhi + # if not isinstance(replica_joint, np.ndarray): + # replica_joint = np.array(replica_joint) + if stages is None: stages = np.zeros_like(timestamps, dtype=np.int64) elif not isinstance(stages, np.ndarray): @@ -324,28 +338,26 @@ def exec_actions(self, receive_time = time.time() is_new = timestamps > receive_time new_actions = actions[is_new] + # new_replica_joint = replica_joint[is_new] # added to capture replica joints 4/2/24 Abhi new_timestamps = timestamps[is_new] new_stages = stages[is_new] # schedule waypoints for i in range(len(new_actions)): self.robot.schedule_waypoint( - pose=new_actions[i], - target_time=new_timestamps[i] + pose=new_actions[i], target_time=new_timestamps[i] ) - + # record actions if self.action_accumulator is not None: - self.action_accumulator.put( - new_actions, - new_timestamps - ) + self.action_accumulator.put(new_actions, new_timestamps) if self.stage_accumulator is not None: - self.stage_accumulator.put( - new_stages, - new_timestamps - ) - + self.stage_accumulator.put(new_stages, new_timestamps) + + # # added to capture replica joints 4/2/24 Abhi + # if self.replica_joint_accumulator is not None: + # 
self.replica_joint_accumulator.put(new_replica_joint, new_timestamps) + def get_robot_state(self): return self.robot.get_state() @@ -365,32 +377,34 @@ def start_episode(self, start_time=None): n_cameras = self.realsense.n_cameras video_paths = list() for i in range(n_cameras): - video_paths.append( - str(this_video_dir.joinpath(f'{i}.mp4').absolute())) - + video_paths.append(str(this_video_dir.joinpath(f"{i}.mp4").absolute())) + # start recording on realsense self.realsense.restart_put(start_time=start_time) self.realsense.start_recording(video_path=video_paths, start_time=start_time) # create accumulators self.obs_accumulator = TimestampObsAccumulator( - start_time=start_time, - dt=1/self.frequency + start_time=start_time, dt=1 / self.frequency ) self.action_accumulator = TimestampActionAccumulator( - start_time=start_time, - dt=1/self.frequency + start_time=start_time, dt=1 / self.frequency ) self.stage_accumulator = TimestampActionAccumulator( - start_time=start_time, - dt=1/self.frequency + start_time=start_time, dt=1 / self.frequency ) - print(f'Episode {episode_id} started!') - + + # # added to capture replica joints 4/2/24 Abhi + # self.replica_joint_accumulator = TimestampActionAccumulator( + # start_time=start_time, dt=1 / self.frequency + # ) + + print(f"Episode {episode_id} started!") + def end_episode(self): "Stop recording" assert self.is_ready - + # stop video recorder self.realsense.stop_recording() @@ -399,6 +413,9 @@ def end_episode(self): assert self.action_accumulator is not None assert self.stage_accumulator is not None + # added to capture replica joints 4/2/24 Abhi + # assert self.replica_joint_accumulator is not None + # Since the only way to accumulate obs and action is by calling # get_obs and exec_actions, which will be in the same thread. # We don't need to worry new data come in here. 
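
For reference, the end_episode change in the next hunk reduces to a simple clipping pattern: the observation, action, and stage streams are cut to the length of the shortest timestamp series so that every key written to the replay buffer has the same number of steps. A minimal standalone sketch of that pattern is below; the argument and function names are illustrative placeholders, not the accumulator API itself.

import numpy as np

def build_episode(obs_timestamps, obs_data, actions, action_timestamps, stages):
    # Clip every accumulated stream to a common length so all saved keys align.
    n_steps = min(len(obs_timestamps), len(action_timestamps))
    if n_steps == 0:
        return None
    episode = {
        "timestamp": obs_timestamps[:n_steps],
        "action": actions[:n_steps],
        "stage": stages[:n_steps],
    }
    for key, value in obs_data.items():
        episode[key] = value[:n_steps]
    return episode

# Example (hypothetical shapes): 5 obs samples but only 4 actions -> every
# array in the returned episode is truncated to 4 steps.
# build_episode(np.arange(5), {"robot_eef_pose": np.zeros((5, 6))},
#               np.zeros((4, 6)), np.arange(4), np.zeros(4))
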
@@ -408,21 +425,31 @@ def end_episode(self): actions = self.action_accumulator.actions action_timestamps = self.action_accumulator.timestamps stages = self.stage_accumulator.actions + + # added to capture replica joints 4/2/24 Abhi + # replica_joint = self.replica_joint_accumulator.actions + n_steps = min(len(obs_timestamps), len(action_timestamps)) if n_steps > 0: episode = dict() - episode['timestamp'] = obs_timestamps[:n_steps] - episode['action'] = actions[:n_steps] - episode['stage'] = stages[:n_steps] + episode["timestamp"] = obs_timestamps[:n_steps] + episode["action"] = actions[:n_steps] + episode["stage"] = stages[:n_steps] + + # added to capture replica joints + # episode["replica_joint"] = replica_joint[:n_steps] + for key, value in obs_data.items(): + # print(key) episode[key] = value[:n_steps] - self.replay_buffer.add_episode(episode, compressors='disk') + self.replay_buffer.add_episode(episode, compressors="disk") episode_id = self.replay_buffer.n_episodes - 1 - print(f'Episode {episode_id} saved!') - + print(f"Episode {episode_id} saved!") + self.obs_accumulator = None self.action_accumulator = None self.stage_accumulator = None + # self.replica_joint_accumulator = None def drop_episode(self): self.end_episode() @@ -431,5 +458,4 @@ def drop_episode(self): this_video_dir = self.video_dir.joinpath(str(episode_id)) if this_video_dir.exists(): shutil.rmtree(str(this_video_dir)) - print(f'Episode {episode_id} dropped!') - + print(f"Episode {episode_id} dropped!") diff --git a/diffusion_policy/real_world/rtde_interpolation_controller.py b/diffusion_policy/real_world/rtde_interpolation_controller.py index af27b2ed..12574a2e 100644 --- a/diffusion_policy/real_world/rtde_interpolation_controller.py +++ b/diffusion_policy/real_world/rtde_interpolation_controller.py @@ -29,8 +29,8 @@ class RTDEInterpolationController(mp.Process): def __init__(self, shm_manager: SharedMemoryManager, robot_ip, - frequency=125, - lookahead_time=0.1, + frequency=500, + lookahead_time=0.2, gain=300, max_pos_speed=0.25, # 5% of max speed max_rot_speed=0.16, # 5% of max speed diff --git a/diffusion_policy/real_world/single_realsense.py b/diffusion_policy/real_world/single_realsense.py index 7a8443b9..0c80c263 100644 --- a/diffusion_policy/real_world/single_realsense.py +++ b/diffusion_policy/real_world/single_realsense.py @@ -2,11 +2,13 @@ import os import enum import time -import json +# import torch import numpy as np import pyrealsense2 as rs +import pyzed.sl as sl import multiprocessing as mp import cv2 +import threading as th from threadpoolctl import threadpool_limits from multiprocessing.managers import SharedMemoryManager from diffusion_policy.common.timestamp_accumulator import get_accumulate_timestamp_idxs @@ -14,6 +16,7 @@ from diffusion_policy.shared_memory.shared_memory_ring_buffer import SharedMemoryRingBuffer from diffusion_policy.shared_memory.shared_memory_queue import SharedMemoryQueue, Full, Empty from diffusion_policy.real_world.video_recorder import VideoRecorder +from aloha_scripts.constants import * class Command(enum.Enum): SET_COLOR_OPTION = 0 @@ -144,6 +147,7 @@ def __init__( self.video_recorder = video_recorder self.verbose = verbose self.put_start_time = None + # self.image_recorder = ImageRecorder(camera_names=self.serial_number,init_node=True) # shared variables self.stop_event = mp.Event() @@ -164,6 +168,7 @@ def get_connected_devices_serial(): # only works with D400 series serials.append(serial) serials = sorted(serials) + #print('serials :', serials) return serials # ========= 
context manager =========== @@ -180,6 +185,7 @@ def start(self, wait=True, put_start_time=None): super().start() if wait: self.start_wait() + # print("camera process is alive:", self.is_ready) def stop(self, wait=True): self.stop_event.set() @@ -187,7 +193,9 @@ def stop(self, wait=True): self.end_wait() def start_wait(self): + # print("in camera start_wait") self.ready_event.wait() + # print("camera process after start_wait:", self.is_ready) def end_wait(self): self.join() @@ -280,36 +288,88 @@ def run(self): # limit threads threadpool_limits(1) cv2.setNumThreads(1) - - w, h = self.resolution - fps = self.capture_fps - align = rs.align(rs.stream.color) - # Enable the streams from all the intel realsense devices - rs_config = rs.config() - if self.enable_color: - rs_config.enable_stream(rs.stream.color, - w, h, rs.format.bgr8, fps) - if self.enable_depth: - rs_config.enable_stream(rs.stream.depth, - w, h, rs.format.z16, fps) - if self.enable_infrared: - rs_config.enable_stream(rs.stream.infrared, - w, h, rs.format.y8, fps) - + # print("Inside run") + ## try: - rs_config.enable_device(self.serial_number) - - # start pipeline - pipeline = rs.pipeline() - pipeline_profile = pipeline.start(rs_config) + + """ + Initialize wrist camera + Put: config.enable_device(WRIST_CAM_ID) if using UR rbot + Put: config.enable_device(WRIST_CAM_MASTER_ID) if using replica + """ + if self.serial_number == 'cam_wrist': + # Initialize the RealSense pipeline + config = rs.config() + config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, self.capture_fps) + config.enable_device(WRIST_CAM_ID) + pipeline = rs.pipeline() + pipeline.start(config) + sensor = pipeline.get_active_profile().get_device().query_sensors()[0] + sensor.set_option(rs.option.enable_auto_exposure, False) + sensor.set_option(rs.option.exposure, 10000) ## experiment with exposure value + + """ + Initialize Zed camera + """ + if self.serial_number == 'cam_low': + # print('Inside ZED init') + # Create a ZED camera object + zed = sl.Camera() + # Set configuration parameters + init_params = sl.InitParameters() + init_params.camera_resolution = sl.RESOLUTION.HD1080 # Change the resolution as needed + init_params.camera_fps = self.capture_fps # Use the specified FPS + init_params.depth_mode = sl.DEPTH_MODE.NONE # Disable depth calculation + init_params.camera_disable_self_calib = True + # Initialize the camera + err = zed.open(init_params) + # print(err) + if err != sl.ERROR_CODE.SUCCESS: + print(f"Error initializing ZED camera: {err}") + zed.close() + return + # Create a runtime parameters object + runtime_params = sl.RuntimeParameters() + + + """ + Initialize Logitec camera + Put: os.path.realpath("/dev/CAM_HIGH") for UR robot setup + Put: os.path.realpath("/dev/CAM_HIGH_MASTER") for replica robot setup + """ + # Logitech camera setup + if self.serial_number == 'cam_high': + cam_path = os.path.realpath("/dev/CAM_HIGH") + cam_idx = int(cam_path.split("/dev/video")[-1]) + cap = cv2.VideoCapture(cam_idx) + cap.set(cv2.CAP_PROP_EXPOSURE, 0) + cap.set(cv2.CAP_PROP_FPS, self.capture_fps) + cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640) + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480) + cap.set(cv2.CAP_PROP_AUTOFOCUS, 0) + + if self.serial_number == 'cam_front': + # Initialize the RealSense pipeline + config = rs.config() + config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, self.capture_fps) + config.enable_device(FRONT_CAM_ID) + pipeline = rs.pipeline() + pipeline.start(config) + # sensor = pipeline.get_active_profile().get_device().query_sensors()[0] 
+ # sensor.set_option(rs.option.enable_auto_exposure, False) + # sensor.set_option(rs.option.exposure, 10000) # report global time # https://github.com/IntelRealSense/librealsense/pull/3909 + + """ + Can omit this step to incorporate D405 data ## abhi 3/7/24 + d = pipeline_profile.get_device().first_color_sensor() d.set_option(rs.option.global_time_enabled, 1) - + """ # setup advanced mode - if self.advanced_mode_config is not None: + """if self.advanced_mode_config is not None: json_text = json.dumps(self.advanced_mode_config) device = pipeline_profile.get_device() advanced_mode = rs.rs400_advanced_mode(device) @@ -326,6 +386,7 @@ def run(self): depth_sensor = pipeline_profile.get_device().first_depth_sensor() depth_scale = depth_sensor.get_depth_scale() self.intrinsics_array.get()[-1] = depth_scale + """ # one-time setup (intrinsics etc, ignore for now) if self.verbose: @@ -336,25 +397,87 @@ def run(self): put_start_time = self.put_start_time if put_start_time is None: put_start_time = time.time() - iter_idx = 0 t_start = time.time() + + # frame grabbing part + # print("entering data aquisition") while not self.stop_event.is_set(): + black_img = np.zeros((480,640,3), dtype= np.uint8) # wait for frames to come in - frameset = pipeline.wait_for_frames() + if self.serial_number == 'cam_wrist': + # Realsense image acquisition + rs_frames = pipeline.wait_for_frames() + rs_color_frame = rs_frames.get_color_frame() + if rs_color_frame: + frameset = np.asanyarray(rs_color_frame.get_data()) + # frameset = cv2.cvtColor(rs_color_image, cv2.COLOR_BGR2RGB) + + if self.serial_number == 'cam_low': + # ZED image acquisition + if zed.grab(runtime_params) == sl.ERROR_CODE.SUCCESS: + zed_image_frame = sl.Mat() + zed.retrieve_image(zed_image_frame, sl.VIEW.RIGHT) + frameset = cv2.resize(zed_image_frame.get_data()[:, :, :3], (640,480)) + """ + Below is cropping logic when performing domain gap + comment/ uncomment these lines as needed + """ + # black_img[y1:y2, x1:x2] = frameset[y1:y2, x1:x2] + # frameset = black_img.copy() + """ + Cropping logic ends here + """ + # alpha_channel = frameset[:, :, 3] + # print(np.min(alpha_channel), np.max(alpha_channel)) + # frameset = cv2.cvtColor(frameset, cv2.COLOR_BGR2RGB) + + if self.serial_number == 'cam_high': + _, frameset = cap.read() + frameset = cv2.resize(frameset, (640,480)) + """ + Below is cropping logic when performing domain gap + comment/ uncomment these lines as needed + """ + # black_img[y1:y2, x1:x2] = frameset[y1:y2, x1:x2] + # frameset = black_img.copy() + + """ + Cropping logic ends here + """ + # frameset = cv2.cvtColor(frameset, cv2.COLOR_BGR2RGB) + + if self.serial_number == 'cam_front': + # Realsense image acquisition + rs_frames = pipeline.wait_for_frames() + rs_color_frame = rs_frames.get_color_frame() + if rs_color_frame: + frameset = np.asanyarray(rs_color_frame.get_data()) + black_img[30:240, 200:550] = frameset[30:240, 200:550] + frameset = black_img.copy() + + # frameset = pipeline.wait_for_frames() receive_time = time.time() # align frames to color - frameset = align.process(frameset) + # frameset = align.process(frameset) # grab data data = dict() data['camera_receive_timestamp'] = receive_time # realsense report in ms - data['camera_capture_timestamp'] = frameset.get_timestamp() / 1000 + # data['camera_capture_timestamp'] = frameset[f'{self.serial_number}_timestamps'] + data['camera_capture_timestamp'] = time.time() + # print("camera is:", self.serial_number) if self.enable_color: - color_frame = frameset.get_color_frame() - 
data['color'] = np.asarray(color_frame.get_data()) - t = color_frame.get_timestamp() / 1000 + color_frame = frameset + # cv2.imwrite(f'diffusion_policy/LHW_images/{self.serial_number}.png', color_frame) + # print("camera data from ros:",color_frame.shape, type(color_frame)) + # data['color'] = np.asarray(color_frame.get_data()) + data['color'] = color_frame + # t = color_frame.get_timestamp() / 1000 + # t = frameset[f'{self.serial_number}_timestamps'] + t= time.time() + # print("timestamp from ros:",t, type(t)) data['camera_capture_timestamp'] = t # print('device', time.time() - t) # print(color_frame.get_frame_timestamp_domain()) @@ -364,13 +487,17 @@ def run(self): if self.enable_infrared: data['infrared'] = np.asarray( frameset.get_infrared_frame().get_data()) - + + # apply transform put_data = data + # print(self.transform) if self.transform is not None: put_data = self.transform(dict(data)) - if self.put_downsample: + # by default is false; value from real_env.py # abhi 3/7/24 + if self.put_downsample: + # print("inside if stmt:", self.put_downsample) # put frequency regulation local_idxs, global_idxs, put_idx \ = get_accumulate_timestamp_idxs( @@ -390,13 +517,14 @@ def run(self): # put_data['timestamp'] = put_start_time + step_idx / self.put_fps put_data['timestamp'] = receive_time # print(step_idx, data['timestamp']) - self.ring_buffer.put(put_data, wait=False) + self.ring_buffer.put(put_data, wait=True) ## observations from camera put in ring buffer else: step_idx = int((receive_time - put_start_time) * self.put_fps) put_data['step_idx'] = step_idx put_data['timestamp'] = receive_time - self.ring_buffer.put(put_data, wait=False) - + self.ring_buffer.put(put_data) + # print("************* put data color shapes") + # print("put data dict:", put_data["color"].shape) # signal ready if iter_idx == 0: self.ready_event.set() @@ -407,7 +535,7 @@ def run(self): vis_data = put_data elif self.vis_transform is not None: vis_data = self.vis_transform(dict(data)) - self.vis_ring_buffer.put(vis_data, wait=False) + self.vis_ring_buffer.put(vis_data, wait=True) # record frame rec_data = data @@ -441,7 +569,7 @@ def run(self): for key, value in commands.items(): command[key] = value[i] cmd = command['cmd'] - if cmd == Command.SET_COLOR_OPTION.value: + """if cmd == Command.SET_COLOR_OPTION.value: sensor = pipeline_profile.get_device().first_color_sensor() option = rs.option(command['option_enum']) value = float(command['option_value']) @@ -453,8 +581,8 @@ def run(self): sensor = pipeline_profile.get_device().first_depth_sensor() option = rs.option(command['option_enum']) value = float(command['option_value']) - sensor.set_option(option, value) - elif cmd == Command.START_RECORDING.value: + sensor.set_option(option, value)""" + if cmd == Command.START_RECORDING.value: video_path = str(command['video_path']) start_time = command['recording_start_time'] if start_time < 0: @@ -473,8 +601,14 @@ def run(self): iter_idx += 1 finally: self.video_recorder.stop() - rs_config.disable_all_streams() + if self.serial_number == 'cam_wrist': + pipeline.stop() + if self.serial_number == 'cam_low': + zed.close() + if self.serial_number == 'cam_high': + cap.release() + # rs_config.disable_all_streams() self.ready_event.set() if self.verbose: - print(f'[SingleRealsense {self.serial_number}] Exiting worker process.') + print(f'[SingleRealsense {self.serial_number}] Exiting worker process.') \ No newline at end of file diff --git a/diffusion_policy/real_world/spacemouse_shared_memory.py 
b/diffusion_policy/real_world/spacemouse_shared_memory.py index 06102fdb..0a6596a4 100644 --- a/diffusion_policy/real_world/spacemouse_shared_memory.py +++ b/diffusion_policy/real_world/spacemouse_shared_memory.py @@ -1,7 +1,6 @@ import multiprocessing as mp import numpy as np import time -from spnav import spnav_open, spnav_poll_event, spnav_close, SpnavMotionEvent, SpnavButtonEvent from diffusion_policy.shared_memory.shared_memory_ring_buffer import SharedMemoryRingBuffer class Spacemouse(mp.Process): @@ -53,7 +52,6 @@ def __init__(self, example = { # 3 translation, 3 rotation, 1 period 'motion_event': np.zeros((7,), dtype=np.int64), - # left and right button 'button_state': np.zeros((n_buttons,), dtype=bool), 'receive_timestamp': time.time() } @@ -126,7 +124,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # ========= main loop ========== def run(self): - spnav_open() + self._spnav_open() try: motion_event = np.zeros((7,), dtype=np.int64) button_state = np.zeros((self.n_buttons,), dtype=bool) @@ -139,7 +137,7 @@ def run(self): self.ready_event.set() while not self.stop_event.is_set(): - event = spnav_poll_event() + event = self._spnav_poll_event() receive_timestamp = time.time() if isinstance(event, SpnavMotionEvent): motion_event[:3] = event.translation @@ -157,4 +155,28 @@ def run(self): }) time.sleep(1/self.frequency) finally: - spnav_close() + self._spnav_close() + + def _spnav_open(self): + # Placeholder for spnav_open + pass + + def _spnav_poll_event(self): + # Placeholder for spnav_poll_event + # This should return an event-like object for testing purposes + return None + + def _spnav_close(self): + # Placeholder for spnav_close + pass + +class SpnavMotionEvent: + def __init__(self, translation, rotation, period): + self.translation = translation + self.rotation = rotation + self.period = period + +class SpnavButtonEvent: + def __init__(self, bnum, press): + self.bnum = bnum + self.press = press diff --git a/eval_real_robot.py b/eval_real_robot.py index cb3cdd17..e9efb0c8 100644 --- a/eval_real_robot.py +++ b/eval_real_robot.py @@ -33,7 +33,8 @@ import skvideo.io from omegaconf import OmegaConf import scipy.spatial.transform as st -from diffusion_policy.real_world.real_env import RealEnv +# from diffusion_policy.real_world.real_env import RealEnv +from diffusion_policy.real_world.real_env_with_ft import RealEnv from diffusion_policy.real_world.spacemouse_shared_memory import Spacemouse from diffusion_policy.common.precise_sleep import precise_wait from diffusion_policy.real_world.real_inference_util import ( @@ -53,7 +54,7 @@ @click.option('--robot_ip', '-ri', required=True, help="UR5's IP address e.g. 
192.168.0.204") @click.option('--match_dataset', '-m', default=None, help='Dataset used to overlay and adjust initial condition') @click.option('--match_episode', '-me', default=None, type=int, help='Match specific episode from the match dataset') -@click.option('--vis_camera_idx', default=0, type=int, help="Which RealSense camera to visualize.") +@click.option('--vis_camera_idx', default=1, type=int, help="Which RealSense camera to visualize.") @click.option('--init_joints', '-j', is_flag=True, default=False, help="Whether to initialize robot joint configuration in the beginning.") @click.option('--steps_per_inference', '-si', default=6, type=int, help="Action horizon for inference.") @click.option('--max_duration', '-md', default=60, help='Max duration for each epoch in seconds.') @@ -78,75 +79,26 @@ def main(input, output, robot_ip, match_dataset, match_episode, episode_first_frame_map[episode_idx] = frames[0] print(f"Loaded initial frame for {len(episode_first_frame_map)} episodes") - # load checkpoint - ckpt_path = input - payload = torch.load(open(ckpt_path, 'rb'), pickle_module=dill) - cfg = payload['cfg'] - cls = hydra.utils.get_class(cfg._target_) - workspace = cls(cfg) - workspace: BaseWorkspace - workspace.load_payload(payload, exclude_keys=None, include_keys=None) - - # hacks for method-specific setup. - action_offset = 0 - delta_action = False - if 'diffusion' in cfg.name: - # diffusion model - policy: BaseImagePolicy - policy = workspace.model - if cfg.training.use_ema: - policy = workspace.ema_model - - device = torch.device('cuda') - policy.eval().to(device) - - # set inference params - policy.num_inference_steps = 16 # DDIM inference iterations - policy.n_action_steps = policy.horizon - policy.n_obs_steps + 1 - - elif 'robomimic' in cfg.name: - # BCRNN model - policy: BaseImagePolicy - policy = workspace.model - - device = torch.device('cuda') - policy.eval().to(device) - - # BCRNN always has action horizon of 1 - steps_per_inference = 1 - action_offset = cfg.n_latency_steps - delta_action = cfg.task.dataset.get('delta_action', False) - - elif 'ibc' in cfg.name: - policy: BaseImagePolicy - policy = workspace.model - policy.pred_n_iter = 5 - policy.pred_n_samples = 4096 - - device = torch.device('cuda') - policy.eval().to(device) - steps_per_inference = 1 - action_offset = 1 - delta_action = cfg.task.dataset.get('delta_action', False) - else: - raise RuntimeError("Unsupported policy type: ", cfg.name) + # setup experiment dt = 1/frequency - obs_res = get_real_obs_resolution(cfg.task.shape_meta) - n_obs_steps = cfg.n_obs_steps - print("n_obs_steps: ", n_obs_steps) - print("steps_per_inference:", steps_per_inference) - print("action_offset:", action_offset) + # obs_res = get_real_obs_resolution(cfg.task.shape_meta) + # n_obs_steps = cfg.n_obs_steps + # print("n_obs_steps: ", n_obs_steps) + # print("steps_per_inference:", steps_per_inference) + # print("action_offset:", action_offset) with SharedMemoryManager() as shm_manager: with Spacemouse(shm_manager=shm_manager) as sm, RealEnv( output_dir=output, robot_ip=robot_ip, frequency=frequency, - n_obs_steps=n_obs_steps, - obs_image_resolution=obs_res, + + ### change n_obs_steps same as training + n_obs_steps= 1, + obs_image_resolution= (640,480), obs_float32=True, init_joints=init_joints, enable_multi_cam_vis=True, @@ -155,14 +107,42 @@ def main(input, output, robot_ip, match_dataset, match_episode, thread_per_video=3, # video recording quality, lower is better (but slower). 
video_crf=21, - shm_manager=shm_manager) as env: + shm_manager=shm_manager, + camera_serial_numbers= ["cam_high", "cam_low"]) as env: cv2.setNumThreads(1) + # load checkpoint + ckpt_path = input + payload = torch.load(open(ckpt_path, 'rb'), pickle_module=dill) + cfg = payload['cfg'] + cls = hydra.utils.get_class(cfg._target_) + workspace = cls(cfg) + workspace: BaseWorkspace + workspace.load_payload(payload, exclude_keys=None, include_keys=None) + + # hacks for method-specific setup. + action_offset = 0 + delta_action = False + + if 'diffusion' in cfg.name: + # diffusion model + policy: BaseImagePolicy + policy = workspace.model + if cfg.training.use_ema: + policy = workspace.ema_model + + device = torch.device('cuda') + policy.eval().to(device) + + # set inference params + policy.num_inference_steps = 15 # DDIM inference iterations + policy.n_action_steps = policy.horizon - policy.n_obs_steps + 1 + # Should be the same as demo # realsense exposure - env.realsense.set_exposure(exposure=120, gain=0) - # realsense white balance - env.realsense.set_white_balance(white_balance=5900) + # env.realsense.set_exposure(exposure=150, gain=0) + # # realsense white balance + # env.realsense.set_white_balance(white_balance=3700) print("Waiting for realsense") time.sleep(1.0) @@ -177,7 +157,7 @@ def main(input, output, robot_ip, match_dataset, match_episode, lambda x: torch.from_numpy(x).unsqueeze(0).to(device)) result = policy.predict_action(obs_dict) action = result['action'][0].detach().to('cpu').numpy() - assert action.shape[-1] == 2 + assert action.shape[-1] == 6 del result print('Ready!') @@ -256,11 +236,11 @@ def main(input, output, robot_ip, match_dataset, match_episode, target_pose[3:] = (drot * st.Rotation.from_rotvec( target_pose[3:])).as_rotvec() # clip target pose - target_pose[:2] = np.clip(target_pose[:2], [0.25, -0.45], [0.77, 0.40]) + # target_pose[:2] = np.clip(target_pose[:2], [0.25, -0.45], [0.77, 0.40]) ## commented this to make teleop working # execute teleop command env.exec_actions( - actions=[target_pose], + actions=[target_pose], timestamps=[t_command_target-time.monotonic()+time.time()]) precise_wait(t_cycle_end) iter_idx += 1 @@ -268,6 +248,7 @@ def main(input, output, robot_ip, match_dataset, match_episode, # ========== policy control loop ============== try: # start episode + env.ft_sensor.calibrate_sensor() policy.reset() start_delay = 1.0 eval_t_start = time.time() + start_delay @@ -286,7 +267,8 @@ def main(input, output, robot_ip, match_dataset, match_episode, t_cycle_end = t_start + (iter_idx + steps_per_inference) * dt # get obs - print('get_obs') + # print('get_obs') + # env.ft_sensor.calibrate_sensor() obs = env.get_obs() obs_timestamps = obs['timestamp'] print(f'Obs latency {time.time() - obs_timestamps[-1]}') @@ -315,7 +297,7 @@ def main(input, output, robot_ip, match_dataset, match_episode, else: this_target_poses = np.zeros((len(action), len(target_pose)), dtype=np.float64) this_target_poses[:] = target_pose - this_target_poses[:,[0,1]] = action + this_target_poses[:,[0,1,2,3,4,5]] = action # deal with timing # the same step actions are always the target for @@ -337,8 +319,8 @@ def main(input, output, robot_ip, match_dataset, match_episode, action_timestamps = action_timestamps[is_new] # clip actions - this_target_poses[:,:2] = np.clip( - this_target_poses[:,:2], [0.25, -0.45], [0.77, 0.40]) + # this_target_poses[:,:2] = np.clip( + # this_target_poses[:,:2], [0.25, -0.45], [0.77, 0.40]) # execute actions env.exec_actions( diff --git a/visualization/ft.py 
b/visualization/ft.py new file mode 100644 index 00000000..d4aa46cf --- /dev/null +++ b/visualization/ft.py @@ -0,0 +1,135 @@ +import zarr +import numpy as np +import matplotlib.pyplot as plt +from viz_constants import VIZ_DIR + +dataset_dir = VIZ_DIR +# Load the Zarr dataset +z = zarr.open(f'/home/bmv/diffusion_policy_mod_apr24/data/{dataset_dir}/replay_buffer.zarr', mode='r') + +ep_len = z['meta/episode_ends'][:] +ft = z['data/ft_data'][:] +poses = z['data/replica_eef_pose'][:] +#print("ft data", ft) + +def plot_per_episode_ft(ax, episode_data): + # Plot ft data for a specific episode on a given axis (subplot) + ax.plot(episode_data[:,1]) + ax.grid(True) + +def visualize_all_episodes_ft(ft_data, ep_len, num_cols=5): + num_episodes = len(ep_len) + num_rows = (num_episodes + num_cols - 1) // num_cols # Calculate number of rows based on number of episodes and columns + print(f"Total number of episodes: {num_episodes}") + + fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows)) + + for episode_index in range(num_episodes): + row = episode_index // num_cols + col = episode_index % num_cols + ax = axs[row, col] if num_rows > 1 else axs[col] + + start_index = 0 if episode_index == 0 else ep_len[episode_index - 1] + end_index = ep_len[episode_index] + episode_data = ft_data[start_index:end_index] + + plot_per_episode_ft(ax, episode_data) + + # Adjust layout and spacing + fig.tight_layout(pad=0.5) + # plt.show() + +# Function to plot ft data for a particular episode +def plot_episode_ft(ft_data, ep_len, episode_index): + # Plot ft data for a specific episode + start_index = 0 if episode_index == 0 else ep_len[episode_index - 1] + end_index = ep_len[episode_index] + episode_ft_data = ft_data[start_index:end_index] + reshaped_ft_data = episode_ft_data.reshape(-1, 6) + print("reshaped:",len(reshaped_ft_data)) + # print("epusde_ft_data:", episode_ft_data) + # print("epusdie len:", len(episode_ft_data)) + plt.figure(figsize=(10, 6)) + for i in range(reshaped_ft_data.shape[1]): + plt.plot(reshaped_ft_data[:, i], label=f'Component {i+1}') + plt.title(f'Force/Torque Data for Episode {episode_index}') + plt.xlabel('Time Step') + plt.ylabel('Force/Torque') + plt.legend(['FX', 'FY', 'FZ', 'TX', 'TY', 'TZ']) + plt.grid(True) + plt.show() + +def plot_episode_ft_scatter(ft_data, ep_len, episode_index, set_index=0): + # Determine the start and end indices for the specific episode + start_index = 0 if episode_index == 0 else ep_len[episode_index - 1] + end_index = ep_len[episode_index] + + # Extract only the 3rd column (FZ) data for the specific episode and specified set + episode_ft_data_fz = ft_data[start_index:end_index, set_index, 2] # 2 is the index for the 3rd column + x = poses[:, 0][start_index:end_index] + y = poses[:, 1][start_index:end_index] + z = poses[:, 2][start_index:end_index] + + print("z values", z) + + for j in range(len(episode_ft_data_fz)): + print(f"Episode FZ {j}: {episode_ft_data_fz[j]}") + print(f"Episode X {j}: {x[j]}, Y {j}: {y[j]}, Z {j}: {z[j]}") + + + # Create scatter plot for the FZ data + plt.figure(figsize=(10, 6)) + plt.scatter(range(len(episode_ft_data_fz)), episode_ft_data_fz, label='FZ', color='b') + + # Set plot titles and labels + plt.title(f'Scatter Plot of FZ Data for Episode {episode_index}, Set {set_index}') + plt.xlabel('Time Step') + plt.ylabel('FZ') + plt.legend() + plt.grid(True) + plt.show() + +def plot_episode_ft_scatter_frequency(ft_data, ep_len, episode_index, set_index=0): + # Determine the start and end indices for the specific episode + 
start_index = 0 if episode_index == 0 else ep_len[episode_index - 1] + end_index = ep_len[episode_index] + + # Extract the specific episode data + episode_ft_data = ft_data[start_index:end_index] + reshaped_ft_data = episode_ft_data.reshape(-1, 6) + + # Create scatter plot for the reshaped FZ data + plt.figure(figsize=(10, 6)) + for i in range(reshaped_ft_data.shape[1]): + plt.scatter(range(reshaped_ft_data.shape[0]), reshaped_ft_data[:, i], label=f'Component {i+1}') + plt.title(f'Scatter Plot of Force/Torque Data for Episode {episode_index}') + plt.xlabel('Time Step') + plt.ylabel('Force/Torque Value') + plt.legend(['FX', 'FY', 'FZ', 'TX', 'TY', 'TZ']) + plt.grid(True) + plt.show() + + +# Function to plot all FX data in one plot +# def plot_all_fx_in_one_plot(ft_data, ep_len): +# plt.figure(figsize=(15, 6)) + +# for episode_index in range(len(ep_len)): +# start_index = 0 if episode_index == 0 else ep_len[episode_index - 1] +# end_index = ep_len[episode_index] +# episode_data = ft_data[start_index:end_index] # Only FX values +# plt.plot(episode_data, alpha=0.5) # Plot each episode's FX with some transparency + +# plt.title('Force in X direction (FX) for All Episodes') +# plt.xlabel('Time Step') +# plt.ylabel('Force (N)') +# plt.grid(True) +# plt.show() +print(f"Visualizing Datset: {dataset_dir}") +episode_index = int(input("Enter episode index to visualize || -1 for all data: ")) +if episode_index>=0: + plot_episode_ft_scatter_frequency(ft, ep_len, episode_index=episode_index) +elif episode_index==-1: + plot_all_fx_in_one_plot(ft, ep_len) +else: + raise NotImplementedError \ No newline at end of file diff --git a/visualization/visualize_episode_length.py b/visualization/visualize_episode_length.py new file mode 100644 index 00000000..9e57d1f7 --- /dev/null +++ b/visualization/visualize_episode_length.py @@ -0,0 +1,22 @@ +import zarr +import numpy as np +import matplotlib.pyplot as plt +from viz_constants import VIZ_DIR + + +dataset_dir = VIZ_DIR +# Load the Zarr dataset +z = zarr.open(f'/home/bmv/diffusion_policy_mod_apr24/data/{dataset_dir}/replay_buffer.zarr', mode='r') + +ep_len = z['meta/episode_ends'][:] +ep_len = np.diff(np.insert(ep_len,0,0)) + +plt.figure(figsize=(10, 6)) +plt.plot(ep_len) +plt.title('Episode Lengths') +plt.xlabel('Episode Number') +plt.ylabel('Length') +plt.grid(True) +plt.show() + + diff --git a/visualization/visualize_ft_data.py b/visualization/visualize_ft_data.py new file mode 100644 index 00000000..e94ea221 --- /dev/null +++ b/visualization/visualize_ft_data.py @@ -0,0 +1,64 @@ +import zarr +import numpy as np +import matplotlib.pyplot as plt +from viz_constants import VIZ_DIR + +dataset_dir = VIZ_DIR +# Load the Zarr dataset +z = zarr.open(f'/home/bmv/diffusion_policy_mod_apr24/data/{dataset_dir}/replay_buffer.zarr', mode='r') + +ep_len = z['meta/episode_ends'][:] +ft = z['data/ft_data'][:] +print("ft data shape:", ft.shape) +def plot_per_episode_ft(ax, episode_data): + # Plot ft data for a specific episode on a given axis (subplot) + ax.plot(episode_data) + ax.grid(True) + +def visualize_all_episodes_ft(ft_data, ep_len, num_cols=5): + num_episodes = len(ep_len) + num_rows = (num_episodes + num_cols - 1) // num_cols # Calculate number of rows based on number of episodes and columns + print(f"Total number of episodes: {num_episodes}") + + fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows)) + + for episode_index in range(num_episodes): + row = episode_index // num_cols + col = episode_index % num_cols + ax = axs[row, col] if num_rows > 1 
else axs[col] + + start_index = 0 if episode_index == 0 else ep_len[episode_index - 1] + end_index = ep_len[episode_index] + episode_data = ft_data[start_index:end_index] + + plot_per_episode_ft(ax, episode_data) + + # Adjust layout and spacing + fig.tight_layout(pad=0.5) + plt.show() + +# Function to plot ft data for a particular episode +def plot_episode_ft(ft_data, ep_len, episode_index): + # Plot ft data for a specific episode + start_index = 0 if episode_index == 0 else ep_len[episode_index - 1] + end_index = ep_len[episode_index] + episode_ft_data = ft_data[start_index:end_index] + with np.printoptions(threshold=np.inf): + print("episode_ft_data:", episode_ft_data) + plt.figure(figsize=(10, 6)) + plt.plot(episode_ft_data) + plt.title(f'Force/Torque Data for Episode {episode_index}') + plt.xlabel('Time Step') + plt.ylabel('Force/Torque') + plt.legend(['FX', 'FY', 'FZ', 'TX', 'TY', 'TZ']) # Assuming order of dimensions + plt.grid(True) + plt.show() + +print(f"Visualizing Datset: {dataset_dir}") +episode_index = int(input("Enter episode index to visualize || -1 for all data: ")) +if episode_index>=0: + plot_episode_ft(ft, ep_len, episode_index=episode_index) +elif episode_index==-1: + visualize_all_episodes_ft(ft, ep_len) +else: + raise NotImplementedError \ No newline at end of file diff --git a/visualization/visualize_robot_calibration.py b/visualization/visualize_robot_calibration.py new file mode 100644 index 00000000..bdc9fdda --- /dev/null +++ b/visualization/visualize_robot_calibration.py @@ -0,0 +1,79 @@ +import zarr +import numpy as np +from visual_kinematics.RobotSerial import * +from visual_kinematics.RobotTrajectory import * +import matplotlib.pyplot as plt +from matplotlib import gridspec +from matplotlib.patches import Circle +from math import pi +import scipy.spatial.transform as st +from viz_constants import VIZ_DIR + +dataset_dir = VIZ_DIR +# Load the Zarr dataset +z = zarr.open(f'/home/bmv/diffusion_policy_mod_apr24/data/{dataset_dir}/replay_buffer.zarr', mode='r') + +ep_len = z['meta/episode_ends'][:] +robot_eef = z['data/robot_eef_pose'][:] +replica_eef = z['data/action'][:] + +def plot_episode_ft(robot_eef, replica_eef, ep_len, episode_index): + + if episode_index == -1: + + # Select ft data for the episode + ep_robot_eef = robot_eef[:] + ep_replica_eef = replica_eef[:] + + else: + # Determine start and end index of the episode + start_index = 0 if episode_index == 0 else ep_len[episode_index - 1] + end_index = ep_len[episode_index] + + # Select ft data for the episode + ep_robot_eef = robot_eef[start_index:end_index] + ep_replica_eef = replica_eef[start_index:end_index] + + # Calculate MSE for each component + mse_error = np.mean((ep_robot_eef[:, 0:3] - ep_replica_eef[:, 0:3])**2, axis=0)*1000 + mae_error = np.mean(np.abs((ep_robot_eef[:, 0:3] - ep_replica_eef[:, 0:3])), axis=0)*1000 + print(f'Mean Squared Error milimeters: {mse_error}') + print(f'Mean Absolute Error milimeters: {mae_error}') + + + mse_error_rad = np.mean(np.abs(ep_robot_eef[:, 3:] - ep_replica_eef[:, 3:]), axis=0) + mse_error_deg= np.rad2deg(mse_error_rad) + print(f'Mean Absolute Error Radians: {mse_error_rad}') + print(f'Mean Absolute Error Degrees: {mse_error_deg}') + + # Extract x, y, z coordinates for plotting + robot_x = ep_robot_eef[:, 0] + robot_y = ep_robot_eef[:, 1] + robot_z = ep_robot_eef[:, 2] + replica_x = ep_replica_eef[:, 0] + replica_y = ep_replica_eef[:, 1] + replica_z = ep_replica_eef[:, 2] + + # Plotting the trajectories + fig = plt.figure(figsize=(10, 8)) + ax = 
fig.add_subplot(111, projection='3d') + + # Plot robot and replica trajectories + ax.plot(robot_x, robot_y, robot_z, label='Robot EEF', color='b') + ax.plot(replica_x, replica_y, replica_z, label='Replica EEF', color='r') + + # Set labels and legend + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.legend() + + plt.title(f'Trajectories for Episode {episode_index}') + plt.show() + +# Example: Plotting episode 0 +print(f"Visualizing Datset: {dataset_dir}") +episode_index = int(input("Enter episode index to visualize || -1 for all data: ")) +plot_episode_ft(robot_eef, replica_eef, ep_len, episode_index) + + diff --git a/visualization/visualize_robot_trajectory.py b/visualization/visualize_robot_trajectory.py new file mode 100644 index 00000000..dee19dce --- /dev/null +++ b/visualization/visualize_robot_trajectory.py @@ -0,0 +1,61 @@ +import zarr +import numpy as np +from visual_kinematics.RobotSerial import * +from visual_kinematics.RobotTrajectory import * +import matplotlib.pyplot as plt +from matplotlib import gridspec +from matplotlib.patches import Circle +from math import pi +import scipy.spatial.transform as st +from viz_constants import VIZ_DIR + +dataset_dir = VIZ_DIR +# Load the Zarr dataset +z = zarr.open(f'/home/bmv/diffusion_policy_new/data/{dataset_dir}/replay_buffer.zarr', mode='r') + +poses = z['data/robot_eef_pose'][:] +# qpos = z['data/robot_joint'][:] +ep_len = z['meta/episode_ends'][:] +ts = z['data/timestamp'][:] + +# Determine start and end index of the episode +print(f"Visualizing Datset: {dataset_dir}") +episode_index = int(input("Enter episode index to visualize || -1 for all data: ")) + +# visualise all tyrajectories +if episode_index == -1: + # Extract x, y, z coordinates + x = poses[:, 0] + y = poses[:, 1] + z = poses[:, 2] + +# visualise specific trajectory +else: + start_index = 0 if episode_index == 0 else ep_len[episode_index - 1] + end_index = ep_len[episode_index] + + x = poses[:, 0][start_index:end_index] + y = poses[:, 1][start_index:end_index] + z = poses[:, 2][start_index:end_index] + + print("poses in the episode are: ", poses[start_index:end_index]) + # print(z) +# Create 3D scatter plot +fig = plt.figure() +ax = fig.add_subplot(111, projection='3d') +ax.scatter(x, y, z) + +# Set labels and title +ax.set_xlabel('X Label') +ax.set_ylabel('Y Label') +ax.set_zlabel('Z Label') +ax.set_title('Robot Poses') + +ax.set_xlim([-1, 1]) # Set limits for the x-axis +ax.set_ylim([-1, 1]) # Set limits for the y-axis +# ax.set_zlim([-1, 1]) # Set limits for the z-axis + +plt.show() + + + diff --git a/visualization/visualize_robot_tree.py b/visualization/visualize_robot_tree.py new file mode 100644 index 00000000..50d8dcc5 --- /dev/null +++ b/visualization/visualize_robot_tree.py @@ -0,0 +1,89 @@ +import zarr +import numpy as np +from visual_kinematics.RobotSerial import * +from visual_kinematics.RobotTrajectory import * +import matplotlib.pyplot as plt +from matplotlib import gridspec +from matplotlib.patches import Circle +from math import pi +import scipy.spatial.transform as st +from viz_constants import VIZ_DIR + +dataset_dir = VIZ_DIR +# Load the Zarr dataset +z = zarr.open(f'/home/bmv/diffusion_policy_new/data/{dataset_dir}/replay_buffer.zarr', mode='r') + +poses = z['data/action'][:] +qpos = z['data/robot_joint'][:] +ep_len = z['meta/episode_ends'][:] +ts = z['data/timestamp'][:] + +dh_params = np.array( +[ + [0.1625, 0.0, 0.5 * pi, 0.0], + [0.0, -0.425, 0, 0.0], + [0.0, -0.3922, 0, 0.0], + [0.1333, 0.0, 0.5 * pi, 0.0], + [0.0997, 0.0, -0.5 * 
pi, 0.0], + [0.0996, 0.0, 0.0, 0.0], +]) + +print(f"Visualizing Datset: {dataset_dir}") +# Determine start and end index of the episode +episode_index = int(input("Enter episode index to visualize: ")) +start_index = 0 if episode_index == 0 else ep_len[episode_index - 1] +end_index = ep_len[episode_index] + + +viz_type = 'no' +robot = RobotSerial(dh_params) + +fig = plt.figure(figsize=(64, 64)) # Create a figure +gs = gridspec.GridSpec(2, 3, height_ratios=[1, 1]) + +ax_robot = plt.subplot(gs[0, 0], projection="3d") + + +plt.ion() +print("Episode length is " , (end_index-start_index)," steps") +try: + for i in range(start_index, end_index): + # plt.clf() # Clear the previous figure + ax_robot.clear();a = np.arange(10) + if viz_type == 'action': + print("qpos for the episode is ", qpos[i]) + xyz = qpos[i][:3].reshape((3,1)) + rpy = qpos[i][3:] + rotation_object = st.Rotation.from_rotvec(rpy) + euler_angles = rotation_object.as_euler('zxy') # Adjust 'zyx' as necessary based on your convention + end = Frame.from_euler_3(euler_angles, xyz) # Assuming this function expects Euler angles and position + robot.inverse(end) + robot.ax= ax_robot + robot.draw() + ax_robot.set_title('Robot Kinematics') + else: + robot.forward(qpos[i]) + # print(qpos[i]) + robot.ax= ax_robot + robot.draw() + ax_robot.set_title('Robot Kinematics') + + + # plt.tight_layout() # Ensure proper spacing between subplots + plt.pause(0.0001) + fig.canvas.flush_events() + +except KeyboardInterrupt: + plt.ioff() # Turn off interactive mode + plt.close() + +plt.ioff() # Turn off interactive mode +plt.close() + + + + + + + + diff --git a/visualization/viz_constants.py b/visualization/viz_constants.py new file mode 100644 index 00000000..6262bb91 --- /dev/null +++ b/visualization/viz_constants.py @@ -0,0 +1 @@ +VIZ_DIR = "rice_scoop_teleop_org_repo" \ No newline at end of file From e6680b6b43e795f07ca1f0dcad03755f21c3595d Mon Sep 17 00:00:00 2001 From: mohamedamrali1993 Date: Fri, 4 Oct 2024 16:03:17 -0400 Subject: [PATCH 2/4] force torque , propriception, sensor fusion using multihead self attention --- .../config/task/real_pusht_image.yaml | 25 ++---- diffusion_policy/model/common/back_up.py | 90 +++++++++++++++++++ diffusion_policy/model/common/mha.py | 90 +++++++++++++++++++ .../force_torque/end_effector_encoding.py | 39 ++++++++ .../model/force_torque/ft_transformer.py | 55 ++++++++++++ .../model/force_torque/positional_encoding.py | 35 ++++++++ .../model/vision/multi_image_obs_encoder.py | 65 ++++++++++---- 7 files changed, 364 insertions(+), 35 deletions(-) create mode 100644 diffusion_policy/model/common/back_up.py create mode 100644 diffusion_policy/model/common/mha.py create mode 100644 diffusion_policy/model/force_torque/end_effector_encoding.py create mode 100644 diffusion_policy/model/force_torque/ft_transformer.py create mode 100644 diffusion_policy/model/force_torque/positional_encoding.py diff --git a/diffusion_policy/config/task/real_pusht_image.yaml b/diffusion_policy/config/task/real_pusht_image.yaml index f7254a24..26a0b86f 100644 --- a/diffusion_policy/config/task/real_pusht_image.yaml +++ b/diffusion_policy/config/task/real_pusht_image.yaml @@ -1,28 +1,19 @@ name: real_image -data_name: rice_scoop_teleop_org_repo +data_name: ft_100hz -task_type: all_cameras +task_type: ft_100hz image_shape: [3,480,640] -dataset_path: /home/bmv/diffusion_policy_new/data/rice_scoop_teleop_org_repo +dataset_path: /home/bmv/diffusion_policy_new/data/ft_100hz shape_meta: &shape_meta # acceptable types: rgb, low_dim - obs: - camera_0: 
- shape: ${task.image_shape} - type: rgb - camera_1: - shape: ${task.image_shape} - type: rgb - camera_2: - shape: ${task.image_shape} - type: rgb - robot_eef_pose: + obs: + replica_eef_pose: shape: [6,] type: low_dim - # ft_data: - # shape: [6,] - # type: low_dim + ft_data: + shape: [10,6,] + type: low_dim action: shape: [6,] diff --git a/diffusion_policy/model/common/back_up.py b/diffusion_policy/model/common/back_up.py new file mode 100644 index 00000000..db70e522 --- /dev/null +++ b/diffusion_policy/model/common/back_up.py @@ -0,0 +1,90 @@ +from torch.nn.modules.activation import MultiheadAttention +from torchvision.models import resnet18 +from torchvision.models.feature_extraction import create_feature_extractor +import torch +from torch import nn + +import cv2 +import numpy as np +import time +import matplotlib.pyplot as plt + +from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin +from diffusion_policy.model.force_torque.end_effector_encoding import EndEffectorEncoder +from diffusion_policy.model.force_torque.ft_transformer import ForceTorqueEncoder + + +class Actor(ModuleAttrMixin): + def __init__(self, encoder_dim =256, num_heads = 8, action_dim =6, use_eef_encoder = True): + super().__init__() + + self.force_torque_encoder = ForceTorqueEncoder(ft_seq_len=10).to(self.device) + if use_eef_encoder: + self.end_effector_encoder = EndEffectorEncoder().to(self.device) + else: + self.end_effector_encoder = nn.Linear(6, encoder_dim) + + + + self.layernorm_embed_shape = encoder_dim + self.encoder_dim = encoder_dim + + self.use_mha = True + + self.modalities = ['force_torque', 'end_effector'] + + self.embed_dim = self.layernorm_embed_shape * len(self.modalities) + + + self.layernorm = nn.LayerNorm(self.layernorm_embed_shape) + self.mha = MultiheadAttention(self.layernorm_embed_shape, num_heads) + self.bottleneck = nn.Linear( + self.embed_dim, self.layernorm_embed_shape + ) # if we dont use mha + + + + self.mlp = torch.nn.Sequential( + torch.nn.Linear(self.layernorm_embed_shape, 1024), + torch.nn.ReLU(), + torch.nn.Linear(1024, 1024), + torch.nn.ReLU(), + torch.nn.Linear(1024, 3**action_dim), + ) + self.aux_mlp = torch.nn.Linear(self.layernorm_embed_shape, 6) + + def forward(self, ft_data,end_effector): + """ + Args: + + ft_data: [batch, dim] + end_effector: [batch, dim] + + """ + + embeds = [] + + ft_data = self.force_torque_encoder(ft_data) + ft_data = ft_data.view(-1, self.layernorm_embed_shape) + embeds.append(ft_data) + + end_effector = self.end_effector_encoder(end_effector) + end_effector = end_effector.view(-1, self.layernorm_embed_shape) + embeds.append(end_effector) + + + + + mlp_inp = torch.stack(embeds, dim=0) # [3, batch, D] + + mha_out, weights = self.mha(mlp_inp, mlp_inp, mlp_inp) # [1, batch, D] + mha_out += mlp_inp + mlp_inp = torch.concat([mha_out[i] for i in range(mha_out.shape[0])], 1) + mlp_inp = self.bottleneck(mlp_inp) + + + + + action_logits = self.mlp(mlp_inp) + xyzrpy = self.aux_mlp(mlp_inp) + return action_logits,xyzrpy , weights diff --git a/diffusion_policy/model/common/mha.py b/diffusion_policy/model/common/mha.py new file mode 100644 index 00000000..db70e522 --- /dev/null +++ b/diffusion_policy/model/common/mha.py @@ -0,0 +1,90 @@ +from torch.nn.modules.activation import MultiheadAttention +from torchvision.models import resnet18 +from torchvision.models.feature_extraction import create_feature_extractor +import torch +from torch import nn + +import cv2 +import numpy as np +import time +import matplotlib.pyplot as plt + +from 
diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin +from diffusion_policy.model.force_torque.end_effector_encoding import EndEffectorEncoder +from diffusion_policy.model.force_torque.ft_transformer import ForceTorqueEncoder + + +class Actor(ModuleAttrMixin): + def __init__(self, encoder_dim =256, num_heads = 8, action_dim =6, use_eef_encoder = True): + super().__init__() + + self.force_torque_encoder = ForceTorqueEncoder(ft_seq_len=10).to(self.device) + if use_eef_encoder: + self.end_effector_encoder = EndEffectorEncoder().to(self.device) + else: + self.end_effector_encoder = nn.Linear(6, encoder_dim) + + + + self.layernorm_embed_shape = encoder_dim + self.encoder_dim = encoder_dim + + self.use_mha = True + + self.modalities = ['force_torque', 'end_effector'] + + self.embed_dim = self.layernorm_embed_shape * len(self.modalities) + + + self.layernorm = nn.LayerNorm(self.layernorm_embed_shape) + self.mha = MultiheadAttention(self.layernorm_embed_shape, num_heads) + self.bottleneck = nn.Linear( + self.embed_dim, self.layernorm_embed_shape + ) # if we dont use mha + + + + self.mlp = torch.nn.Sequential( + torch.nn.Linear(self.layernorm_embed_shape, 1024), + torch.nn.ReLU(), + torch.nn.Linear(1024, 1024), + torch.nn.ReLU(), + torch.nn.Linear(1024, 3**action_dim), + ) + self.aux_mlp = torch.nn.Linear(self.layernorm_embed_shape, 6) + + def forward(self, ft_data,end_effector): + """ + Args: + + ft_data: [batch, dim] + end_effector: [batch, dim] + + """ + + embeds = [] + + ft_data = self.force_torque_encoder(ft_data) + ft_data = ft_data.view(-1, self.layernorm_embed_shape) + embeds.append(ft_data) + + end_effector = self.end_effector_encoder(end_effector) + end_effector = end_effector.view(-1, self.layernorm_embed_shape) + embeds.append(end_effector) + + + + + mlp_inp = torch.stack(embeds, dim=0) # [3, batch, D] + + mha_out, weights = self.mha(mlp_inp, mlp_inp, mlp_inp) # [1, batch, D] + mha_out += mlp_inp + mlp_inp = torch.concat([mha_out[i] for i in range(mha_out.shape[0])], 1) + mlp_inp = self.bottleneck(mlp_inp) + + + + + action_logits = self.mlp(mlp_inp) + xyzrpy = self.aux_mlp(mlp_inp) + return action_logits,xyzrpy , weights diff --git a/diffusion_policy/model/force_torque/end_effector_encoding.py b/diffusion_policy/model/force_torque/end_effector_encoding.py new file mode 100644 index 00000000..b2bc2c51 --- /dev/null +++ b/diffusion_policy/model/force_torque/end_effector_encoding.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn + +from diffusion_policy.model.force_torque.ft_transformer import PositionalEncoding + +class EndEffectorEncoder(nn.Module): + + def __init__(self, d_model=256, nhead=8, num_encoder_layers=3): + super(EndEffectorEncoder, self).__init__() + + self.embedding_ee = nn.Linear(6, d_model) + + self.positional_encoding_ee = PositionalEncoding(d_model, max_len=1) + + self.transformer_encoder_ee = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model, nhead, batch_first=True), + num_layers=num_encoder_layers + ) + + self.layer_norm = nn.LayerNorm(d_model) + + self.fc = nn.Linear(d_model, d_model) + + def forward(self, ee_data): + + ee_data = ee_data.unsqueeze(1) + ee_data = self.embedding_ee(ee_data) + + ee_data = self.positional_encoding_ee(ee_data) + + ee_data = self.transformer_encoder_ee(ee_data) + ee_data = self.layer_norm(ee_data) + + output = self.fc(ee_data) + + # Flatten to get [batch_size, d_model] shape + output = output.view(output.size(0), -1) + + return output diff --git a/diffusion_policy/model/force_torque/ft_transformer.py 
b/diffusion_policy/model/force_torque/ft_transformer.py new file mode 100644 index 00000000..35a18289 --- /dev/null +++ b/diffusion_policy/model/force_torque/ft_transformer.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn + +class PositionalEncoding(nn.Module): + def __init__(self, model_dim, max_len=5000): + super(PositionalEncoding, self).__init__() + pe = torch.zeros(max_len, model_dim) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, model_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / model_dim)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) # (1, max_len, model_dim) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1), :] + return x + + +class ForceTorqueEncoder(nn.Module): + + def __init__(self, ft_seq_len, d_model=256, nhead=8, num_encoder_layers=3): + super(ForceTorqueEncoder, self).__init__() + + self.ft_seq_len = ft_seq_len + + # input embedding layer + self.embedding_ft = nn.Linear(6, d_model) # batch_size, seq_len, 6 -> batch_size, seq_len, d_model + + self.positional_encoding_ft = PositionalEncoding(d_model, max_len=ft_seq_len) # batch_size, seq_len, d_model -> batch_size, seq_len, d_model + + self.transformer_encoder_ft = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model, nhead, batch_first=True ), + num_layers=num_encoder_layers + ) # batch_size, seq_len, d_model -> batch_size, seq_len, d_model + + self.layer_norm = nn.LayerNorm(d_model) + + self.fc = nn.Linear(d_model * ft_seq_len, d_model) + + + def forward(self, ft_data): + + ft_data = self.embedding_ft(ft_data) + ft_data = self.positional_encoding_ft(ft_data) + ft_data = self.transformer_encoder_ft(ft_data) + + ft_data = self.layer_norm(ft_data) + + output = ft_data.view(ft_data.size(0), -1) + output = self.fc(output) + + + + return output \ No newline at end of file diff --git a/diffusion_policy/model/force_torque/positional_encoding.py b/diffusion_policy/model/force_torque/positional_encoding.py new file mode 100644 index 00000000..5c1c71b9 --- /dev/null +++ b/diffusion_policy/model/force_torque/positional_encoding.py @@ -0,0 +1,35 @@ +import math + +import torch +import torch.nn as nn + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, max_len=5000): + super(PositionalEncoding, self).__init__() + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x): + x = torch.add(x, self.pe[:x.size(-2), :]) + return x + + +# class LearnablePositionalEncoding(nn.Module): + +# def __init__(self, dict_size=128, num_pos_feats=16): +# super().__init__() +# self.embed = nn.Embedding(dict_size, num_pos_feats) + +# def forward(self, x): +# w = x.shape[-2] +# i = torch.arange(w, device=x.device) +# emb = self.embed(i) +# x = torch.add(x, emb) +# return x \ No newline at end of file diff --git a/diffusion_policy/model/vision/multi_image_obs_encoder.py b/diffusion_policy/model/vision/multi_image_obs_encoder.py index de6aa658..93724f87 100644 --- a/diffusion_policy/model/vision/multi_image_obs_encoder.py +++ b/diffusion_policy/model/vision/multi_image_obs_encoder.py @@ -6,8 +6,10 @@ from diffusion_policy.model.vision.crop_randomizer 
import CropRandomizer from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules - - +from diffusion_policy.model.force_torque.ft_transformer import ForceTorqueEncoder +from diffusion_policy.model.force_torque.end_effector_encoding import EndEffectorEncoder +from diffusion_policy.model.common.mha import Actor + class MultiImageObsEncoder(ModuleAttrMixin): def __init__(self, shape_meta: dict, @@ -28,18 +30,18 @@ def __init__(self, Assumes low_dim input: B,D """ super().__init__() - + rgb_keys = list() low_dim_keys = list() key_model_map = nn.ModuleDict() key_transform_map = nn.ModuleDict() key_shape_map = dict() - + # handle sharing vision backbone if share_rgb_model: assert isinstance(rgb_model, nn.Module) key_model_map['rgb'] = rgb_model - + obs_shape_meta = shape_meta['obs'] for key, attr in obs_shape_meta.items(): shape = tuple(attr['shape']) @@ -64,7 +66,7 @@ def __init__(self, root_module=this_model, predicate=lambda x: isinstance(x, nn.BatchNorm2d), func=lambda x: nn.GroupNorm( - num_groups=x.num_features//16, + num_groups=x.num_features//16, num_channels=x.num_features) ) key_model_map[key] = this_model @@ -81,7 +83,7 @@ def __init__(self, size=(h,w) ) input_shape = (shape[0],h,w) - + # configure randomizer this_randomizer = nn.Identity() if crop_shape is not None: @@ -115,7 +117,7 @@ def __init__(self, raise RuntimeError(f"Unsupported obs type: {type}") rgb_keys = sorted(rgb_keys) low_dim_keys = sorted(low_dim_keys) - + self.shape_meta = shape_meta self.key_model_map = key_model_map self.key_transform_map = key_transform_map @@ -123,7 +125,12 @@ def __init__(self, self.rgb_keys = rgb_keys self.low_dim_keys = low_dim_keys self.key_shape_map = key_shape_map - + + device = self.device + + self.mha = Actor(use_eef_encoder=False).to(device) + + def forward(self, obs_dict): batch_size = None features = list() @@ -162,17 +169,38 @@ def forward(self, obs_dict): assert img.shape[1:] == self.key_shape_map[key] img = self.key_transform_map[key](img) feature = self.key_model_map[key](img) + # import pdb; pdb.set_trace() + # print("the feature shape is: ", feature.shape) features.append(feature) # process lowdim input - for key in self.low_dim_keys: - data = obs_dict[key] - if batch_size is None: - batch_size = data.shape[0] - else: - assert batch_size == data.shape[0] - assert data.shape[1:] == self.key_shape_map[key] - features.append(data) + # for key in self.low_dim_keys: + # print("working 1") + # data = obs_dict[key] + # if batch_size is None: + # batch_size = data.shape[0] + # else: + # assert batch_size == data.shape[0] + # # assert data.shape[1:] == self.key_shape_map[key] + + # if key == "ft_data": + + # data = self.force_torque_encoder(data) + # else: + # # data = data + # data = self.end_effector_encoder(data) + + # features.append(data) + + + if len(self.low_dim_keys) > 0: + ft_data = obs_dict['ft_data'] + end_effector = obs_dict['replica_eef_pose'] + action_logits,xyzrpy, weights = self.mha(ft_data, end_effector) + # print("the action logits are: ", action_logits.shape) + # print("the weights are: ", weights.shape) + + features.append(action_logits) # concatenate all features result = torch.cat(features, dim=-1) @@ -186,10 +214,11 @@ def output_shape(self): for key, attr in obs_shape_meta.items(): shape = tuple(attr['shape']) this_obs = torch.zeros( - (batch_size,) + shape, + (batch_size,) + shape, dtype=self.dtype, device=self.device) example_obs_dict[key] = this_obs example_output = 
self.forward(example_obs_dict) output_shape = example_output.shape[1:] return output_shape + \ No newline at end of file From 2f76384b9ddf24d7469a0461b2377ea03a0ea788 Mon Sep 17 00:00:00 2001 From: mohamedamrali1993 Date: Mon, 7 Oct 2024 14:18:12 -0400 Subject: [PATCH 3/4] model 1 --- .../config/task/real_pusht_image.yaml | 9 +++-- ...n_diffusion_unet_real_image_workspace.yaml | 6 +-- diffusion_policy/model/common/mha.py | 22 +++++++++-- .../model/vision/multi_image_obs_encoder.py | 39 ++++++++----------- 4 files changed, 44 insertions(+), 32 deletions(-) diff --git a/diffusion_policy/config/task/real_pusht_image.yaml b/diffusion_policy/config/task/real_pusht_image.yaml index 26a0b86f..8c405d0d 100644 --- a/diffusion_policy/config/task/real_pusht_image.yaml +++ b/diffusion_policy/config/task/real_pusht_image.yaml @@ -1,13 +1,16 @@ name: real_image -data_name: ft_100hz +data_name: ft_vision -task_type: ft_100hz +task_type: ft_vision image_shape: [3,480,640] -dataset_path: /home/bmv/diffusion_policy_new/data/ft_100hz +dataset_path: /home/bmv/diffusion_policy_new/data/ft_vision shape_meta: &shape_meta # acceptable types: rgb, low_dim obs: + camera_0: + shape: ${task.image_shape} + type: rgb replica_eef_pose: shape: [6,] type: low_dim diff --git a/diffusion_policy/config/train_diffusion_unet_real_image_workspace.yaml b/diffusion_policy/config/train_diffusion_unet_real_image_workspace.yaml index 99f157a0..e28ff7a4 100644 --- a/diffusion_policy/config/train_diffusion_unet_real_image_workspace.yaml +++ b/diffusion_policy/config/train_diffusion_unet_real_image_workspace.yaml @@ -50,7 +50,7 @@ policy: random_crop: True use_group_norm: True share_rgb_model: False - imagenet_norm: True + imagenet_norm: False horizon: ${horizon} n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} @@ -104,7 +104,7 @@ training: # optimization lr_scheduler: cosine lr_warmup_steps: 500 - num_epochs: 300 + num_epochs: 2000 gradient_accumulate_every: 1 # EMA destroys performance when used with BatchNorm # replace BatchNorm with GroupNorm. 
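Note on the new force/torque pathway: shape_meta declares ft_data as [10, 6], i.e. a window of 10 six-axis force/torque samples per observation, which is exactly the window length ForceTorqueEncoder(ft_seq_len=10) expects; both it and EndEffectorEncoder emit 256-dim embeddings, the per-modality token size used by the attention fusion in mha.py below. A minimal shape check, assuming the modules added in this patch are importable from the repo (batch size 4 is arbitrary):

    import torch
    from diffusion_policy.model.force_torque.ft_transformer import ForceTorqueEncoder
    from diffusion_policy.model.force_torque.end_effector_encoding import EndEffectorEncoder

    ft_enc = ForceTorqueEncoder(ft_seq_len=10)   # d_model defaults to 256
    ee_enc = EndEffectorEncoder()                # d_model defaults to 256
    ft = torch.randn(4, 10, 6)                   # (batch, ft window, 6-axis wrench)
    ee = torch.randn(4, 6)                       # (batch, 6-DoF eef pose)
    print(ft_enc(ft).shape, ee_enc(ee).shape)    # torch.Size([4, 256]) torch.Size([4, 256])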
@@ -135,7 +135,7 @@ checkpoint: topk: monitor_key: train_loss mode: min - k: 5 + k: 20 format_str: 'epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt' save_last_ckpt: True save_last_snapshot: False diff --git a/diffusion_policy/model/common/mha.py b/diffusion_policy/model/common/mha.py index db70e522..6effca1d 100644 --- a/diffusion_policy/model/common/mha.py +++ b/diffusion_policy/model/common/mha.py @@ -18,9 +18,9 @@ class Actor(ModuleAttrMixin): def __init__(self, encoder_dim =256, num_heads = 8, action_dim =6, use_eef_encoder = True): super().__init__() - self.force_torque_encoder = ForceTorqueEncoder(ft_seq_len=10).to(self.device) + self.force_torque_encoder = ForceTorqueEncoder(ft_seq_len=10) if use_eef_encoder: - self.end_effector_encoder = EndEffectorEncoder().to(self.device) + self.end_effector_encoder = EndEffectorEncoder() else: self.end_effector_encoder = nn.Linear(6, encoder_dim) @@ -31,7 +31,7 @@ def __init__(self, encoder_dim =256, num_heads = 8, action_dim =6, use_eef_encod self.use_mha = True - self.modalities = ['force_torque', 'end_effector'] + self.modalities = ['force_torque', 'end_effector',"cf0" ] self.embed_dim = self.layernorm_embed_shape * len(self.modalities) @@ -53,7 +53,7 @@ def __init__(self, encoder_dim =256, num_heads = 8, action_dim =6, use_eef_encod ) self.aux_mlp = torch.nn.Linear(self.layernorm_embed_shape, 6) - def forward(self, ft_data,end_effector): + def forward(self, ft_data,end_effector, cf0= None,cf1=None,cf2=None,cf3=None ): """ Args: @@ -72,6 +72,20 @@ def forward(self, ft_data,end_effector): end_effector = end_effector.view(-1, self.layernorm_embed_shape) embeds.append(end_effector) + if cf0 is not None: + cf0 = cf0.view(-1, self.layernorm_embed_shape) + embeds.append(cf0) + if cf1 is not None: + cf1 = cf1.view(-1, self.layernorm_embed_shape) + embeds.append(cf1) + if cf2 is not None: + cf2 = cf2.view(-1, self.layernorm_embed_shape) + embeds.append(cf2) + if cf3 is not None: + cf3 = cf3.view(-1, self.layernorm_embed_shape) + embeds.append(cf3) + + diff --git a/diffusion_policy/model/vision/multi_image_obs_encoder.py b/diffusion_policy/model/vision/multi_image_obs_encoder.py index 93724f87..ecb0dbd6 100644 --- a/diffusion_policy/model/vision/multi_image_obs_encoder.py +++ b/diffusion_policy/model/vision/multi_image_obs_encoder.py @@ -134,6 +134,7 @@ def __init__(self, def forward(self, obs_dict): batch_size = None features = list() + cf0,cf1,cf2,cf3 = None, None, None, None # process rgb input if self.share_rgb_model: # pass all rgb obs to rgb model @@ -169,34 +170,28 @@ def forward(self, obs_dict): assert img.shape[1:] == self.key_shape_map[key] img = self.key_transform_map[key](img) feature = self.key_model_map[key](img) - # import pdb; pdb.set_trace() - # print("the feature shape is: ", feature.shape) - features.append(feature) - - # process lowdim input - # for key in self.low_dim_keys: - # print("working 1") - # data = obs_dict[key] - # if batch_size is None: - # batch_size = data.shape[0] - # else: - # assert batch_size == data.shape[0] - # # assert data.shape[1:] == self.key_shape_map[key] - - # if key == "ft_data": - - # data = self.force_torque_encoder(data) - # else: - # # data = data - # data = self.end_effector_encoder(data) + feature = nn.Linear(512,256).to(self.device) (feature) + + if key =="camera_0": + cf0 = feature.to(self.device) + if key =="camera_1": + cf1 = feature.to(self.device) + if key =="camera_2": + cf2 = feature.to(self.device) + if key =="camera_3": + cf3 = feature.to(self.device) - # features.append(data) + # 
features.append(feature) + + if len(self.low_dim_keys) > 0: ft_data = obs_dict['ft_data'] end_effector = obs_dict['replica_eef_pose'] - action_logits,xyzrpy, weights = self.mha(ft_data, end_effector) + + + action_logits,xyzrpy, weights = self.mha(ft_data, end_effector,cf0,cf1,cf2,cf3 ) # print("the action logits are: ", action_logits.shape) # print("the weights are: ", weights.shape) From 38041acaeaf69c9a0d35a45f2cc5481c67891629 Mon Sep 17 00:00:00 2001 From: mohamedamrali1993 Date: Tue, 8 Oct 2024 11:57:58 -0400 Subject: [PATCH 4/4] no propriception encoding --- diffusion_policy/model/common/mha.py | 10 +++++----- .../model/vision/multi_image_obs_encoder.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/diffusion_policy/model/common/mha.py b/diffusion_policy/model/common/mha.py index 6effca1d..d0f3e306 100644 --- a/diffusion_policy/model/common/mha.py +++ b/diffusion_policy/model/common/mha.py @@ -31,7 +31,7 @@ def __init__(self, encoder_dim =256, num_heads = 8, action_dim =6, use_eef_encod self.use_mha = True - self.modalities = ['force_torque', 'end_effector',"cf0" ] + self.modalities = ['force_torque',"cf0" ] self.embed_dim = self.layernorm_embed_shape * len(self.modalities) @@ -53,7 +53,7 @@ def __init__(self, encoder_dim =256, num_heads = 8, action_dim =6, use_eef_encod ) self.aux_mlp = torch.nn.Linear(self.layernorm_embed_shape, 6) - def forward(self, ft_data,end_effector, cf0= None,cf1=None,cf2=None,cf3=None ): + def forward(self, ft_data, cf0= None,cf1=None,cf2=None,cf3=None ): """ Args: @@ -68,9 +68,9 @@ def forward(self, ft_data,end_effector, cf0= None,cf1=None,cf2=None,cf3=None ): ft_data = ft_data.view(-1, self.layernorm_embed_shape) embeds.append(ft_data) - end_effector = self.end_effector_encoder(end_effector) - end_effector = end_effector.view(-1, self.layernorm_embed_shape) - embeds.append(end_effector) + # end_effector = self.end_effector_encoder(end_effector) + # end_effector = end_effector.view(-1, self.layernorm_embed_shape) + # embeds.append(end_effector) if cf0 is not None: cf0 = cf0.view(-1, self.layernorm_embed_shape) diff --git a/diffusion_policy/model/vision/multi_image_obs_encoder.py b/diffusion_policy/model/vision/multi_image_obs_encoder.py index ecb0dbd6..9a08dbeb 100644 --- a/diffusion_policy/model/vision/multi_image_obs_encoder.py +++ b/diffusion_policy/model/vision/multi_image_obs_encoder.py @@ -188,14 +188,23 @@ def forward(self, obs_dict): if len(self.low_dim_keys) > 0: ft_data = obs_dict['ft_data'] - end_effector = obs_dict['replica_eef_pose'] + # end_effector = obs_dict['replica_eef_pose'] - action_logits,xyzrpy, weights = self.mha(ft_data, end_effector,cf0,cf1,cf2,cf3 ) + action_logits,xyzrpy, weights = self.mha(ft_data,cf0,cf1,cf2,cf3 ) # print("the action logits are: ", action_logits.shape) # print("the weights are: ", weights.shape) features.append(action_logits) + + # if obs_dict.get('replica_eef_pose') is not None: + # data = obs_dict["replica_eef_pose"] + # if batch_size is None: + # batch_size = data.shape[0] + # else: + # assert batch_size == data.shape[0] + # assert data.shape[1:] == self.key_shape_map["replica_eef_pose"] + # features.append(data) # concatenate all features result = torch.cat(features, dim=-1)
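After patch 4, the observation encoder fuses two tokens per sample, the ForceTorqueEncoder output and the camera_0 feature ("cf0"), via multi-head self-attention with a residual connection, followed by concatenation and a bottleneck linear layer (embed_dim = 2 * 256). One caution: nn.Linear(512, 256) is instantiated inside forward() in multi_image_obs_encoder.py, so the 512->256 camera projection is re-initialized with random weights on every call and is never registered as a trainable parameter; creating it once in __init__ would make it learnable. A minimal, self-contained sketch of the fusion pattern (illustrative only, not the repo's exact Actor class):

    import torch
    from torch import nn

    D, H, B = 256, 8, 4                                 # embed dim, attention heads, batch
    mha = nn.MultiheadAttention(D, H)                   # expects (tokens, batch, D)
    bottleneck = nn.Linear(2 * D, D)                    # 2 modalities -> D

    ft_embed = torch.randn(B, D)                        # e.g. ForceTorqueEncoder output
    cam_embed = torch.randn(B, D)                       # e.g. projected ResNet feature (cf0)

    tokens = torch.stack([ft_embed, cam_embed], dim=0)     # (2, B, D)
    attn_out, attn_weights = mha(tokens, tokens, tokens)   # self-attention across modalities
    attn_out = attn_out + tokens                           # residual, as in Actor.forward
    fused = bottleneck(torch.cat([attn_out[i] for i in range(attn_out.shape[0])], dim=1))
    print(fused.shape, attn_weights.shape)                 # torch.Size([4, 256]) torch.Size([4, 2, 2])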