Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts committed Nov 27, 2024
1 parent 31429e8 commit 6366c7f
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 40 deletions.
48 changes: 35 additions & 13 deletions lerobot/common/robot_devices/cameras/reachy2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
Wrapper for Reachy2 camera from sdk
"""

from dataclasses import dataclass
from dataclasses import dataclass, replace

import cv2
import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
Expand All @@ -18,6 +19,14 @@ class ReachyCameraConfig:
rotation: int | None = None
mock: bool = False

def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)

self.channels = 3


class ReachyCamera:
def __init__(
Expand All @@ -29,8 +38,18 @@ def __init__(
config: ReachyCameraConfig | None = None,
**kwargs,
):
if config is None:
config = ReachyCameraConfig()

# Overwrite config arguments using kwargs
config = replace(config, **kwargs)

self.host = host
self.port = port
self.width = config.width
self.height = config.height
self.channels = config.channels
self.fps = config.fps
self.image_type = image_type
self.name = name
self.config = config
Expand All @@ -48,21 +67,24 @@ def read(self) -> np.ndarray:
if not self.is_connected:
self.connect()

frame = None

if self.name == "teleop" and hasattr(self.cam_manager, "teleop"):
if self.image_type == "left":
return self.cam_manager.teleop.get_frame(CameraView.LEFT)
# return self.cam_manager.teleop.get_compressed_frame(CameraView.LEFT)
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT)
elif self.image_type == "right":
return self.cam_manager.teleop.get_frame(CameraView.RIGHT)
# return self.cam_manager.teleop.get_compressed_frame(CameraView.RIGHT)
else:
return None
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT)
elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.image_type == "depth":
return self.cam_manager.depth.get_depth_frame()
frame = self.cam_manager.depth.get_depth_frame()
elif self.image_type == "rgb":
return self.cam_manager.depth.get_frame()
# return self.cam_manager.depth.get_compressed_frame()
else:
return None
return None
frame = self.cam_manager.depth.get_frame()

if frame is None:
return None

if frame is not None and self.config.color_mode == "rgb":
img, timestamp = frame
frame = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB), timestamp)

return frame
2 changes: 1 addition & 1 deletion lerobot/common/robot_devices/control_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def log_dt(shortname, dt_val_s):
log_dt("dt", dt_s)

# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not robot.robot_type.startswith(("stretch", "Reachy")):
if not robot.robot_type.lower().startswith(("stretch", "reachy")):
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:
Expand Down
46 changes: 31 additions & 15 deletions lerobot/common/robot_devices/robots/reachy2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from copy import copy
from dataclasses import dataclass, field, replace

import numpy as np
import torch
from reachy2_sdk import ReachySDK

from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.reachy2 import ReachyCamera

REACHY_MOTORS = [
"neck_yaw.pos",
Expand Down Expand Up @@ -52,8 +53,9 @@
@dataclass
class ReachyRobotConfig:
robot_type: str | None = "reachy2"
cameras: dict[str, Camera] = field(default_factory=lambda: {})
cameras: dict[str, ReachyCamera] = field(default_factory=lambda: {})
ip_address: str | None = "172.17.135.207"
# ip_address: str | None = "192.168.0.197"
# ip_address: str | None = "localhost"


Expand All @@ -74,10 +76,8 @@ def __init__(self, config: ReachyRobotConfig | None = None, **kwargs):
self.is_connected = False
self.teleop = None
self.logs = {}
self.reachy: ReachySDK = ReachySDK(host=config.ip_address)
self.reachy.turn_on()
self.is_connected = True # at init Reachy2 is in fact connected...
self.mobile_base_available = self.reachy.mobile_base is not None
self.reachy = None
self.mobile_base_available = False

self.state_keys = None
self.action_keys = None
Expand All @@ -96,16 +96,19 @@ def camera_features(self) -> dict:

@property
def motor_features(self) -> dict:
motors = REACHY_MOTORS
# if self.mobile_base_available:
# motors += REACHY_MOBILE_BASE
return {
"action": {
"dtype": "float32",
"shape": (len(REACHY_MOTORS),),
"names": REACHY_MOTORS,
"shape": (len(motors),),
"names": motors,
},
"observation.state": {
"dtype": "float32",
"shape": (len(REACHY_MOTORS),),
"names": REACHY_MOTORS,
"shape": (len(motors),),
"names": motors,
},
}

Expand All @@ -114,14 +117,16 @@ def features(self):
return {**self.motor_features, **self.camera_features}

def connect(self) -> None:
self.reachy = ReachySDK(host=self.config.ip_address)
print("Connecting to Reachy")
self.reachy.is_connected = self.reachy.connect()
self.reachy.connect()
self.is_connected = self.reachy.is_connected
if not self.is_connected:
print(
f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists."
)
raise ConnectionError()
self.reachy.turn_on()
# self.reachy.turn_on()
print(self.cameras)
if self.cameras is not None:
for name in self.cameras:
Expand All @@ -133,6 +138,8 @@ def connect(self) -> None:
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()

self.mobile_base_available = self.reachy.mobile_base is not None

def run_calibration(self):
pass

Expand Down Expand Up @@ -169,8 +176,14 @@ def teleop_step(
action["mobile_base_x.vel"] = last_cmd_vel["x"]
action["mobile_base_y.vel"] = last_cmd_vel["y"]
action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
else:
action["mobile_base_x.vel"] = 0
action["mobile_base_y.vel"] = 0
action["mobile_base_theta.vel"] = 0

action = torch.as_tensor(list(action.values()))
dtype = self.motor_features["action"]["dtype"]
action = np.array(list(action.values()), dtype=dtype)
# action = torch.as_tensor(list(action.values()))

obs_dict = self.capture_observation()
action_dict = {}
Expand Down Expand Up @@ -224,7 +237,9 @@ def capture_observation(self) -> dict:
if self.state_keys is None:
self.state_keys = list(state)

state = torch.as_tensor(list(state.values()))
dtype = self.motor_features["observation.state"]["dtype"]
state = np.array(list(state.values()), dtype=dtype)
# state = torch.as_tensor(list(state.values()))

# Capture images from cameras
images = {}
Expand All @@ -233,6 +248,7 @@ def capture_observation(self) -> dict:
images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
# print(f'name: {name} img: {images[name]}')
if images[name] is not None:
# images[name] = copy(images[name][0]) # seems like I need to copy?
images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt

Expand Down Expand Up @@ -295,7 +311,7 @@ def disconnect(self) -> None:
print("Disconnecting")
self.is_connected = False
print("Turn off")
self.reachy.turn_off_smoothly()
# self.reachy.turn_off_smoothly()
# self.reachy.turn_off()
print("\t turn off done")
self.reachy.disconnect()
15 changes: 14 additions & 1 deletion lerobot/configs/robot/reachy2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,39 @@ cameras:
head_left:
_target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
name: teleop
host: 172.17.135.207
host: 172.17.134.85
# host: 192.168.0.197
# host: localhost
port: 50065
fps: 30
width: 960
height: 720
image_type: left
# head_right:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: teleop
# host: 172.17.135.207
# port: 50065
# image_type: right
# fps: 30
# width: 960
# height: 720
# torso_rgb:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: depth
# host: 172.17.135.207
# # host: localhost
# port: 50065
# image_type: rgb
# fps: 30
# width: 1280
# height: 720
# torso_depth:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: depth
# host: 172.17.135.207
# port: 50065
# image_type: depth
# fps: 30
# width: 1280
# height: 720
21 changes: 11 additions & 10 deletions lerobot/scripts/control_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def teleoperate(
@safe_disconnect
def record(
robot: Robot,
root: str,
root: Path,
repo_id: str,
single_task: str,
pretrained_policy_name_or_path: str | None = None,
Expand All @@ -204,6 +204,7 @@ def record(
video: bool = True,
run_compute_stats: bool = True,
push_to_hub: bool = True,
tags: list[str] | None = None,
num_image_writer_processes: int = 0,
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
Expand Down Expand Up @@ -331,7 +332,7 @@ def record(
dataset.consolidate(run_compute_stats)

if push_to_hub:
dataset.push_to_hub()
dataset.push_to_hub(tags=tags)

log_say("Exiting", play_sounds)
return dataset
Expand Down Expand Up @@ -427,7 +428,7 @@ def replay(
parser_record.add_argument(
"--root",
type=Path,
default="data",
default=None,
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_record.add_argument(
Expand All @@ -436,6 +437,12 @@ def replay(
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_record.add_argument(
"--resume",
type=int,
default=0,
help="Resume recording on an existing dataset.",
)
parser_record.add_argument(
"--warmup-time-s",
type=int,
Expand Down Expand Up @@ -494,12 +501,6 @@ def replay(
"Not enough threads might cause low camera fps."
),
)
parser_record.add_argument(
"--force-override",
type=int,
default=0,
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
)
parser_record.add_argument(
"-p",
"--pretrained-policy-name-or-path",
Expand All @@ -523,7 +524,7 @@ def replay(
parser_replay.add_argument(
"--root",
type=Path,
default="data",
default=None,
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_replay.add_argument(
Expand Down

0 comments on commit 6366c7f

Please sign in to comment.