diff --git a/data/.lfs/xarm7.tar.gz b/data/.lfs/xarm7.tar.gz index 8e2cfa368a..897f052bb8 100644 --- a/data/.lfs/xarm7.tar.gz +++ b/data/.lfs/xarm7.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:47dd79f13845ae6a35368345b7443a9190c7584d548caddd9c3eae224442c6fc -size 3280557 +oid sha256:c97e2283c0a726afd48e91172f84605765b8af8ace7ac107b810a8d11869bc99 +size 1606344 diff --git a/dimos/hardware/manipulators/sim/adapter.py b/dimos/hardware/manipulators/sim/adapter.py index 3979ce98c5..581dfb43a9 100644 --- a/dimos/hardware/manipulators/sim/adapter.py +++ b/dimos/hardware/manipulators/sim/adapter.py @@ -12,59 +12,232 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""MuJoCo simulation adapter for ControlCoordinator integration. - -Thin wrapper around SimManipInterface that plugs into the adapter registry. -Arm joint methods are inherited from SimManipInterface. +"""Shared-memory adapter for MuJoCo-based manipulator simulation. +this adapter reads from and writes to the same SHM buffers. """ from __future__ import annotations -from pathlib import Path +import math +import time from typing import TYPE_CHECKING, Any -from dimos.simulation.engines.mujoco_engine import MujocoEngine -from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface +from dimos.hardware.manipulators.spec import ( + ControlMode, + JointLimits, + ManipulatorInfo, +) +from dimos.simulation.engines.mujoco_shm import ( + ManipShmReader, + shm_key_from_path, +) +from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: from dimos.hardware.manipulators.registry import AdapterRegistry -class SimMujocoAdapter(SimManipInterface): - """Uses ``address`` as the MJCF XML path (same field real adapters use for IP/port). - If the engine has more joints than ``dof``, the extra joint at index ``dof`` - is treated as the gripper, with ctrl range scaled automatically. +logger = setup_logger() + +_READY_WAIT_TIMEOUT_S = 60.0 +_READY_WAIT_POLL_S = 0.1 +_ATTACH_RETRY_TIMEOUT_S = 30.0 +_ATTACH_RETRY_POLL_S = 0.2 + + +class ShmMujocoAdapter: + """``ManipulatorAdapter`` that proxies to a ``MujocoSimModule`` via SHM. + + Uses ``address`` (the MJCF XML path) as the discovery key. The sim module + must be running and have signalled ready before ``connect()`` returns. """ def __init__( self, dof: int = 7, address: str | None = None, - headless: bool = True, + hardware_id: str | None = None, **_: Any, ) -> None: if address is None: raise ValueError("address (MJCF XML path) is required for sim_mujoco adapter") - engine = MujocoEngine(config_path=Path(address), headless=headless) + self._dof = dof + self._address = address + self._hardware_id = hardware_id + self._shm_key = shm_key_from_path(address) + self._shm: ManipShmReader | None = None + self._connected = False + self._servos_enabled = False + self._control_mode = ControlMode.POSITION + self._error_code = 0 + self._error_message = "" + self._has_gripper = False + self._effort_mode_warned = False + + def connect(self) -> bool: + deadline = time.monotonic() + _ATTACH_RETRY_TIMEOUT_S + while True: + try: + self._shm = ManipShmReader(self._shm_key) + break + except FileNotFoundError: + if time.monotonic() > deadline: + logger.error( + "SHM buffers not found", + address=self._address, + shm_key=self._shm_key, + timeout_s=_ATTACH_RETRY_TIMEOUT_S, + ) + return False + time.sleep(_ATTACH_RETRY_POLL_S) + + # Wait for sim module to signal ready. + deadline = time.monotonic() + _READY_WAIT_TIMEOUT_S + while not self._shm.is_ready(): + if time.monotonic() > deadline: + logger.error("sim module not ready", timeout_s=_READY_WAIT_TIMEOUT_S) + return False + time.sleep(_READY_WAIT_POLL_S) + + num_joints = self._shm.num_joints() + self._has_gripper = num_joints > self._dof + self._connected = True + self._servos_enabled = True + logger.info("ShmMujocoAdapter connected", dof=self._dof, gripper=self._has_gripper) + return True + + def disconnect(self) -> None: + try: + if self._shm is not None: + self._shm.cleanup() + finally: + self._shm = None + self._connected = False + + def is_connected(self) -> bool: + return self._connected and self._shm is not None + + def get_info(self) -> ManipulatorInfo: + return ManipulatorInfo( + vendor="Simulation", + model="Simulation", + dof=self._dof, + firmware_version=None, + serial_number=None, + ) + + def get_dof(self) -> int: + return self._dof + + def get_limits(self) -> JointLimits: + lower = [-math.pi] * self._dof + upper = [math.pi] * self._dof + max_vel_rad = math.radians(180.0) + return JointLimits( + position_lower=lower, + position_upper=upper, + velocity_max=[max_vel_rad] * self._dof, + ) + + def set_control_mode(self, mode: ControlMode) -> bool: + self._control_mode = mode + return True + + def get_control_mode(self) -> ControlMode: + return self._control_mode + + def read_joint_positions(self) -> list[float]: + if self._shm is None: + return [0.0] * self._dof + return self._shm.read_positions(self._dof) + + def read_joint_velocities(self) -> list[float]: + if self._shm is None: + return [0.0] * self._dof + return self._shm.read_velocities(self._dof) + + def read_joint_efforts(self) -> list[float]: + if self._shm is None: + return [0.0] * self._dof + return self._shm.read_efforts(self._dof) + + def read_state(self) -> dict[str, int]: + velocities = self.read_joint_velocities() + is_moving = any(abs(v) > 1e-4 for v in velocities) + mode_int = list(ControlMode).index(self._control_mode) + return {"state": 1 if is_moving else 0, "mode": mode_int} + + def read_error(self) -> tuple[int, str]: + return self._error_code, self._error_message + + def write_joint_positions(self, positions: list[float], velocity: float = 1.0) -> bool: + if not self._servos_enabled or self._shm is None: + return False + self._control_mode = ControlMode.POSITION + self._shm.write_position_command(positions[: self._dof]) + return True + + def write_joint_velocities(self, velocities: list[float]) -> bool: + if not self._servos_enabled or self._shm is None: + return False + self._control_mode = ControlMode.VELOCITY + self._shm.write_velocity_command(velocities[: self._dof]) + return True + + def write_joint_efforts(self, efforts: list[float]) -> bool: + # Effort mode not exposed via SHM yet; caller can fall back to position. + if not self._effort_mode_warned: + logger.warning( + "write_joint_efforts not supported by sim adapter; ignoring and returning False", + dof=self._dof, + ) + self._effort_mode_warned = True + return False + + def write_stop(self) -> bool: + # Hold current position. + if self._shm is None: + return False + positions = self._shm.read_positions(self._dof) + self._shm.write_position_command(positions) + return True + + def write_enable(self, enable: bool) -> bool: + self._servos_enabled = enable + return True + + def read_enabled(self) -> bool: + return self._servos_enabled + + def write_clear_errors(self) -> bool: + self._error_code = 0 + self._error_message = "" + return True + + def read_cartesian_position(self) -> dict[str, float] | None: + return None + + def write_cartesian_position(self, pose: dict[str, float], velocity: float = 1.0) -> bool: + return False + + def read_gripper_position(self) -> float | None: + if not self._has_gripper or self._shm is None: + return None + return self._shm.read_gripper_position() - # Detect gripper from engine joints - gripper_idx = None - gripper_kwargs = {} - joint_names = list(engine.joint_names) - if len(joint_names) > dof: - gripper_idx = dof - ctrl_range = engine.get_actuator_ctrl_range(dof) - joint_range = engine.get_joint_range(dof) - if ctrl_range is None or joint_range is None: - raise ValueError(f"Gripper joint at index {dof} missing ctrl/joint range in MJCF") - gripper_kwargs = {"gripper_ctrl_range": ctrl_range, "gripper_joint_range": joint_range} + def write_gripper_position(self, position: float) -> bool: + if not self._has_gripper or self._shm is None: + return False + self._shm.write_gripper_command(position) + return True - super().__init__(engine=engine, dof=dof, gripper_idx=gripper_idx, **gripper_kwargs) + def read_force_torque(self) -> list[float] | None: + return None def register(registry: AdapterRegistry) -> None: """Register this adapter with the registry.""" - registry.register("sim_mujoco", SimMujocoAdapter) + registry.register("sim_mujoco", ShmMujocoAdapter) -__all__ = ["SimMujocoAdapter"] +__all__ = ["ShmMujocoAdapter"] diff --git a/dimos/hardware/manipulators/sim/test_shm_adapter.py b/dimos/hardware/manipulators/sim/test_shm_adapter.py new file mode 100644 index 0000000000..45dae82964 --- /dev/null +++ b/dimos/hardware/manipulators/sim/test_shm_adapter.py @@ -0,0 +1,198 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ShmMujocoAdapter (the SHM-backed ManipulatorAdapter).""" + +from __future__ import annotations + +from unittest.mock import MagicMock +import uuid + +import pytest + +import dimos.hardware.manipulators.sim.adapter as adapter_mod +from dimos.hardware.manipulators.sim.adapter import ShmMujocoAdapter, register +from dimos.hardware.manipulators.spec import ControlMode, ManipulatorAdapter +from dimos.simulation.engines.mujoco_shm import ManipShmWriter + +ARM_DOF = 7 + + +@pytest.fixture +def shm_key(): + return f"test_{uuid.uuid4().hex[:10]}" + + +@pytest.fixture +def writer(shm_key, monkeypatch): + """Pretend we're the sim module: create SHM, signal ready. + + We monkey-patch ``shm_key_from_path`` so the adapter under test resolves + to our fixture's key regardless of the address string. + """ + monkeypatch.setattr(adapter_mod, "shm_key_from_path", lambda _: shm_key) + w = ManipShmWriter(shm_key) + w.signal_ready(num_joints=ARM_DOF) + yield w + w.cleanup() + + +@pytest.fixture +def writer_with_gripper(shm_key, monkeypatch): + monkeypatch.setattr(adapter_mod, "shm_key_from_path", lambda _: shm_key) + w = ManipShmWriter(shm_key) + w.signal_ready(num_joints=ARM_DOF + 1) + yield w + w.cleanup() + + +@pytest.fixture +def adapter(writer): + a = ShmMujocoAdapter(dof=ARM_DOF, address="/fake/scene.xml") + assert a.connect() is True + yield a + a.disconnect() + + +@pytest.fixture +def adapter_with_gripper(writer_with_gripper): + a = ShmMujocoAdapter(dof=ARM_DOF, address="/fake/scene.xml") + assert a.connect() is True + yield a + a.disconnect() + + +class TestProtocolConformance: + def test_implements_manipulator_adapter(self): + a = ShmMujocoAdapter(dof=ARM_DOF, address="/fake/scene.xml") + assert isinstance(a, ManipulatorAdapter) + + def test_address_required(self): + with pytest.raises(ValueError, match="address"): + ShmMujocoAdapter(dof=ARM_DOF, address=None) + + def test_register(self): + registry = MagicMock() + register(registry) + registry.register.assert_called_once_with("sim_mujoco", ShmMujocoAdapter) + + +class TestReadState: + def test_read_joint_positions(self, adapter, writer): + writer.write_joint_state( + positions=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], + velocities=[0.0] * 7, + efforts=[0.0] * 7, + ) + assert adapter.read_joint_positions() == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + + def test_read_joint_velocities(self, adapter, writer): + writer.write_joint_state( + positions=[0.0] * 7, + velocities=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + efforts=[0.0] * 7, + ) + assert adapter.read_joint_velocities() == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0] + + def test_read_joint_efforts(self, adapter, writer): + writer.write_joint_state( + positions=[0.0] * 7, + velocities=[0.0] * 7, + efforts=[-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0], + ) + assert adapter.read_joint_efforts() == [-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0] + + def test_returns_only_dof_joints(self, adapter_with_gripper, writer_with_gripper): + # Writer publishes 8 joints (7 arm + 1 gripper); adapter should return 7. + writer_with_gripper.write_joint_state( + positions=list(range(8)), + velocities=[0.0] * 8, + efforts=[0.0] * 8, + ) + positions = adapter_with_gripper.read_joint_positions() + assert len(positions) == ARM_DOF + + +class TestWriteCommand: + def test_write_joint_positions(self, adapter, writer): + assert adapter.write_joint_positions([0.1] * 7) is True + cmd = writer.read_position_command(7) + assert cmd is not None + assert cmd.tolist() == pytest.approx([0.1] * 7) + + def test_write_joint_velocities(self, adapter, writer): + assert adapter.write_joint_velocities([0.5] * 7) is True + cmd = writer.read_velocity_command(7) + assert cmd is not None + assert cmd.tolist() == pytest.approx([0.5] * 7) + + def test_write_when_disabled(self, adapter): + adapter.write_enable(False) + assert adapter.write_joint_positions([0.0] * 7) is False + + def test_control_mode_tracked(self, adapter): + adapter.write_joint_positions([0.0] * 7) + assert adapter.get_control_mode() == ControlMode.POSITION + adapter.write_joint_velocities([0.0] * 7) + assert adapter.get_control_mode() == ControlMode.VELOCITY + + +class TestGripper: + def test_gripper_detected(self, adapter_with_gripper): + assert adapter_with_gripper._has_gripper is True + + def test_no_gripper_when_dof_matches(self, adapter): + assert adapter._has_gripper is False + + def test_read_gripper_position(self, adapter_with_gripper, writer_with_gripper): + writer_with_gripper.write_gripper_state(0.33) + assert adapter_with_gripper.read_gripper_position() == pytest.approx(0.33) + + def test_read_gripper_position_no_gripper(self, adapter): + assert adapter.read_gripper_position() is None + + def test_write_gripper_position(self, adapter_with_gripper, writer_with_gripper): + assert adapter_with_gripper.write_gripper_position(0.5) is True + # Gripper command is raw (unscaled) — sim module handles joint->ctrl. + assert writer_with_gripper.read_gripper_command() == pytest.approx(0.5) + + def test_write_gripper_position_no_gripper(self, adapter): + assert adapter.write_gripper_position(0.5) is False + + +class TestConnect: + def test_connect_before_sim_ready_times_out(self, shm_key, monkeypatch): + """If sim module never signals ready, connect() returns False after timeout.""" + monkeypatch.setattr(adapter_mod, "shm_key_from_path", lambda _: shm_key) + # Shrink timeouts so the test runs fast. + monkeypatch.setattr(adapter_mod, "_READY_WAIT_TIMEOUT_S", 0.2) + monkeypatch.setattr(adapter_mod, "_READY_WAIT_POLL_S", 0.02) + + w = ManipShmWriter(shm_key) + try: + # Note: writer exists but signal_ready is NOT called. + a = ShmMujocoAdapter(dof=ARM_DOF, address="/fake/scene.xml") + assert a.connect() is False + finally: + w.cleanup() + + def test_connect_waits_for_shm(self, shm_key, monkeypatch): + """If SHM buffers don't exist yet, connect() retries briefly.""" + monkeypatch.setattr(adapter_mod, "shm_key_from_path", lambda _: shm_key) + monkeypatch.setattr(adapter_mod, "_ATTACH_RETRY_TIMEOUT_S", 0.2) + monkeypatch.setattr(adapter_mod, "_ATTACH_RETRY_POLL_S", 0.02) + + a = ShmMujocoAdapter(dof=ARM_DOF, address="/fake/scene.xml") + # SHM was never created — attach must time out. + assert a.connect() is False diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index d906b27e3b..3c7f23f768 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -262,6 +262,61 @@ ) +# Sim perception: MujocoSimModule owns the MujocoEngine and publishes both +# camera streams and joint state via shared memory. +# ShmMujocoAdapter attaches to the same SHM buffers by MJCF path. + +from dimos.robot.catalog.ufactory import XARM7_SIM_PATH +from dimos.simulation.engines.mujoco_sim_module import MujocoSimModule +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode + +_xarm7_sim_cfg = _catalog_xarm7( + name="arm", + adapter_type="sim_mujoco", + address=str(XARM7_SIM_PATH), + add_gripper=True, + pitch=math.radians(45), + tf_extra_links=["link7"], + home_joints=[0.0, 0.0, 0.0, 0.0, 0.0, -0.7, 0.0], + pre_grasp_offset=0.05, +) + +xarm_perception_sim = autoconnect( + PickAndPlaceModule.blueprint( + robots=[_xarm7_sim_cfg.to_robot_model_config()], + planning_timeout=10.0, + enable_viz=True, + ), + MujocoSimModule.blueprint( + address=str(XARM7_SIM_PATH), + headless=False, + dof=7, + camera_name="wrist_camera", + base_frame_id="link7", + ), + ObjectSceneRegistrationModule.blueprint(target_frame="world"), + ControlCoordinator.blueprint( + tick_rate=100.0, + publish_joint_state=True, + joint_state_frame_id="coordinator", + hardware=[_xarm7_sim_cfg.to_hardware_component()], + tasks=[_xarm7_sim_cfg.to_task_config()], + ), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode()), +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + + +xarm_perception_sim_agent = autoconnect( + xarm_perception_sim, + McpServer.blueprint(), + McpClient.blueprint(system_prompt=_MANIPULATION_AGENT_SYSTEM_PROMPT), +) + + __all__ = [ "dual_xarm6_planner", "xarm6_planner_only", @@ -269,4 +324,6 @@ "xarm7_planner_coordinator_agent", "xarm_perception", "xarm_perception_agent", + "xarm_perception_sim", + "xarm_perception_sim_agent", ] diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index 35d1f0c43c..6e0b6be35b 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -472,6 +472,14 @@ def _plan_path_only( if start is None: return self._fail("No joint state") + # Trim goal to planner DOF (e.g. strip gripper joint from coordinator state) + planner_dof = len(start.position) + if len(goal.position) > planner_dof: + goal = JointState( + name=list(goal.name[:planner_dof]) if goal.name else [], + position=list(goal.position[:planner_dof]), + ) + result = self._planner.plan_joint_path( world=self._world_monitor.world, robot_id=robot_id, diff --git a/dimos/manipulation/planning/spec/config.py b/dimos/manipulation/planning/spec/config.py index e278a645ab..410b32a220 100644 --- a/dimos/manipulation/planning/spec/config.py +++ b/dimos/manipulation/planning/spec/config.py @@ -16,7 +16,6 @@ from __future__ import annotations -from collections.abc import Sequence from pathlib import Path from pydantic import Field @@ -74,7 +73,7 @@ class RobotModelConfig(ModuleConfig): coordinator_task_name: str | None = None gripper_hardware_id: str | None = None # TF publishing for extra links (e.g., camera mount) - tf_extra_links: Sequence[str] = () + tf_extra_links: list[str] = Field(default_factory=list) # Home/observe joint configuration for go_home skill home_joints: list[float] | None = None # Pre-grasp offset distance in meters (along approach direction) diff --git a/dimos/msgs/sensor_msgs/CameraInfo.py b/dimos/msgs/sensor_msgs/CameraInfo.py index a371475675..a37682f7a5 100644 --- a/dimos/msgs/sensor_msgs/CameraInfo.py +++ b/dimos/msgs/sensor_msgs/CameraInfo.py @@ -90,6 +90,38 @@ def __init__( self.roi_width = 0 self.roi_do_rectify = False + @classmethod + def from_intrinsics( + cls, + fx: float, + fy: float, + cx: float, + cy: float, + width: int, + height: int, + frame_id: str = "", + ) -> CameraInfo: + """Create CameraInfo from pinhole intrinsics (no distortion). + + Args: + fx: Focal length x (pixels) + fy: Focal length y (pixels) + cx: Principal point x (pixels) + cy: Principal point y (pixels) + width: Image width + height: Image height + frame_id: Frame ID + """ + return cls( + height=height, + width=width, + distortion_model="plumb_bob", + D=[0.0, 0.0, 0.0, 0.0, 0.0], + K=[fx, 0.0, cx, 0.0, fy, cy, 0.0, 0.0, 1.0], + P=[fx, 0.0, cx, 0.0, 0.0, fy, cy, 0.0, 0.0, 0.0, 1.0, 0.0], + frame_id=frame_id, + ) + def with_ts(self, ts: float) -> CameraInfo: """Return a copy of this CameraInfo with the given timestamp. diff --git a/dimos/msgs/sensor_msgs/PointCloud2.py b/dimos/msgs/sensor_msgs/PointCloud2.py index e031228088..7a76a06a96 100644 --- a/dimos/msgs/sensor_msgs/PointCloud2.py +++ b/dimos/msgs/sensor_msgs/PointCloud2.py @@ -127,9 +127,10 @@ def __getstate__(self) -> dict[str, object]: # Remove non-picklable objects del state["_pcd_tensor"] state["_pcd_legacy_cache"] = None - # Remove cached_property entries that hold unpicklable Open3D types - state.pop("oriented_bounding_box", None) - state.pop("axis_aligned_bounding_box", None) + # Remove all cached_property entries + for key in list(state): + if isinstance(getattr(type(self), key, None), functools.cached_property): + del state[key] return state def __setstate__(self, state: dict[str, object]) -> None: diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 0498b77c75..5289ec74de 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -92,6 +92,8 @@ "unity-sim": "dimos.simulation.unity.blueprint:unity_sim", "xarm-perception": "dimos.manipulation.blueprints:xarm_perception", "xarm-perception-agent": "dimos.manipulation.blueprints:xarm_perception_agent", + "xarm-perception-sim": "dimos.manipulation.blueprints:xarm_perception_sim", + "xarm-perception-sim-agent": "dimos.manipulation.blueprints:xarm_perception_sim_agent", "xarm6-planner-only": "dimos.manipulation.blueprints:xarm6_planner_only", "xarm7-planner-coordinator": "dimos.manipulation.blueprints:xarm7_planner_coordinator", "xarm7-planner-coordinator-agent": "dimos.manipulation.blueprints:xarm7_planner_coordinator_agent", @@ -138,6 +140,7 @@ "mock-b1-connection-module": "dimos.robot.unitree.b1.connection.MockB1ConnectionModule", "module-a": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleA", "module-b": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleB", + "mujoco-sim-module": "dimos.simulation.engines.mujoco_sim_module.MujocoSimModule", "navigation-module": "dimos.robot.unitree.rosnav.NavigationModule", "navigation-skill-container": "dimos.agents.skills.navigation.NavigationSkillContainer", "object-db-module": "dimos.perception.detection.moduleDB.ObjectDBModule", diff --git a/dimos/simulation/engines/mujoco_engine.py b/dimos/simulation/engines/mujoco_engine.py index 7cdc5122c3..a5aa45d903 100644 --- a/dimos/simulation/engines/mujoco_engine.py +++ b/dimos/simulation/engines/mujoco_engine.py @@ -16,12 +16,17 @@ from __future__ import annotations +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path import threading import time from typing import TYPE_CHECKING import mujoco import mujoco.viewer as viewer # type: ignore[import-untyped,import-not-found] +import numpy as np +from numpy.typing import NDArray from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT from dimos.simulation.engines.base import SimulationEngine @@ -29,12 +34,41 @@ from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from pathlib import Path - from dimos.msgs.sensor_msgs.JointState import JointState logger = setup_logger() +# Step hook signature: called with the engine instance inside the sim thread. +StepHook = Callable[["MujocoEngine"], None] + + +@dataclass +class CameraConfig: + name: str + width: int = 640 + height: int = 480 + fps: float = 15.0 + + +@dataclass +class CameraFrame: + rgb: NDArray[np.uint8] + depth: NDArray[np.float32] + cam_pos: NDArray[np.float64] + cam_mat: NDArray[np.float64] + fovy: float + timestamp: float + + +@dataclass +class _CameraRendererState: + cfg: CameraConfig + cam_id: int + rgb_renderer: mujoco.Renderer + depth_renderer: mujoco.Renderer + interval: float + last_render_time: float = 0.0 + class MujocoEngine(SimulationEngine): """ @@ -45,8 +79,17 @@ class MujocoEngine(SimulationEngine): - applies control commands """ - def __init__(self, config_path: Path, headless: bool) -> None: + def __init__( + self, + config_path: Path, + headless: bool, + cameras: list[CameraConfig] | None = None, + on_before_step: StepHook | None = None, + on_after_step: StepHook | None = None, + ) -> None: super().__init__(config_path=config_path, headless=headless) + self._on_before_step: StepHook | None = on_before_step + self._on_after_step: StepHook | None = on_after_step xml_path = self._resolve_xml_path(config_path) self._model = mujoco.MjModel.from_xml_path(str(xml_path)) @@ -77,6 +120,11 @@ def __init__(self, config_path: Path, headless: bool) -> None: self._joint_position_targets[i] = current_pos self._joint_positions[i] = current_pos + # Camera rendering state (renderers created in sim thread) + self._camera_configs = cameras or [] + self._camera_frames: dict[str, CameraFrame] = {} + self._camera_lock = threading.Lock() + def _resolve_xml_path(self, config_path: Path) -> Path: if config_path is None: raise ValueError("config_path is required for MuJoCo simulation loading") @@ -143,7 +191,7 @@ def _update_joint_state(self) -> None: def connect(self) -> bool: try: - logger.info(f"{self.__class__.__name__}: connect()") + logger.info("connect()", cls=self.__class__.__name__) with self._lock: self._connected = True self._stop_event.clear() @@ -157,12 +205,12 @@ def connect(self) -> bool: self._sim_thread.start() return True except Exception as e: - logger.error(f"{self.__class__.__name__}: connect() failed: {e}") + logger.error("connect() failed", cls=self.__class__.__name__, error=str(e)) return False def disconnect(self) -> bool: try: - logger.info(f"{self.__class__.__name__}: disconnect()") + logger.info("disconnect()", cls=self.__class__.__name__) with self._lock: self._connected = False self._stop_event.set() @@ -171,20 +219,85 @@ def disconnect(self) -> bool: self._sim_thread = None return True except Exception as e: - logger.error(f"{self.__class__.__name__}: disconnect() failed: {e}") + logger.error("disconnect() failed", cls=self.__class__.__name__, error=str(e)) return False + def _init_cameras(self) -> dict[str, _CameraRendererState]: + """Create renderers for all configured cameras""" + cam_renderers: dict[str, _CameraRendererState] = {} + for cfg in self._camera_configs: + cam_id = mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_CAMERA, cfg.name) + if cam_id < 0: + logger.warning("Camera not found in MJCF, skipping", camera_name=cfg.name) + continue + rgb_renderer = mujoco.Renderer(self._model, height=cfg.height, width=cfg.width) + depth_renderer = mujoco.Renderer(self._model, height=cfg.height, width=cfg.width) + depth_renderer.enable_depth_rendering() + interval = 1.0 / cfg.fps if cfg.fps > 0 else float("inf") + cam_renderers[cfg.name] = _CameraRendererState( + cfg=cfg, + cam_id=cam_id, + rgb_renderer=rgb_renderer, + depth_renderer=depth_renderer, + interval=interval, + ) + return cam_renderers + + def _render_cameras(self, now: float, cam_renderers: dict[str, _CameraRendererState]) -> None: + """Render all due cameras and store frames. Must be called from sim thread.""" + for state in cam_renderers.values(): + if now - state.last_render_time < state.interval: + continue + state.last_render_time = now + + state.rgb_renderer.update_scene(self._data, camera=state.cam_id) + rgb = state.rgb_renderer.render().copy() + + state.depth_renderer.update_scene(self._data, camera=state.cam_id) + depth = state.depth_renderer.render().copy() + + frame = CameraFrame( + rgb=rgb, + depth=depth.astype(np.float32), + cam_pos=self._data.cam_xpos[state.cam_id].copy(), + cam_mat=self._data.cam_xmat[state.cam_id].copy(), + fovy=float(self._model.cam_fovy[state.cam_id]), + timestamp=now, + ) + with self._camera_lock: + self._camera_frames[state.cfg.name] = frame + + @staticmethod + def _close_cam_renderers(cam_renderers: dict[str, _CameraRendererState]) -> None: + for state in cam_renderers.values(): + state.rgb_renderer.close() + state.depth_renderer.close() + def _sim_loop(self) -> None: - logger.info(f"{self.__class__.__name__}: sim loop started") + logger.info("sim loop started", cls=self.__class__.__name__) dt = 1.0 / self._control_frequency + # Camera renderers: created once in the sim thread + cam_renderers = self._init_cameras() + def _step_once(sync_viewer: bool) -> None: loop_start = time.time() + if self._on_before_step is not None: + try: + self._on_before_step(self) + except Exception as exc: + logger.error("on_before_step failed", error=str(exc)) self._apply_control() mujoco.mj_step(self._model, self._data) if sync_viewer: m_viewer.sync() self._update_joint_state() + if self._on_after_step is not None: + try: + self._on_after_step(self) + except Exception as exc: + logger.error("on_after_step failed", error=str(exc)) + self._render_cameras(loop_start, cam_renderers) elapsed = time.time() - loop_start sleep_time = dt - elapsed @@ -201,7 +314,8 @@ def _step_once(sync_viewer: bool) -> None: while m_viewer.is_running() and not self._stop_event.is_set(): _step_once(sync_viewer=True) - logger.info(f"{self.__class__.__name__}: sim loop stopped") + self._close_cam_renderers(cam_renderers) + logger.info("sim loop stopped", cls=self.__class__.__name__) @property def connected(self) -> bool: @@ -328,7 +442,25 @@ def get_joint_range(self, joint_index: int) -> tuple[float, float] | None: ) return None + def read_camera(self, camera_name: str) -> CameraFrame | None: + """Read the latest rendered frame for a camera (thread-safe). + + Returns None if the camera hasn't rendered yet or doesn't exist. + """ + with self._camera_lock: + return self._camera_frames.get(camera_name) + + def get_camera_fovy(self, camera_name: str) -> float | None: + """Get vertical field of view for a named camera, in degrees.""" + cam_id = mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_CAMERA, camera_name) + if cam_id < 0: + return None + return float(self._model.cam_fovy[cam_id]) + __all__ = [ + "CameraConfig", + "CameraFrame", "MujocoEngine", + "StepHook", ] diff --git a/dimos/simulation/engines/mujoco_shm.py b/dimos/simulation/engines/mujoco_shm.py new file mode 100644 index 0000000000..1679a00e42 --- /dev/null +++ b/dimos/simulation/engines/mujoco_shm.py @@ -0,0 +1,347 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared-memory buffers for sim-manipulator IPC. + +Layout for exchanging joint state and commands between ``MujocoSimModule`` +(which owns the physics engine) and ``ShmMujocoAdapter`` (which plugs into +ControlCoordinator). Modeled after ``dimos.simulation.mujoco.shared_memory`` +(the Go2 SHM pattern). + +Names are deterministic: both sides derive them from the resolved MJCF path, +so no name exchange over RPC is needed. The sim module creates the buffers +and signals ``ready``; the adapter attaches to them by name. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import hashlib +from multiprocessing import resource_tracker +from multiprocessing.shared_memory import SharedMemory +from pathlib import Path +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# Upper bound on joint count per sim. Arms + gripper are typically <= 10. +MAX_JOINTS = 16 +_FLOAT_BYTES = 8 # float64 +_INT32_BYTES = 4 + +_joint_array_size = MAX_JOINTS * _FLOAT_BYTES # float64 array + +# Buffer sizes (in bytes). +# seq: one int64 counter per buffer type. +# control: int32 [ready, stop, command_mode, num_joints]. +_shm_sizes = { + "positions": _joint_array_size, + "velocities": _joint_array_size, + "efforts": _joint_array_size, + "position_targets": _joint_array_size, + "velocity_targets": _joint_array_size, + "gripper": 2 * _FLOAT_BYTES, # [gripper_position, gripper_target] + "seq": 8 * _FLOAT_BYTES, # 8 int64 counters + "control": 4 * _INT32_BYTES, # [ready, stop, command_mode, num_joints] +} + +# Sequence counter indices. +SEQ_POSITIONS = 0 +SEQ_VELOCITIES = 1 +SEQ_EFFORTS = 2 +SEQ_POSITION_CMD = 3 +SEQ_VELOCITY_CMD = 4 +SEQ_GRIPPER_STATE = 5 +SEQ_GRIPPER_CMD = 6 + +# Control indices. +CTRL_READY = 0 +CTRL_STOP = 1 +CTRL_COMMAND_MODE = 2 +CTRL_NUM_JOINTS = 3 + +# Command modes. +CMD_MODE_POSITION = 0 +CMD_MODE_VELOCITY = 1 + +_NAME_PREFIX = "dimos_mjmanip" + + +def shm_key_from_path(config_path: Path | str) -> str: + """Derive a deterministic short key from an MJCF path. + + Both sim module and adapter compute the same key from the same path, + so SHM buffer names can be agreed upon without an RPC round-trip. + """ + resolved = str(Path(config_path).expanduser().resolve()) + return hashlib.md5(resolved.encode("utf-8")).hexdigest()[:12] + + +def _buffer_name(key: str, buffer: str) -> str: + return f"{_NAME_PREFIX}_{key}_{buffer}" + + +def _unregister(shm: SharedMemory) -> SharedMemory: + """Detach ``shm`` from ``resource_tracker`` to silence spurious warnings. + + Same technique as ``dimos.simulation.mujoco.shared_memory._unregister``. + """ + try: + resource_tracker.unregister(shm._name, "shared_memory") # type: ignore[attr-defined] + except Exception: + pass + return shm + + +@dataclass(frozen=True) +class ManipShmSet: + """Frozen set of named SharedMemory buffers for manipulator IPC.""" + + positions: SharedMemory + velocities: SharedMemory + efforts: SharedMemory + position_targets: SharedMemory + velocity_targets: SharedMemory + gripper: SharedMemory + seq: SharedMemory + control: SharedMemory + + @classmethod + def create(cls, key: str) -> ManipShmSet: + """Create new SHM buffers with deterministic names derived from *key*""" + buffers: dict[str, SharedMemory] = {} + for buffer_name, size in _shm_sizes.items(): + name = _buffer_name(key, buffer_name) + try: + stale = _unregister(SharedMemory(name=name)) + stale.close() + try: + stale.unlink() + logger.info("ManipShmSet: unlinked stale SHM", name=name) + except FileNotFoundError: + pass + except FileNotFoundError: + pass + buffers[buffer_name] = SharedMemory(create=True, size=size, name=name) + return cls(**buffers) + + @classmethod + def attach(cls, key: str) -> ManipShmSet: + """Attach to existing SHM buffers created by the sim side.""" + buffers: dict[str, SharedMemory] = {} + for buffer_name in _shm_sizes: + name = _buffer_name(key, buffer_name) + buffers[buffer_name] = _unregister(SharedMemory(name=name)) + return cls(**buffers) + + def as_list(self) -> list[SharedMemory]: + return [getattr(self, k) for k in _shm_sizes] + + +class ManipShmWriter: + """Sim-side handle: writes joint state, reads command targets. + Owned by ``MujocoSimModule``. Creates the SHM buffers on init and + unlinks them on cleanup. + """ + + shm: ManipShmSet + + def __init__(self, key: str) -> None: + self.shm = ManipShmSet.create(key) + self._last_pos_cmd_seq = 0 + self._last_vel_cmd_seq = 0 + self._last_gripper_cmd_seq = 0 + # Zero everything. + for buf in self.shm.as_list(): + np.ndarray((buf.size,), dtype=np.uint8, buffer=buf.buf)[:] = 0 + + def write_joint_state( + self, + positions: list[float], + velocities: list[float], + efforts: list[float], + ) -> None: + n = min(len(positions), MAX_JOINTS) + pos_arr = self._array(self.shm.positions, MAX_JOINTS, np.float64) + vel_arr = self._array(self.shm.velocities, MAX_JOINTS, np.float64) + eff_arr = self._array(self.shm.efforts, MAX_JOINTS, np.float64) + pos_arr[:n] = positions[:n] + vel_arr[:n] = velocities[:n] + eff_arr[:n] = efforts[:n] + self._increment_seq(SEQ_POSITIONS) + self._increment_seq(SEQ_VELOCITIES) + self._increment_seq(SEQ_EFFORTS) + + def write_gripper_state(self, position: float) -> None: + arr = self._array(self.shm.gripper, 2, np.float64) + arr[0] = position + self._increment_seq(SEQ_GRIPPER_STATE) + + def read_position_command(self, num_joints: int) -> NDArray[np.float64] | None: + """Return a copy of position targets if a new command arrived since last call.""" + seq = self._get_seq(SEQ_POSITION_CMD) + if seq <= self._last_pos_cmd_seq: + return None + self._last_pos_cmd_seq = seq + arr = self._array(self.shm.position_targets, MAX_JOINTS, np.float64) + return arr[:num_joints].copy() + + def read_velocity_command(self, num_joints: int) -> NDArray[np.float64] | None: + seq = self._get_seq(SEQ_VELOCITY_CMD) + if seq <= self._last_vel_cmd_seq: + return None + self._last_vel_cmd_seq = seq + arr = self._array(self.shm.velocity_targets, MAX_JOINTS, np.float64) + return arr[:num_joints].copy() + + def read_gripper_command(self) -> float | None: + seq = self._get_seq(SEQ_GRIPPER_CMD) + if seq <= self._last_gripper_cmd_seq: + return None + self._last_gripper_cmd_seq = seq + arr = self._array(self.shm.gripper, 2, np.float64) + return float(arr[1]) + + def read_command_mode(self) -> int: + return int(self._control()[CTRL_COMMAND_MODE]) + + def signal_ready(self, num_joints: int) -> None: + ctrl = self._control() + ctrl[CTRL_NUM_JOINTS] = num_joints + ctrl[CTRL_READY] = 1 + + def signal_stop(self) -> None: + self._control()[CTRL_STOP] = 1 + + def should_stop(self) -> bool: + return bool(self._control()[CTRL_STOP] == 1) + + def cleanup(self) -> None: + for shm in self.shm.as_list(): + try: + shm.close() + except FileNotFoundError: + pass # already detached + except OSError as exc: + logger.warning("SHM close failed", name=shm.name, error=str(exc)) + try: + shm.unlink() + except FileNotFoundError: + pass # already unlinked (e.g. cleanup called twice) + except OSError as exc: + logger.warning("SHM unlink failed", name=shm.name, error=str(exc)) + + def _array(self, buf: SharedMemory, n: int, dtype: Any) -> NDArray[Any]: + return np.ndarray((n,), dtype=dtype, buffer=buf.buf) + + def _control(self) -> NDArray[np.int32]: + return np.ndarray((4,), dtype=np.int32, buffer=self.shm.control.buf) + + def _increment_seq(self, index: int) -> None: + seq_arr = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_arr[index] += 1 + + def _get_seq(self, index: int) -> int: + seq_arr = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + return int(seq_arr[index]) + + +class ManipShmReader: + """Adapter-side handle: reads joint state, writes command targets. + + Owned by ``ShmMujocoAdapter``. Attaches to existing buffers created by + the sim module; does not unlink them on cleanup. + """ + + shm: ManipShmSet + + def __init__(self, key: str) -> None: + self.shm = ManipShmSet.attach(key) + + def read_positions(self, num_joints: int) -> list[float]: + arr = np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.positions.buf) + return [float(x) for x in arr[:num_joints]] + + def read_velocities(self, num_joints: int) -> list[float]: + arr = np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.velocities.buf) + return [float(x) for x in arr[:num_joints]] + + def read_efforts(self, num_joints: int) -> list[float]: + arr = np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.efforts.buf) + return [float(x) for x in arr[:num_joints]] + + def read_gripper_position(self) -> float: + arr = np.ndarray((2,), dtype=np.float64, buffer=self.shm.gripper.buf) + return float(arr[0]) + + def write_position_command(self, positions: list[float]) -> None: + n = min(len(positions), MAX_JOINTS) + arr = np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.position_targets.buf) + arr[:n] = positions[:n] + self._set_command_mode(CMD_MODE_POSITION) + self._increment_seq(SEQ_POSITION_CMD) + + def write_velocity_command(self, velocities: list[float]) -> None: + n = min(len(velocities), MAX_JOINTS) + arr = np.ndarray((MAX_JOINTS,), dtype=np.float64, buffer=self.shm.velocity_targets.buf) + arr[:n] = velocities[:n] + self._set_command_mode(CMD_MODE_VELOCITY) + self._increment_seq(SEQ_VELOCITY_CMD) + + def write_gripper_command(self, position: float) -> None: + arr = np.ndarray((2,), dtype=np.float64, buffer=self.shm.gripper.buf) + arr[1] = position + self._increment_seq(SEQ_GRIPPER_CMD) + + def is_ready(self) -> bool: + return bool(self._control()[CTRL_READY] == 1) + + def num_joints(self) -> int: + return int(self._control()[CTRL_NUM_JOINTS]) + + def signal_stop(self) -> None: + self._control()[CTRL_STOP] = 1 + + def cleanup(self) -> None: + for shm in self.shm.as_list(): + try: + shm.close() + except FileNotFoundError: + pass # already detached + except OSError as exc: + logger.warning("SHM close failed", name=shm.name, error=str(exc)) + + def _control(self) -> NDArray[np.int32]: + return np.ndarray((4,), dtype=np.int32, buffer=self.shm.control.buf) + + def _set_command_mode(self, mode: int) -> None: + self._control()[CTRL_COMMAND_MODE] = mode + + def _increment_seq(self, index: int) -> None: + seq_arr = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_arr[index] += 1 + + +__all__ = [ + "MAX_JOINTS", + "ManipShmReader", + "ManipShmSet", + "ManipShmWriter", + "shm_key_from_path", +] diff --git a/dimos/simulation/engines/mujoco_sim_module.py b/dimos/simulation/engines/mujoco_sim_module.py new file mode 100644 index 0000000000..9250a2c6ac --- /dev/null +++ b/dimos/simulation/engines/mujoco_sim_module.py @@ -0,0 +1,503 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified MuJoCo simulation Module. + +Owns a single ``MujocoEngine`` and publishes: +- camera streams (Out ports), replacing ``MujocoCamera`` +- joint state via shared memory, consumed by ``ShmMujocoAdapter`` inside + ``ControlCoordinator`` + +This avoids the prior pattern of sharing engines via a global in-process +registry, which was fragile when ``WorkerManager`` places the adapter and +the camera in different worker processes. +""" + +from __future__ import annotations + +import math +from pathlib import Path +import threading +import time +from typing import Any + +from pydantic import Field +import reactivex as rx +from scipy.spatial.transform import Rotation as R + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import Out +from dimos.hardware.sensors.camera.spec import DepthCameraConfig, DepthCameraHardware +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.simulation.engines.mujoco_engine import ( + CameraConfig, + CameraFrame, + MujocoEngine, +) +from dimos.simulation.engines.mujoco_shm import ( + ManipShmWriter, + shm_key_from_path, +) +from dimos.spec import perception +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +_RX180 = R.from_euler("x", 180, degrees=True) + + +def _default_identity_transform() -> Transform: + return Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + + +class MujocoSimModuleConfig(ModuleConfig, DepthCameraConfig): + """Configuration for the unified MuJoCo simulation module.""" + + address: str = "" + headless: bool = False + dof: int = 7 + + # Camera config (matches former MujocoCameraConfig). + camera_name: str = "wrist_camera" + width: int = 640 + height: int = 480 + fps: int = 15 + base_frame_id: str = "link7" + base_transform: Transform | None = Field(default_factory=_default_identity_transform) + align_depth_to_color: bool = True + enable_depth: bool = True + enable_pointcloud: bool = False + pointcloud_fps: float = 5.0 + camera_info_fps: float = 1.0 + + +class MujocoSimModule( + DepthCameraHardware, + Module[MujocoSimModuleConfig], + perception.DepthCamera, +): + """Single Module that owns a MujocoEngine, publishes camera streams, and + exposes joint state/commands to a ``ShmMujocoAdapter`` via shared memory. + + The adapter attaches to the same SHM buffers using the MJCF path as the + discovery key — no RPC, no globals. From ControlCoordinator's perspective + the adapter is an ordinary ``ManipulatorAdapter``; SHM is its transport. + """ + + color_image: Out[Image] + depth_image: Out[Image] + pointcloud: Out[PointCloud2] + camera_info: Out[CameraInfo] + depth_camera_info: Out[CameraInfo] + + default_config = MujocoSimModuleConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._engine: MujocoEngine | None = None + self._shm: ManipShmWriter | None = None + self._gripper_idx: int | None = None + self._gripper_ctrl_range: tuple[float, float] = (0.0, 1.0) + self._gripper_joint_range: tuple[float, float] = (0.0, 1.0) + self._stop_event = threading.Event() + self._publish_thread: threading.Thread | None = None + self._camera_info_base: CameraInfo | None = None + + @property + def _camera_link(self) -> str: + return f"{self.config.camera_name}_link" + + @property + def _color_frame(self) -> str: + return f"{self.config.camera_name}_color_frame" + + @property + def _color_optical_frame(self) -> str: + return f"{self.config.camera_name}_color_optical_frame" + + @property + def _depth_frame(self) -> str: + return f"{self.config.camera_name}_depth_frame" + + @property + def _depth_optical_frame(self) -> str: + return f"{self.config.camera_name}_depth_optical_frame" + + @rpc + def get_color_camera_info(self) -> CameraInfo | None: + if self._camera_info_base is None: + return None + return self._camera_info_base.with_ts(time.time()) + + @rpc + def get_depth_camera_info(self) -> CameraInfo | None: + if self._camera_info_base is None: + return None + return self._camera_info_base.with_ts(time.time()) + + @rpc + def get_depth_scale(self) -> float: + return 1.0 + + @rpc + def start(self) -> None: + if not self.config.address: + raise RuntimeError("MujocoSimModule: config.address (MJCF path) is required") + + # SHM key — adapter derives the same key from the same MJCF path. + shm_key = shm_key_from_path(self.config.address) + self._shm = ManipShmWriter(shm_key) + + # Build engine with SHM hooks installed. + self._engine = MujocoEngine( + config_path=Path(self.config.address), + headless=self.config.headless, + cameras=[ + CameraConfig( + name=self.config.camera_name, + width=self.config.width, + height=self.config.height, + fps=float(self.config.fps), + ) + ], + on_before_step=self._apply_shm_commands, + on_after_step=self._publish_shm_state, + ) + + # Detect gripper (extra joint beyond dof). + dof = self.config.dof + joint_names = list(self._engine.joint_names) + if len(joint_names) > dof: + ctrl_range = self._engine.get_actuator_ctrl_range(dof) + joint_range = self._engine.get_joint_range(dof) + if ctrl_range is None or joint_range is None: + raise ValueError(f"Gripper joint at index {dof} missing ctrl/joint range in MJCF") + self._gripper_idx = dof + self._gripper_ctrl_range = ctrl_range + self._gripper_joint_range = joint_range + logger.info( + "MujocoSimModule: gripper detected", + idx=dof, + ctrl_range=ctrl_range, + joint_range=joint_range, + ) + + # Start physics (sim thread spawned inside engine.connect()). + if not self._engine.connect(): + raise RuntimeError("MujocoSimModule: engine.connect() failed") + + self._shm.signal_ready(num_joints=len(joint_names)) + + # Camera intrinsics. + self._build_camera_info() + + self._stop_event.clear() + self._publish_thread = threading.Thread( + target=self._publish_loop, daemon=True, name="MujocoSimPublish" + ) + self._publish_thread.start() + + # Periodic camera_info publishing. + interval_sec = 1.0 / self.config.camera_info_fps + self._disposables.add( + rx.interval(interval_sec).subscribe( + on_next=lambda _: self._publish_camera_info(), + on_error=lambda e: logger.error("CameraInfo publish error", error=str(e)), + ) + ) + + # Optional pointcloud generation. + if self.config.enable_pointcloud and self.config.enable_depth: + pc_interval = 1.0 / self.config.pointcloud_fps + self._disposables.add( + rx.interval(pc_interval).subscribe( + on_next=lambda _: self._generate_pointcloud(), + on_error=lambda e: logger.error("Pointcloud error", error=str(e)), + ) + ) + + logger.info( + "MujocoSimModule started", + address=self.config.address, + dof=dof, + camera=self.config.camera_name, + shm_key=shm_key, + ) + + @rpc + def stop(self) -> None: + self._stop_event.set() + if self._publish_thread and self._publish_thread.is_alive(): + self._publish_thread.join(timeout=2.0) + self._publish_thread = None + + errors: list[tuple[str, BaseException]] = [] + if self._engine is not None: + try: + self._engine.disconnect() + self._engine = None + except Exception as exc: + logger.error("engine.disconnect() failed", error=str(exc)) + errors.append(("engine.disconnect", exc)) + if self._shm is not None: + try: + self._shm.signal_stop() + self._shm.cleanup() + self._shm = None + except Exception as exc: + logger.error("SHM cleanup failed", error=str(exc)) + errors.append(("shm.cleanup", exc)) + + self._camera_info_base = None + super().stop() + + if errors: + op, err = errors[0] + raise RuntimeError(f"MujocoSimModule.stop() failed during {op}: {err}") from err + + def _apply_shm_commands(self, engine: MujocoEngine) -> None: + """Pre-step hook: pull command targets from SHM into the engine.""" + shm = self._shm + if shm is None: + return + dof = self.config.dof + + pos_cmd = shm.read_position_command(dof) + if pos_cmd is not None: + engine.write_joint_command(JointState(position=pos_cmd.tolist())) + + vel_cmd = shm.read_velocity_command(dof) + if vel_cmd is not None: + engine.write_joint_command(JointState(velocity=vel_cmd.tolist())) + + if self._gripper_idx is not None: + gripper_cmd = shm.read_gripper_command() + if gripper_cmd is not None: + ctrl_value = self._gripper_joint_to_ctrl(gripper_cmd) + engine.set_position_target(self._gripper_idx, ctrl_value) + + def _publish_shm_state(self, engine: MujocoEngine) -> None: + """Post-step hook: publish joint state to SHM.""" + shm = self._shm + if shm is None: + return + shm.write_joint_state( + positions=engine.joint_positions, + velocities=engine.joint_velocities, + efforts=engine.joint_efforts, + ) + if self._gripper_idx is not None: + positions = engine.joint_positions + if self._gripper_idx < len(positions): + shm.write_gripper_state(positions[self._gripper_idx]) + + def _gripper_joint_to_ctrl(self, joint_position: float) -> float: + """Map joint-space gripper position to actuator control value.""" + jlo, jhi = self._gripper_joint_range + clo, chi = self._gripper_ctrl_range + clamped = max(jlo, min(jhi, joint_position)) + if jhi == jlo: + return clo + t = (clamped - jlo) / (jhi - jlo) + return chi - t * (chi - clo) + + def _build_camera_info(self) -> None: + if self._engine is None: + return + fovy_deg = self._engine.get_camera_fovy(self.config.camera_name) + if fovy_deg is None: + logger.error("Camera not found in MJCF", camera_name=self.config.camera_name) + return + h = self.config.height + w = self.config.width + fovy_rad = math.radians(fovy_deg) + fy = h / (2.0 * math.tan(fovy_rad / 2.0)) + fx = fy # square pixels + self._camera_info_base = CameraInfo.from_intrinsics( + fx=fx, + fy=fy, + cx=w / 2.0, + cy=h / 2.0, + width=w, + height=h, + frame_id=self._color_optical_frame, + ) + + def _publish_loop(self) -> None: + """Poll engine for rendered frames and publish at configured FPS.""" + engine = self._engine + if engine is None: + return + + interval = 1.0 / self.config.fps + last_timestamp = 0.0 + published_count = 0 + + # Wait for engine to actually be connected (sim thread may take a tick). + deadline = time.monotonic() + 30.0 + while not self._stop_event.is_set() and not engine.connected: + if time.monotonic() > deadline: + logger.error("MujocoSimModule: timed out waiting for engine to connect") + return + self._stop_event.wait(timeout=0.1) + + if self._stop_event.is_set(): + return + + while not self._stop_event.is_set(): + try: + frame = engine.read_camera(self.config.camera_name) + except RuntimeError as exc: + logger.error( + "MuJoCo render failed; stopping publish loop", + camera_name=self.config.camera_name, + error=str(exc), + exc_info=True, + ) + return + + if frame is None or frame.timestamp <= last_timestamp: + self._stop_event.wait(timeout=interval * 0.5) + continue + last_timestamp = frame.timestamp + ts = time.time() + + color_img = Image( + data=frame.rgb, + format=ImageFormat.RGB, + frame_id=self._color_optical_frame, + ts=ts, + ) + self.color_image.publish(color_img) + + if self.config.enable_depth: + depth_img = Image( + data=frame.depth, + format=ImageFormat.DEPTH, + frame_id=self._color_optical_frame, + ts=ts, + ) + self.depth_image.publish(depth_img) + + self._publish_tf(ts, frame) + + published_count += 1 + if published_count == 1: + logger.info( + "MujocoSimModule first frame published", + rgb_shape=frame.rgb.shape, + depth_shape=frame.depth.shape, + ) + + elapsed = time.time() - ts + sleep_time = interval - elapsed + if sleep_time > 0: + time.sleep(sleep_time) + + def _publish_camera_info(self) -> None: + base = self._camera_info_base + if base is None: + return + ts = time.time() + info = CameraInfo( + height=base.height, + width=base.width, + distortion_model=base.distortion_model, + D=base.D, + K=base.K, + P=base.P, + frame_id=base.frame_id, + ts=ts, + ) + self.camera_info.publish(info) + self.depth_camera_info.publish(info) + + def _publish_tf(self, ts: float, frame: CameraFrame | None) -> None: + if frame is None: + return + mj_rot = R.from_matrix(frame.cam_mat.reshape(3, 3)) + optical_rot = mj_rot * _RX180 + q = optical_rot.as_quat() # xyzw + pos = Vector3( + float(frame.cam_pos[0]), + float(frame.cam_pos[1]), + float(frame.cam_pos[2]), + ) + rot = Quaternion(float(q[0]), float(q[1]), float(q[2]), float(q[3])) + self.tf.publish( + Transform( + translation=pos, + rotation=rot, + frame_id="world", + child_frame_id=self._color_optical_frame, + ts=ts, + ), + Transform( + translation=pos, + rotation=rot, + frame_id="world", + child_frame_id=self._depth_optical_frame, + ts=ts, + ), + Transform( + translation=pos, + rotation=rot, + frame_id="world", + child_frame_id=self._camera_link, + ts=ts, + ), + ) + + def _generate_pointcloud(self) -> None: + if self._engine is None or self._camera_info_base is None: + return + frame = self._engine.read_camera(self.config.camera_name) + if frame is None: + return + try: + color_img = Image( + data=frame.rgb, + format=ImageFormat.RGB, + frame_id=self._color_optical_frame, + ts=frame.timestamp, + ) + depth_img = Image( + data=frame.depth, + format=ImageFormat.DEPTH, + frame_id=self._color_optical_frame, + ts=frame.timestamp, + ) + pcd = PointCloud2.from_rgbd( + color_image=color_img, + depth_image=depth_img, + camera_info=self._camera_info_base, + depth_scale=1.0, + ) + pcd = pcd.voxel_downsample(0.005) + self.pointcloud.publish(pcd) + except Exception as exc: + logger.error("Pointcloud generation error", error=str(exc)) + + +__all__ = ["MujocoSimModule", "MujocoSimModuleConfig"] diff --git a/dimos/simulation/manipulators/sim_manip_interface.py b/dimos/simulation/manipulators/sim_manip_interface.py deleted file mode 100644 index 07e56c5afd..0000000000 --- a/dimos/simulation/manipulators/sim_manip_interface.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Simulation-agnostic manipulator interface.""" - -from __future__ import annotations - -import logging -import math -from typing import TYPE_CHECKING - -from dimos.hardware.manipulators.spec import ControlMode, JointLimits, ManipulatorInfo -from dimos.msgs.sensor_msgs.JointState import JointState - -if TYPE_CHECKING: - from dimos.simulation.engines.base import SimulationEngine - - -class SimManipInterface: - """Adapter wrapper around a simulation engine to provide a uniform manipulator API.""" - - def __init__( - self, - engine: SimulationEngine, - dof: int | None = None, - gripper_idx: int | None = None, - gripper_ctrl_range: tuple[float, float] = (0.0, 1.0), - gripper_joint_range: tuple[float, float] = (0.0, 1.0), - ) -> None: - self.logger = logging.getLogger(self.__class__.__name__) - self._engine = engine - self._joint_names = list(engine.joint_names) - self._dof = dof if dof is not None else len(self._joint_names) - self._connected = False - self._servos_enabled = False - self._control_mode = ControlMode.POSITION - self._error_code = 0 - self._error_message = "" - self._gripper_idx = gripper_idx - self._gripper_ctrl_range = gripper_ctrl_range - self._gripper_joint_range = gripper_joint_range - - def connect(self) -> bool: - """Connect to the simulation engine.""" - try: - self.logger.info("Connecting to simulation engine...") - if not self._engine.connect(): - self.logger.error("Failed to connect to simulation engine") - return False - if self._engine.connected: - self._connected = True - self._servos_enabled = True - self.logger.info( - "Successfully connected to simulation", - extra={"dof": self._dof}, - ) - return True - self.logger.error("Failed to connect to simulation engine") - return False - except Exception as exc: - self.logger.error(f"Sim connection failed: {exc}") - return False - - def disconnect(self) -> None: - """Disconnect from simulation.""" - try: - self._engine.disconnect() - except Exception as exc: - self.logger.error(f"Sim disconnection failed: {exc}") - finally: - self._connected = False - - def is_connected(self) -> bool: - return bool(self._connected and self._engine.connected) - - def get_info(self) -> ManipulatorInfo: - vendor = "Simulation" - model = "Simulation" - dof = self._dof - return ManipulatorInfo( - vendor=vendor, - model=model, - dof=dof, - firmware_version=None, - serial_number=None, - ) - - def get_dof(self) -> int: - return self._dof - - def get_joint_names(self) -> list[str]: - return list(self._joint_names) - - def get_limits(self) -> JointLimits: - lower = [-math.pi] * self._dof - upper = [math.pi] * self._dof - max_vel_rad = math.radians(180.0) - return JointLimits( - position_lower=lower, - position_upper=upper, - velocity_max=[max_vel_rad] * self._dof, - ) - - def set_control_mode(self, mode: ControlMode) -> bool: - self._control_mode = mode - return True - - def get_control_mode(self) -> ControlMode: - return self._control_mode - - def read_joint_positions(self) -> list[float]: - positions = self._engine.read_joint_positions() - return positions[: self._dof] - - def read_joint_velocities(self) -> list[float]: - velocities = self._engine.read_joint_velocities() - return velocities[: self._dof] - - def read_joint_efforts(self) -> list[float]: - efforts = self._engine.read_joint_efforts() - return efforts[: self._dof] - - def read_state(self) -> dict[str, int]: - velocities = self.read_joint_velocities() - is_moving = any(abs(v) > 1e-4 for v in velocities) - mode_int = list(ControlMode).index(self._control_mode) - return { - "state": 1 if is_moving else 0, - "mode": mode_int, - } - - def read_error(self) -> tuple[int, str]: - return self._error_code, self._error_message - - def write_joint_positions(self, positions: list[float], velocity: float = 1.0) -> bool: - if not self._servos_enabled: - return False - self._control_mode = ControlMode.POSITION - self._engine.write_joint_command(JointState(position=positions[: self._dof])) - return True - - def write_joint_velocities(self, velocities: list[float]) -> bool: - if not self._servos_enabled: - return False - self._control_mode = ControlMode.VELOCITY - self._engine.write_joint_command(JointState(velocity=velocities[: self._dof])) - return True - - def write_joint_efforts(self, efforts: list[float]) -> bool: - if not self._servos_enabled: - return False - self._control_mode = ControlMode.TORQUE - self._engine.write_joint_command(JointState(effort=efforts[: self._dof])) - return True - - def write_stop(self) -> bool: - self._engine.hold_current_position() - return True - - def write_enable(self, enable: bool) -> bool: - self._servos_enabled = enable - return True - - def read_enabled(self) -> bool: - return self._servos_enabled - - def write_clear_errors(self) -> bool: - self._error_code = 0 - self._error_message = "" - return True - - def read_cartesian_position(self) -> dict[str, float] | None: - return None - - def write_cartesian_position( - self, - pose: dict[str, float], - velocity: float = 1.0, - ) -> bool: - _pose = pose - _velocity = velocity - return False - - def read_gripper_position(self) -> float | None: - if self._gripper_idx is None: - return None - positions = self._engine.read_joint_positions() - return positions[self._gripper_idx] - - def write_gripper_position(self, position: float) -> bool: - if self._gripper_idx is None: - return False - jlo, jhi = self._gripper_joint_range - clo, chi = self._gripper_ctrl_range - position = max(jlo, min(jhi, position)) - if jhi != jlo: - t = (position - jlo) / (jhi - jlo) - ctrl_value = chi - t * (chi - clo) - else: - ctrl_value = clo - self._engine.set_position_target(self._gripper_idx, ctrl_value) - return True - - def read_force_torque(self) -> list[float] | None: - return None - - -__all__ = [ - "SimManipInterface", -] diff --git a/dimos/simulation/manipulators/test_sim_adapter.py b/dimos/simulation/manipulators/test_sim_adapter.py deleted file mode 100644 index 8f253229f0..0000000000 --- a/dimos/simulation/manipulators/test_sim_adapter.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for SimMujocoAdapter and gripper integration.""" - -from __future__ import annotations - -from pathlib import Path -from unittest.mock import MagicMock, patch - -import numpy as np -import pytest - -from dimos.hardware.manipulators.sim.adapter import SimMujocoAdapter, register -from dimos.simulation.utils.xml_parser import JointMapping - -ARM_DOF = 7 - - -def _make_joint_mapping(name: str, idx: int) -> JointMapping: - """Create a JointMapping for a simple revolute joint.""" - return JointMapping( - name=name, - joint_id=idx, - actuator_id=idx, - qpos_adr=idx, - dof_adr=idx, - tendon_qpos_adrs=(), - tendon_dof_adrs=(), - ) - - -def _make_gripper_mapping(name: str, idx: int) -> JointMapping: - """Create a JointMapping for a tendon-driven gripper.""" - return JointMapping( - name=name, - joint_id=None, - actuator_id=idx, - qpos_adr=None, - dof_adr=None, - tendon_qpos_adrs=(idx, idx + 1), - tendon_dof_adrs=(idx, idx + 1), - ) - - -def _patch_mujoco_engine(n_joints: int): - """Patch only the MuJoCo C-library and filesystem boundaries. - - Mocks ``_resolve_xml_path``, ``MjModel.from_xml_path``, ``MjData``, and - ``build_joint_mappings`` — the rest of ``MujocoEngine.__init__`` runs as-is. - """ - mappings = [_make_joint_mapping(f"joint{i}", i) for i in range(ARM_DOF)] - if n_joints > ARM_DOF: - mappings.append(_make_gripper_mapping(f"joint{ARM_DOF}", ARM_DOF)) - - fake_model = MagicMock() - fake_model.opt.timestep = 0.002 - fake_model.nu = n_joints - fake_model.nq = n_joints - fake_model.njnt = n_joints - fake_model.actuator_ctrlrange = np.array( - [[-6.28, 6.28]] * ARM_DOF + ([[0.0, 255.0]] if n_joints > ARM_DOF else []) - ) - fake_model.jnt_range = np.array( - [[-6.28, 6.28]] * ARM_DOF + ([[0.0, 0.85]] if n_joints > ARM_DOF else []) - ) - fake_model.jnt_qposadr = np.arange(n_joints) - - fake_data = MagicMock() - fake_data.qpos = np.zeros(n_joints + 4) # extra for tendon qpos addresses - fake_data.actuator_length = np.zeros(n_joints) - - patches = [ - patch( - "dimos.simulation.engines.mujoco_engine.MujocoEngine._resolve_xml_path", - return_value=Path("/fake/scene.xml"), - ), - patch( - "dimos.simulation.engines.mujoco_engine.mujoco.MjModel.from_xml_path", - return_value=fake_model, - ), - patch("dimos.simulation.engines.mujoco_engine.mujoco.MjData", return_value=fake_data), - patch("dimos.simulation.engines.mujoco_engine.build_joint_mappings", return_value=mappings), - ] - return patches - - -class TestSimMujocoAdapter: - """Tests for SimMujocoAdapter with and without gripper.""" - - @pytest.fixture - def adapter_with_gripper(self): - """SimMujocoAdapter with ARM_DOF arm joints + 1 gripper joint.""" - patches = _patch_mujoco_engine(ARM_DOF + 1) - for p in patches: - p.start() - try: - adapter = SimMujocoAdapter(dof=ARM_DOF, address="/fake/scene.xml", headless=True) - finally: - for p in patches: - p.stop() - return adapter - - @pytest.fixture - def adapter_no_gripper(self): - """SimMujocoAdapter with ARM_DOF arm joints, no gripper.""" - patches = _patch_mujoco_engine(ARM_DOF) - for p in patches: - p.start() - try: - adapter = SimMujocoAdapter(dof=ARM_DOF, address="/fake/scene.xml", headless=True) - finally: - for p in patches: - p.stop() - return adapter - - def test_address_required(self): - patches = _patch_mujoco_engine(ARM_DOF) - for p in patches: - p.start() - try: - with pytest.raises(ValueError, match="address"): - SimMujocoAdapter(dof=ARM_DOF, address=None) - finally: - for p in patches: - p.stop() - - def test_gripper_detected(self, adapter_with_gripper): - assert adapter_with_gripper._gripper_idx == ARM_DOF - - def test_no_gripper_when_dof_matches(self, adapter_no_gripper): - assert adapter_no_gripper._gripper_idx is None - - def test_read_gripper_position(self, adapter_with_gripper): - pos = adapter_with_gripper.read_gripper_position() - assert pos is not None - - def test_write_gripper_sets_target(self, adapter_with_gripper): - """Write a gripper position and verify the control target was set.""" - assert adapter_with_gripper.write_gripper_position(0.42) is True - target = adapter_with_gripper._engine._joint_position_targets[ARM_DOF] - assert target != 0.0, "write_gripper_position should update the control target" - - def test_read_gripper_position_no_gripper(self, adapter_no_gripper): - assert adapter_no_gripper.read_gripper_position() is None - - def test_write_gripper_position_no_gripper(self, adapter_no_gripper): - assert adapter_no_gripper.write_gripper_position(0.5) is False - - def test_write_gripper_does_not_clobber_arm(self, adapter_with_gripper): - """Gripper write must not overwrite arm joint targets.""" - engine = adapter_with_gripper._engine - for i in range(ARM_DOF): - engine._joint_position_targets[i] = float(i) + 1.0 - - adapter_with_gripper.write_gripper_position(0.0) - - for i in range(ARM_DOF): - assert engine._joint_position_targets[i] == pytest.approx(float(i) + 1.0) - - def test_read_joint_positions_excludes_gripper(self, adapter_with_gripper): - positions = adapter_with_gripper.read_joint_positions() - assert len(positions) == ARM_DOF - - def test_connect_and_disconnect(self, adapter_with_gripper): - with patch("dimos.simulation.engines.mujoco_engine.mujoco.mj_step"): - assert adapter_with_gripper.connect() is True - adapter_with_gripper.disconnect() - - def test_register(self): - registry = MagicMock() - register(registry) - registry.register.assert_called_once_with("sim_mujoco", SimMujocoAdapter)