diff --git a/data/.lfs/go2_bigoffice.db.tar.gz b/data/.lfs/go2_bigoffice.db.tar.gz index cad393bfcc..315610b5cb 100644 --- a/data/.lfs/go2_bigoffice.db.tar.gz +++ b/data/.lfs/go2_bigoffice.db.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2d48cb0b8250bb2878d1008093d45ea377406de00ad42f0f96d7b382e1a9617b -size 191193336 +oid sha256:142f7a7d64d3b77c97acd0d15d53e9ea28c4f558776a6bb3919a4da32c2f4d37 +size 192241937 diff --git a/dimos/agents/agent_test_runner.py b/dimos/agents/agent_test_runner.py index 80758b30eb..b8b6994bec 100644 --- a/dimos/agents/agent_test_runner.py +++ b/dimos/agents/agent_test_runner.py @@ -49,8 +49,8 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self._disposables.add(Disposable(self.agent.subscribe(self._on_agent_message))) - self._disposables.add(Disposable(self.agent_idle.subscribe(self._on_agent_idle))) + self.register_disposable(Disposable(self.agent.subscribe(self._on_agent_message))) + self.register_disposable(Disposable(self.agent_idle.subscribe(self._on_agent_idle))) # Signal that subscription is ready self._subscription_ready.set() diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index 0c820cdf36..3ab8e62e59 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -169,7 +169,7 @@ def start(self) -> None: def _on_human_input(string: str) -> None: self._message_queue.put(HumanMessage(content=string)) - self._disposables.add(Disposable(self.human_input.subscribe(_on_human_input))) + self.register_disposable(Disposable(self.human_input.subscribe(_on_human_input))) @rpc def on_system_modules(self, _modules: list[RPCClient]) -> None: diff --git a/dimos/agents/skills/demo_robot.py b/dimos/agents/skills/demo_robot.py index 2917ec2d76..9e7ac8433b 100644 --- a/dimos/agents/skills/demo_robot.py +++ b/dimos/agents/skills/demo_robot.py @@ -25,7 +25,7 @@ class DemoRobot(Module): def start(self) -> None: super().start() 
- self._disposables.add(interval(1.0).subscribe(lambda _: self._publish_gps_location())) + self.register_disposable(interval(1.0).subscribe(lambda _: self._publish_gps_location())) def stop(self) -> None: super().stop() diff --git a/dimos/agents/skills/google_maps_skill_container.py b/dimos/agents/skills/google_maps_skill_container.py index ee48e51653..259f3ced6c 100644 --- a/dimos/agents/skills/google_maps_skill_container.py +++ b/dimos/agents/skills/google_maps_skill_container.py @@ -15,6 +15,8 @@ import json from typing import Any +from reactivex.disposable import Disposable + from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module @@ -49,7 +51,7 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) # type: ignore[arg-type] + self.register_disposable(Disposable(self.gps_location.subscribe(self._on_gps_location))) @rpc def stop(self) -> None: diff --git a/dimos/agents/skills/gps_nav_skill.py b/dimos/agents/skills/gps_nav_skill.py index 52f4e726c5..136db9987d 100644 --- a/dimos/agents/skills/gps_nav_skill.py +++ b/dimos/agents/skills/gps_nav_skill.py @@ -14,6 +14,8 @@ import json +from reactivex.disposable import Disposable + from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module @@ -38,7 +40,7 @@ class GpsNavSkillContainer(Module): @rpc def start(self) -> None: super().start() - self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) # type: ignore[arg-type] + self.register_disposable(Disposable(self.gps_location.subscribe(self._on_gps_location))) @rpc def stop(self) -> None: diff --git a/dimos/agents/skills/navigation.py b/dimos/agents/skills/navigation.py index d028f9847c..d88bec452e 100644 --- a/dimos/agents/skills/navigation.py +++ b/dimos/agents/skills/navigation.py @@ -62,8 +62,8 @@ def __init__(self, **kwargs: Any) -> 
None: @rpc def start(self) -> None: super().start() - self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) - self._disposables.add(Disposable(self.odom.subscribe(self._on_odom))) + self.register_disposable(Disposable(self.color_image.subscribe(self._on_color_image))) + self.register_disposable(Disposable(self.odom.subscribe(self._on_odom))) self._skill_started = True @rpc diff --git a/dimos/agents/skills/osm.py b/dimos/agents/skills/osm.py index a89e86044f..2172ed5dc0 100644 --- a/dimos/agents/skills/osm.py +++ b/dimos/agents/skills/osm.py @@ -13,6 +13,8 @@ # limitations under the License. +from reactivex.disposable import Disposable + from dimos.agents.annotation import skill from dimos.core.module import Module from dimos.core.stream import In @@ -39,7 +41,7 @@ def __init__(self) -> None: def start(self) -> None: super().start() if hasattr(self.gps_location, "subscribe"): - self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) # type: ignore[arg-type] + self.register_disposable(Disposable(self.gps_location.subscribe(self._on_gps_location))) else: logger.warning( "OsmSkill: gps_location stream does not support direct subscribe (RemoteIn)" diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index ea4f5d8cda..2cb02576b0 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -94,9 +94,9 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) + self.register_disposable(Disposable(self.color_image.subscribe(self._on_color_image))) if self.config.use_3d_navigation: - self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud))) + self.register_disposable(Disposable(self.global_map.subscribe(self._on_pointcloud))) @rpc def stop(self) -> None: diff --git a/dimos/agents/vlm_agent.py 
b/dimos/agents/vlm_agent.py index 114302b397..7e05cd7379 100644 --- a/dimos/agents/vlm_agent.py +++ b/dimos/agents/vlm_agent.py @@ -16,6 +16,7 @@ from langchain.chat_models import init_chat_model from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from reactivex.disposable import Disposable from dimos.agents.system_prompt import SYSTEM_PROMPT from dimos.core.core import rpc @@ -60,8 +61,8 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self._disposables.add(self.color_image.subscribe(self._on_image)) # type: ignore[arg-type] - self._disposables.add(self.query_stream.subscribe(self._on_query)) # type: ignore[arg-type] + self.register_disposable(Disposable(self.color_image.subscribe(self._on_image))) + self.register_disposable(Disposable(self.query_stream.subscribe(self._on_query))) @rpc def stop(self) -> None: diff --git a/dimos/agents/vlm_stream_tester.py b/dimos/agents/vlm_stream_tester.py index f10bb59708..382fe874cd 100644 --- a/dimos/agents/vlm_stream_tester.py +++ b/dimos/agents/vlm_stream_tester.py @@ -16,6 +16,7 @@ import time from langchain_core.messages import AIMessage, HumanMessage +from reactivex.disposable import Disposable from dimos.agents.vlm_agent_spec import VLMAgentSpec from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT @@ -62,8 +63,8 @@ def __init__( # type: ignore[no-untyped-def] @rpc def start(self) -> None: super().start() - self._disposables.add(self.color_image.subscribe(self._on_image)) # type: ignore[arg-type] - self._disposables.add(self.answer_stream.subscribe(self._on_answer)) # type: ignore[arg-type] + self.register_disposable(Disposable(self.color_image.subscribe(self._on_image))) + self.register_disposable(Disposable(self.answer_stream.subscribe(self._on_answer))) self._worker = threading.Thread(target=self._run_queries, daemon=True) self._worker.start() diff --git a/dimos/agents/web_human_input.py b/dimos/agents/web_human_input.py index 
09d55c8d64..0a4fe7c3f3 100644 --- a/dimos/agents/web_human_input.py +++ b/dimos/agents/web_human_input.py @@ -65,11 +65,11 @@ def start(self) -> None: # Subscribe to both text input sources # 1. Direct text from web interface unsub = self._web_interface.query_stream.subscribe(self._human_transport.publish) - self._disposables.add(unsub) + self.register_disposable(unsub) # 2. Transcribed text from STT unsub = stt_node.emit_text().subscribe(self._human_transport.publish) - self._disposables.add(unsub) + self.register_disposable(unsub) self._thread = Thread(target=self._web_interface.run, daemon=True) self._thread.start() diff --git a/dimos/core/module.py b/dimos/core/module.py index 078917314a..f340d2325e 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -37,7 +37,8 @@ from dimos.core.global_config import GlobalConfig, global_config from dimos.core.introspection.module.info import extract_module_info from dimos.core.introspection.module.render import render_module_io -from dimos.core.resource import Resource +from dimos.core.resource import CompositeResource +from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out, RemoteOut, Transport from dimos.protocol.rpc.pubsubrpc import LCMRPC from dimos.protocol.rpc.spec import DEFAULT_RPC_TIMEOUT, DEFAULT_RPC_TIMEOUTS, RPCSpec @@ -97,7 +98,7 @@ class _BlueprintPartial(Protocol): def __call__(self, **kwargs: Any) -> "Blueprint": ... -class ModuleBase(Configurable[ModuleConfigT], Resource): +class ModuleBase(Configurable[ModuleConfigT], CompositeResource): # This won't type check against the TypeVar, but we need it as the default. 
default_config: type[ModuleConfigT] = ModuleConfig # type: ignore[assignment] @@ -109,6 +110,7 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): _tf: TFSpec[Any] | None = None _loop: asyncio.AbstractEventLoop | None = None _loop_thread: threading.Thread | None + _bound_rpc_calls: dict[str, RpcCall] = {} _disposables: CompositeDisposable _module_closed: bool = False _module_closed_lock: threading.Lock @@ -118,7 +120,6 @@ def __init__(self, config_args: dict[str, Any]): super().__init__(**config_args) self._module_closed_lock = threading.Lock() self._loop, self._loop_thread = get_loop() - self._disposables = CompositeDisposable() try: self.rpc = self.config.rpc_transport( # type: ignore[call-arg] rpc_timeouts=self.config.rpc_timeouts, @@ -151,6 +152,7 @@ def start(self) -> None: @rpc def stop(self) -> None: + super().stop() self._close_module() def _close_module(self) -> None: @@ -177,14 +179,12 @@ def _close_module(self) -> None: if hasattr(self, "_tf") and self._tf is not None: self._tf.stop() self._tf = None - if hasattr(self, "_disposables"): - self._disposables.dispose() - # Break the In/Out -> owner -> self reference cycle so the instance - # can be freed by refcount instead of waiting for GC. - for attr in list(vars(self).values()): - if isinstance(attr, (In, Out)): - attr.owner = None + # Stop transports and break the In/Out -> owner -> self reference + # cycle so the instance can be freed by refcount instead of waiting for GC. 
+ for attr in [*self.inputs.values(), *self.outputs.values()]: + attr.stop() + attr.owner = None def _close_rpc(self) -> None: if self.rpc: @@ -207,7 +207,6 @@ def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] """Restore object from pickled state.""" self.__dict__.update(state) # Reinitialize runtime attributes - self._disposables = CompositeDisposable() self._module_closed_lock = threading.Lock() self._loop = None self._loop_thread = None diff --git a/dimos/core/resource.py b/dimos/core/resource.py index a4c008b806..a924ed8be3 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -16,7 +16,7 @@ from abc import abstractmethod import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar if sys.version_info >= (3, 11): from typing import Self @@ -29,6 +29,8 @@ from reactivex.abc import DisposableBase from reactivex.disposable import CompositeDisposable +D = TypeVar("D", bound=DisposableBase) + class Resource(DisposableBase): @abstractmethod @@ -75,18 +77,17 @@ def __exit__( class CompositeResource(Resource): """Resource that owns child disposables, disposed on stop().""" - _disposables: CompositeDisposable - - def __init__(self) -> None: - self._disposables = CompositeDisposable() + _disposables: CompositeDisposable | None = None - def register_disposables(self, *disposables: DisposableBase) -> None: - """Register child disposables to be disposed when this resource stops.""" - for d in disposables: - self._disposables.add(d) + def register_disposable(self, disposable: D) -> D: + """Register a child disposable to be disposed when this resource stops.""" + if self._disposables is None: + self._disposables = CompositeDisposable() + self._disposables.add(disposable) + return disposable - def start(self) -> None: - pass + def start(self) -> None: ... 
def stop(self) -> None: - self._disposables.dispose() + if self._disposables is not None: + self._disposables.dispose() diff --git a/dimos/core/stream.py b/dimos/core/stream.py index 7791968a29..41462ddbaa 100644 --- a/dimos/core/stream.py +++ b/dimos/core/stream.py @@ -135,6 +135,10 @@ def __str__(self) -> str: + ("" if not self._transport else " via " + str(self._transport)) ) + def stop(self) -> None: + if self._transport is not None: + self._transport.stop() + class Out(Stream[T], ObservableMixin[T]): _transport: Transport # type: ignore[type-arg] diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index f8cc019af0..bf62e526ee 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -47,7 +47,7 @@ def _odom(msg) -> None: self.mov.publish(msg.position) unsub = self.odometry.subscribe(_odom) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) def _lidar(msg) -> None: self.lidar_msg_count += 1 @@ -57,7 +57,7 @@ def _lidar(msg) -> None: print("RCV: unknown time", msg) unsub = self.lidar.subscribe(_lidar) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) def test_classmethods() -> None: diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index ab61babf8d..e2c7451009 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -77,11 +77,11 @@ def on_image(image: Image) -> None: self.color_image.publish(image) self._latest_image = image - self._disposables.add( + self.register_disposable( stream.subscribe(on_image), ) - self._disposables.add( + self.register_disposable( rx.interval(1.0).subscribe(lambda _: self.publish_metadata()), ) diff --git a/dimos/hardware/sensors/camera/realsense/camera.py b/dimos/hardware/sensors/camera/realsense/camera.py index 1271caf79f..fc0676e44f 100644 --- a/dimos/hardware/sensors/camera/realsense/camera.py +++ 
b/dimos/hardware/sensors/camera/realsense/camera.py @@ -163,7 +163,7 @@ def start(self) -> None: if self.config.enable_pointcloud and self.config.enable_depth: interval_sec = 1.0 / self.config.pointcloud_fps - self._disposables.add( + self.register_disposable( backpressure(rx.interval(interval_sec)).subscribe( on_next=lambda _: self._generate_pointcloud(), on_error=lambda e: print(f"Pointcloud error: {e}"), @@ -171,7 +171,7 @@ def start(self) -> None: ) interval_sec = 1.0 / self.config.camera_info_fps - self._disposables.add( + self.register_disposable( rx.interval(interval_sec).subscribe( on_next=lambda _: self._publish_camera_info(), on_error=lambda e: print(f"CameraInfo error: {e}"), diff --git a/dimos/hardware/sensors/camera/zed/camera.py b/dimos/hardware/sensors/camera/zed/camera.py index d48f50e25e..d49d583b66 100644 --- a/dimos/hardware/sensors/camera/zed/camera.py +++ b/dimos/hardware/sensors/camera/zed/camera.py @@ -181,7 +181,7 @@ def start(self) -> None: self._enable_tracking() interval_sec = 1.0 / self.config.camera_info_fps - self._disposables.add( + self.register_disposable( rx.interval(interval_sec).subscribe( on_next=lambda _: self._publish_camera_info(), on_error=lambda e: print(f"CameraInfo error: {e}"), @@ -194,7 +194,7 @@ def start(self) -> None: if self.config.enable_pointcloud and self.config.enable_depth: interval_sec = 1.0 / self.config.pointcloud_fps - self._disposables.add( + self.register_disposable( backpressure(rx.interval(interval_sec)).subscribe( on_next=lambda _: self._generate_pointcloud(), on_error=lambda e: print(f"Pointcloud error: {e}"), diff --git a/dimos/hardware/sensors/fake_zed_module.py b/dimos/hardware/sensors/fake_zed_module.py index 16e85aa93c..21c1d27599 100644 --- a/dimos/hardware/sensors/fake_zed_module.py +++ b/dimos/hardware/sensors/fake_zed_module.py @@ -224,7 +224,7 @@ def start(self) -> None: unsub = self._get_color_stream().subscribe( lambda msg: self.color_image.publish(msg) if self._running else None ) - 
self._disposables.add(unsub) + self.register_disposable(unsub) logger.info("Started color image replay stream") except Exception as e: logger.warning(f"Color image stream not available: {e}") @@ -234,7 +234,7 @@ def start(self) -> None: unsub = self._get_depth_stream().subscribe( lambda msg: self.depth_image.publish(msg) if self._running else None ) - self._disposables.add(unsub) + self.register_disposable(unsub) logger.info("Started depth image replay stream") except Exception as e: logger.warning(f"Depth image stream not available: {e}") @@ -244,7 +244,7 @@ def start(self) -> None: unsub = self._get_pose_stream().subscribe( lambda msg: self._publish_pose(msg) if self._running else None ) - self._disposables.add(unsub) + self.register_disposable(unsub) logger.info("Started pose replay stream") except Exception as e: logger.warning(f"Pose stream not available: {e}") @@ -254,7 +254,7 @@ def start(self) -> None: unsub = self._get_camera_info_stream().subscribe( lambda msg: self.camera_info.publish(msg) if self._running else None ) - self._disposables.add(unsub) + self.register_disposable(unsub) logger.info("Started camera info replay stream") except Exception as e: logger.warning(f"Camera info stream not available: {e}") diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index d2212922cc..86076c5a39 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -30,7 +30,7 @@ mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), - VoxelGridMapper.blueprint(publish_interval=1.0, voxel_size=voxel_size, carve_columns=False), + VoxelGridMapper.blueprint(voxel_size=voxel_size, carve_columns=False), RerunBridgeModule.blueprint( visual_override={ "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), diff --git a/dimos/mapping/costmapper.py b/dimos/mapping/costmapper.py index 
87ed64d404..0ec376a88f 100644 --- a/dimos/mapping/costmapper.py +++ b/dimos/mapping/costmapper.py @@ -60,7 +60,7 @@ def _calculate_and_time( elapsed_ms = (time.perf_counter() - start) * 1000 return grid, elapsed_ms, rx_monotonic - self._disposables.add( + self.register_disposable( self.global_map.observable() # type: ignore[no-untyped-call] .pipe(ops.map(_calculate_and_time)) .subscribe(lambda result: _publish_costmap(result[0], result[1], result[2])) diff --git a/dimos/mapping/pointclouds/test_occupancy_speed.py b/dimos/mapping/pointclouds/test_occupancy_speed.py index ac4085e971..115ee73ae0 100644 --- a/dimos/mapping/pointclouds/test_occupancy_speed.py +++ b/dimos/mapping/pointclouds/test_occupancy_speed.py @@ -18,7 +18,7 @@ import pytest from dimos.mapping.pointclouds.occupancy import OCCUPANCY_ALGOS -from dimos.mapping.voxels import VoxelGridMapper +from dimos.mapping.voxels import VoxelGrid from dimos.utils.cli.plot import bar from dimos.utils.data import get_data, get_data_dir from dimos.utils.testing.replay import TimedSensorReplay @@ -26,18 +26,18 @@ @pytest.mark.tool def test_build_map(): - mapper = VoxelGridMapper(publish_interval=-1) + grid = VoxelGrid() for _ts, frame in TimedSensorReplay("unitree_go2_bigoffice/lidar").iterate(): - mapper.add_frame(frame) + grid.add_frame(frame) pickle_file = get_data_dir() / "unitree_go2_bigoffice_map.pickle" - global_pcd = mapper.get_global_pointcloud2() + global_pcd = grid.get_global_pointcloud2() with open(pickle_file, "wb") as f: pickle.dump(global_pcd, f) - mapper.stop() + grid.dispose() def test_costmap_calc(): diff --git a/dimos/mapping/test_voxels.py b/dimos/mapping/test_voxels.py index bb5f4ed764..fc95b4652b 100644 --- a/dimos/mapping/test_voxels.py +++ b/dimos/mapping/test_voxels.py @@ -19,7 +19,7 @@ import pytest from dimos.core.transport import LCMTransport -from dimos.mapping.voxels import VoxelGridMapper +from dimos.mapping.voxels import VoxelGrid from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 
from dimos.utils.data import get_data from dimos.utils.testing.moment import OutputMoment @@ -28,10 +28,10 @@ @pytest.fixture -def mapper() -> Generator[VoxelGridMapper, None, None]: - mapper = VoxelGridMapper() - yield mapper - mapper.stop() +def grid() -> Generator[VoxelGrid, None, None]: + g = VoxelGrid() + yield g + g.dispose() class Go2MapperMoment(Go2Moment): @@ -78,21 +78,19 @@ def two_perspectives_loop(moment: MomentFactory) -> None: @pytest.mark.tool -def test_carving( - mapper: VoxelGridMapper, moment1: Go2MapperMoment, moment2: Go2MapperMoment -) -> None: +def test_carving(grid: VoxelGrid, moment1: Go2MapperMoment, moment2: Go2MapperMoment) -> None: lidar_frame1 = moment1.lidar.value assert lidar_frame1 is not None lidar_frame2 = moment2.lidar.value assert lidar_frame2 is not None - # Carving mapper (default, carve_columns=True) - mapper.add_frame(lidar_frame1) - mapper.add_frame(lidar_frame2) - count_carving = mapper.size() + # Carving grid (default, carve_columns=True) + grid.add_frame(lidar_frame1) + grid.add_frame(lidar_frame2) + count_carving = grid.size() - voxel_size = mapper.config.voxel_size + voxel_size = grid._voxel_size pts1 = np.asarray(lidar_frame1.pointcloud.points) pts2 = np.asarray(lidar_frame2.pointcloud.points) combined_vox = np.floor(np.vstack([pts1, pts2]) / voxel_size).astype(np.int64) @@ -109,7 +107,7 @@ def test_carving( ) -def test_injest_a_few(mapper: VoxelGridMapper) -> None: +def test_ingest_a_few(grid: VoxelGrid) -> None: data_dir = get_data("unitree_go2_office_walk2") lidar_store = TimedSensorReplay(f"{data_dir}/lidar") @@ -117,9 +115,9 @@ def test_injest_a_few(mapper: VoxelGridMapper) -> None: frame = lidar_store.find_closest_seek(i) assert frame is not None print("add", frame) - mapper.add_frame(frame) + grid.add_frame(frame) - assert len(mapper.get_global_pointcloud2()) == 30136 + assert len(grid.get_global_pointcloud2()) == 30136 @pytest.mark.parametrize( @@ -134,10 +132,10 @@ def test_roundtrip(moment1: Go2MapperMoment, 
voxel_size: float, expected_points: lidar_frame = moment1.lidar.value assert lidar_frame is not None - mapper = VoxelGridMapper(voxel_size=voxel_size) - mapper.add_frame(lidar_frame) + grid = VoxelGrid(voxel_size=voxel_size) + grid.add_frame(lidar_frame) - global1 = mapper.get_global_pointcloud2() + global1 = grid.get_global_pointcloud2() assert len(global1) == expected_points # loseless roundtrip @@ -146,15 +144,15 @@ def test_roundtrip(moment1: Go2MapperMoment, voxel_size: float, expected_points: # TODO: we want __eq__ on PointCloud2 - should actually compare # all points in both frames - mapper.add_frame(global1) + grid.add_frame(global1) # no new information, no global map change - assert len(mapper.get_global_pointcloud2()) == len(global1) + assert len(grid.get_global_pointcloud2()) == len(global1) moment1.publish() - mapper.stop() + grid.dispose() -def test_roundtrip_range_preserved(mapper: VoxelGridMapper) -> None: +def test_roundtrip_range_preserved(grid: VoxelGrid) -> None: """Test that input coordinate ranges are preserved in output.""" data_dir = get_data("unitree_go2_office_walk2") lidar_store = TimedSensorReplay(f"{data_dir}/lidar") @@ -163,12 +161,12 @@ def test_roundtrip_range_preserved(mapper: VoxelGridMapper) -> None: assert frame is not None input_pts = np.asarray(frame.pointcloud.points) - mapper.add_frame(frame) + grid.add_frame(frame) - out_pcd = mapper.get_global_pointcloud().to_legacy() + out_pcd = grid.get_global_pointcloud().to_legacy() out_pts = np.asarray(out_pcd.points) - voxel_size = mapper.config.voxel_size + voxel_size = grid._voxel_size tolerance = voxel_size # Allow one voxel of difference at boundaries # TODO: we want __eq__ on PointCloud2 - should actually compare diff --git a/dimos/mapping/voxels.py b/dimos/mapping/voxels.py index 94c63b099f..9f976e14c0 100644 --- a/dimos/mapping/voxels.py +++ b/dimos/mapping/voxels.py @@ -12,61 +12,65 @@ # See the License for the specific language governing permissions and # limitations under 
the License. +from __future__ import annotations + import time -from typing import Any +from typing import TYPE_CHECKING, Any -import numpy as np import open3d as o3d # type: ignore[import-untyped] import open3d.core as o3c # type: ignore[import-untyped] -from reactivex import interval, operators as ops -from reactivex.disposable import Disposable -from reactivex.subject import Subject -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig +from dimos.core.module import ModuleConfig from dimos.core.stream import In, Out +from dimos.memory2.module import StreamModule +from dimos.memory2.stream import Stream +from dimos.memory2.transform import Transformer from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.logging_config import setup_logger -from dimos.utils.reactive import backpressure -logger = setup_logger() +if TYPE_CHECKING: + from collections.abc import Iterator + from dimos.memory2.type.observation import Observation -class Config(ModuleConfig): - frame_id: str = "world" - # -1 never publishes, 0 publishes on every frame, >0 publishes at interval in seconds - publish_interval: float = 0 - voxel_size: float = 0.05 - block_count: int = 2_000_000 - device: str = "CUDA:0" - carve_columns: bool = True +logger = setup_logger() -class VoxelGridMapper(Module[Config]): - default_config = Config +class VoxelGrid: + """Pure voxel grid accumulator using Open3D VoxelBlockGrid. - lidar: In[PointCloud2] - global_map: Out[PointCloud2] + No Module/framework dependency. Can be used standalone or wrapped + by VoxelGridMapper (Module) or VoxelMapTransformer (memory2 Transformer). 
+ """ - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) + def __init__( + self, + voxel_size: float = 0.05, + block_count: int = 2_000_000, + device: str = "CUDA:0", + carve_columns: bool = True, + frame_id: str = "world", + ) -> None: + self._voxel_size = voxel_size + self._carve_columns = carve_columns + self._frame_id = frame_id dev = ( - o3c.Device(self.config.device) - if (self.config.device.startswith("CUDA") and o3c.cuda.is_available()) + o3c.Device(device) + if (device.startswith("CUDA") and o3c.cuda.is_available()) else o3c.Device("CPU:0") ) - logger.info(f"VoxelGridMapper using device: {dev}") + logger.info(f"VoxelGrid using device: {dev}") - self.vbg = o3d.t.geometry.VoxelBlockGrid( + self.vbg: o3d.t.geometry.VoxelBlockGrid | None = o3d.t.geometry.VoxelBlockGrid( attr_names=("dummy",), attr_dtypes=(o3c.uint8,), attr_channels=(o3c.SizeVector([1]),), - voxel_size=self.config.voxel_size, + voxel_size=voxel_size, block_resolution=1, - block_count=self.config.block_count, + block_count=block_count, device=dev, ) @@ -74,71 +78,27 @@ def __init__(self, **kwargs: Any) -> None: self._voxel_hashmap = self.vbg.hashmap() self._key_dtype = self._voxel_hashmap.key_tensor().dtype self._latest_frame_ts: float = 0.0 + self._disposed = False - @rpc - def start(self) -> None: - super().start() - - # Subject to trigger publishing, with backpressure to drop if busy - self._publish_trigger: Subject[None] = Subject() - self._disposables.add( - backpressure(self._publish_trigger) - .pipe(ops.map(lambda _: self.publish_global_map())) - .subscribe() - ) - - lidar_unsub = self.lidar.subscribe(self._on_frame) - self._disposables.add(Disposable(lidar_unsub)) - - # If publish_interval > 0, publish on timer; otherwise publish on each frame - if self.config.publish_interval > 0: - self._disposables.add( - interval(self.config.publish_interval).subscribe( - lambda _: self._publish_trigger.on_next(None) - ) - ) - - @rpc - def stop(self) -> None: - super().stop() - 
# Free tensor-tracked objects eagerly so Open3D does not report them as leaks. - self.get_global_pointcloud.invalidate_cache(self) - self.get_global_pointcloud2.invalidate_cache(self) - self.vbg = None - self._voxel_hashmap = None - - def _on_frame(self, frame: PointCloud2) -> None: - self.add_frame(frame) - if self.config.publish_interval == 0: - self._publish_trigger.on_next(None) - - def publish_global_map(self) -> None: - pc = self.get_global_pointcloud2() - self.global_map.publish(pc) - - def size(self) -> int: - return self._voxel_hashmap.size() # type: ignore[no-any-return] - - def __len__(self) -> int: - return self.size() + def _check_disposed(self) -> None: + if self._disposed: + raise RuntimeError("VoxelGrid has been disposed and cannot be used") - # @timed() # TODO: fix thread leak in timed decorator def add_frame(self, frame: PointCloud2) -> None: - # Track latest frame timestamp for proper latency measurement - if hasattr(frame, "ts") and frame.ts: + self._check_disposed() + if frame.ts is not None: self._latest_frame_ts = frame.ts - # we are potentially moving into CUDA here pcd = ensure_tensor_pcd(frame.pointcloud, self._dev) if pcd.is_empty(): return pts = pcd.point["positions"].to(self._dev, o3c.float32) - vox = (pts / self.config.voxel_size).floor().to(self._key_dtype) + vox = (pts / self._voxel_size).floor().to(self._key_dtype) keys_Nx3 = vox.contiguous() - if self.config.carve_columns: + if self._carve_columns: self._carve_and_insert(keys_Nx3) else: self._voxel_hashmap.activate(keys_Nx3) @@ -152,10 +112,8 @@ def _carve_and_insert(self, new_keys: o3c.Tensor) -> None: self._voxel_hashmap.activate(new_keys) return - # Extract (X, Y) from incoming keys xy_keys = new_keys[:, :2].contiguous() - # Build temp hashmap for O(1) (X,Y) membership lookup xy_hashmap = o3c.HashMap( init_capacity=xy_keys.shape[0], key_dtype=self._key_dtype, @@ -167,7 +125,6 @@ def _carve_and_insert(self, new_keys: o3c.Tensor) -> None: dummy_vals = 
o3c.Tensor.zeros((xy_keys.shape[0], 1), o3c.uint8, self._dev) xy_hashmap.insert(xy_keys, dummy_vals) - # Get existing keys from main hashmap active_indices = self._voxel_hashmap.active_buf_indices() if active_indices.shape[0] == 0: self._voxel_hashmap.activate(new_keys) @@ -176,38 +133,128 @@ def _carve_and_insert(self, new_keys: o3c.Tensor) -> None: existing_keys = self._voxel_hashmap.key_tensor()[active_indices] existing_xy = existing_keys[:, :2].contiguous() - # Find which existing keys have (X,Y) in the incoming set _, found_mask = xy_hashmap.find(existing_xy) - # Erase those columns to_erase = existing_keys[found_mask] if to_erase.shape[0] > 0: self._voxel_hashmap.erase(to_erase) - # Insert new keys self._voxel_hashmap.activate(new_keys) - # returns PointCloud2 message (ready to send off down the pipeline) @simple_mcache def get_global_pointcloud2(self) -> PointCloud2: + self._check_disposed() return PointCloud2( - # we are potentially moving out of CUDA here ensure_legacy_pcd(self.get_global_pointcloud()), - frame_id=self.frame_id, + frame_id=self._frame_id, ts=self._latest_frame_ts if self._latest_frame_ts else time.time(), ) @simple_mcache - # @timed() def get_global_pointcloud(self) -> o3d.t.geometry.PointCloud: + self._check_disposed() + assert self.vbg is not None voxel_coords, _ = self.vbg.voxel_coordinates_and_flattened_indices() # Move to CPU immediately to avoid holding a large duplicate on GPU. cpu = o3c.Device("CPU:0") - pts = voxel_coords.to(cpu) + (self.config.voxel_size * 0.5) + pts = voxel_coords.to(cpu) + (self._voxel_size * 0.5) out = o3d.t.geometry.PointCloud(device=cpu) out.point["positions"] = pts return out + def size(self) -> int: + self._check_disposed() + return self._voxel_hashmap.size() # type: ignore[no-any-return] + + def __len__(self) -> int: + return self.size() + + def dispose(self) -> None: + """Free GPU resources. 
The object is unusable after this call.""" + if self._disposed: + return + self._disposed = True + self.get_global_pointcloud.invalidate_cache(self) # type: ignore[attr-defined] + self.get_global_pointcloud2.invalidate_cache(self) # type: ignore[attr-defined] + self.vbg = None + self._voxel_hashmap = None + + +class VoxelMapTransformer(Transformer[PointCloud2, PointCloud2]): + """Accumulate PointCloud2 observations into a global voxel map. + + Assumes input clouds are already in world frame. + All keyword arguments except ``emit_every`` are forwarded to + :class:`VoxelGrid`. + + Args: + emit_every: Yield the current accumulated map every *n* frames. + ``1`` (default) = yield after every frame (live-compatible). + ``0`` = yield only when upstream exhausts (batch mode). + **grid_kwargs: Forwarded to ``VoxelGrid()``. + """ + + def __init__(self, *, emit_every: int = 1, **grid_kwargs: Any) -> None: + self.emit_every = emit_every + self._grid_kwargs = grid_kwargs + + def _make_obs( + self, grid: VoxelGrid, last_obs: Observation[PointCloud2], count: int + ) -> Observation[PointCloud2]: + # pose=None: the global map is in world frame, per-observation pose is meaningless + return last_obs.derive( + data=grid.get_global_pointcloud2(), + pose=None, + tags={**last_obs.tags, "frame_count": count}, + ) + + def __call__( + self, upstream: Iterator[Observation[PointCloud2]] + ) -> Iterator[Observation[PointCloud2]]: + grid = VoxelGrid(**self._grid_kwargs) + try: + last_obs: Observation[PointCloud2] | None = None + count = 0 + + for obs in upstream: + grid.add_frame(obs.data) + last_obs = obs + count += 1 + + if self.emit_every > 0 and count % self.emit_every == 0: + yield self._make_obs(grid, last_obs, count) + + # Yield on exhaustion: always in batch mode, or if there are un-emitted frames + if last_obs is not None and (self.emit_every == 0 or count % self.emit_every != 0): + yield self._make_obs(grid, last_obs, count) + finally: + grid.dispose() + + +class 
VoxelGridMapperConfig(ModuleConfig): + """Configuration for VoxelGridMapper.""" + + voxel_size: float = 0.05 + block_count: int = 2_000_000 + device: str = "CUDA:0" + carve_columns: bool = True + frame_id: str = "world" + + +class VoxelGridMapper(StreamModule[VoxelGridMapperConfig]): + """Accumulate lidar point clouds into a global voxel map.""" + + default_config = VoxelGridMapperConfig + + def pipeline(self, stream: Stream[PointCloud2]) -> Stream[PointCloud2]: + cfg = self.config.model_dump( + include=set(VoxelGridMapperConfig.model_fields) - set(ModuleConfig.model_fields) + ) + return stream.transform(VoxelMapTransformer(**cfg)) + + lidar: In[PointCloud2] + global_map: Out[PointCloud2] + def ensure_tensor_pcd( pcd_any: o3d.t.geometry.PointCloud | o3d.geometry.PointCloud, @@ -222,14 +269,7 @@ def ensure_tensor_pcd( "Input must be a legacy PointCloud or a tensor PointCloud" ) - # Legacy CPU point cloud -> tensor - if isinstance(pcd_any, o3d.geometry.PointCloud): - return o3d.t.geometry.PointCloud.from_legacy(pcd_any, o3c.float32, device) - - pts = np.asarray(pcd_any.points, dtype=np.float32) - pcd_t = o3d.t.geometry.PointCloud(device=device) - pcd_t.point["positions"] = o3c.Tensor(pts, o3c.float32, device) - return pcd_t + return o3d.t.geometry.PointCloud.from_legacy(pcd_any, o3c.float32, device) def ensure_legacy_pcd( diff --git a/dimos/memory/embedding.py b/dimos/memory/embedding.py index df047292a0..9dece58bb7 100644 --- a/dimos/memory/embedding.py +++ b/dimos/memory/embedding.py @@ -56,7 +56,7 @@ class EmbeddingMemory(Module[Config]): def get_costmap(self) -> OccupancyGrid: if self._costmap_getter is None: self._costmap_getter = getter_hot(self.global_costmap.pure_observable()) - self._disposables.add(self._costmap_getter) + self.register_disposable(self._costmap_getter) return self._costmap_getter() @rpc diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index c861993de9..d330b10fd5 100644 --- a/dimos/memory2/backend.py +++ 
b/dimos/memory2/backend.py @@ -19,6 +19,7 @@ from dataclasses import replace from typing import TYPE_CHECKING, Any, Generic, TypeVar +from dimos.core.resource import CompositeResource from dimos.memory2.codecs.base import Codec, codec_id from dimos.memory2.notifier.subject import SubjectNotifier from dimos.memory2.type.observation import _UNLOADED @@ -39,12 +40,9 @@ T = TypeVar("T") -class Backend(Generic[T]): +class Backend(CompositeResource, Generic[T]): """Orchestrates metadata, blob, vector, and live stores for one stream. - - This is a concrete class — NOT a protocol. All shared orchestration logic (encode → insert → store blob → index vector → notify) lives here, - eliminating duplication between ListObservationStore and SqliteObservationStore. """ def __init__( @@ -57,13 +55,21 @@ def __init__( notifier: Notifier[T] | None = None, eager_blobs: bool = False, ) -> None: - self.metadata_store = metadata_store + super().__init__() + self.metadata_store = self.register_disposable(metadata_store) self.codec = codec - self.blob_store = blob_store - self.vector_store = vector_store - self.notifier: Notifier[T] = notifier or SubjectNotifier() + self.blob_store = self.register_disposable(blob_store) if blob_store else None + self.vector_store = self.register_disposable(vector_store) if vector_store else None + self.notifier: Notifier[T] = self.register_disposable(notifier or SubjectNotifier()) self.eager_blobs = eager_blobs + def start(self) -> None: + self.metadata_store.start() + if self.blob_store is not None: + self.blob_store.start() + if self.vector_store is not None: + self.vector_store.start() + @property def name(self) -> str: return self.metadata_store.name @@ -237,8 +243,3 @@ def serialize(self) -> dict[str, Any]: "vector_store": self.vector_store.serialize() if self.vector_store else None, "notifier": self.notifier.serialize(), } - - def stop(self) -> None: - """Stop the metadata store (closes per-stream connections if any).""" - if 
hasattr(self.metadata_store, "stop"): - self.metadata_store.stop() diff --git a/dimos/memory2/blobstore/sqlite.py b/dimos/memory2/blobstore/sqlite.py index 1cb5f1aa38..8092a34d1d 100644 --- a/dimos/memory2/blobstore/sqlite.py +++ b/dimos/memory2/blobstore/sqlite.py @@ -78,7 +78,7 @@ def start(self) -> None: if self._conn is None: assert self._path is not None disposable, self._conn = open_disposable_sqlite_connection(self._path) - self.register_disposables(disposable) + self.register_disposable(disposable) def put(self, stream_name: str, key: int, data: bytes) -> None: self._ensure_table(stream_name) diff --git a/dimos/memory2/module.py b/dimos/memory2/module.py new file mode 100644 index 0000000000..881b1d929a --- /dev/null +++ b/dimos/memory2/module.py @@ -0,0 +1,110 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +from typing import Any + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfigT +from dimos.memory2.store.null import NullStore +from dimos.memory2.stream import Stream + + +class StreamModule(Module[ModuleConfigT]): + """Module base class that wires a memory2 stream pipeline. 
+ + **Static pipeline** + + class VoxelGridMapper(StreamModule): + pipeline = Stream().transform(VoxelMapTransformer()) + lidar: In[PointCloud2] + global_map: Out[PointCloud2] + + **Config-driven pipeline** + + class VoxelGridMapper(StreamModule[VoxelGridMapperConfig]): + def pipeline(self, stream: Stream) -> Stream: + return stream.transform(VoxelMap(**self.config.model_dump())) + + lidar: In[PointCloud2] + global_map: Out[PointCloud2] + + On start, the single ``In`` port feeds a MemoryStore, and the pipeline + is applied to the live stream, publishing results to the single ``Out`` port. + + The MemoryStore acts as a bridge between the push-based Module In port + and the pull-based memory2 stream pipeline — it also enables replay and + persistence if the store is swapped for a persistent backend later. + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + @rpc + def start(self) -> None: + super().start() + + if len(self.inputs) != 1 or len(self.outputs) != 1: + raise TypeError( + f"{self.__class__.__name__} must have exactly one In and one Out port, " + f"found {len(self.inputs)} In and {len(self.outputs)} Out" + ) + + ((in_name, inp_port),) = self.inputs.items() + ((_, out_port),) = self.outputs.items() + + store = self.register_disposable(NullStore()) + store.start() + + stream: Stream[Any] = store.stream(in_name, inp_port.type) + + # we push input into the stream + inp_port.subscribe(lambda msg: stream.append(msg)) + + live = stream.live() + # and we push stream output to the output port + self._apply_pipeline(live).subscribe( + lambda obs: out_port.publish(obs.data), + ) + + def _apply_pipeline(self, stream: Stream[Any]) -> Stream[Any]: + """Apply the pipeline to a live stream. + + Handles both static (class attr) and dynamic (method) pipelines. 
+ """ + pipeline = getattr(self.__class__, "pipeline", None) + if pipeline is None: + raise TypeError( + f"{self.__class__.__name__} must define a 'pipeline' attribute or method" + ) + + # Method pipeline: self.pipeline(stream) -> stream + if inspect.isfunction(pipeline): + result = pipeline(self, stream) + if not isinstance(result, Stream): + raise TypeError( + f"{self.__class__.__name__}.pipeline() must return a Stream, got {type(result).__name__}" + ) + return result + + # Static class attr: Stream (unbound chain) or Transformer + if isinstance(pipeline, Stream): + return stream.chain(pipeline) + return stream.transform(pipeline) + + @rpc + def stop(self) -> None: + super().stop() diff --git a/dimos/memory2/notifier/base.py b/dimos/memory2/notifier/base.py index 022d26d4e0..bb25a1cbf6 100644 --- a/dimos/memory2/notifier/base.py +++ b/dimos/memory2/notifier/base.py @@ -17,6 +17,7 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Any, Generic, TypeVar +from dimos.core.resource import Resource from dimos.memory2.registry import qual from dimos.protocol.service.spec import BaseConfig, Configurable @@ -33,7 +34,7 @@ class NotifierConfig(BaseConfig): pass -class Notifier(Configurable[NotifierConfig], Generic[T]): +class Notifier(Configurable[NotifierConfig], Resource, Generic[T]): """Push-notification for live observation delivery. Decouples the notification mechanism from storage. The built-in @@ -47,6 +48,12 @@ class Notifier(Configurable[NotifierConfig], Generic[T]): def __init__(self, **kwargs: Any) -> None: Configurable.__init__(self, **kwargs) + def start(self) -> None: + pass + + def stop(self) -> None: + pass + @abstractmethod def subscribe(self, buf: BackpressureBuffer[Observation[T]]) -> DisposableBase: """Register *buf* to receive new observations. 
Returns a diff --git a/dimos/memory2/notifier/subject.py b/dimos/memory2/notifier/subject.py index d1b8d7f888..4b43d28c0a 100644 --- a/dimos/memory2/notifier/subject.py +++ b/dimos/memory2/notifier/subject.py @@ -68,3 +68,11 @@ def notify(self, obs: Observation[T]) -> None: subs = list(self._subscribers) for buf in subs: buf.put(obs) + + def stop(self) -> None: + """Close all subscribed buffers, unblocking any live iterators.""" + with self._lock: + subs = list(self._subscribers) + self._subscribers.clear() + for buf in subs: + buf.close() diff --git a/dimos/memory2/observationstore/memory.py b/dimos/memory2/observationstore/memory.py index 529cd06394..faeb0fbec1 100644 --- a/dimos/memory2/observationstore/memory.py +++ b/dimos/memory2/observationstore/memory.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections import deque import threading from typing import TYPE_CHECKING, Any, TypeVar @@ -30,10 +31,17 @@ class ListObservationStoreConfig(ObservationStoreConfig): name: str = "" + max_size: int | None = None class ListObservationStore(ObservationStore[T]): - """In-memory metadata store for experimentation. Thread-safe.""" + """In-memory metadata store for experimentation. Thread-safe. + + ``max_size`` controls how many observations are retained: + - ``None`` (default) — keep all (unbounded). + - ``N`` — rolling window of the most recent N observations. + - ``0`` — discard immediately (live-only, no history). 
+ """ default_config = ListObservationStoreConfig config: ListObservationStoreConfig @@ -41,7 +49,8 @@ class ListObservationStore(ObservationStore[T]): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self._name = self.config.name - self._observations: list[Observation[T]] = [] + max_size = self.config.max_size + self._observations: deque[Observation[T]] = deque(maxlen=max_size) self._next_id = 0 self._lock = threading.Lock() diff --git a/dimos/memory2/observationstore/sqlite.py b/dimos/memory2/observationstore/sqlite.py index 5d680c540a..960bb2ce55 100644 --- a/dimos/memory2/observationstore/sqlite.py +++ b/dimos/memory2/observationstore/sqlite.py @@ -273,7 +273,7 @@ def start(self) -> None: if self._conn is None: assert self._path is not None disposable, self._conn = open_disposable_sqlite_connection(self._path) - self.register_disposables(disposable) + self.register_disposable(disposable) self._ensure_tables() def _ensure_tables(self) -> None: diff --git a/dimos/memory2/store/README.md b/dimos/memory2/store/README.md index ff18640c0b..4766c24998 100644 --- a/dimos/memory2/store/README.md +++ b/dimos/memory2/store/README.md @@ -1,25 +1,50 @@ -# store — Store implementations +# store — Store and ObservationStore implementations -Metadata index backends for memory. Each index implements the `ObservationStore` protocol to provide observation metadata storage with query support. The concrete `Backend` class handles orchestration (blob, vector, live) on top of any index. +Store is the top-level user-facing entry point. You create one, ask it for named streams, and use those streams. 
Internally, each stream gets a Backend that orchestrates the lower-level pieces: -## Existing implementations +``` +Store + └── stream("lidar") → Backend + ├── ObservationStore (metadata: id, timestamp, tags, frame_id) + ├── BlobStore (raw bytes: encoded payloads) + ├── VectorStore (embeddings: similarity search) + └── Notifier (live push: new observation events) +``` + +- **ObservationStore** stores observation *metadata* and handles queries (filters, ordering, limit/offset, text search). Doesn't touch raw data or vectors. +- **BlobStore** stores/retrieves encoded payloads by `(stream_name, row_id)`. Just a key-value byte store. +- **VectorStore** stores/retrieves embedding vectors, handles similarity search. +- **Notifier** pushes new observations to live subscribers (for `.live()` tails). + +The **Backend** is the glue — on `append()` it encodes the payload, inserts metadata into ObservationStore, stores the blob in BlobStore, indexes the vector in VectorStore, and notifies live subscribers. On iterate, it queries ObservationStore for metadata, attaches lazy blob loaders, and handles vector search routing. + +**Store** sits above all that — it manages the mapping of stream names to Backends, handles config inheritance (store-level defaults vs per-stream overrides), and provides the `store.stream("name")` / `store.streams.name` API. `MemoryStore` vs `SqliteStore` vs `NullStore` differ in which component implementations they wire up by default and how they persist the registry of known streams. 
+ +## Store implementations + +| Store | File | Description | +|----------------|-------------|------------------------------------------------------| +| `MemoryStore` | `memory.py` | In-memory store for experimentation | +| `SqliteStore` | `sqlite.py` | SQLite-backed persistent store (WAL, registry, vec0) | +| `NullStore` | `null.py` | Live-only O(1) memory, no history/replay | + +## ObservationStore implementations -| ObservationStore | File | Status | Storage | -|-----------------|-------------|----------|-------------------------------------| -| `ListObservationStore` | `memory.py` | Complete | In-memory lists, brute-force search | -| `SqliteObservationStore` | `sqlite.py` | Complete | SQLite (WAL, R*Tree, vec0) | +| ObservationStore | File | Storage | +|--------------------------|----------------------------|-------------------------------------| +| `ListObservationStore` | `observationstore/memory.py` | In-memory deque, brute-force search. `max_size` controls retention (None=all, N=rolling window, 0=discard) | +| `SqliteObservationStore` | `observationstore/sqlite.py` | SQLite (WAL, R*Tree, vec0) | -## Writing a new index +## Writing a new ObservationStore -### 1. Implement the ObservationStore protocol +### 1. 
Subclass ObservationStore ```python from dimos.memory2.observationstore.base import ObservationStore -from dimos.memory2.type.filter import StreamQuery -from dimos.memory2.type.observation import Observation -class MyObservationStore(Generic[T]): - def __init__(self, name: str) -> None: +class MyObservationStore(ObservationStore[T]): + def __init__(self, name: str, **kwargs: Any) -> None: + super().__init__(**kwargs) self._name = name @property @@ -35,8 +60,8 @@ class MyObservationStore(Generic[T]): def query(self, q: StreamQuery) -> Iterator[Observation[T]]: """Yield observations matching the query.""" - # The index handles metadata query fields: - # q.filters — list of Filter objects (each has .matches(obs)) + # The query carries metadata fields: + # q.filters — tuple of Filter objects (each has .matches(obs)) # q.order_field — sort field name (e.g. "ts") # q.order_desc — sort direction # q.limit_val — max results @@ -53,7 +78,7 @@ class MyObservationStore(Generic[T]): ... ``` -`ObservationStore` is a `@runtime_checkable` Protocol — no base class needed, just implement the methods. +`ObservationStore` is an abstract base class (extends `CompositeResource` and `Configurable`). ### 2. Create a Store subclass @@ -66,10 +91,11 @@ class MyStore(Store): def _create_backend( self, name: str, payload_type: type | None = None, **config: Any ) -> Backend: - index = MyObservationStore(name) - codec = codec_for(payload_type) + obs = MyObservationStore(name) + obs.start() + codec = self._resolve_codec(payload_type, config.get("codec")) return Backend( - index=index, + metadata_store=obs, codec=codec, blob_store=config.get("blob_store"), vector_store=config.get("vector_store"), @@ -84,29 +110,32 @@ class MyStore(Store): self._streams.pop(name, None) ``` -The Store creates a `Backend` composite for each stream. The `Backend` handles all orchestration (encode → insert → store blob → index vector → notify) so your index only needs to handle metadata. 
+The Store creates a `Backend` composite for each stream. The `Backend` handles all orchestration (encode -> insert -> store blob -> index vector -> notify) so your ObservationStore only needs to handle metadata. -### 3. Add to the grid test +### 3. Add to the test grid -In `test_impl.py`, add your store to the fixture so all standard tests run against it: +In `conftest.py`, add your store fixture and include it in the parametrized `session` fixture so all standard tests run against it: ```python -@pytest.fixture(params=["memory", "sqlite", "myindex"]) -def store(request, tmp_path): - if request.param == "myindex": - return MyStore(...) - ... +@pytest.fixture +def my_store() -> Iterator[MyStore]: + with MyStore() as store: + yield store + +@pytest.fixture(params=["memory_store", "sqlite_store", "my_store"]) +def session(request): + return request.getfixturevalue(request.param) ``` Use `pytest.mark.xfail` for features not yet implemented — the grid test covers: append, fetch, iterate, count, first/last, exists, all filters, ordering, limit/offset, embeddings, text search. ### Query contract -The index must handle the `StreamQuery` metadata fields. Vector search and blob loading are handled by the `Backend` composite — the index never needs to deal with them. +The ObservationStore must handle the `StreamQuery` metadata fields. Vector search and blob loading are handled by the `Backend` composite — the ObservationStore never needs to deal with them. -`StreamQuery.apply(iterator)` provides a complete Python-side execution path — filters, text search, vector search, ordering, offset/limit — all as in-memory operations. ObservationStorees can use it in three ways: +`StreamQuery.apply(iterator)` provides a complete Python-side execution path — filters, text search, vector search, ordering, offset/limit — all as in-memory operations. 
ObservationStores can use it in three ways: -**Full delegation** — simplest, good enough for in-memory indexes: +**Full delegation** — simplest, good enough for in-memory stores: ```python def query(self, q: StreamQuery) -> Iterator[Observation[T]]: return q.apply(iter(self._data)) @@ -127,4 +156,4 @@ def query(self, q: StreamQuery) -> Iterator[Observation[T]]: **Full push-down** — translate everything to native queries (SQL WHERE, FTS5 MATCH) without calling `apply()` at all. -For filters, each `Filter` object has a `.matches(obs) -> bool` method that indexes can use directly if they don't have a native equivalent. +For filters, each `Filter` object has a `.matches(obs) -> bool` method that ObservationStores can use directly if they don't have a native equivalent. diff --git a/dimos/memory2/store/base.py b/dimos/memory2/store/base.py index cf571f23b0..ffb4ace8cd 100644 --- a/dimos/memory2/store/base.py +++ b/dimos/memory2/store/base.py @@ -120,17 +120,14 @@ def _create_backend( obs = config.pop("observation_store", self.config.observation_store) if obs is None or isinstance(obs, type): obs = (obs or ListObservationStore)(name=name) - obs.start() bs = config.pop("blob_store", self.config.blob_store) if isinstance(bs, type): bs = bs() - bs.start() vs = config.pop("vector_store", self.config.vector_store) if isinstance(vs, type): vs = vs() - vs.start() notifier = config.pop("notifier", self.config.notifier) if notifier is None or isinstance(notifier, type): @@ -154,6 +151,7 @@ def stream(self, name: str, payload_type: type[T] | None = None, **overrides: An if name not in self._streams: resolved = {**self.config.model_dump(exclude_none=True), **overrides} backend = self._create_backend(name, payload_type, **resolved) + backend.start() self._streams[name] = Stream(source=backend) return cast("Stream[T]", self._streams[name]) @@ -163,4 +161,11 @@ def list_streams(self) -> list[str]: def delete_stream(self, name: str) -> None: """Delete a stream by name (from cache and 
underlying storage).""" - self._streams.pop(name, None) + stream = self._streams.pop(name, None) + if stream is not None: + stream.stop() + + def stop(self) -> None: + for stream in self._streams.values(): + stream.stop() + super().stop() diff --git a/dimos/memory2/store/memory.py b/dimos/memory2/store/memory.py index 6aecde29dd..5b4523aac6 100644 --- a/dimos/memory2/store/memory.py +++ b/dimos/memory2/store/memory.py @@ -12,10 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.memory2.store.base import Store +from typing import Any + +from dimos.memory2.backend import Backend +from dimos.memory2.observationstore.memory import ListObservationStore +from dimos.memory2.store.base import Store, StoreConfig + + +class MemoryStoreConfig(StoreConfig): + max_size: int | None = None class MemoryStore(Store): - """In-memory store for experimentation.""" + """In-memory store for experimentation. + + ``max_size`` controls how many observations each stream retains: + - ``None`` (default) — keep all (unbounded). + - ``N`` — rolling window of the most recent N observations. + - ``0`` — discard immediately (live-only, no history). + """ + + default_config = MemoryStoreConfig + config: MemoryStoreConfig - pass + def _create_backend( + self, name: str, payload_type: type[Any] | None = None, **config: Any + ) -> Backend[Any]: + if "observation_store" not in config and self.config.observation_store is None: + obs: ListObservationStore[Any] = ListObservationStore( + name=name, max_size=self.config.max_size + ) + config["observation_store"] = obs + return super()._create_backend(name, payload_type, **config) diff --git a/dimos/memory2/store/null.py b/dimos/memory2/store/null.py new file mode 100644 index 0000000000..71f02c4aee --- /dev/null +++ b/dimos/memory2/store/null.py @@ -0,0 +1,29 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from dimos.memory2.store.memory import MemoryStore + + +class NullStore(MemoryStore): + """Live-only store — O(1) memory, no history/replay. + + Shorthand for ``MemoryStore(max_size=0)``. + Observations get IDs (for live dedup) but are immediately discarded. + """ + + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("max_size", 0) + super().__init__(**kwargs) diff --git a/dimos/memory2/store/sqlite.py b/dimos/memory2/store/sqlite.py index b655e0a8bc..1071e9977f 100644 --- a/dimos/memory2/store/sqlite.py +++ b/dimos/memory2/store/sqlite.py @@ -14,8 +14,11 @@ from __future__ import annotations +import os import sqlite3 -from typing import Any +from typing import Annotated, Any + +from pydantic import BeforeValidator from dimos.memory2.backend import Backend from dimos.memory2.blobstore.base import BlobStore @@ -33,7 +36,9 @@ class SqliteStoreConfig(StoreConfig): """Config for SQLite-backed store.""" - path: str = "memory.db" + path: Annotated[ + str, BeforeValidator(lambda v: os.fspath(v) if isinstance(v, os.PathLike) else v) + ] = "memory.db" page_size: int = 256 @@ -51,7 +56,7 @@ def __init__(self, **kwargs: Any) -> None: def _open_connection(self) -> sqlite3.Connection: """Open a new WAL-mode connection with sqlite-vec loaded.""" disposable, connection = open_disposable_sqlite_connection(self.config.path) - self.register_disposables(disposable) + 
self.register_disposable(disposable) return connection def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: @@ -75,7 +80,6 @@ def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: bs = deserialize_component(bs_data) else: bs = SqliteBlobStore(conn=backend_conn) - bs.start() vs_data = stored.get("vector_store") if vs_data is not None: @@ -86,7 +90,6 @@ def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: vs = deserialize_component(vs_data) else: vs = SqliteVectorStore(conn=backend_conn) - vs.start() notifier_data = stored.get("notifier") if notifier_data is not None: @@ -105,8 +108,6 @@ def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: blob_store_conn_match=blob_store_conn_match and eager_blobs, page_size=page_size, ) - metadata_store.start() - backend: Backend[Any] = Backend( metadata_store=metadata_store, codec=codec, @@ -161,13 +162,9 @@ def _create_backend( # Inject conn-shared instances unless user provided overrides if not isinstance(config.get("blob_store"), BlobStore): - bs = SqliteBlobStore(conn=backend_conn) - bs.start() - config["blob_store"] = bs + config["blob_store"] = SqliteBlobStore(conn=backend_conn) if not isinstance(config.get("vector_store"), VectorStore): - vs = SqliteVectorStore(conn=backend_conn) - vs.start() - config["vector_store"] = vs + config["vector_store"] = SqliteVectorStore(conn=backend_conn) # Resolve codec early — needed for SqliteObservationStore codec = self._resolve_codec(payload_type, config.get("codec")) @@ -184,7 +181,6 @@ def _create_backend( blob_store_conn_match=blob_conn_match and eager_blobs, page_size=config.pop("page_size", self.config.page_size), ) - obs_store.start() config["observation_store"] = obs_store backend = super()._create_backend(name, payload_type, **config) diff --git a/dimos/memory2/store/test_null.py b/dimos/memory2/store/test_null.py new file mode 100644 index 0000000000..3461ff3d9d --- 
/dev/null +++ b/dimos/memory2/store/test_null.py @@ -0,0 +1,56 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for NullStore and max_size=0 discard behavior.""" + +from __future__ import annotations + +from dimos.memory2.store.null import NullStore + + +def test_max_size_zero_monotonic_ids() -> None: + """NullStore assigns monotonically increasing IDs despite discarding data.""" + store = NullStore() + with store: + stream = store.stream("test", str) + obs0 = stream.append("hello") + obs1 = stream.append("world") + obs2 = stream.append("!") + + assert obs0.id == 0 + assert obs1.id == 1 + assert obs2.id == 2 + + +def test_max_size_zero_empty_query() -> None: + """NullStore queries always return empty.""" + store = NullStore() + with store: + stream = store.stream("test", str) + stream.append("data") + assert stream.count() == 0 + assert stream.fetch() == [] + + +def test_null_store_discards_history() -> None: + """NullStore discards history but still supports live streaming.""" + store = NullStore() + with store: + stream = store.stream("test", int) + stream.append(1) + stream.append(2) + stream.append(3) + + assert stream.count() == 0 + assert stream.fetch() == [] diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index 545d387c32..75bf6ab6a0 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -15,9 +15,9 @@ from __future__ import annotations import time -from typing import TYPE_CHECKING, 
Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast -from dimos.core.resource import Resource +from dimos.core.resource import CompositeResource from dimos.memory2.buffer import BackpressureBuffer, KeepLast from dimos.memory2.transform import FnIterTransformer, FnTransformer, Transformer from dimos.memory2.type.filter import ( @@ -32,6 +32,7 @@ TimeRangeFilter, ) from dimos.memory2.type.observation import EmbeddedObservation, Observation +from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: from collections.abc import Callable, Iterator @@ -44,53 +45,61 @@ T = TypeVar("T") R = TypeVar("R") +logger = setup_logger() -class Stream(Resource, Generic[T]): +class Stream(CompositeResource, Generic[T]): """Lazy, pull-based stream over observations. Every filter/transform method returns a new Stream — no computation happens until iteration. Backends handle query application for stored data; transform sources apply filters as Python predicates. - Implements Resource so live streams can be cleanly stopped via - ``stop()`` or used as a context manager. + Implements CompositeResource so subscriptions created via ``.subscribe()`` + and ``.publish()`` are tracked and disposed on ``stop()``. + + An *unbound* stream (``Stream()``) records a chain of transforms + without a real source. 
Use ``.chain()`` to apply it to a bound stream:: + + pipeline = Stream().transform(VoxelMapTransformer()).map(postprocess) + store.stream("lidar", PointCloud2).live().chain(pipeline) """ def __init__( self, - source: Backend[T] | Stream[Any], + source: Backend[T] | Stream[Any] | None = None, *, - xf: Transformer[Any, T] | None = None, + transform: Transformer[Any, T] | None = None, query: StreamQuery = StreamQuery(), ) -> None: + super().__init__() self._source = source - self._xf = xf + if source is not None: + self.register_disposable(source) + self._transform = transform self._query = query - def start(self) -> None: - pass - def stop(self) -> None: - """Close the live buffer (if any), unblocking iteration.""" buf = self._query.live_buffer if buf is not None: buf.close() - if isinstance(self._source, Stream): - self._source.stop() + super().stop() def __str__(self) -> str: # Walk the source chain to collect (xf, query) pairs chain: list[tuple[Any, StreamQuery]] = [] current: Any = self while isinstance(current, Stream): - chain.append((current._xf, current._query)) + chain.append((current._transform, current._query)) current = current._source chain.reverse() # innermost first - # current is the Backend - name = getattr(current, "name", "?") - result = f'Stream("{name}")' + # current is the Backend (or None for unbound) + if current is None: + result = "Stream(unbound)" + else: + name = getattr(current, "name", "?") + result = f'Stream("{name}")' for xf, query in chain: if xf is not None: @@ -110,9 +119,10 @@ def is_live(self) -> bool: return False def __iter__(self) -> Iterator[Observation[T]]: - return self._build_iter() - - def _build_iter(self) -> Iterator[Observation[T]]: + if self._source is None: + raise TypeError( + "Cannot iterate an unbound stream. Use .chain() to apply it to a real stream first." 
+ ) if isinstance(self._source, Stream): return self._iter_transform() # Backend handles all query application (including live if requested) @@ -120,8 +130,8 @@ def _build_iter(self) -> Iterator[Observation[T]]: def _iter_transform(self) -> Iterator[Observation[T]]: """Iterate a transform source, applying query filters in Python.""" - assert isinstance(self._source, Stream) and self._xf is not None - it: Iterator[Observation[T]] = self._xf(iter(self._source)) + assert isinstance(self._source, Stream) and self._transform is not None + it: Iterator[Observation[T]] = self._transform(iter(self._source)) return self._query.apply(it, live=self.is_live()) def _replace_query(self, **overrides: Any) -> Stream[T]: @@ -137,7 +147,7 @@ def _replace_query(self, **overrides: Any) -> Stream[T]: search_k=overrides.get("search_k", q.search_k), search_text=overrides.get("search_text", q.search_text), ) - return Stream(self._source, xf=self._xf, query=new_q) + return Stream(self._source, transform=self._transform, query=new_q) def _with_filter(self, f: Filter) -> Stream[T]: return self._replace_query(filters=(*self._query.filters, f)) @@ -210,7 +220,7 @@ def detect(upstream): """ if not isinstance(xf, Transformer): xf = FnIterTransformer(xf) - return Stream(source=self, xf=xf, query=StreamQuery()) + return Stream(source=self, transform=xf, query=StreamQuery()) def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: """Return a stream whose iteration never ends — backfill then live tail. @@ -221,9 +231,9 @@ def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> St Default buffer: KeepLast(). The backend handles subscription, dedup, and backpressure — how it does so is its business. """ - if isinstance(self._source, Stream): + if isinstance(self._source, Stream) or self._source is None: raise TypeError( - "Cannot call .live() on a transform stream. " + "Cannot call .live() on a transform/unbound stream. 
" "Call .live() on the source stream, then .transform()." ) buf = buffer if buffer is not None else KeepLast() @@ -234,8 +244,10 @@ def save(self, target: Stream[T]) -> Stream[T]: Returns the target stream for continued querying. """ - if isinstance(target._source, Stream): - raise TypeError("Cannot save to a transform stream. Target must be backend-backed.") + if isinstance(target._source, Stream) or target._source is None: + raise TypeError( + "Cannot save to a transform/unbound stream. Target must be backend-backed." + ) backend = target._source for obs in self: backend.append(obs) @@ -264,7 +276,7 @@ def last(self) -> Observation[T]: def count(self) -> int: """Count matching observations.""" - if not isinstance(self._source, Stream): + if self._source is not None and not isinstance(self._source, Stream): return self._source.count(self._query) if self.is_live(): raise TypeError(".count() on a live transform stream would block forever.") @@ -328,13 +340,79 @@ def subscribe( on_error: Callable[[Exception], None] | None = None, on_completed: Callable[[], None] | None = None, ) -> DisposableBase: - """Subscribe to this stream as an RxPY Observable.""" - return self.observable().subscribe( # type: ignore[call-overload] - on_next=on_next, - on_error=on_error, - on_completed=on_completed, + """Subscribe to this stream as an RxPY Observable. + + The subscription is tracked and disposed when this stream is stopped. + """ + return self.register_disposable( + self.observable().subscribe( # type: ignore[call-overload] + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + ) + ) + + def publish(self, out: Any) -> DisposableBase: + """Publish each observation's data to a Module ``Out`` port. + + Iteration runs on the dimos thread pool (via :meth:`subscribe`). + Returns a ``DisposableBase`` suitable for ``register_disposable()``. 
+ + Example:: + + lidar.live().transform(VoxelMapTransformer()).publish(self.global_map) + """ + + def _on_error(e: Exception) -> None: + logger.error("Stream.publish() pipeline error: %s", e, exc_info=True) + + return self.subscribe( + on_next=lambda obs: out.publish(obs.data), + on_error=_on_error, ) + def chain(self, other: Stream[R]) -> Stream[R]: + """Append operations from an unbound stream to this stream. + + Extracts the transform/filter chain from *other* (which must be + unbound) and replays it on top of ``self``:: + + pipeline = Stream().transform(VoxelMapTransformer()).map(postprocess) + store.stream("lidar").live().chain(pipeline) + """ + ops: list[tuple[Transformer[Any, Any] | None, StreamQuery]] = [] + current: Stream[Any] | None | Any = other + found_root = False + while isinstance(current, Stream): + ops.append((current._transform, current._query)) + if current._source is None: + found_root = True + break + current = current._source + if not found_root: + raise TypeError("Can only chain an unbound stream (created with Stream())") + + # Validate no unsupported query fields in the unbound chain + for _, query in ops: + if query.search_vec is not None or query.search_text is not None: + raise TypeError("search() / search_text() cannot be used on unbound streams") + if query.live_buffer is not None: + raise TypeError("live() cannot be used on unbound streams") + + result: Stream[Any] = self + for xf, query in reversed(ops): + if xf is not None: + result = result.transform(xf) + for f in query.filters: + result = result._with_filter(f) + if query.limit_val is not None: + result = result.limit(query.limit_val) + if query.offset_val is not None and query.offset_val != 0: + result = result.offset(query.offset_val) + if query.order_field is not None: + result = result.order_by(query.order_field, desc=query.order_desc) + return cast("Stream[R]", result) + def append( self, payload: T, @@ -345,8 +423,10 @@ def append( embedding: Embedding | None = None, ) -> 
Observation[T]: """Append to the backing store. Only works if source is a Backend.""" - if isinstance(self._source, Stream): - raise TypeError("Cannot append to a transform stream. Append to the source stream.") + if isinstance(self._source, Stream) or self._source is None: + raise TypeError( + "Cannot append to a transform/unbound stream. Append to the source stream." + ) _ts = ts if ts is not None else time.time() _tags = tags or {} if embedding is not None: diff --git a/dimos/memory2/test_e2e.py b/dimos/memory2/test_e2e.py index efea5a59a2..31d5ee1720 100644 --- a/dimos/memory2/test_e2e.py +++ b/dimos/memory2/test_e2e.py @@ -126,6 +126,31 @@ def test_import_lidar( assert lidar.count() == count print(f"Imported {count} lidar frames") + def test_embed_and_save(self, session: SqliteStore, clip: CLIPModel) -> None: + """Embed video frames at 1Hz and persist to an embedded stream.""" + video = session.stream("color_image", Image) + + # Clear any prior run so the test is idempotent + if "color_image_embedded" in session.list_streams(): + session.delete_stream("color_image_embedded") + + embedded = session.stream("color_image_embedded", Image) + + # Downsample to 1Hz, then embed + pipeline = ( + video.transform(QualityWindow(lambda img: img.sharpness, window=1.0)) + .transform(EmbedImages(clip)) + .save(embedded) + ) + + count = 0 + for obs in pipeline: + count += 1 + print(f" [{count}] ts={obs.ts:.2f} pose={obs.pose}") + + assert count > 0 + print(f"Embedded {count} frames (1Hz from {video.count()} total)") + def test_query_imported_data(self, session: SqliteStore) -> None: video = session.stream("color_image", Image) lidar = session.stream("lidar", PointCloud2) @@ -256,38 +281,12 @@ def test_cross_stream_time_alignment(self, session: SqliteStore) -> None: overlap_start = max(v_first, l_first) overlap_end = min(v_last, l_last) assert overlap_start < overlap_end, "Video and lidar should overlap in time" - assert overlap_start < overlap_end, "Video and lidar should 
overlap in time" @pytest.mark.tool class TestEmbedImages: """CLIP-embed imported video frames and search by text.""" - def test_embed_and_save(self, session: SqliteStore, clip: CLIPModel) -> None: - """Embed video frames at 1Hz and persist to an embedded stream.""" - video = session.stream("color_image", Image) - - # Clear any prior run so the test is idempotent - if "color_image_embedded" in session.list_streams(): - session.delete_stream("color_image_embedded") - - embedded = session.stream("color_image_embedded", Image) - - # Downsample to 1Hz, then embed - pipeline = ( - video.transform(QualityWindow(lambda img: img.sharpness, window=1.0)) - .transform(EmbedImages(clip)) - .save(embedded) - ) - - count = 0 - for obs in pipeline: - count += 1 - print(f" [{count}] ts={obs.ts:.2f} pose={obs.pose}") - - assert count > 0 - print(f"Embedded {count} frames (1Hz from {video.count()} total)") - def test_search_by_text(self, session: SqliteStore, clip: CLIPModel) -> None: """Search embedded frames with a text query.""" embedded = session.stream("color_image_embedded", Image) diff --git a/dimos/memory2/test_module.py b/dimos/memory2/test_module.py new file mode 100644 index 0000000000..a944539063 --- /dev/null +++ b/dimos/memory2/test_module.py @@ -0,0 +1,131 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Grid tests for StreamModule — same e2e logic across all pipeline styles.""" + +from __future__ import annotations + +from collections.abc import Iterator +import threading + +import pytest +from reactivex.scheduler import ThreadPoolScheduler + +from dimos.core.module import ModuleConfig +from dimos.core.stream import In, Out +from dimos.core.transport import pLCMTransport +from dimos.memory2.module import StreamModule +from dimos.memory2.stream import Stream +from dimos.memory2.transform import Transformer +from dimos.memory2.type.observation import Observation + +# -- Shared transformer --------------------------------------------------- + + +class Double(Transformer[int, int]): + def __init__(self, factor: int = 2) -> None: + self.factor = factor + + def __call__(self, upstream: Iterator[Observation[int]]) -> Iterator[Observation[int]]: + for obs in upstream: + yield obs.derive(data=obs.data * self.factor) + + +# -- Pipeline styles ------------------------------------------------------- + + +class StaticStreamModule(StreamModule): + """Pipeline as a static Stream chain on the class.""" + + pipeline = Stream().transform(Double()) + numbers: In[int] + doubled: Out[int] + + +class StaticTransformerModule(StreamModule): + """Pipeline as a bare Transformer on the class.""" + + pipeline = Double() + numbers: In[int] + doubled: Out[int] + + +class MethodPipelineConfig(ModuleConfig): + factor: int = 2 + + +class MethodPipelineModule(StreamModule[MethodPipelineConfig]): + """Pipeline as a method with access to self.config.""" + + default_config = MethodPipelineConfig + + def pipeline(self, stream: Stream) -> Stream: + return stream.transform(Double(factor=self.config.factor)) + + numbers: In[int] + doubled: Out[int] + + +# -- Grid ------------------------------------------------------------------ + +module_cases = [ + pytest.param(StaticStreamModule, id="static-stream"), + pytest.param(StaticTransformerModule, id="static-transformer"), + 
pytest.param(MethodPipelineModule, id="method-pipeline"), +] + + +@pytest.mark.parametrize("module_cls", module_cases) +def test_blueprint_ports(module_cls: type[StreamModule]) -> None: + """All pipeline styles produce a blueprint with the correct In/Out ports.""" + bp = module_cls.blueprint() + + assert len(bp.blueprints) == 1 + atom = bp.blueprints[0] + stream_names = {s.name for s in atom.streams} + assert "numbers" in stream_names + assert "doubled" in stream_names + + +def _reset_thread_pool() -> None: + """Shut down and replace the global RxPY thread pool so conftest thread-leak check passes.""" + import dimos.utils.threadpool as tp + + tp.scheduler.executor.shutdown(wait=True) + tp.scheduler = ThreadPoolScheduler(max_workers=tp.get_max_workers()) + + +@pytest.mark.tool +@pytest.mark.parametrize("module_cls", module_cases) +def test_e2e_runtime_wiring(module_cls: type[StreamModule]) -> None: + """Push data into In port, assert doubled data arrives on Out port.""" + module = module_cls() + module.numbers.transport = pLCMTransport("/test/numbers") + module.doubled.transport = pLCMTransport("/test/doubled") + + received: list[int] = [] + done = threading.Event() + + unsub = module.doubled.subscribe(lambda msg: (received.append(msg), done.set())) + + module.start() + try: + module.numbers.transport.publish(42) + assert done.wait(timeout=5.0), f"Timed out, received={received}" + assert received == [84] + finally: + unsub() + module.stop() + _reset_thread_pool() + _reset_thread_pool() diff --git a/dimos/memory2/test_save.py b/dimos/memory2/test_save.py index 13ee73d46a..8ebb12082b 100644 --- a/dimos/memory2/test_save.py +++ b/dimos/memory2/test_save.py @@ -101,7 +101,7 @@ def test_save_rejects_transform_target(self) -> None: base = make_stream(2) transform_stream = base.transform(FnTransformer(lambda obs: obs.derive(obs.data))) - with pytest.raises(TypeError, match="Cannot save to a transform stream"): + with pytest.raises(TypeError, match="Cannot save to a 
transform"): source.save(transform_stream) def test_save_target_queryable(self) -> None: diff --git a/dimos/memory2/test_store.py b/dimos/memory2/test_store.py index dfba6d6d2b..aa525c8758 100644 --- a/dimos/memory2/test_store.py +++ b/dimos/memory2/test_store.py @@ -24,6 +24,7 @@ import pytest +from dimos.memory2.backend import Backend from dimos.memory2.blobstore.base import BlobStore from dimos.memory2.vectorstore.base import VectorStore @@ -525,3 +526,94 @@ def test_accessor_dynamic(self, session: Store) -> None: assert "late" not in dir(session.streams) session.stream("late", str) assert "late" in dir(session.streams) + + +class TestStoreLifecycle: + """Cleanup chain: Store → Stream → Backend → components.""" + + def test_stop_stream_keeps_other_streams(self, session: Store) -> None: + """Stopping one stream doesn't affect another.""" + s1 = session.stream("a", int) + s2 = session.stream("b", int) + s1.append(1) + s2.append(2) + + s1.stop() + + # s2 still works + s2.append(3) + assert [obs.data for obs in s2] == [2, 3] + + def test_store_stop_stops_backends(self, session: Store) -> None: + """Store.stop() disposes backends transitively via streams.""" + s1 = session.stream("x", int) + s2 = session.stream("y", int) + s1.append(10) + s2.append(20) + + backend1 = s1._source + backend2 = s2._source + assert isinstance(backend1, Backend) + assert isinstance(backend2, Backend) + + session.stop() + + # Both backends' disposables are disposed + assert backend1._disposables is not None + assert backend1._disposables.is_disposed + assert backend2._disposables is not None + assert backend2._disposables.is_disposed + + def test_stream_stop_stops_backend(self, session: Store) -> None: + """stream.stop() disposes its backend (Stream owns Backend).""" + s = session.stream("owned", int) + s.append(42) + + backend = s._source + assert isinstance(backend, Backend) + + s.stop() + + assert backend._disposables is not None + assert backend._disposables.is_disposed + + def 
test_stream_stop_stops_backend_components(self, session: Store) -> None: + """stream.stop() cascades through backend to its components.""" + s = session.stream("cascade", int) + backend = s._source + assert isinstance(backend, Backend) + + s.stop() + + # Backend registers notifier as disposable, so it gets disposed + assert backend._disposables is not None + assert backend._disposables.is_disposed + # Notifier's own disposables may be None (no children registered), + # but the backend's disposal cascade is what matters. + + def test_delete_stream_stops_backend(self, session: Store) -> None: + """delete_stream() stops the stream+backend and removes from cache.""" + s = session.stream("ephemeral", int) + s.append(1) + + backend = s._source + assert isinstance(backend, Backend) + assert "ephemeral" in session.list_streams() + + session.delete_stream("ephemeral") + + assert backend._disposables is not None + assert backend._disposables.is_disposed + assert "ephemeral" not in session.list_streams() + + def test_backend_stop_stops_components(self, session: Store) -> None: + """Backend.stop() propagates to metadata_store, blob_store, vector_store.""" + s = session.stream("z", int) + backend = s._source + assert isinstance(backend, Backend) + + session.stop() + + # Backend always registers its components, so _disposables is always set + assert backend._disposables is not None + assert backend._disposables.is_disposed diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index 03c3caec76..e53cd15d9f 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -26,14 +26,14 @@ import pytest from dimos.memory2.buffer import KeepLast, Unbounded +from dimos.memory2.store.memory import MemoryStore +from dimos.memory2.stream import Stream from dimos.memory2.transform import FnTransformer, QualityWindow, Transformer from dimos.memory2.type.observation import Observation if TYPE_CHECKING: from collections.abc import Callable - from 
dimos.memory2.stream import Stream - @pytest.fixture def make_stream(session) -> Callable[..., Stream[int]]: @@ -50,11 +50,6 @@ def f(n: int = 5, start_ts: float = 0.0): return f -# ═══════════════════════════════════════════════════════════════════ -# 1. Basic iteration -# ═══════════════════════════════════════════════════════════════════ - - class TestBasicIteration: """Streams are lazy iterables — nothing runs until you iterate.""" @@ -85,11 +80,6 @@ def test_stream_is_reiterable(self, make_stream): assert first == second == [0, 10, 20] -# ═══════════════════════════════════════════════════════════════════ -# 2. Temporal filters -# ═══════════════════════════════════════════════════════════════════ - - class TestTemporalFilters: """Temporal filters constrain observations by timestamp.""" @@ -119,11 +109,6 @@ def test_chained_temporal_filters(self, make_stream): assert [o.ts for o in result] == [3.0, 4.0, 5.0, 6.0] -# ═══════════════════════════════════════════════════════════════════ -# 3. Spatial filter -# ═══════════════════════════════════════════════════════════════════ - - class TestSpatialFilter: """.near(pose, radius) filters by Euclidean distance.""" @@ -145,11 +130,6 @@ def test_near_excludes_no_pose(self, memory_session): assert [o.data for o in result] == ["has_pose"] -# ═══════════════════════════════════════════════════════════════════ -# 4. Tags filter -# ═══════════════════════════════════════════════════════════════════ - - class TestTagsFilter: """.filter_tags() matches on observation metadata.""" @@ -171,11 +151,6 @@ def test_filter_multiple_tags(self, memory_session): assert [o.data for o in result] == ["a"] -# ═══════════════════════════════════════════════════════════════════ -# 5. 
Ordering, limit, offset -# ═══════════════════════════════════════════════════════════════════ - - class TestOrderLimitOffset: def test_limit(self, make_stream): result = make_stream(10).limit(3).fetch() @@ -220,11 +195,6 @@ def test_drain(self, make_stream): assert make_stream(0).drain() == 0 -# ═══════════════════════════════════════════════════════════════════ -# 6. Functional API: .filter(), .map() -# ═══════════════════════════════════════════════════════════════════ - - class TestFunctionalAPI: """Functional combinators receive the full Observation.""" @@ -249,11 +219,6 @@ def test_map_preserves_ts(self, make_stream): assert [o.data for o in result] == ["0", "10", "20"] -# ═══════════════════════════════════════════════════════════════════ -# 7. Transform chaining -# ═══════════════════════════════════════════════════════════════════ - - class TestTransformChaining: """Transforms chain lazily — each obs flows through the full pipeline.""" @@ -352,9 +317,109 @@ def __call__(self, upstream): assert len(calls) == 3 -# ═══════════════════════════════════════════════════════════════════ -# 8. 
Store -# ═══════════════════════════════════════════════════════════════════ +class TestUnboundStream: + """Unbound streams: pipelines built without a source, applied later via .chain().""" + + def test_creation(self) -> None: + """Stream() with no args creates an unbound stream.""" + s = Stream() + assert s._transform is None + + def test_multi_transform_chain(self) -> None: + """Unbound pipeline with multiple transforms produces correct results when bound.""" + + class Double(Transformer[int, int]): + def __call__(self, upstream): + for obs in upstream: + yield obs.derive(data=obs.data * 2) + + pipeline = Stream().transform(Double()).map(lambda obs: obs.derive(data=obs.data + 1)) + + store = MemoryStore() + with store: + stream = store.stream("test", int) + stream.append(5) + stream.append(10) + + result = stream.chain(pipeline).fetch() + assert [obs.data for obs in result] == [11, 21] + + def test_iteration_raises(self) -> None: + """Iterating an unbound stream raises TypeError.""" + s = Stream().transform(FnTransformer(lambda obs: obs)) + with pytest.raises(TypeError, match="unbound"): + list(s) + + def test_chain_applies_transforms(self) -> None: + """chain() replays unbound transforms on a real stream.""" + store = MemoryStore() + with store: + stream = store.stream("test", int) + stream.append(10) + stream.append(20) + stream.append(30) + + class Double(Transformer[int, int]): + def __call__(self, upstream): + for obs in upstream: + yield obs.derive(data=obs.data * 2) + + pipeline = Stream().transform(Double()) + result = stream.chain(pipeline).fetch() + assert [obs.data for obs in result] == [20, 40, 60] + + def test_chain_multiple_transforms(self) -> None: + """chain() preserves order of multiple transforms.""" + store = MemoryStore() + with store: + stream = store.stream("test", int) + stream.append(5) + + class Double(Transformer[int, int]): + def __call__(self, upstream): + for obs in upstream: + yield obs.derive(data=obs.data * 2) + + class 
AddTen(Transformer[int, int]): + def __call__(self, upstream): + for obs in upstream: + yield obs.derive(data=obs.data + 10) + + pipeline = Stream().transform(Double()).transform(AddTen()) + result = stream.chain(pipeline).fetch() + assert result[0].data == 20 # (5 * 2) + 10 + + def test_chain_preserves_filters(self) -> None: + """chain() replays filters from the unbound stream.""" + store = MemoryStore() + with store: + stream = store.stream("test", int) + stream.append(10, ts=1.0) + stream.append(20, ts=2.0) + stream.append(30, ts=3.0) + + pipeline = Stream().after(1.5) + result = stream.chain(pipeline).fetch() + assert [obs.data for obs in result] == [20, 30] + + def test_chain_rejects_bound_stream(self) -> None: + """chain() raises if passed a bound (non-unbound) stream.""" + store = MemoryStore() + with store: + s1 = store.stream("a", int) + s2 = store.stream("b", int) + with pytest.raises(TypeError, match="unbound"): + s1.chain(s2) + + def test_live_rejects_unbound(self) -> None: + """live() raises on an unbound stream.""" + with pytest.raises(TypeError, match="unbound"): + Stream().live() + + def test_str(self) -> None: + """Unbound streams display as Stream(unbound).""" + s = Stream() + assert "unbound" in str(s) class TestStore: @@ -385,11 +450,6 @@ def test_delete_stream(self, memory_store): assert "temp" not in memory_store.list_streams() -# ═══════════════════════════════════════════════════════════════════ -# 9. Lazy data loading -# ═══════════════════════════════════════════════════════════════════ - - class TestLazyData: """Observation.data supports lazy loading with cleanup.""" @@ -430,11 +490,6 @@ def test_derive_preserves_metadata(self): assert derived.data == "transformed" -# ═══════════════════════════════════════════════════════════════════ -# 10. 
Live mode -# ═══════════════════════════════════════════════════════════════════ - - class TestLiveMode: """Live streams yield backfill then block for new observations.""" diff --git a/dimos/memory2/test_visualizer.py b/dimos/memory2/test_visualizer.py index 033d60205f..0830c946fd 100644 --- a/dimos/memory2/test_visualizer.py +++ b/dimos/memory2/test_visualizer.py @@ -16,6 +16,7 @@ from __future__ import annotations +import pickle from typing import TYPE_CHECKING import pytest @@ -24,8 +25,12 @@ from dimos.memory2.transform import Batch, QualityWindow from dimos.models.embedding.clip import CLIPModel from dimos.models.vl.florence import Florence2Model +from dimos.models.vl.moondream import MoondreamVlModel +from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.sensor_msgs.Image import Image -from dimos.utils.data import get_data_dir +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC +from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.utils.data import get_data, get_data_dir if TYPE_CHECKING: from collections.abc import Iterator @@ -105,13 +110,11 @@ def test_search_near_pose(self, store: SqliteStore) -> None: # VIS GOAL: draw 2d detections somehow, or project into 3d, draw 3d bounding boxes def test_detect_objects(self, store: SqliteStore, clip: CLIPModel) -> None: """CLIP pre-filter + VLM detection on top candidates.""" - from dimos.models.vl.moondream import MoondreamVlModel - vlm = MoondreamVlModel() embedded = store.streams.color_image_embedded lidar = store.streams.lidar - for obs in embedded.search(clip.embed_text("bottle"), k=10).map( + for obs in embedded.search(clip.embed_text("bottle"), k=1).map( lambda obs: obs.derive(data=vlm.query_detections(obs.data, "bottle")) ): print(f"ts={obs.ts:.2f} sim={obs.similarity:.3f} pose={obs.pose}") @@ -135,13 +138,65 @@ def test_search_reconstruct_full_path(self, store: SqliteStore) -> None: def test_agent_visual_description_passive(self, store: 
SqliteStore) -> None: florence = Florence2Model() with florence: - pipeline = store.streams.color_image.transform( - QualityWindow(lambda img: img.sharpness, window=5.0) - # we are batch processing images here, - # so we can use the more efficient batch captioning API - # (instead of using .map() and calling caption() for each image, - ).transform(Batch(lambda imgs: florence.caption_batch(*imgs))) + pipeline = ( + store.streams.color_image.limit(200) + .transform( + QualityWindow(lambda img: img.sharpness, window=5.0) + # we are batch processing images here, + # so we can use the more efficient batch captioning API + # (instead of using .map() and calling caption() for each image, + ) + .transform(Batch(lambda imgs: florence.caption_batch(*imgs))) + ) # this can be stored, further embedded etc for obs in pipeline: print(obs.ts, obs.data) + + def test_build_global_map(self, store: SqliteStore) -> None: + global_map = pickle.loads(get_data("unitree_go2_bigoffice_map.pickle").read_bytes()) + print(f"Global map: {len(global_map)}") + + # we semantically search, then detect with a detection model + # + # VIS GOAL: draw 2d detections somehow, or project into 3d, draw 3d bounding boxes + def test_detect_objects_smart(self, store: SqliteStore, clip: CLIPModel) -> None: + """CLIP pre-filter + VLM detection on top candidates.""" + vlm = MoondreamVlModel() + embedded = store.streams.color_image_embedded + lidar = store.streams.lidar + + # find a location in the world with highest semantic similarity to a bottle + bottle_pos = embedded.search(clip.embed_text("bottle"), k=1).first().pose_stamped + + for obs in ( + store.streams.color_image + # find all frames within 60 seconds of the semantic hotspot + .at(bottle_pos.ts, tolerance=60.0) + # filter the frames within 1m radius near the semantic hotspot + .near(bottle_pos, radius=1.0) + # select highest quality frames from these results (based on sharpness) + .transform(QualityWindow(lambda img: img.sharpness, window=1.0)) + # run 
detection on these frames to find bottles + .map(lambda obs: obs.derive(data=vlm.query_detections(obs.data, "bottle"))) + ): + print(f"ts={obs.ts:.2f} pose={obs.pose_stamped}") + + # find the lidar frame captured closest in time to an image + lidar_frame = lidar.at(obs.ts).first().data + + for det in obs.data.detections: + print(det) + # project each bottle into 3D using lidar frame + # known camera intrinsics + extrinsics + det3d = Detection3DPC.from_2d( + det, + lidar_frame, + camera_info=GO2Connection.camera_info_static, + world_to_optical_transform=Transform( + ts=obs.ts, + translation=obs.pose_stamped.position, + rotation=obs.pose_stamped.orientation, + ).inverse(), + ) + print(det3d) diff --git a/dimos/memory2/test_voxel_map.py b/dimos/memory2/test_voxel_map.py new file mode 100644 index 0000000000..0fd254be60 --- /dev/null +++ b/dimos/memory2/test_voxel_map.py @@ -0,0 +1,135 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from dimos.mapping.voxels import VoxelMapTransformer +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.type.observation import Observation +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.data import get_data + +if TYPE_CHECKING: + from collections.abc import Iterator + + +@pytest.fixture(scope="module") +def store() -> Iterator[SqliteStore]: + db = SqliteStore(path=get_data("go2_bigoffice.db")) + with db: + yield db + + +def _make_obs(obs_id: int, points: np.ndarray, ts: float = 0.0) -> Observation[PointCloud2]: + return Observation(id=obs_id, ts=ts, _data=PointCloud2.from_numpy(points)) + + +def _unit_cube_points(n: int = 100) -> np.ndarray: + rng = np.random.default_rng(42) + return rng.random((n, 3)).astype(np.float32) + + +def test_accumulate_two_frames() -> None: + """Two non-overlapping frames produce a larger global map.""" + pts = _unit_cube_points(50) + obs1 = _make_obs(0, pts, ts=1.0) + obs2 = _make_obs(1, pts + 10.0, ts=2.0) # offset by 10m, no overlap + + xf = VoxelMapTransformer(voxel_size=0.5, carve_columns=False) + results = list(xf(iter([obs1, obs2]))) + + assert len(results) == 2 # emit_every=1 default + global_map = results[-1].data # last result has the full accumulated map + + single_results = list(VoxelMapTransformer(voxel_size=0.5)(iter([obs1]))) + assert len(global_map) > len(single_results[0].data) + + +def test_empty_stream() -> None: + xf = VoxelMapTransformer(voxel_size=0.5) + assert list(xf(iter([]))) == [] + + +def test_frame_count_tag() -> None: + pts = _unit_cube_points(30) + obs = [_make_obs(i, pts, ts=float(i)) for i in range(5)] + + xf = VoxelMapTransformer(voxel_size=0.5, device="CPU:0") + results = list(xf(iter(obs))) + + assert len(results) == 5 # emit_every=1 (default), one result per frame + assert results[-1].tags["frame_count"] == 5 + + +def 
test_emit_every_batch_mode() -> None: + """emit_every=0 yields only on exhaustion (batch mode).""" + pts = _unit_cube_points(30) + obs = [_make_obs(i, pts, ts=float(i)) for i in range(5)] + + xf = VoxelMapTransformer(voxel_size=0.5, device="CPU:0", emit_every=0) + results = list(xf(iter(obs))) + + assert len(results) == 1 + assert results[0].tags["frame_count"] == 5 + + +def test_emit_every_n() -> None: + """emit_every=3 yields after every 3rd frame, plus remainder on exhaustion.""" + pts = _unit_cube_points(30) + obs = [_make_obs(i, pts, ts=float(i)) for i in range(7)] + + xf = VoxelMapTransformer(voxel_size=0.5, device="CPU:0", emit_every=3) + results = list(xf(iter(obs))) + + # 7 frames / emit_every=3 → yields at frame 3, 6, then remainder (7) on exhaustion + assert len(results) == 3 + assert results[0].tags["frame_count"] == 3 + assert results[1].tags["frame_count"] == 6 + assert results[2].tags["frame_count"] == 7 + + +# -- Integration tests against real replay data -- + + +@pytest.mark.tool +def test_build_global_map(store: SqliteStore) -> None: + t_total = time.perf_counter() + + lidar = store.stream("lidar", PointCloud2) + n_frames = lidar.count() + + t0 = time.perf_counter() + result = lidar.transform(VoxelMapTransformer(voxel_size=0.05)).last() + t_transform = time.perf_counter() - t0 + + t_total = time.perf_counter() - t_total + + global_map = result.data + frame_count = result.tags["frame_count"] + + assert frame_count == n_frames + assert len(global_map) > 0 + + print( + lidar.summary(), + f"\n{frame_count} frames -> {len(global_map)} voxels" + f"\n transform: {t_transform:.2f}s ({t_transform / frame_count * 1000:.1f}ms/frame)" + f"\n total wall: {t_total:.2f}s", + ) diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py index 20d6bf0baf..5754ac36e3 100644 --- a/dimos/memory2/transform.py +++ b/dimos/memory2/transform.py @@ -105,6 +105,19 @@ def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R yield 
o.derive(data=r) +def stride(n: int) -> FnIterTransformer[T, T]: + """Yield every *n*-th observation, skipping the rest.""" + if n < 1: + raise ValueError(f"stride(n) requires n >= 1, got {n}") + + def _stride(upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + for i, obs in enumerate(upstream): + if i % n == 0: + yield obs + + return FnIterTransformer(_stride) + + class QualityWindow(Transformer[T, T]): """Keeps the highest-quality item per time window. diff --git a/dimos/memory2/type/observation.py b/dimos/memory2/type/observation.py index 0a6dd16ea5..03a8819867 100644 --- a/dimos/memory2/type/observation.py +++ b/dimos/memory2/type/observation.py @@ -22,6 +22,7 @@ from collections.abc import Callable from dimos.models.embedding.base import Embedding + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped T = TypeVar("T") @@ -50,6 +51,15 @@ class Observation(Generic[T]): _loader: Callable[[], T] | None = field(default=None, repr=False) _data_lock: threading.Lock = field(default_factory=threading.Lock, repr=False) + @property + def pose_stamped(self) -> PoseStamped: + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + if self.pose is None: + raise LookupError("No pose set on this observation") + x, y, z, qx, qy, qz, qw = self.pose + return PoseStamped(ts=self.ts, position=(x, y, z), orientation=(qx, qy, qz, qw)) + @property def data(self) -> T: val = self._data diff --git a/dimos/memory2/utils/sqlite.py b/dimos/memory2/utils/sqlite.py index e242a6e1f5..02a48f22b7 100644 --- a/dimos/memory2/utils/sqlite.py +++ b/dimos/memory2/utils/sqlite.py @@ -14,12 +14,13 @@ from __future__ import annotations +from pathlib import Path import sqlite3 from reactivex.disposable import Disposable -def open_sqlite_connection(path: str) -> sqlite3.Connection: +def open_sqlite_connection(path: str | Path) -> sqlite3.Connection: """Open a WAL-mode SQLite connection with sqlite-vec loaded.""" import sqlite_vec @@ -33,7 +34,7 @@ def 
open_sqlite_connection(path: str) -> sqlite3.Connection: def open_disposable_sqlite_connection( - path: str, + path: str | Path, ) -> tuple[Disposable, sqlite3.Connection]: """Open a WAL-mode SQLite connection and return (disposable, connection). diff --git a/dimos/memory2/vectorstore/sqlite.py b/dimos/memory2/vectorstore/sqlite.py index cd6573cc0c..31ebba45d6 100644 --- a/dimos/memory2/vectorstore/sqlite.py +++ b/dimos/memory2/vectorstore/sqlite.py @@ -76,7 +76,7 @@ def start(self) -> None: if self._conn is None: assert self._path is not None disposable, self._conn = open_disposable_sqlite_connection(self._path) - self.register_disposables(disposable) + self.register_disposable(disposable) def put(self, stream_name: str, key: int, embedding: Embedding) -> None: vec = embedding.to_numpy().tolist() diff --git a/dimos/navigation/bbox_navigation.py b/dimos/navigation/bbox_navigation.py index c96ba9efad..2be8015721 100644 --- a/dimos/navigation/bbox_navigation.py +++ b/dimos/navigation/bbox_navigation.py @@ -48,10 +48,10 @@ def start(self) -> None: unsub = self.camera_info.subscribe( lambda msg: setattr(self, "camera_intrinsics", [msg.K[0], msg.K[4], msg.K[2], msg.K[5]]) ) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) unsub = self.detection2d.subscribe(self._on_detection) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) @rpc def stop(self) -> None: diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 4f5ade1a6f..a12c55f99c 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -154,22 +154,22 @@ def start(self) -> None: super().start() unsub = self.global_costmap.subscribe(self._on_costmap) - self._disposables.add(Disposable(unsub)) + 
self.register_disposable(Disposable(unsub)) unsub = self.odom.subscribe(self._on_odometry) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.goal_reached.transport is not None: unsub = self.goal_reached.subscribe(self._on_goal_reached) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.explore_cmd.transport is not None: unsub = self.explore_cmd.subscribe(self._on_explore_cmd) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.stop_explore_cmd.transport is not None: unsub = self.stop_explore_cmd.subscribe(self._on_stop_explore_cmd) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) @rpc def stop(self) -> None: diff --git a/dimos/navigation/patrolling/module.py b/dimos/navigation/patrolling/module.py index 49cfea8342..72a2b821fc 100644 --- a/dimos/navigation/patrolling/module.py +++ b/dimos/navigation/patrolling/module.py @@ -64,11 +64,11 @@ def __init__(self, g: GlobalConfig = global_config) -> None: def start(self) -> None: super().start() - self._disposables.add(Disposable(self.odom.subscribe(self._on_odom))) - self._disposables.add( + self.register_disposable(Disposable(self.odom.subscribe(self._on_odom))) + self.register_disposable( Disposable(self.global_costmap.subscribe(self._router.handle_occupancy_grid)) ) - self._disposables.add(Disposable(self.goal_reached.subscribe(self._on_goal_reached))) + self.register_disposable(Disposable(self.goal_reached.subscribe(self._on_goal_reached))) @rpc def stop(self) -> None: diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index 26c540a254..2375af20ce 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -53,16 +53,18 @@ def __init__(self, **kwargs: Any) -> None: def start(self) -> None: super().start() - 
self._disposables.add(Disposable(self.odom.subscribe(self._planner.handle_odom))) - self._disposables.add( + self.register_disposable(Disposable(self.odom.subscribe(self._planner.handle_odom))) + self.register_disposable( Disposable(self.global_costmap.subscribe(self._planner.handle_global_costmap)) ) - self._disposables.add( + self.register_disposable( Disposable(self.goal_request.subscribe(self._planner.handle_goal_request)) ) - self._disposables.add(Disposable(self.target.subscribe(self._planner.handle_goal_request))) + self.register_disposable( + Disposable(self.target.subscribe(self._planner.handle_goal_request)) + ) - self._disposables.add( + self.register_disposable( Disposable( self.clicked_point.subscribe( lambda pt: self._planner.handle_goal_request(pt.to_pose_stamped()) @@ -70,14 +72,14 @@ def start(self) -> None: ) ) - self._disposables.add(self._planner.path.subscribe(self.path.publish)) + self.register_disposable(self._planner.path.subscribe(self.path.publish)) - self._disposables.add(self._planner.cmd_vel.subscribe(self.cmd_vel.publish)) + self.register_disposable(self._planner.cmd_vel.subscribe(self.cmd_vel.publish)) - self._disposables.add(self._planner.goal_reached.subscribe(self.goal_reached.publish)) + self.register_disposable(self._planner.goal_reached.subscribe(self.goal_reached.publish)) if "DEBUG_NAVIGATION" in os.environ: - self._disposables.add( + self.register_disposable( self._planner.navigation_costmap.subscribe(self.navigation_costmap.publish) ) diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index ec8cc495c3..9b3bdd71f0 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -132,7 +132,7 @@ def __init__(self, **kwargs: Any) -> None: def start(self) -> None: self._running = True - self._disposables.add( + self.register_disposable( self._local_pointcloud_subject.pipe( ops.sample(1.0 / self.config.local_pointcloud_freq), ).subscribe( @@ -141,7 +141,7 @@ def start(self) -> None: ) ) - 
self._disposables.add( + self.register_disposable( self._global_map_subject.pipe( ops.sample(1.0 / self.config.global_map_freq), ).subscribe( diff --git a/dimos/perception/detection/type/detection2d/bbox.py b/dimos/perception/detection/type/detection2d/bbox.py index 9ce3f11b96..a38d6f3bce 100644 --- a/dimos/perception/detection/type/detection2d/bbox.py +++ b/dimos/perception/detection/type/detection2d/bbox.py @@ -98,33 +98,6 @@ def to_repr_dict(self) -> dict[str, Any]: "bbox": f"[{x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f}]", } - def center_to_3d( - self, - pixel: tuple[int, int], - camera_info: CameraInfo, # type: ignore[name-defined] - assumed_depth: float = 1.0, - ) -> PoseStamped: # type: ignore[name-defined] - """Unproject 2D pixel coordinates to 3D position in camera optical frame. - - Args: - camera_info: Camera calibration information - assumed_depth: Assumed depth in meters (default 1.0m from camera) - - Returns: - Vector3 position in camera optical frame coordinates - """ - # Extract camera intrinsics - fx, fy = camera_info.K[0], camera_info.K[4] - cx, cy = camera_info.K[2], camera_info.K[5] - - # Unproject pixel to normalized camera coordinates - x_norm = (pixel[0] - cx) / fx - y_norm = (pixel[1] - cy) / fy - - # Create 3D point at assumed depth in camera optical frame - # Camera optical frame: X right, Y down, Z forward - return Vector3(x_norm * assumed_depth, y_norm * assumed_depth, assumed_depth) # type: ignore[name-defined] - # return focused image, only on the bbox def cropped_image(self, padding: int = 20) -> Image: """Return a cropped version of the image focused on the bounding box. 
diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py index da9fe62370..3342ef9a5e 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -297,11 +297,11 @@ def _on_frame(img: Image) -> None: f"buffered={len(self._accumulator._buffer)}" ) - self._disposables.add( + self.register_disposable( frame_subject.pipe(sharpness_barrier(self.config.fps)).subscribe(_on_frame) ) unsub_image = self.color_image.subscribe(frame_subject.on_next) - self._disposables.add(Disposable(unsub_image)) + self.register_disposable(Disposable(unsub_image)) # Odometry tracking for entity world positioning (optional — # module works without it, entities just won't have world positions) @@ -313,14 +313,14 @@ def _on_odom(msg: PoseStamped) -> None: if self.odom.transport is not None: unsub_odom = self.odom.subscribe(_on_odom) - self._disposables.add(Disposable(unsub_odom)) + self.register_disposable(Disposable(unsub_odom)) else: logger.warning( "[temporal-memory] odom stream not connected — entity positions will be (0,0,0)" ) # Periodic window analysis - self._disposables.add( + self.register_disposable( interval(self.config.stride_s).subscribe(lambda _: self._analyze_window()) ) logger.info("TemporalMemory started") diff --git a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py index 0b1dd15ca5..17c56fd397 100644 --- a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py +++ b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py @@ -537,7 +537,7 @@ def emit_frames(observer, scheduler): # type: ignore[no-untyped-def] time.sleep(0.5) observer.on_completed() - self._disposables.add( + self.register_disposable( reactivex.create(emit_frames) .pipe( 
ops.observe_on(reactivex.scheduler.NewThreadScheduler()), diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index a20ecf24ac..b4570547a5 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -148,7 +148,7 @@ def on_aligned_frames(frames_tuple) -> None: # type: ignore[no-untyped-def] match_tolerance=0.5, # 500ms tolerance ) unsub = aligned_frames.subscribe(on_aligned_frames) - self._disposables.add(unsub) + self.register_disposable(unsub) # Subscribe to camera info stream separately (doesn't need alignment) def on_camera_info(camera_info_msg: CameraInfo) -> None: @@ -163,7 +163,7 @@ def on_camera_info(camera_info_msg: CameraInfo) -> None: ] unsub = self.camera_info.subscribe(on_camera_info) # type: ignore[assignment] - self._disposables.add(Disposable(unsub)) # type: ignore[arg-type] + self.register_disposable(Disposable(unsub)) # type: ignore[arg-type] @rpc def stop(self) -> None: diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index e5059fdd22..abf0674632 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -97,7 +97,7 @@ def on_frame(frame_msg: Image) -> None: self._frame_arrival_time = arrival_time unsub = self.color_image.subscribe(on_frame) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) logger.info("ObjectTracker2D module started") @rpc diff --git a/dimos/perception/object_tracker_3d.py b/dimos/perception/object_tracker_3d.py index 317a58dba0..f6945920fb 100644 --- a/dimos/perception/object_tracker_3d.py +++ b/dimos/perception/object_tracker_3d.py @@ -99,7 +99,7 @@ def on_aligned_frames(frames_tuple) -> None: # type: ignore[no-untyped-def] match_tolerance=0.5, # 500ms tolerance ) unsub = aligned_frames.subscribe(on_aligned_frames) - self._disposables.add(unsub) + self.register_disposable(unsub) # Subscribe to camera info def on_camera_info(camera_info_msg: 
CameraInfo) -> None: diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 66502d8f80..9e27030d4a 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -196,10 +196,10 @@ def set_video(image_msg: Image) -> None: else: logger.warning("Received image message without data attribute") - self._disposables.add(Disposable(self.color_image.subscribe(set_video))) + self.register_disposable(Disposable(self.color_image.subscribe(set_video))) # Start periodic processing using interval - self._disposables.add( + self.register_disposable( interval(self._process_interval).subscribe(lambda _: self._process_frame()) ) diff --git a/dimos/robot/drone/connection_module.py b/dimos/robot/drone/connection_module.py index c1caaf609d..b2704fd95b 100644 --- a/dimos/robot/drone/connection_module.py +++ b/dimos/robot/drone/connection_module.py @@ -22,7 +22,7 @@ from typing import Any from dimos_lcm.std_msgs import String -from reactivex.disposable import CompositeDisposable, Disposable +from reactivex.disposable import Disposable from dimos.agents.annotation import skill from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT @@ -43,13 +43,6 @@ logger = setup_logger() -def _add_disposable(composite: CompositeDisposable, item: Disposable | Any) -> None: - if isinstance(item, Disposable): - composite.add(item) - elif callable(item): - composite.add(Disposable(item)) - - class Config(ModuleConfig): connection_string: str = "udp:0.0.0.0:14550" video_port: int = 5600 @@ -127,8 +120,7 @@ def start(self) -> None: if self.video_stream.start(): logger.info("Video stream started") # Subscribe to video, store latest frame and publish it - _add_disposable( - self._disposables, + self.register_disposable( self.video_stream.get_stream().subscribe(self._store_and_publish_frame), ) # # TEMPORARY - DELETE AFTER RECORDING @@ -140,29 +132,25 @@ def start(self) -> None: logger.warning("Video stream failed to start") # 
Subscribe to drone streams - _add_disposable( - self._disposables, self.connection.odom_stream().subscribe(self._publish_tf) - ) - _add_disposable( - self._disposables, self.connection.status_stream().subscribe(self._publish_status) - ) - _add_disposable( - self._disposables, self.connection.telemetry_stream().subscribe(self._publish_telemetry) + self.register_disposable(self.connection.odom_stream().subscribe(self._publish_tf)) + self.register_disposable(self.connection.status_stream().subscribe(self._publish_status)) + self.register_disposable( + self.connection.telemetry_stream().subscribe(self._publish_telemetry) ) # Subscribe to movement commands - _add_disposable(self._disposables, self.movecmd.subscribe(self.move)) + self.register_disposable(Disposable(self.movecmd.subscribe(self.move))) # Subscribe to Twist movement commands if self.movecmd_twist.transport: - _add_disposable(self._disposables, self.movecmd_twist.subscribe(self._on_move_twist)) + self.register_disposable(Disposable(self.movecmd_twist.subscribe(self._on_move_twist))) if self.gps_goal.transport: - _add_disposable(self._disposables, self.gps_goal.subscribe(self._on_gps_goal)) + self.register_disposable(Disposable(self.gps_goal.subscribe(self._on_gps_goal))) if self.tracking_status.transport: - _add_disposable( - self._disposables, self.tracking_status.subscribe(self._on_tracking_status) + self.register_disposable( + Disposable(self.tracking_status.subscribe(self._on_tracking_status)) ) # Start telemetry update thread diff --git a/dimos/robot/drone/test_drone.py b/dimos/robot/drone/test_drone.py index 0b30c22c35..2b9517614a 100644 --- a/dimos/robot/drone/test_drone.py +++ b/dimos/robot/drone/test_drone.py @@ -240,13 +240,13 @@ def test_connection_module_replay_mode(self) -> None: mock_conn_instance = MagicMock() mock_conn_instance.connected = True mock_conn_instance.odom_stream.return_value.subscribe = MagicMock( - return_value=lambda: None + return_value=MagicMock() ) 
mock_conn_instance.status_stream.return_value.subscribe = MagicMock( - return_value=lambda: None + return_value=MagicMock() ) mock_conn_instance.telemetry_stream.return_value.subscribe = MagicMock( - return_value=lambda: None + return_value=MagicMock() ) mock_conn_instance.disconnect = MagicMock() mock_fake_conn.return_value = mock_conn_instance @@ -255,7 +255,7 @@ def test_connection_module_replay_mode(self) -> None: mock_video_instance = MagicMock() mock_video_instance.start.return_value = True mock_video_instance.get_stream.return_value.subscribe = MagicMock( - return_value=lambda: None + return_value=MagicMock() ) mock_video_instance.stop = MagicMock() mock_fake_video.return_value = mock_video_instance @@ -264,7 +264,7 @@ def test_connection_module_replay_mode(self) -> None: module = DroneConnectionModule(connection_string="replay") module.video = MagicMock() module.movecmd = MagicMock() - module.movecmd.subscribe = MagicMock(return_value=lambda: None) + module.movecmd.subscribe = MagicMock(return_value=MagicMock()) module.tf = MagicMock() try: diff --git a/dimos/robot/test_all_blueprints_generation.py b/dimos/robot/test_all_blueprints_generation.py index 48d482f3b6..d40ad2aed5 100644 --- a/dimos/robot/test_all_blueprints_generation.py +++ b/dimos/robot/test_all_blueprints_generation.py @@ -33,7 +33,7 @@ "dimos/core/test_blueprints.py", } BLUEPRINT_METHODS = {"transports", "global_config", "remappings", "requirements", "configurators"} -_EXCLUDED_MODULE_NAMES = {"Module", "ModuleBase"} +_EXCLUDED_MODULE_NAMES = {"Module", "ModuleBase", "StreamModule"} def test_all_blueprints_is_current() -> None: diff --git a/dimos/robot/unitree/b1/connection.py b/dimos/robot/unitree/b1/connection.py index 11af31b296..26fe3db933 100644 --- a/dimos/robot/unitree/b1/connection.py +++ b/dimos/robot/unitree/b1/connection.py @@ -121,24 +121,24 @@ def start(self) -> None: # Subscribe to input streams if self.cmd_vel: unsub = self.cmd_vel.subscribe(self.handle_twist_stamped) - 
self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.mode_cmd: unsub = self.mode_cmd.subscribe(self.handle_mode) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.odom_in: unsub = self.odom_in.subscribe(self._publish_odom_pose) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) # Subscribe to ROS In ports if self.ros_cmd_vel: unsub = self.ros_cmd_vel.subscribe(self.handle_twist_stamped) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.ros_odom_in: unsub = self.ros_odom_in.subscribe(self._publish_odom_pose) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) if self.ros_tf: unsub = self.ros_tf.subscribe(self._on_ros_tf) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) # Start threads self.running = True diff --git a/dimos/robot/unitree/g1/connection.py b/dimos/robot/unitree/g1/connection.py index 6f43783d99..56eb0ff1b6 100644 --- a/dimos/robot/unitree/g1/connection.py +++ b/dimos/robot/unitree/g1/connection.py @@ -92,7 +92,7 @@ def start(self) -> None: assert self.connection is not None self.connection.start() - self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) + self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) @rpc def stop(self) -> None: diff --git a/dimos/robot/unitree/g1/sim.py b/dimos/robot/unitree/g1/sim.py index d83ba19368..6c09cfd5d3 100644 --- a/dimos/robot/unitree/g1/sim.py +++ b/dimos/robot/unitree/g1/sim.py @@ -70,10 +70,10 @@ def start(self) -> None: assert self.connection is not None self.connection.start() - self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) - self._disposables.add(self.connection.odom_stream().subscribe(self._publish_sim_odom)) - self._disposables.add(self.connection.lidar_stream().subscribe(self.lidar.publish)) - 
self._disposables.add(self.connection.video_stream().subscribe(self.color_image.publish)) + self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) + self.register_disposable(self.connection.odom_stream().subscribe(self._publish_sim_odom)) + self.register_disposable(self.connection.lidar_stream().subscribe(self.lidar.publish)) + self.register_disposable(self.connection.video_stream().subscribe(self.color_image.publish)) self._camera_info_thread = Thread( target=self._publish_camera_info_loop, diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index 9fbe3811f0..826e3a0b1f 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -238,10 +238,10 @@ def onimage(image: Image) -> None: self.color_image.publish(image) self._latest_video_frame = image - self._disposables.add(self.connection.lidar_stream().subscribe(self.lidar.publish)) - self._disposables.add(self.connection.odom_stream().subscribe(self._publish_tf)) - self._disposables.add(self.connection.video_stream().subscribe(onimage)) - self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) + self.register_disposable(self.connection.lidar_stream().subscribe(self.lidar.publish)) + self.register_disposable(self.connection.odom_stream().subscribe(self._publish_tf)) + self.register_disposable(self.connection.video_stream().subscribe(onimage)) + self.register_disposable(Disposable(self.cmd_vel.subscribe(self.move))) self._camera_info_thread = Thread( target=self.publish_camera_info, diff --git a/dimos/robot/unitree/type/map.py b/dimos/robot/unitree/type/map.py index dcd5ddaa14..e1f420972a 100644 --- a/dimos/robot/unitree/type/map.py +++ b/dimos/robot/unitree/type/map.py @@ -67,11 +67,11 @@ def __init__(self, **kwargs: Any) -> None: def start(self) -> None: super().start() - self._disposables.add(Disposable(self.lidar.subscribe(self.add_frame))) + 
self.register_disposable(Disposable(self.lidar.subscribe(self.add_frame))) if self.global_publish_interval is not None: unsub = interval(self.global_publish_interval).subscribe(self._publish) - self._disposables.add(unsub) + self.register_disposable(unsub) @rpc def stop(self) -> None: diff --git a/dimos/simulation/unity/module.py b/dimos/simulation/unity/module.py index f67678c10b..67c05a3279 100644 --- a/dimos/simulation/unity/module.py +++ b/dimos/simulation/unity/module.py @@ -295,8 +295,8 @@ def __init__(self, **kwargs: Any) -> None: @rpc def start(self) -> None: super().start() - self._disposables.add(Disposable(self.cmd_vel.subscribe(self._on_cmd_vel))) - self._disposables.add(Disposable(self.terrain_map.subscribe(self._on_terrain))) + self.register_disposable(Disposable(self.cmd_vel.subscribe(self._on_cmd_vel))) + self.register_disposable(Disposable(self.terrain_map.subscribe(self._on_terrain))) self._running.set() self._sim_thread = threading.Thread(target=self._sim_loop, daemon=True) self._sim_thread.start() diff --git a/dimos/simulation/unity/test_unity_sim.py b/dimos/simulation/unity/test_unity_sim.py index 1e7dfea8b2..e0e740e081 100644 --- a/dimos/simulation/unity/test_unity_sim.py +++ b/dimos/simulation/unity/test_unity_sim.py @@ -65,6 +65,9 @@ def subscribe(self, cb, *_a): self._subscribers.append(cb) return lambda: self._subscribers.remove(cb) + def stop(self): + pass + def _wire(module) -> dict[str, _MockTransport]: ts = {} diff --git a/dimos/utils/demo_image_encoding.py b/dimos/utils/demo_image_encoding.py index 56048d3a40..39c31d1504 100644 --- a/dimos/utils/demo_image_encoding.py +++ b/dimos/utils/demo_image_encoding.py @@ -76,7 +76,7 @@ class ReceiverModule(Module): def start(self) -> None: super().start() - self._disposables.add(Disposable(self.image.subscribe(self._on_image))) + self.register_disposable(Disposable(self.image.subscribe(self._on_image))) self._open_file = open("/tmp/receiver-times", "w") def stop(self) -> None: diff --git 
a/dimos/utils/testing/collector.py b/dimos/utils/testing/collector.py index bcc3150e73..faf9464843 100644 --- a/dimos/utils/testing/collector.py +++ b/dimos/utils/testing/collector.py @@ -30,7 +30,7 @@ class CallbackCollector: assert len(collector.results) == 3 """ - def __init__(self, n: int, timeout: float = 2.0) -> None: + def __init__(self, n: int, timeout: float = 5.0) -> None: self.results: list[tuple[Any, Any]] = [] self._done = threading.Event() self._n = n diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 2ef674133e..1f49295f68 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -341,12 +341,12 @@ def start(self) -> None: if hasattr(pubsub, "start"): pubsub.start() # type: ignore[union-attr] unsub = pubsub.subscribe_all(self._on_message) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) # Add pubsub stop as disposable for pubsub in self.config.pubsubs: if hasattr(pubsub, "stop"): - self._disposables.add(Disposable(pubsub.stop)) # type: ignore[union-attr] + self.register_disposable(Disposable(pubsub.stop)) # type: ignore[union-attr] self._log_static() diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 39163bfb95..61cc7f9f75 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -174,25 +174,25 @@ def start(self) -> None: try: unsub = self.odom.subscribe(self._on_robot_pose) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) except Exception: ... try: unsub = self.gps_location.subscribe(self._on_gps_location) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) except Exception: ... 
try: unsub = self.path.subscribe(self._on_path) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) except Exception: ... try: unsub = self.global_costmap.subscribe(self._on_global_costmap) - self._disposables.add(Disposable(unsub)) + self.register_disposable(Disposable(unsub)) except Exception: ... diff --git a/docs/agents/index.md b/docs/agents/index.md index ec9d66e886..4170a0e898 100644 --- a/docs/agents/index.md +++ b/docs/agents/index.md @@ -1,19 +1,8 @@ # For Agents -These docs are mostly for coding agents - -```sh -tree . -P '*.md' --prune -``` - - -``` -. -├── docs +├── testing.md (docs about writing tests) +├── docs (these are docs about writing docs) │   ├── codeblocks.md │   ├── doclinks.md │   └── index.md └── index.md - -2 directories, 4 files -``` diff --git a/docs/agents/style.md b/docs/agents/style.md new file mode 100644 index 0000000000..37354cc681 --- /dev/null +++ b/docs/agents/style.md @@ -0,0 +1,49 @@ +# Code Style Guidelines + +Rules for writing code in dimos. These address recurring issues found in code review. + +## No comment banners + +Don't use decorative section dividers or box comments. + +```python +# BAD +# ═══════════════════════════════════════════════════════════════════ +# 1. Basic iteration +# ═══════════════════════════════════════════════════════════════════ + +# BAD +# ------------------------------------------------------------------- +# Section name +# ------------------------------------------------------------------- + +# GOOD — just use a plain comment if a section heading is needed +# Basic iteration +``` + +If a file has enough sections to warrant banners, it should probably be split into separate files instead. 
For example, instead of one large `test_something.py` with banner-separated sections, create a `something/` directory: + +``` +# BAD +test_something.py (500 lines with banner-separated sections) + +# GOOD +something/ + test_iteration.py + test_lifecycle.py + test_queries.py +``` + +## No `__init__.py` re-exports + +Never add imports to `__init__.py` files. Re-exporting from `__init__.py` makes imports too wide and slow — importing one symbol pulls in the entire package tree. + +```python +# BAD — dimos/memory2/__init__.py +from dimos.memory2.store import Store, SqliteStore +from dimos.memory2.stream import Stream + +# GOOD — import directly from the module +from dimos.memory2.store.base import Store +from dimos.memory2.stream import Stream +``` diff --git a/docs/agents/testing.md b/docs/agents/testing.md new file mode 100644 index 0000000000..45614c81d2 --- /dev/null +++ b/docs/agents/testing.md @@ -0,0 +1,149 @@ +# Testing Guidelines + +Rules for writing tests in dimos. These address recurring issues found in code review. + +For grid testing (spec/impl tests across multiple backends), see [Grid Testing Strategy](/docs/development/grid_testing.md). + +## Imports at the top + +All imports must be at module level, not inside test functions. + +```python +# BAD +def test_something() -> None: + import threading + from dimos.core.transport import pLCMTransport + ... + +# GOOD +import threading +from dimos.core.transport import pLCMTransport + +def test_something() -> None: + ... +``` + +## Always clean up resources + +Use context managers or try/finally. If a test creates a resource, it must be cleaned up even if assertions fail. 
+ +```python +# BAD - store.stop() never called +def test_something() -> None: + store = ListObservationStore(name="test", max_size=0) + store.start() + assert store.count(StreamQuery()) == 0 + +# BAD - module.stop() skipped if assertion fails +def test_wiring() -> None: + module = MyModule() + module.start() + assert received == [84] + module.stop() + +# GOOD - context manager (ideal) +def test_something() -> None: + store = ListObservationStore(name="test", max_size=0) + with store: + assert store.count(StreamQuery()) == 0 + +# GOOD - try/finally +def test_wiring() -> None: + module = MyModule() + module.start() + try: + assert received == [84] + finally: + module.stop() +``` + +When a resource is shared across multiple tests, use a pytest fixture with `yield` instead of repeating context managers in each test: + +```python +# GOOD - fixture handles lifecycle for all tests that use it +@pytest.fixture(scope="module") +def store() -> Iterator[SqliteStore]: + db = SqliteStore(path=str(DB_PATH)) + with db: + yield db + +def test_query(store: SqliteStore) -> None: + assert store.stream("video", Image).count() > 0 + +def test_search(store: SqliteStore) -> None: + results = store.stream("video", Image).limit(5).fetch() + assert len(results) == 5 +``` + +## No conditional logic in assertions + +Tests must be deterministic. If you don't know the state, the test is wrong. + +```python +# BAD - assertion may never execute +if hasattr(obj, "_disposables") and obj._disposables is not None: + assert obj._disposables.is_disposed + +# BAD - masks whether disposables were created +assert obj._disposables is None or obj._disposables.is_disposed + +# GOOD - explicit about what we expect +assert obj._disposables is not None +assert obj._disposables.is_disposed +``` + +## Print statements + +- **Unit tests**: no prints. Use assertions. +- **`@pytest.mark.tool` tests** (integration/exploration): prints are fine for progress and inspection output. 
+ +## Avoid unnecessary sleeps + +Don't use `time.sleep()` to wait for async operations. Use `threading.Event` to synchronize emitter/receiver patterns. + +```python +# BAD - arbitrary sleep, fragile +module.start() +time.sleep(0.5) +module.numbers.transport.publish(42) +time.sleep(1.0) +assert len(received) == 1 + +# GOOD - use threading.Event with a timeout +done = threading.Event() +unsub = module.doubled.subscribe(lambda msg: (received.append(msg), done.set())) +module.start() +module.numbers.transport.publish(42) +assert done.wait(timeout=5.0), f"Timed out, received={received}" +assert received == [84] +``` + +## Private fields + +Configuration fields on non-Pydantic classes should be private (underscore-prefixed) unless they are part of the public API. + +```python +# BAD +self.voxel_size = voxel_size +self.carve_columns = carve_columns + +# GOOD +self._voxel_size = voxel_size +self._carve_columns = carve_columns +``` + +## Type ignores + +Avoid `# type: ignore` by using proper types: + +```python +# BAD +self.vbg = None # type: ignore[assignment] + +# GOOD - type as Optional +self.vbg: VoxelBlockGrid | None = VoxelBlockGrid(...) +# then later: +self.vbg = None # no ignore needed +``` + +Type ignores are acceptable when caused by untyped third-party libraries (e.g. `open3d`) or decorator-generated attributes (e.g. `@simple_mcache` adding `invalidate_cache`). 
diff --git a/docs/capabilities/memory/.gitattributes b/docs/capabilities/memory/.gitattributes new file mode 100644 index 0000000000..03438845de --- /dev/null +++ b/docs/capabilities/memory/.gitattributes @@ -0,0 +1,7 @@ +assets/color_image.svg filter=lfs diff=lfs merge=lfs -text +assets/embedding.svg filter=lfs diff=lfs merge=lfs -text +assets/embedding_focused.svg filter=lfs diff=lfs merge=lfs -text +assets/imageposes.svg filter=lfs diff=lfs merge=lfs -text +assets/speed.svg filter=lfs diff=lfs merge=lfs -text +assets/brightness.svg filter=lfs diff=lfs merge=lfs -text +assets/grid.png filter=lfs diff=lfs merge=lfs -text diff --git a/docs/capabilities/memory/demo.md b/docs/capabilities/memory/demo.md new file mode 100644 index 0000000000..1a9f89982f --- /dev/null +++ b/docs/capabilities/memory/demo.md @@ -0,0 +1,68 @@ + +```python skip +from dimos.msgs.sensor_msgs.Image import Image +from dimos.utils.data import get_data +from dimos.memory2.store.sqlite import SqliteStore + + +store = SqliteStore(path=get_data("go2_bigoffice.db")) + +print(store.streams.color_image) + +``` + + +``` +Stream("color_image") +``` + +```python + # Downsample to 2Hz, then embed + pipeline = ( + video.filter(lambda obs: obs.data.brightness > 0.1) + .transform(QualityWindow(lambda img: img.sharpness, window=0.5)) + .transform(EmbedImages(clip)) + .save(embedded) + ) + + + +``` + + +``` +File "/tmp/tmpmfb4vuia.py", line 2 + pipeline = ( +IndentationError: unexpected indent + +Exit code: 1 +``` + +```python +import pickle +from dimos.mapping.pointclouds.occupancy import general_occupancy, simple_occupancy, height_cost_occupancy +from dimos.mapping.occupancy.inflation import simple_inflate +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.vis.drawing.drawing import Drawing2D +from dimos.utils.data import get_data +from dimos.memory2.vis.type import Point +from dimos.models.embedding.clip import CLIPModel + +clip = CLIPModel() + +#global_map = 
pickle.loads(get_data("unitree_go2_bigoffice_map.pickle").read_bytes()) +#drawing = Drawing2D() +#costmap = simple_inflate(general_occupancy(global_map), 0.05) +#drawing.add(costmap) + +store = SqliteStore(path=get_data("go2_bigoffice.db")) + +store.streams.color_image \ +.filter(lambda obs: obs.data.brightness > 0.1) \ +.map(drawing.add) + +drawing.to_svg("assets/test.svg") +``` + + +![output](assets/test.svg) diff --git a/docs/capabilities/memory/search.ipynb b/docs/capabilities/memory/search.ipynb new file mode 100644 index 0000000000..4c41553bc4 --- /dev/null +++ b/docs/capabilities/memory/search.ipynb @@ -0,0 +1,29 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dimos.memory2.store.sqlite import SqliteStore\n", + "from dimos.models.embedding.clip import CLIPModel\n", + "from dimos.utils.data import get_data\n", + "\n", + "store = SqliteStore(path=get_data(\"go2_bigoffice.db\"))\n", + "clip = CLIPModel()\n", + "print(\"Ready\")\n", + "print(\"Streams:\", store.streams)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/capabilities/memory/test.db b/docs/capabilities/memory/test.db new file mode 100644 index 0000000000..2d416423c7 Binary files /dev/null and b/docs/capabilities/memory/test.db differ diff --git a/examples/simplerobot/simplerobot.py b/examples/simplerobot/simplerobot.py index 517684d7cd..902736f06a 100644 --- a/examples/simplerobot/simplerobot.py +++ b/examples/simplerobot/simplerobot.py @@ -68,11 +68,11 @@ class SimpleRobot(Module[SimpleRobotConfig]): @rpc def start(self) -> None: - self._disposables.add(self.cmd_vel.observable().subscribe(self._on_twist)) - self._disposables.add( + self.register_disposable(self.cmd_vel.observable().subscribe(self._on_twist)) + self.register_disposable( rx.interval(1.0 / 
self.config.update_rate).subscribe(lambda _: self._update()) ) - self._disposables.add( + self.register_disposable( rx.interval(1.0).subscribe(lambda _: print(f"\033[34m{self._pose}\033[0m")) ) diff --git a/zenoh.md b/zenoh.md new file mode 100644 index 0000000000..4eb8b283ac --- /dev/null +++ b/zenoh.md @@ -0,0 +1,478 @@ +# Zenoh Investigation + +Research for [#1726](https://github.com/dimensionalOS/dimos/issues/1726). + +## What is Zenoh + +Zenoh is a communication protocol and implementation by ZettaScale (the people behind the original DDS/CycloneDDS). Designed for the cloud-to-edge-to-microcontroller continuum. Written in Rust, with bindings for Python, C, C++, and others. + +Key difference from LCM/DDS: it unifies pub/sub, request/response, and distributed storage in one protocol, and works across unreliable/heterogeneous networks natively — not bolted on. + +## Communication Modes + +Three session modes: + +- **Peer:** Direct connectivity via multicast scouting (`224.0.0.224:7446`). Similar to LCM's current model but with more features. Peers auto-discover each other and routers via multicast/gossip. +- **Router:** Zenoh routers (`zenohd`) mediate between subnets. Routers **do not auto-discover each other** — every router-to-router link must be explicitly configured via `connect.endpoints`/`listen.endpoints`. Once connected, they exchange full link-state topology and compute spanning trees automatically. +- **Client:** Applications connect to a router as clients. Simplest for constrained devices. Clients auto-discover routers via multicast scouting. + +These can be mixed — robots peer locally, connect via router to cloud. + +### Router discovery and topology + +Routers join the multicast group (`224.0.0.224:7446`) and **respond** to scouting so clients/peers can find them, but they **never initiate connections from scouting**. This is by design — you control the router topology explicitly. 
+ +Auto-connect *can* be enabled (`scouting.multicast.autoconnect: { router: ["router"] }`) but is off by default. Once router links are established, routers maintain two link-state networks: +- **Router graph** — full link-state of all routers (always on) +- **Linkstate-peer graph** — peers operating in linkstate mode (optional) + +Link-state carries **topology only** (who's connected to whom, with what weights) — NOT subscription state. Subscription declarations propagate separately via the spanning tree. + +## Pub/Sub + +- Key expressions are hierarchical paths (like MQTT topics): `/robot1/sensors/lidar` +- Wildcard matching: `*` (single level), `**` (multi-level) +- Discovery: multicast scouting locally, gossip protocol when multicast unavailable +- Late joiners can receive latest cached values via publication caches (zenoh-ext `AdvancedPublisher`) or storage nodes + +### How subscription interest propagates through routers + +The network has two subsystems with different behavior: + +**Router backbone (fully informed):** Routers propagate subscription declarations along the spanning tree. When a client behind Router B subscribes to `robot1/pose`, Router B sends a `DeclareSubscriber` to its peer routers. Router A receives it and starts forwarding matching publications toward Router B. If nobody behind Router B subscribes, **no data flows** — this is interest-based routing. + +**Client/peer edge (simple):** Since Zenoh 1.0.0, subscription declarations are **not** propagated to clients/peers. All publications from a client/peer go to their nearest router unconditionally — the router decides what to forward. This is a scalability optimization: edge devices don't get flooded with subscription state from the entire network. + +**Consequence:** Writer-side filtering (don't publish if nobody's listening) only works at the edge if the publisher explicitly declares a `Publisher` object, which triggers an Interest exchange with the router. 
A bare `session.put()` always goes to the router. + +**Wildcard gotcha:** A wildcard storage (`key_expr: "**"`) declares itself as a subscriber for everything. This causes the hosting router to pull ALL publications from every connected router. Be careful with broad storage key expressions on routers that bridge constrained links. + +**Interest timeout:** `routing.interests.timeout` (default 10s) — when a node connects and sends Interests to discover existing subscribers, it waits this long for the router to reply. If it times out, discovery may be incomplete (potential message loss at startup). + +## Request/Response (Queryables) + +- `Session.get(key_expr)` sends a query, returns zero or more replies +- Any node can declare a `queryable` that handles queries for a key expression +- A queryable can send multiple replies (streaming) — maps well to LLM token streaming, chunked image transfer +- Query payload is arbitrary — you can encode whatever request format you want +- Queries can be cancelled (since 1.7.x) — relevant for interactive/long-running queries + +## Shared Memory (Zero-Copy Local IPC) + +Zenoh has built-in SHM transport for zero-copy communication between processes on the same machine: + +- Automatic: messages above a threshold (default 3 KB) are placed in SHM instead of copied +- Default pool size: 16 MB, configurable +- POSIX shared memory protocol (`PosixShmProvider`) +- Available in Rust, C, C++, and Python (since 1.6.x) +- Automatic garbage collection and defragmentation +- Typed SHM buffers with dynamic resize + +Relevant for us: local sensor pipelines between co-located processes can bypass network serialization entirely, similar to our existing `SHMTransport` but with Zenoh's routing/discovery on top. 
+ +## Liveliness (Presence Detection) + +Built-in presence/health system tied to session lifetime: + +- `session.liveliness().declare_token("robot/robot1")` — announce presence +- `session.liveliness().declare_subscriber("robot/**")` — get `Put` on appearance, `Delete` on disappearance +- `session.liveliness().get("robot/**")` — query currently live tokens +- Crashes/disconnects automatically trigger `Delete` — no heartbeat polling needed + +Directly useful for robot fleet management, module health monitoring. + +## Timestamps (Hybrid Logical Clocks) + +Every value in Zenoh gets a HLC timestamp automatically (when passing through a router): + +- Timestamp = NTP64 time + HLC UUID +- Theoretical resolution: ~3.5 nanoseconds +- Guarantees: unique timestamps, happened-before ordering preserved, no consensus required +- Used by the storage alignment protocol for eventual consistency + +## Reliability on Unreliable Networks + +- Minimum 5-byte wire overhead (3-byte data message + 2-byte frame header; multiple messages share the frame overhead) +- Works at OSI Layer 2 — supports non-IP networks: BLE, serial, LoRa, CANbus +- Two channel types: best-effort and reliable (ordered delivery) +- Three reliability strategies: + 1. **Hop-by-hop** (default) — single reliability state per app, reliable under stable topology + 2. **End-to-end** — dedicated reliability channel per producer-consumer pair, no loss even during topology changes, higher resource cost + 3. **First-router-to-last-router** — reliability between edge routers, offloading pressure from endpoints to infrastructure +- Designed for WiFi/GSM from the start, unlike DDS which was designed for reliable LAN + +## Backpressure & Congestion Control + +Three independently configurable concerns: + +1. **Resending (receiver controls):** Declares whether missing messages should be resent. Hop-by-hop or end-to-end. +2. **Buffering (each node controls):** Each node decides how much memory to dedicate. 
Constrained devices use minimal, routers buffer more. +3. **Dropping (sender controls):** Congestion control policy — drop sample or block publication. Propagates through routing path. + +## QoS + +- **Priority levels:** `RealTime`, `InteractiveHigh`, `InteractiveLow`, `DataHigh`, `Data` (default), `DataLow`, `Background` +- **Reliability:** best-effort vs reliable per-channel +- **Congestion control:** drop vs block +- **Express mode:** Bypass batching for immediate transmission (latency vs throughput tradeoff) +- **Locality control:** `SessionLocal` / `Remote` — control whether messages stay local-only, go remote-only, or both +- **Downsampling:** Configurable per-key-expression rate limiting — useful for high-frequency sensor data on slow links +- All configurable per key-expression via config file + +## Performance + +Official benchmarks (2023, Zenoh team + NTU, peer mode, single machine): + +| Metric | 8-byte payload | 8 KB payload | +|-----------------|-------------------|---------------------| +| Messages/sec | ~4M msg/s | — | +| Throughput | — | ~67 Gbps | +| Latency | ~10 us (single machine), ~16 us (multi-machine, 64B) | | + +Comparisons (vary by payload size and conditions): + +| vs | Messages/sec (8B) | Throughput (8KB) | Latency | +|----------|-------------------|------------------|----------------| +| MQTT | ~130x faster | ~27x higher | significantly lower | +| DDS | ~2x faster | ~4x higher | comparable (DDS can be lower single-machine) | + +Source: [Zenoh benchmarks blog](https://zenoh.io/blog/2023-03-21-zenoh-vs-mqtt-kafka-dds/), [arXiv paper](https://arxiv.org/abs/2303.09419) + +## Scope Control (Local vs Internet) + +Controlled via deployment topology and access control: + +- **Local:** Peer mode, multicast scouting, stays on LAN +- **Site-wide:** Routers connect subnets +- **Internet:** Routers with TCP/TLS/QUIC endpoints +- **Access control:** ACLs per key-expression restrict what data flows where and to whom + +No magic — you design the router 
topology to match your network architecture.
+
+## Traffic Control at Routers
+
+There is no single "routing filter" config. Instead, six layered mechanisms combine to control what crosses which link. All configured per-router in JSON5.
+
+### 1. Interest-Based Routing (automatic)
+
+Routers only forward data if a downstream subscriber exists for that key expression. No subscriber on the far side of the LoRa link = no data flows. This is automatic but **not sufficient alone** — a single wildcard subscriber (`**`) on the far side would pull everything.
+
+### 2. ACL (deny key expressions per interface/link)
+
+The primary hard filter. Block specific key expressions from egressing on specific interfaces or link protocols:
+
+```json5
+{
+  access_control: {
+    enabled: true,
+    default_permission: "allow",
+    rules: [
+      {
+        id: "block_highbw_on_lora",
+        messages: ["put", "delete"],
+        flows: ["egress"],
+        permission: "deny",
+        key_exprs: ["robot1/sensors/lidar/**", "robot1/sensors/camera/**"]
+      }
+    ],
+    subjects: [
+      // match by interface name, link protocol, TLS cert CN, username, or Zenoh ID
+      { id: "lora_link", interfaces: ["ttyLoRa0"], link_protocols: ["serial"] }
+    ],
+    policies: [
+      { id: "lora_restrict", rules: ["block_highbw_on_lora"], subjects: ["lora_link"] }
+    ]
+  }
+}
+```
+
+Limitation: not runtime-updatable, requires restart.
+
+### 3. Multilink (route by priority to different physical links)
+
+A single router can listen on multiple transports simultaneously. Each endpoint declares which priority range and reliability mode it carries.
Zenoh automatically routes messages to the matching link: + +```json5 +{ + listen: { + endpoints: [ + "tcp/0.0.0.0:7447?prio=1-5;rel=1", // priorities 1-5 over WiFi (reliable) + "serial//dev/ttyLoRa0#baudrate=9600?prio=6-7;rel=0" // priorities 6-7 over LoRa (best-effort) + ] + } +} +``` + +Publishers choose their priority: +```python +session.declare_publisher("robot1/pose", priority=Priority.INTERACTIVE_HIGH) # → WiFi +session.declare_publisher("robot1/battery", priority=Priority.BACKGROUND) # → LoRa +``` + +Priority values: 1=RealTime through 7=Background. Also supports DSCP marking for IP-level QoS (`tcp/...#dscp=0x08`). + +### 4. Downsampling (rate-limit per key expression per interface) + +Rate-limit what crosses a link, even if ACL allows it: + +```json5 +{ + downsampling: [ + { + interfaces: ["ttyLoRa0"], + link_protocols: ["serial"], + flows: ["egress"], + messages: ["put", "delete"], + rules: [ + { key_expr: "robot1/pose", freq: 1.0 }, // 100 Hz → 1 Hz on LoRa + { key_expr: "robot1/battery", freq: 0.1 }, // once per 10s + { key_expr: "fleet/status/**", freq: 0.5 } + ] + } + ] +} +``` + +### 5. Low-Pass Filter (message size limit per interface) + +Drop messages exceeding a size threshold — prevents accidentally large payloads from saturating a constrained link: + +```json5 +{ + low_pass_filter: [ + { + interfaces: ["ttyLoRa0"], + link_protocols: ["serial"], + flows: ["egress"], + messages: ["put", "delete"], + key_exprs: ["**"], + size_limit: 256 // bytes + } + ] +} +``` + +### 6. Locality API (application-level) + +Publishers/subscribers can declare whether they're local-only or remote-only: + +```python +# Never leaves this process — internal debug data +session.declare_publisher("robot1/internal/debug", locality=Locality.SESSION_LOCAL) +``` + +### Combining them (example: robot over LoRa) + +For a robot with lidar/video locally and poses/commands over LoRa, layer: +1. **ACL** blocks lidar/camera/pointcloud from egressing on the serial interface +2. 
**Multilink** routes high-priority control on WiFi, low-priority telemetry on LoRa +3. **Downsampling** reduces pose from 100 Hz to 1 Hz on the LoRa egress +4. **Low-pass filter** caps message size at 256 bytes on LoRa as a safety net +5. **Locality** keeps internal debug topics process-local +6. **Interest-based routing** automatically avoids forwarding anything nobody subscribed to + +## Video / Teleop + +- **gst-plugin-zenoh:** Third-party GStreamer plugin for video over Zenoh (not official eclipse-zenoh) +- Express mode for low-latency frames +- Plugin supports optional compression (zstd, lz4, gzip) as compile-time features +- NAT traversal via router relay (router with public IP mediates), not ICE/STUN hole-punching +- For production teleop: deploy a cloud-side Zenoh router as relay endpoint + +Note: Zenoh itself has transport-level compression (enabled via `transport.unicast.compression.enabled` in config), but it's all-or-nothing for a session — you don't pick the algorithm. The GStreamer plugin's compression is application-level and per-stream. + +## Router (`zenohd`) + +The Zenoh router is a standalone daemon: + +- Installed as binary (Debian packages, Docker images, `cargo install zenohd`) +- Default REST plugin on **port 8000** — maps HTTP verbs to Zenoh operations (GET=get, PUT=put, DELETE=remove, GET+SSE=subscribe) +- Plugin system: loads dynamic libraries (`libzenoh_plugin_.so`) +- Configured via JSON5 config file or CLI arguments +- **Admin space:** Every node exposes internal state under the `@` key prefix, queryable via standard Zenoh or REST API + +## Built-in Storage (Distributed KV) + +Storage manager plugin watches key expressions and persists values. Five backends: + +- **Memory:** In-process RAM, lost on restart. Caching. +- **Filesystem:** Each key maps to a file in a configured directory. +- **RocksDB:** Embedded LSM-tree database. Persistent, fast, handles large volumes. +- **InfluxDB:** Time-series database backend. 
Good for sensor telemetry history. +- **S3:** Amazon S3 / MinIO compatible. Blob storage for large artifacts (maps, point clouds). + +Any node can run the storage plugin. Other nodes query transparently — `session.get("/map/global")` routes to whichever node stores it. Multiple nodes can store the same keys — alignment protocol (using HLC timestamps) syncs them (eventual consistency, not linearizable). + +Example config: +```json5 +{ + "plugins": { + "storage_manager": { + "storages": { + "robot_map": { + "key_expr": "/map/**", + "volume": "rocksdb", + "db": "/var/zenoh/map_db" + }, + "sensor_cache": { + "key_expr": "/sensors/**", + "volume": "memory" + } + } + } + } +} +``` + +### Storage limitations + +This is a **flat key-value store**, not a database. You can: +- Look up by exact key +- Scan by key prefix/range (lexicographic) +- Get latest value for late joiners + +You cannot: +- Query by value content or field values +- Do spatial queries ("find within 5m of X") +- Do temporal range queries on arbitrary fields +- Secondary indexes, joins, aggregation + +### RocksDB specifics + +Embedded library (no server process). LSM-tree — optimized for fast writes. Keys sorted lexicographically, only primary key index. Bloom filters for fast point lookups. Column families for namespaces. TTL for auto-expiry. Compression per level (snappy, zstd, lz4). + +For real spatial/temporal queries, you'd implement a queryable node that runs actual DB logic (PostGIS, SQLite R-tree, etc.) and returns results via Zenoh's reply mechanism. Zenoh transports the query/response, your code does the indexing. 
+ +## Access Control + +ACL system with three components: + +- **Rules:** Define permission (allow/deny), message types (put, delete, declare_subscriber, query, reply, liveliness_token, etc.), flows (ingress/egress), and key expressions +- **Subjects:** Match remote nodes by network interface, TLS cert common name, or username +- **Policies:** Link rules to subjects +- `default_permission`: global allow/deny fallback +- **Limitation:** Cannot be updated at runtime — requires restart + +## What Zenoh doesn't have + +- **No consensus/leader election** — no Raft, no distributed locking. Build on top or use a separate tool. +- **No rich queries** — KV store only, not a database +- **No ICE/STUN NAT traversal** — relay through routers only +- **No runtime ACL updates** — access control requires restart + +## Routing in Router Meshes + +When routers form a mesh (multiple paths exist between source and destination), routing works as follows: + +### Spanning tree computation + +Every router computes **N trees, one rooted at each router** in the network, using Bellman-Ford shortest paths on the weighted topology graph. Each tree stores, for every destination, which direct neighbor is the next hop. + +- **Default link weight:** 100 (configurable via `routing.router.linkstate.transport_weights`) +- **Tiebreaking:** Deterministic hash jitter (up to 1%) added to weights so all routers agree on the same tree when weights are equal. This prevents loops. +- **If both sides of a link set weights, the greater value is used.** If only one side sets it, that value is used. + +### Data forwarding + +Messages carry the source router's NodeId. Each forwarding router looks up `trees[source].directions[subscriber]` — "in the tree rooted at the original publisher's router, what's my next hop toward this subscriber?" + +This is **multicast along the source-rooted tree**: one copy per distinct outgoing face that has subscribers behind it. Intermediate routers branch as needed. 
Multiple subscribers behind the same next-hop collapse into one forwarded copy. + +**No load balancing.** Strictly single-path per (source, destination) pair. No ECMP. + +### Router election (deduplication) + +When peers operate in linkstate mode and connect to multiple routers, they'd receive duplicate messages. `elect_router()` deterministically picks one router per key expression (via hash of key_expr + router ZID) to be responsible for forwarding to/from that peer. Different key expressions can elect different routers, spreading load. + +### Failover + +When a router-to-router link fails: +1. Transport layer detects disconnect +2. Router removes the edge from its graph, broadcasts updated link-state to remaining neighbors +3. **100ms debounce delay** (`TREES_COMPUTATION_DELAY_MS`), then trees are recomputed +4. All cached routes are invalidated; next message triggers fresh route computation + +**During the ~100ms recomputation window, messages routed through the dead link are silently lost.** There is no buffering or automatic reroute during recomputation. The hop-by-hop reliability layer ensures delivery over working links but cannot reroute around a failed one. End-to-end reliability (dedicated per-producer-consumer channels) can recover from this via retransmission after the new route is established. + +### Loop prevention + +Guaranteed loop-free by three mechanisms: +1. Source-rooted trees are DAGs by construction (shortest-path tree from Bellman-Ford) +2. Deterministic edge weight jitter ensures all routers compute identical trees +3. 
Ingress/egress filters prevent forwarding back to the arrival face + +## Protocol Bridges (Plugins) + +- **zenoh-plugin-mqtt:** MQTT broker that routes to/from Zenoh — existing MQTT devices participate transparently +- **zenoh-plugin-ros2dds:** Bridges DDS-based ROS 2 topics into Zenoh without changing RMW — useful for mixed fleets +- **zenoh-plugin-webserver:** HTTP/WebSocket access to Zenoh key space + +## Advanced Pub/Sub (zenoh-ext) + +Extensions beyond basic pub/sub: + +- **Publication cache:** Publisher retains N recent samples, replies to late-joiner queries automatically +- **Advanced subscriber:** Detects sample miss (gaps in sequence numbers), can recover from publication caches +- **Matching listeners:** Get notified when a subscriber/publisher for your key expression appears/disappears + +## Python Support + +- PyPI: `eclipse-zenoh` +- Latest version: 1.8.0 (March 2026) +- **Development status: Beta** — not marked as production-ready on PyPI +- Rust-backed bindings (not pure Python) — good performance +- Minimum CPython 3.9 (binary wheels use `cp39-abi3` tag) +- Binary wheels for Linux (x86_64, ARM, ARMv6/v7, i686), macOS, Windows +- SHM API available since 1.6.x + +## vs LCM (our current default) + +| | LCM | Zenoh | +|------------------|--------------------|-----------------------------------------------------| +| Transport | UDP multicast only | UDP, TCP, TLS, QUIC, BLE, serial, WebSocket | +| Scope | Local LAN only | Local to internet-scale | +| Reliability | Best-effort only | Configurable (best-effort, reliable, 3 strategies) | +| Discovery | Multicast | Multicast + gossip fallback | +| Request/response | No (pub/sub only) | Built-in queryables | +| Storage | No | Built-in KV with 5 backends | +| QoS | None | 7 priority levels, congestion control, express mode | +| Backpressure | None | Per-concern (resend/buffer/drop) | +| Local IPC | UDP (no zero-copy) | Shared memory (zero-copy above 3 KB threshold) | +| Presence | None | Liveliness tokens 
(automatic crash detection) | +| Timestamps | None | HLC (hybrid logical clocks) | +| Cross-language | Via LCM codegen | Native bindings (Rust, C, C++, Python) | + +## vs DDS + +| | DDS | Zenoh | +|---------------|-----------------------------|----------------------------| +| Wireless | Poor (multicast flooding) | Designed for it | +| Network scale | LAN (multicast constraints) | Internet-scale via routers | +| Discovery | Complex, unreliable on WiFi | Robust scouting + gossip | +| Modes | Peer-to-peer only | Peer, routed, client | +| Throughput | Good | ~2-4x DDS depending on payload | +| Latency | Good (can beat Zenoh locally) | Comparable | +| Complexity | Heavy config | Simpler API | + +## Integration path for dimos + +Stub exists at `dimos/core/transport.py:323` — `class ZenohTransport(PubSubTransport[T]): ...` + +The pubsub abstraction layer lives in `dimos/protocol/pubsub/`: +- `spec.py` — `PubSub[TopicT, MsgT]` base, `AllPubSub`, `DiscoveryPubSub` +- `encoders.py` — `PickleEncoderMixin`, `LCMEncoderMixin`, `JpegEncoderMixin` +- `impl/` — existing implementations: `lcmpubsub.py`, `shmpubsub.py`, `ddspubsub.py`, `rospubsub.py`, `redispubsub.py`, `memory.py` + +Implementation steps: + +1. Create `dimos/protocol/pubsub/impl/zenohpubsub.py` — implement `PubSub[TopicT, MsgT]` backed by Zenoh sessions +2. Reuse existing encoder mixins (e.g. `PickleEncoderMixin`) for payload serialization +3. Complete `ZenohTransport(PubSubTransport[T])` in `transport.py` +4. Implement RPC via `PubSubRPCMixin` + Zenoh queryables (fits naturally — queryables already support streaming replies) +5. Test: split system — local control loop on robot (peer mode), GPU module offloaded to server (routed via Zenoh router) + +## Robotics adoption + +- **ROS 2:** Official `rmw_zenoh` middleware — supported on Rolling, Kilted, Jazzy, Humble. 
Zenoh is **Tier 1 middleware** in Kilted Kaiju (May 2025) but **Fast DDS remains the default** +- **PX4:** Zenoh-Pico client runs on the flight controller, communicates via UART/TCP/UDP to a `zenohd` router on companion computer, bridged to ROS 2 topics via `rmw_zenoh`. Alternative to the uXRCE-DDS bridge, not a replacement. Note: ROS 2 Jazzy added message type hashes that break PX4 compat unless `CONFIG_ZENOH_KEY_TYPE_HASH=n` +- **ZettaScale** (creators) are the CycloneDDS team — same people, evolved design