From c4fc46910e82dbfb95804c6a975cf32ca588020b Mon Sep 17 00:00:00 2001 From: Pratham Gala <90299416+pmg5408@users.noreply.github.com> Date: Fri, 3 Apr 2026 16:11:16 +0000 Subject: [PATCH 1/8] feat(safety): add RGB collision guardrail scaffold Adds a guardrail module that intercepts Twist commands and evaluates risk against RGB camera frames. Establishes the module lifecycle, runtime state, background decision thread with condition-based synchronization, and stream interfaces. Risk evaluation logic is stubbed for a follow-up commit. --- dimos/control/safety/guardrail_policy.py | 33 ++++ .../control/safety/rgb_collision_guardrail.py | 147 ++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 dimos/control/safety/guardrail_policy.py create mode 100644 dimos/control/safety/rgb_collision_guardrail.py diff --git a/dimos/control/safety/guardrail_policy.py b/dimos/control/safety/guardrail_policy.py new file mode 100644 index 0000000000..b53561aa71 --- /dev/null +++ b/dimos/control/safety/guardrail_policy.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Protocol + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image + + +class GuardrailState(str, Enum): + INIT = "init" + PASS = "pass" + CLAMP = "clamp" + STOP_LATCHED = "stop_latched" + SENSOR_DEGRADED = "sensor_degraded" + + +@dataclass +class GuardrailDecision: + state: GuardrailState + cmd_vel: Twist + reason: str + + +class GuardrailPolicy(Protocol): + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + ) -> GuardrailDecision: + """Evaluate the latest frame pair and command.""" diff --git a/dimos/control/safety/rgb_collision_guardrail.py b/dimos/control/safety/rgb_collision_guardrail.py new file mode 100644 index 0000000000..923021d0d8 --- /dev/null +++ b/dimos/control/safety/rgb_collision_guardrail.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from dataclasses import dataclass +import time +from threading import Condition, Event, Lock, Thread +from typing import Any + +from reactivex.disposable import Disposable + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image +from dimos.utils.logging_config import setup_logger + +from dimos.control.safety.guardrail_policy import GuardrailDecision, GuardrailState + +_DEFAULT_FALLBACK_PERIOD_S = 0.1 +_THREAD_JOIN_TIMEOUT_S = 2.0 + +logger = setup_logger() + + +class RGBCollisionGuardrailConfig(ModuleConfig): + # TODO: Step 2 + guarded_output_publish_hz: float = 10.0 + risk_evaluation_hz: float = 10.0 + command_timeout_s: float = 0.25 + image_timeout_s: float = 0.25 + risk_timeout_s: float = 0.25 + fail_closed_on_missing_image: bool = True + publish_zero_on_stop: bool = True + + +@dataclass +class _GuardrailRuntimeState: + # TODO: Step 2 + latest_image: Image | None = None + previous_image: Image | None = None + latest_image_time: float | None = None + previous_image_time: float | None = None + latest_cmd_vel: Twist | None = None + latest_cmd_time: float | None = None + last_decision: GuardrailDecision | None = None + last_risk_time: float | None = None + state: GuardrailState = GuardrailState.INIT + + +class RGBCollisionGuardrail(Module[RGBCollisionGuardrailConfig]): + """RGB-only motion guardrail for direct Twist control.""" + + default_config = RGBCollisionGuardrailConfig + + color_image: In[Image] + incoming_cmd_vel: In[Twist] + safe_cmd_vel: Out[Twist] + + _condition: Condition + _runtime_lock: Lock + _runtime_state: _GuardrailRuntimeState + _stop_event: Event + _thread: Thread | None + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._runtime_lock = Lock() + self._condition = Condition(self._runtime_lock) + self._runtime_state = _GuardrailRuntimeState() + self._stop_event = Event() + self._thread = None + + @rpc + def start(self) -> None: + super().start() + self._stop_event.clear() + self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) + self._disposables.add(Disposable(self.incoming_cmd_vel.subscribe(self._on_incoming_cmd_vel))) + + self._thread = Thread( + target=self._decision_loop, + name=f"{self.__class__.__name__}-thread", + daemon=True, + ) + self._thread.start() + + @rpc + def stop(self) -> None: + self._stop_event.set() + with self._condition: + self._condition.notify_all() + + if self.config.publish_zero_on_stop: + self.safe_cmd_vel.publish(Twist.zero()) + + if self._thread is not None: + self._thread.join(timeout=_THREAD_JOIN_TIMEOUT_S) + self._thread = None + + super().stop() + + def _on_color_image(self, image: Image) -> None: + now = time.monotonic() + with self._condition: + self._runtime_state.previous_image = self._runtime_state.latest_image + self._runtime_state.previous_image_time = self._runtime_state.latest_image_time + self._runtime_state.latest_image = image + self._runtime_state.latest_image_time = now + self._condition.notify() + + def _on_incoming_cmd_vel(self, cmd_vel: Twist) -> None: + now = time.monotonic() + with self._condition: + self._runtime_state.latest_cmd_vel = cmd_vel + self._runtime_state.latest_cmd_time = now + self._condition.notify() + + def _decision_loop(self) -> None: + while not self._stop_event.is_set(): + with self._condition: + timeout_s = self._next_wakeup_timeout_locked() + self._condition.wait(timeout=timeout_s) + + if self._stop_event.is_set(): + return + + # Step 2 will add the timed wakeup and per-command evaluation. + continue + + def _guarded_output_publish_period_s(self) -> float: + if self.config.guarded_output_publish_hz <= 0: + return _DEFAULT_FALLBACK_PERIOD_S + return 1.0 / self.config.guarded_output_publish_hz + + def _risk_evaluation_period_s(self) -> float: + if self.config.risk_evaluation_hz <= 0: + return _DEFAULT_FALLBACK_PERIOD_S + return 1.0 / self.config.risk_evaluation_hz + + def _next_wakeup_timeout_locked(self) -> float: + return min( + self._guarded_output_publish_period_s(), + self._risk_evaluation_period_s(), + ) + + +rgb_collision_guardrail = RGBCollisionGuardrail.blueprint From 9e2c25655296705e4ec79868b9028d57c1a9095b Mon Sep 17 00:00:00 2001 From: Pratham Gala <90299416+pmg5408@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:06:42 +0000 Subject: [PATCH 2/8] feat(control): add guardrail worker loop and policy evaluation flow Adds the single-worker guardrail loop, timed wakeup behavior, and shared runtime state for applying the latest safety assessment to incoming Twist commands. Separates module orchestration from policy logic and fails closed if policy evaluation errors out. --- dimos/control/safety/guardrail_policy.py | 28 +++ .../control/safety/rgb_collision_guardrail.py | 225 ++++++++++++++++-- 2 files changed, 229 insertions(+), 24 deletions(-) diff --git a/dimos/control/safety/guardrail_policy.py b/dimos/control/safety/guardrail_policy.py index b53561aa71..8243b84a9f 100644 --- a/dimos/control/safety/guardrail_policy.py +++ b/dimos/control/safety/guardrail_policy.py @@ -15,12 +15,22 @@ class GuardrailState(str, Enum): STOP_LATCHED = "stop_latched" SENSOR_DEGRADED = "sensor_degraded" +@dataclass(frozen=True) +class GuardrailHealth: + has_previous_frame: bool + image_fresh: bool + cmd_fresh: bool + risk_fresh: bool + low_texture: bool = False + occluded: bool = False @dataclass class GuardrailDecision: state: GuardrailState cmd_vel: Twist reason: str + risk_score: float = 0.0 + publish_immediately: bool = False class GuardrailPolicy(Protocol): @@ -29,5 +39,23 @@ def evaluate( previous_image: Image, current_image: Image, incoming_cmd_vel: Twist, + health: GuardrailHealth, ) -> GuardrailDecision: """Evaluate the latest frame pair and command.""" + + +class PassThroughGuardrailPolicy: + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: + return GuardrailDecision( + state=GuardrailState.PASS, + cmd_vel=incoming_cmd_vel, + reason="pass_through", + risk_score=0.0, + publish_immediately=False, + ) \ No newline at end of file diff --git a/dimos/control/safety/rgb_collision_guardrail.py b/dimos/control/safety/rgb_collision_guardrail.py index 923021d0d8..50e467015a 100644 --- a/dimos/control/safety/rgb_collision_guardrail.py +++ b/dimos/control/safety/rgb_collision_guardrail.py @@ -2,8 +2,9 @@ from dataclasses import dataclass import time -from threading import Condition, Event, Lock, Thread +from threading import Condition, Event, Thread from typing import Any +from pydantic import Field from reactivex.disposable import Disposable @@ -14,28 +15,32 @@ from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.logging_config import setup_logger -from dimos.control.safety.guardrail_policy import GuardrailDecision, GuardrailState +from dimos.control.safety.guardrail_policy import ( + GuardrailDecision, + GuardrailHealth, + GuardrailPolicy, + GuardrailState, + PassThroughGuardrailPolicy, +) + -_DEFAULT_FALLBACK_PERIOD_S = 0.1 _THREAD_JOIN_TIMEOUT_S = 2.0 -logger = setup_logger() +logger = setup_logger() class RGBCollisionGuardrailConfig(ModuleConfig): - # TODO: Step 2 - guarded_output_publish_hz: float = 10.0 - risk_evaluation_hz: float = 10.0 - command_timeout_s: float = 0.25 - image_timeout_s: float = 0.25 - risk_timeout_s: float = 0.25 + guarded_output_publish_hz: float = Field(default=10.0, gt=0.0) + risk_evaluation_hz: float = Field(default=10.0, gt=0.0) + command_timeout_s: float = Field(default=0.25, gt=0.0) + image_timeout_s: float = Field(default=0.25, gt=0.0) + risk_timeout_s: float = Field(default=0.25, gt=0.0) fail_closed_on_missing_image: bool = True publish_zero_on_stop: bool = True @dataclass class _GuardrailRuntimeState: - # TODO: Step 2 latest_image: Image | None = None previous_image: Image | None = None latest_image_time: float | None = None @@ -44,9 +49,20 @@ class _GuardrailRuntimeState: latest_cmd_time: float | None = None last_decision: GuardrailDecision | None = None last_risk_time: float | None = None + last_publish_time: float | None = None + next_risk_time: float | None = None + pending_cmd_update: bool = False state: GuardrailState = GuardrailState.INIT +@dataclass(frozen=True) +class _RiskEvaluationInput: + previous_image: Image + current_image: Image + incoming_cmd_vel: Twist + health: GuardrailHealth + + class RGBCollisionGuardrail(Module[RGBCollisionGuardrailConfig]): """RGB-only motion guardrail for direct Twist control.""" @@ -57,23 +73,28 @@ class RGBCollisionGuardrail(Module[RGBCollisionGuardrailConfig]): safe_cmd_vel: Out[Twist] _condition: Condition - _runtime_lock: Lock _runtime_state: _GuardrailRuntimeState _stop_event: Event _thread: Thread | None + _policy: GuardrailPolicy def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - self._runtime_lock = Lock() - self._condition = Condition(self._runtime_lock) + self._condition = Condition() self._runtime_state = _GuardrailRuntimeState() self._stop_event = Event() self._thread = None + # TODO: Replace placeholder policy with RGB optical-flow guardrail logic. + self._policy = PassThroughGuardrailPolicy() @rpc def start(self) -> None: super().start() self._stop_event.clear() + + with self._condition: + self._runtime_state.next_risk_time = time.monotonic() + self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) self._disposables.add(Disposable(self.incoming_cmd_vel.subscribe(self._on_incoming_cmd_vel))) @@ -113,10 +134,13 @@ def _on_incoming_cmd_vel(self, cmd_vel: Twist) -> None: with self._condition: self._runtime_state.latest_cmd_vel = cmd_vel self._runtime_state.latest_cmd_time = now + self._runtime_state.pending_cmd_update = True self._condition.notify() def _decision_loop(self) -> None: while not self._stop_event.is_set(): + risk_input: _RiskEvaluationInput | None = None + with self._condition: timeout_s = self._next_wakeup_timeout_locked() self._condition.wait(timeout=timeout_s) @@ -124,24 +148,177 @@ def _decision_loop(self) -> None: if self._stop_event.is_set(): return - # Step 2 will add the timed wakeup and per-command evaluation. - continue + now = time.monotonic() + if self._should_recompute_risk_locked(now): + risk_input = self._take_risk_evaluation_input_locked(now) + + decision: GuardrailDecision | None = None + if risk_input is not None: + try: + decision = self._policy.evaluate( + previous_image=risk_input.previous_image, + current_image=risk_input.current_image, + incoming_cmd_vel=risk_input.incoming_cmd_vel, + health=risk_input.health, + ) + except Exception: + logger.exception("RGB guardrail policy evaluation failed") + decision = GuardrailDecision( + state=GuardrailState.SENSOR_DEGRADED, + cmd_vel=Twist.zero(), + reason="policy_evaluation_failed", + risk_score=1.0, + publish_immediately=True, + ) + + + cmd_vel_to_publish: Twist | None = None + publish_time: float | None = None + with self._condition: + now = time.monotonic() + + if decision is not None: + self._store_decision_locked(decision, now) + + cmd_vel_to_publish = self._consume_publish_cmd_locked(now) + if cmd_vel_to_publish is not None: + publish_time = now + + if cmd_vel_to_publish is not None and publish_time is not None: + self.safe_cmd_vel.publish(cmd_vel_to_publish) + with self._condition: + self._runtime_state.last_publish_time = publish_time + def _guarded_output_publish_period_s(self) -> float: - if self.config.guarded_output_publish_hz <= 0: - return _DEFAULT_FALLBACK_PERIOD_S return 1.0 / self.config.guarded_output_publish_hz def _risk_evaluation_period_s(self) -> float: - if self.config.risk_evaluation_hz <= 0: - return _DEFAULT_FALLBACK_PERIOD_S return 1.0 / self.config.risk_evaluation_hz - def _next_wakeup_timeout_locked(self) -> float: - return min( - self._guarded_output_publish_period_s(), - self._risk_evaluation_period_s(), + def _should_recompute_risk_locked(self, now: float) -> bool: + next_risk_time = self._runtime_state.next_risk_time + if next_risk_time is None: + return True + return now >= next_risk_time + + def _is_cmd_fresh_locked(self, now: float) -> bool: + latest_cmd_time = self._runtime_state.latest_cmd_time + if latest_cmd_time is None: + return False + return (now - latest_cmd_time) <= self.config.command_timeout_s + + def _is_image_fresh_locked(self, now: float) -> bool: + latest_image_time = self._runtime_state.latest_image_time + if latest_image_time is None: + return False + return (now - latest_image_time) <= self.config.image_timeout_s + + def _is_risk_fresh_locked(self, now: float) -> bool: + last_risk_time = self._runtime_state.last_risk_time + if last_risk_time is None: + return False + return (now - last_risk_time) <= self.config.risk_timeout_s + + def _build_health_locked(self, now: float) -> GuardrailHealth: + return GuardrailHealth( + has_previous_frame=self._runtime_state.previous_image is not None, + image_fresh=self._is_image_fresh_locked(now), + cmd_fresh=self._is_cmd_fresh_locked(now), + risk_fresh=self._is_risk_fresh_locked(now), + low_texture=False, + occluded=False, ) + def _resolved_cmd_for_latest_locked(self) -> Twist | None: + latest_cmd_vel = self._runtime_state.latest_cmd_vel + if latest_cmd_vel is None: + return Twist.zero() + + last_decision = self._runtime_state.last_decision + if last_decision is None: + return latest_cmd_vel + + if last_decision.state == GuardrailState.PASS: + return latest_cmd_vel + + return last_decision.cmd_vel + + def _take_risk_evaluation_input_locked(self, now: float) -> _RiskEvaluationInput | None: + previous_image = self._runtime_state.previous_image + current_image = self._runtime_state.latest_image + incoming_cmd_vel = self._runtime_state.latest_cmd_vel + + self._runtime_state.next_risk_time = now + self._risk_evaluation_period_s() + + if previous_image is None or current_image is None or incoming_cmd_vel is None: + return None + + return _RiskEvaluationInput( + previous_image=previous_image, + current_image=current_image, + incoming_cmd_vel=incoming_cmd_vel, + health=self._build_health_locked(now), + ) + + def _store_decision_locked(self, decision: GuardrailDecision, now: float) -> None: + self._runtime_state.last_decision = decision + self._runtime_state.last_risk_time = now + self._runtime_state.state = decision.state + + def _consume_publish_cmd_locked(self, now: float) -> Twist | None: + cmd_vel_to_publish: Twist | None = None + + if self._runtime_state.pending_cmd_update: + cmd_vel_to_publish = self._resolved_cmd_for_latest_locked() + self._runtime_state.pending_cmd_update = False + elif self._should_republish_non_pass_output_locked(now): + last_decision = self._runtime_state.last_decision + if last_decision is not None: + cmd_vel_to_publish = last_decision.cmd_vel + + return cmd_vel_to_publish + + def _should_republish_non_pass_output_locked(self, now: float) -> bool: + last_decision = self._runtime_state.last_decision + if last_decision is None: + return False + + if last_decision.state == GuardrailState.PASS: + return False + + last_publish_time = self._runtime_state.last_publish_time + if last_publish_time is None: + return True + + return (now - last_publish_time) >= self._guarded_output_publish_period_s() + + + def _next_wakeup_timeout_locked(self) -> float: + now = time.monotonic() + timeouts: list[float] = [] + + next_risk_time = self._runtime_state.next_risk_time + if next_risk_time is not None: + timeouts.append(max(next_risk_time - now, 0.0)) + else: + timeouts.append(self._risk_evaluation_period_s()) + + if self._should_republish_non_pass_output_locked(now): + timeouts.append(0.0) + else: + last_decision = self._runtime_state.last_decision + last_publish_time = self._runtime_state.last_publish_time + if ( + last_decision is not None + and last_decision.state != GuardrailState.PASS + and last_publish_time is not None + ): + next_publish_time = last_publish_time + self._guarded_output_publish_period_s() + timeouts.append(max(next_publish_time - now, 0.0)) + + return min(timeouts, default=self._risk_evaluation_period_s()) + + rgb_collision_guardrail = RGBCollisionGuardrail.blueprint From 6667c7dcca91ba3bcd130e85fde1647543b0c349 Mon Sep 17 00:00:00 2001 From: Pratham Gala <90299416+pmg5408@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:42:11 +0000 Subject: [PATCH 3/8] feat(control): add guardrail init and degraded-state handling Adds fail-closed startup, stale input handling, and explicit INIT and SENSOR_DEGRADED decisions to the RGB guardrail loop. The worker now publishes state changes immediately while keeping policy evaluation outside the callback lock for responsiveness. --- dimos/control/safety/guardrail_policy.py | 24 +- .../control/safety/rgb_collision_guardrail.py | 237 ++++++++++++++---- 2 files changed, 211 insertions(+), 50 deletions(-) diff --git a/dimos/control/safety/guardrail_policy.py b/dimos/control/safety/guardrail_policy.py index 8243b84a9f..4af501b32e 100644 --- a/dimos/control/safety/guardrail_policy.py +++ b/dimos/control/safety/guardrail_policy.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations from dataclasses import dataclass @@ -15,6 +29,7 @@ class GuardrailState(str, Enum): STOP_LATCHED = "stop_latched" SENSOR_DEGRADED = "sensor_degraded" + @dataclass(frozen=True) class GuardrailHealth: has_previous_frame: bool @@ -24,8 +39,15 @@ class GuardrailHealth: low_texture: bool = False occluded: bool = False + @dataclass class GuardrailDecision: + """Policy result consumed by the guardrail worker. + + publish_immediately requests an immediate worker-side publish on the next + loop iteration. It does not bypass command freshness checks. + """ + state: GuardrailState cmd_vel: Twist reason: str @@ -58,4 +80,4 @@ def evaluate( reason="pass_through", risk_score=0.0, publish_immediately=False, - ) \ No newline at end of file + ) diff --git a/dimos/control/safety/rgb_collision_guardrail.py b/dimos/control/safety/rgb_collision_guardrail.py index 50e467015a..12ef1b98ba 100644 --- a/dimos/control/safety/rgb_collision_guardrail.py +++ b/dimos/control/safety/rgb_collision_guardrail.py @@ -1,34 +1,47 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations from dataclasses import dataclass -import time from threading import Condition, Event, Thread +import time from typing import Any -from pydantic import Field +from pydantic import Field from reactivex.disposable import Disposable -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.sensor_msgs.Image import Image -from dimos.utils.logging_config import setup_logger - from dimos.control.safety.guardrail_policy import ( - GuardrailDecision, + GuardrailDecision, GuardrailHealth, GuardrailPolicy, GuardrailState, PassThroughGuardrailPolicy, ) - +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image +from dimos.utils.logging_config import setup_logger _THREAD_JOIN_TIMEOUT_S = 2.0 logger = setup_logger() + class RGBCollisionGuardrailConfig(ModuleConfig): guarded_output_publish_hz: float = Field(default=10.0, gt=0.0) risk_evaluation_hz: float = Field(default=10.0, gt=0.0) @@ -52,6 +65,7 @@ class _GuardrailRuntimeState: last_publish_time: float | None = None next_risk_time: float | None = None pending_cmd_update: bool = False + pending_decision_publish: bool = False state: GuardrailState = GuardrailState.INIT @@ -96,7 +110,9 @@ def start(self) -> None: self._runtime_state.next_risk_time = time.monotonic() self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) - self._disposables.add(Disposable(self.incoming_cmd_vel.subscribe(self._on_incoming_cmd_vel))) + self._disposables.add( + Disposable(self.incoming_cmd_vel.subscribe(self._on_incoming_cmd_vel)) + ) self._thread = Thread( target=self._decision_loop, @@ -152,33 +168,44 @@ def _decision_loop(self) -> None: if self._should_recompute_risk_locked(now): risk_input = self._take_risk_evaluation_input_locked(now) - decision: GuardrailDecision | None = None - if risk_input is not None: - try: - decision = self._policy.evaluate( - previous_image=risk_input.previous_image, - current_image=risk_input.current_image, - incoming_cmd_vel=risk_input.incoming_cmd_vel, - health=risk_input.health, - ) - except Exception: - logger.exception("RGB guardrail policy evaluation failed") - decision = GuardrailDecision( - state=GuardrailState.SENSOR_DEGRADED, - cmd_vel=Twist.zero(), - reason="policy_evaluation_failed", - risk_score=1.0, - publish_immediately=True, - ) - + # Evaluate a consistent snapshot outside the condition lock so image + # and command callbacks stay cheap. If newer inputs arrive while + # evaluation runs, the next wakeup will process that fresher snapshot. + policy_decision: GuardrailDecision | None = None + if risk_input is not None: + try: + policy_decision = self._policy.evaluate( + previous_image=risk_input.previous_image, + current_image=risk_input.current_image, + incoming_cmd_vel=risk_input.incoming_cmd_vel, + health=risk_input.health, + ) + except Exception: + logger.exception("RGB guardrail policy evaluation failed") + policy_decision = self._build_sensor_degraded_decision( + "policy_evaluation_failed" + ) cmd_vel_to_publish: Twist | None = None publish_time: float | None = None + with self._condition: now = time.monotonic() - if decision is not None: - self._store_decision_locked(decision, now) + if policy_decision is not None: + self._store_decision_locked( + policy_decision, + now, + update_risk_time=True, + ) + else: + fallback_decision = self._select_fallback_decision_locked(now) + if fallback_decision is not None: + self._store_decision_locked( + fallback_decision, + now, + update_risk_time=False, + ) cmd_vel_to_publish = self._consume_publish_cmd_locked(now) if cmd_vel_to_publish is not None: @@ -189,7 +216,6 @@ def _decision_loop(self) -> None: with self._condition: self._runtime_state.last_publish_time = publish_time - def _guarded_output_publish_period_s(self) -> float: return 1.0 / self.config.guarded_output_publish_hz @@ -197,6 +223,7 @@ def _risk_evaluation_period_s(self) -> float: return 1.0 / self.config.risk_evaluation_hz def _should_recompute_risk_locked(self, now: float) -> bool: + """Return True when the next scheduled risk evaluation is due.""" next_risk_time = self._runtime_state.next_risk_time if next_risk_time is None: return True @@ -221,6 +248,8 @@ def _is_risk_fresh_locked(self, now: float) -> bool: return (now - last_risk_time) <= self.config.risk_timeout_s def _build_health_locked(self, now: float) -> GuardrailHealth: + """Build a health snapshot from the current cached inputs.""" + # TODO: Populate these from image-quality checks. return GuardrailHealth( has_previous_frame=self._runtime_state.previous_image is not None, image_fresh=self._is_image_fresh_locked(now), @@ -230,15 +259,18 @@ def _build_health_locked(self, now: float) -> GuardrailHealth: occluded=False, ) - - def _resolved_cmd_for_latest_locked(self) -> Twist | None: + def _resolved_cmd_for_latest_locked(self, now: float) -> Twist: + """Resolve the command to publish for the latest cached upstream input.""" latest_cmd_vel = self._runtime_state.latest_cmd_vel if latest_cmd_vel is None: return Twist.zero() + if not self._is_cmd_fresh_locked(now): + return Twist.zero() + last_decision = self._runtime_state.last_decision if last_decision is None: - return latest_cmd_vel + return Twist.zero() if last_decision.state == GuardrailState.PASS: return latest_cmd_vel @@ -246,6 +278,7 @@ def _resolved_cmd_for_latest_locked(self) -> Twist | None: return last_decision.cmd_vel def _take_risk_evaluation_input_locked(self, now: float) -> _RiskEvaluationInput | None: + """Capture a consistent snapshot for policy evaluation and advance the risk deadline.""" previous_image = self._runtime_state.previous_image current_image = self._runtime_state.latest_image incoming_cmd_vel = self._runtime_state.latest_cmd_vel @@ -262,25 +295,61 @@ def _take_risk_evaluation_input_locked(self, now: float) -> _RiskEvaluationInput health=self._build_health_locked(now), ) - def _store_decision_locked(self, decision: GuardrailDecision, now: float) -> None: + def _store_decision_locked( + self, + decision: GuardrailDecision, + now: float, + *, + update_risk_time: bool, + ) -> None: + had_previous_decision = self._runtime_state.last_decision is not None + previous_state = self._runtime_state.state + self._runtime_state.last_decision = decision - self._runtime_state.last_risk_time = now + if update_risk_time: + self._runtime_state.last_risk_time = now self._runtime_state.state = decision.state + # Request an immediate publish when the policy says so or when the + # high-level state changes. Freshness checks still apply later. + if ( + decision.publish_immediately + or not had_previous_decision + or previous_state != decision.state + ): + self._runtime_state.pending_decision_publish = True + + if previous_state != decision.state: + logger.info( + "RGB guardrail state changed", + previous_state=previous_state.value, + state=decision.state.value, + reason=decision.reason, + ) + def _consume_publish_cmd_locked(self, now: float) -> Twist | None: - cmd_vel_to_publish: Twist | None = None + """Consume and return the next command that should be published, if any.""" + if self._runtime_state.pending_decision_publish: + self._runtime_state.pending_decision_publish = False + return self._resolved_cmd_for_latest_locked(now) if self._runtime_state.pending_cmd_update: - cmd_vel_to_publish = self._resolved_cmd_for_latest_locked() self._runtime_state.pending_cmd_update = False - elif self._should_republish_non_pass_output_locked(now): + return self._resolved_cmd_for_latest_locked(now) + + latest_cmd_time = self._runtime_state.latest_cmd_time + if latest_cmd_time is not None and not self._is_cmd_fresh_locked(now): + return Twist.zero() + + if self._should_republish_non_pass_output_locked(now): last_decision = self._runtime_state.last_decision if last_decision is not None: - cmd_vel_to_publish = last_decision.cmd_vel + return last_decision.cmd_vel - return cmd_vel_to_publish + return None def _should_republish_non_pass_output_locked(self, now: float) -> bool: + """Return True when a non-pass output should be republished on heartbeat.""" last_decision = self._runtime_state.last_decision if last_decision is None: return False @@ -294,16 +363,30 @@ def _should_republish_non_pass_output_locked(self, now: float) -> bool: return (now - last_publish_time) >= self._guarded_output_publish_period_s() - def _next_wakeup_timeout_locked(self) -> float: + """Compute the next worker wakeup timeout from pending work and deadlines.""" now = time.monotonic() - timeouts: list[float] = [] + + if self._runtime_state.pending_cmd_update or self._runtime_state.pending_decision_publish: + return 0.0 + + timeouts: list[float] = [self._risk_evaluation_period_s()] next_risk_time = self._runtime_state.next_risk_time if next_risk_time is not None: timeouts.append(max(next_risk_time - now, 0.0)) - else: - timeouts.append(self._risk_evaluation_period_s()) + + latest_cmd_time = self._runtime_state.latest_cmd_time + if latest_cmd_time is not None and self._is_cmd_fresh_locked(now): + timeouts.append(max((latest_cmd_time + self.config.command_timeout_s) - now, 0.0)) + + latest_image_time = self._runtime_state.latest_image_time + if latest_image_time is not None and self._is_image_fresh_locked(now): + timeouts.append(max((latest_image_time + self.config.image_timeout_s) - now, 0.0)) + + last_risk_time = self._runtime_state.last_risk_time + if last_risk_time is not None and self._is_risk_fresh_locked(now): + timeouts.append(max((last_risk_time + self.config.risk_timeout_s) - now, 0.0)) if self._should_republish_non_pass_output_locked(now): timeouts.append(0.0) @@ -318,7 +401,63 @@ def _next_wakeup_timeout_locked(self) -> float: next_publish_time = last_publish_time + self._guarded_output_publish_period_s() timeouts.append(max(next_publish_time - now, 0.0)) - return min(timeouts, default=self._risk_evaluation_period_s()) + return min(timeouts) + + def _build_zero_decision( + self, + state: GuardrailState, + reason: str, + *, + publish_immediately: bool = False, + risk_score: float = 0.0, + ) -> GuardrailDecision: + return GuardrailDecision( + state=state, + cmd_vel=Twist.zero(), + reason=reason, + risk_score=risk_score, + publish_immediately=publish_immediately, + ) + + def _build_init_decision(self, reason: str) -> GuardrailDecision: + return self._build_zero_decision( + GuardrailState.INIT, + reason, + publish_immediately=True, + ) + + def _build_sensor_degraded_decision(self, reason: str) -> GuardrailDecision: + return self._build_zero_decision( + GuardrailState.SENSOR_DEGRADED, + reason, + publish_immediately=True, + risk_score=1.0, + ) + + def _select_fallback_decision_locked(self, now: float) -> GuardrailDecision | None: + if self._runtime_state.latest_cmd_vel is None: + return self._build_init_decision("no_command_received") + + if self._runtime_state.latest_image is None: + if self.config.fail_closed_on_missing_image: + return self._build_init_decision("waiting_for_first_image") + return None + + if self._runtime_state.previous_image is None: + if self.config.fail_closed_on_missing_image: + return self._build_init_decision("waiting_for_frame_pair") + return None + + if not self._is_image_fresh_locked(now): + return self._build_sensor_degraded_decision("image_stale") + + if self._runtime_state.last_risk_time is None: + return self._build_init_decision("waiting_for_first_risk_evaluation") + + if not self._is_risk_fresh_locked(now): + return self._build_sensor_degraded_decision("risk_state_stale") + + return None rgb_collision_guardrail = RGBCollisionGuardrail.blueprint From ed98e1041938ad1c514424ed001dc0170d325492 Mon Sep 17 00:00:00 2001 From: Pratham Gala <90299416+pmg5408@users.noreply.github.com> Date: Fri, 3 Apr 2026 19:55:55 +0000 Subject: [PATCH 4/8] feat(control): add optical-flow RGB guardrail policy for forward motion Implements a reusable RGB-only guardrail policy using optical-flow magnitude in a forward ROI with conservative clamp and stop behavior. Adds image-quality checks, hysteresis, and policy-side configuration while keeping scheduling and publishing logic in the module shell. --- dimos/control/safety/guardrail_policy.py | 336 +++++++++++++++++- .../control/safety/rgb_collision_guardrail.py | 79 +++- 2 files changed, 397 insertions(+), 18 deletions(-) diff --git a/dimos/control/safety/guardrail_policy.py b/dimos/control/safety/guardrail_policy.py index 4af501b32e..fa6f5aa129 100644 --- a/dimos/control/safety/guardrail_policy.py +++ b/dimos/control/safety/guardrail_policy.py @@ -16,11 +16,17 @@ from dataclasses import dataclass from enum import Enum -from typing import Protocol +from typing import Any, Protocol, cast + +import cv2 +import numpy as np +from numpy.typing import NDArray from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.sensor_msgs.Image import Image +GrayImage = NDArray[np.uint8] + class GuardrailState(str, Enum): INIT = "init" @@ -36,8 +42,6 @@ class GuardrailHealth: image_fresh: bool cmd_fresh: bool risk_fresh: bool - low_texture: bool = False - occluded: bool = False @dataclass @@ -55,6 +59,25 @@ class GuardrailDecision: publish_immediately: bool = False +@dataclass(frozen=True) +class OpticalFlowMagnitudePolicyConfig: + forward_motion_deadband_mps: float + clamp_forward_speed_mps: float + flow_downsample_width_px: int + forward_roi_top_fraction: float + forward_roi_bottom_fraction: float + forward_roi_width_fraction: float + low_texture_variance_threshold: float + occlusion_dark_pixel_threshold: int + occlusion_bright_pixel_threshold: int + occlusion_extreme_fraction_threshold: float + caution_flow_magnitude_threshold: float + stop_flow_magnitude_threshold: float + caution_frame_count: int + stop_frame_count: int + clear_frame_count: int + + class GuardrailPolicy(Protocol): def evaluate( self, @@ -62,11 +85,29 @@ def evaluate( current_image: Image, incoming_cmd_vel: Twist, health: GuardrailHealth, - ) -> GuardrailDecision: - """Evaluate the latest frame pair and command.""" + ) -> GuardrailDecision: ... + + +class OpticalFlowMagnitudeGuardrailPolicy(GuardrailPolicy): + """Forward-motion RGB guardrail using flow magnitude in a central lower ROI.""" + # V1 keeps Farneback internals fixed to reduce tuning surface. Promote + # these to config only after hardware tuning shows they need adjustment. + _FARNEBACK_PYR_SCALE = 0.5 + _FARNEBACK_LEVELS = 3 + _FARNEBACK_WINDOW_SIZE = 15 + _FARNEBACK_ITERATIONS = 3 + _FARNEBACK_POLY_N = 5 + _FARNEBACK_POLY_SIGMA = 1.2 + _FARNEBACK_FLAGS = 0 + + def __init__(self, config: OpticalFlowMagnitudePolicyConfig) -> None: + self._config = config + self._hysteresis_state = GuardrailState.PASS + self._caution_hits = 0 + self._stop_hits = 0 + self._clear_hits = 0 -class PassThroughGuardrailPolicy: def evaluate( self, previous_image: Image, @@ -74,10 +115,285 @@ def evaluate( incoming_cmd_vel: Twist, health: GuardrailHealth, ) -> GuardrailDecision: + if not health.has_previous_frame: + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.INIT, + "missing_previous_frame", + risk_score=0.0, + ) + + if not health.image_fresh: + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.SENSOR_DEGRADED, + "image_not_fresh", + risk_score=1.0, + publish_immediately=True, + ) + + forward_speed = float(incoming_cmd_vel.linear.x) + if forward_speed <= self._config.forward_motion_deadband_mps: + self._reset_hysteresis() + return self._pass_decision(incoming_cmd_vel, "forward_guard_inactive", 0.0) + + previous_gray, current_gray = self._prepare_gray_pair(previous_image, current_image) + previous_roi, current_roi = self._extract_forward_rois(previous_gray, current_gray) + + if previous_roi.size == 0 or current_roi.size == 0: + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.SENSOR_DEGRADED, + "invalid_forward_roi", + risk_score=1.0, + publish_immediately=True, + ) + + if self._is_occluded(current_roi): + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.SENSOR_DEGRADED, + "forward_roi_occluded", + risk_score=1.0, + publish_immediately=True, + ) + + if self._is_low_texture(current_roi): + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.SENSOR_DEGRADED, + "forward_roi_low_texture", + risk_score=1.0, + publish_immediately=True, + ) + + mean_flow_magnitude = self._mean_flow_magnitude(previous_roi, current_roi) + next_state = self._next_state(mean_flow_magnitude) + self._active_state = next_state + + if next_state == GuardrailState.STOP_LATCHED: + return self._stop_forward_decision( + incoming_cmd_vel, + "forward_flow_stop", + mean_flow_magnitude, + ) + + if next_state == GuardrailState.CLAMP: + reason = ( + "forward_flow_clamp" + if mean_flow_magnitude >= self._config.caution_flow_magnitude_threshold + else "forward_flow_recovery" + ) + return self._clamp_forward_decision( + incoming_cmd_vel, + reason, + mean_flow_magnitude, + ) + + return self._pass_decision( + incoming_cmd_vel, + "forward_flow_clear", + mean_flow_magnitude, + ) + + def _prepare_gray_pair( + self, + previous_image: Image, + current_image: Image, + ) -> tuple[GrayImage, GrayImage]: + previous_gray = self._to_resized_gray(previous_image) + current_gray = self._to_resized_gray(current_image) + + shared_height = min(previous_gray.shape[0], current_gray.shape[0]) + shared_width = min(previous_gray.shape[1], current_gray.shape[1]) + + return ( + np.ascontiguousarray(previous_gray[:shared_height, :shared_width]), + np.ascontiguousarray(current_gray[:shared_height, :shared_width]), + ) + + def _to_resized_gray(self, image: Image) -> GrayImage: + gray = cast("GrayImage", image.to_grayscale().data) + if gray.dtype != np.uint8: + gray = cv2.convertScaleAbs(gray) # type: ignore[call-overload] + + height, width = gray.shape[:2] + if width <= 0 or height <= 0: + raise ValueError("Image has invalid dimensions") + + target_width = min(width, self._config.flow_downsample_width_px) + if target_width == width: + return cast("GrayImage", np.ascontiguousarray(gray)) + + scale = target_width / float(width) + target_height = max(round(height * scale), 2) + resized = cv2.resize( # type: ignore[call-overload] + gray, + (target_width, target_height), + interpolation=cv2.INTER_AREA, + ) + return cast("GrayImage", np.ascontiguousarray(resized)) + + def _extract_forward_rois( + self, + previous_gray: GrayImage, + current_gray: GrayImage, + ) -> tuple[GrayImage, GrayImage]: + height, width = current_gray.shape + x0, x1, y0, y1 = self._forward_roi_bounds(width=width, height=height) + + return ( + np.ascontiguousarray(previous_gray[y0:y1, x0:x1]), + np.ascontiguousarray(current_gray[y0:y1, x0:x1]), + ) + + def _forward_roi_bounds(self, *, width: int, height: int) -> tuple[int, int, int, int]: + roi_width = max(round(width * self._config.forward_roi_width_fraction), 2) + x0 = max((width - roi_width) // 2, 0) + x1 = min(x0 + roi_width, width) + + y0 = min(max(round(height * self._config.forward_roi_top_fraction), 0), height - 1) + y1 = min(max(round(height * self._config.forward_roi_bottom_fraction), y0 + 1), height) + + return x0, x1, y0, y1 + + def _is_low_texture(self, roi: GrayImage) -> bool: + return float(np.var(roi)) < self._config.low_texture_variance_threshold + + def _is_occluded(self, roi: GrayImage) -> bool: + dark_fraction = float(np.mean(roi <= self._config.occlusion_dark_pixel_threshold)) + bright_fraction = float(np.mean(roi >= self._config.occlusion_bright_pixel_threshold)) + return ( + max(dark_fraction, bright_fraction) >= self._config.occlusion_extreme_fraction_threshold + ) + + def _mean_flow_magnitude(self, previous_roi: GrayImage, current_roi: GrayImage) -> float: + flow = cv2.calcOpticalFlowFarneback( # type: ignore[call-overload] + previous_roi, + current_roi, + cast("Any", None), + self._FARNEBACK_PYR_SCALE, + self._FARNEBACK_LEVELS, + self._FARNEBACK_WINDOW_SIZE, + self._FARNEBACK_ITERATIONS, + self._FARNEBACK_POLY_N, + self._FARNEBACK_POLY_SIGMA, + self._FARNEBACK_FLAGS, + ) + + magnitude, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + return float(np.mean(magnitude)) + + def _next_state(self, mean_flow_magnitude: float) -> GuardrailState: + if mean_flow_magnitude >= self._config.stop_flow_magnitude_threshold: + self._stop_hits += 1 + # Stop-level flow is also caution-level flow. This lets us clamp + # first when stop evidence has not yet met its own persistence rule. + self._caution_hits += 1 + self._clear_hits = 0 + elif mean_flow_magnitude >= self._config.caution_flow_magnitude_threshold: + self._stop_hits = 0 + self._caution_hits += 1 + self._clear_hits = 0 + else: + self._stop_hits = 0 + self._caution_hits = 0 + self._clear_hits += 1 + + if self._hysteresis_state == GuardrailState.STOP_LATCHED: + if self._stop_hits >= self._config.stop_frame_count: + return GuardrailState.STOP_LATCHED + if self._clear_hits >= self._config.clear_frame_count: + return GuardrailState.PASS + return GuardrailState.CLAMP + + if self._hysteresis_state == GuardrailState.CLAMP: + if self._stop_hits >= self._config.stop_frame_count: + return GuardrailState.STOP_LATCHED + if self._clear_hits >= self._config.clear_frame_count: + return GuardrailState.PASS + return GuardrailState.CLAMP + + if self._stop_hits >= self._config.stop_frame_count: + return GuardrailState.STOP_LATCHED + + if self._caution_hits >= self._config.caution_frame_count: + return GuardrailState.CLAMP + + return GuardrailState.PASS + + def _reset_hysteresis(self) -> None: + """Reset internal clamp/stop evidence without changing module output state.""" + self._hysteresis_state = GuardrailState.PASS + self._caution_hits = 0 + self._stop_hits = 0 + self._clear_hits = 0 + + def _pass_decision( + self, + incoming_cmd_vel: Twist, + reason: str, + risk_score: float, + ) -> GuardrailDecision: + cmd_vel = Twist( + linear=incoming_cmd_vel.linear, + angular=incoming_cmd_vel.angular, + ) return GuardrailDecision( state=GuardrailState.PASS, - cmd_vel=incoming_cmd_vel, - reason="pass_through", - risk_score=0.0, - publish_immediately=False, + cmd_vel=cmd_vel, + reason=reason, + risk_score=risk_score, + ) + + def _clamp_forward_decision( + self, + incoming_cmd_vel: Twist, + reason: str, + risk_score: float, + ) -> GuardrailDecision: + cmd_vel = Twist( + linear=incoming_cmd_vel.linear, + angular=incoming_cmd_vel.angular, + ) + cmd_vel.linear.x = min(float(cmd_vel.linear.x), self._config.clamp_forward_speed_mps) + return GuardrailDecision( + state=GuardrailState.CLAMP, + cmd_vel=cmd_vel, + reason=reason, + risk_score=risk_score, + ) + + def _stop_forward_decision( + self, + incoming_cmd_vel: Twist, + reason: str, + risk_score: float, + ) -> GuardrailDecision: + cmd_vel = Twist( + linear=incoming_cmd_vel.linear, + angular=incoming_cmd_vel.angular, + ) + cmd_vel.linear.x = 0.0 + return GuardrailDecision( + state=GuardrailState.STOP_LATCHED, + cmd_vel=cmd_vel, + reason=reason, + risk_score=risk_score, + ) + + def _zero_decision( + self, + state: GuardrailState, + reason: str, + *, + risk_score: float = 0.0, + publish_immediately: bool = False, + ) -> GuardrailDecision: + return GuardrailDecision( + state=state, + cmd_vel=Twist.zero(), + reason=reason, + risk_score=risk_score, + publish_immediately=publish_immediately, ) diff --git a/dimos/control/safety/rgb_collision_guardrail.py b/dimos/control/safety/rgb_collision_guardrail.py index 12ef1b98ba..7c649ac1a7 100644 --- a/dimos/control/safety/rgb_collision_guardrail.py +++ b/dimos/control/safety/rgb_collision_guardrail.py @@ -17,9 +17,9 @@ from dataclasses import dataclass from threading import Condition, Event, Thread import time -from typing import Any +from typing import Any, Self -from pydantic import Field +from pydantic import Field, model_validator from reactivex.disposable import Disposable from dimos.control.safety.guardrail_policy import ( @@ -27,7 +27,8 @@ GuardrailHealth, GuardrailPolicy, GuardrailState, - PassThroughGuardrailPolicy, + OpticalFlowMagnitudeGuardrailPolicy, + OpticalFlowMagnitudePolicyConfig, ) from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig @@ -43,14 +44,60 @@ class RGBCollisionGuardrailConfig(ModuleConfig): + # Scheduling guarded_output_publish_hz: float = Field(default=10.0, gt=0.0) risk_evaluation_hz: float = Field(default=10.0, gt=0.0) + + # Freshness and fail-closed behavior command_timeout_s: float = Field(default=0.25, gt=0.0) image_timeout_s: float = Field(default=0.25, gt=0.0) risk_timeout_s: float = Field(default=0.25, gt=0.0) fail_closed_on_missing_image: bool = True publish_zero_on_stop: bool = True + # Motion gating + forward_motion_deadband_mps: float = Field(default=0.05, ge=0.0) + clamp_forward_speed_mps: float = Field(default=0.1, ge=0.0) + + # Forward ROI geometry + flow_downsample_width_px: int = Field(default=160, ge=32) + forward_roi_top_fraction: float = Field(default=0.45, ge=0.0, le=1.0) + forward_roi_bottom_fraction: float = Field(default=0.95, ge=0.0, le=1.0) + forward_roi_width_fraction: float = Field(default=0.5, gt=0.0, le=1.0) + + # Image-quality checks + low_texture_variance_threshold: float = Field(default=150.0, ge=0.0) + occlusion_dark_pixel_threshold: int = Field(default=20, ge=0, le=255) + occlusion_bright_pixel_threshold: int = Field(default=235, ge=0, le=255) + occlusion_extreme_fraction_threshold: float = Field(default=0.9, ge=0.0, le=1.0) + + # Flow thresholds and hysteresis + caution_flow_magnitude_threshold: float = Field(default=0.8, ge=0.0) + stop_flow_magnitude_threshold: float = Field(default=1.5, ge=0.0) + caution_frame_count: int = Field(default=2, ge=1) + stop_frame_count: int = Field(default=2, ge=1) + clear_frame_count: int = Field(default=3, ge=1) + + @model_validator(mode="after") + def validate_thresholds(self) -> Self: + if self.forward_roi_top_fraction >= self.forward_roi_bottom_fraction: + raise ValueError( + "forward_roi_top_fraction must be less than forward_roi_bottom_fraction" + ) + + if self.occlusion_dark_pixel_threshold >= self.occlusion_bright_pixel_threshold: + raise ValueError( + "occlusion_dark_pixel_threshold must be less than occlusion_bright_pixel_threshold" + ) + + if self.caution_flow_magnitude_threshold > self.stop_flow_magnitude_threshold: + raise ValueError( + "caution_flow_magnitude_threshold must be less than or equal to " + "stop_flow_magnitude_threshold" + ) + + return self + @dataclass class _GuardrailRuntimeState: @@ -98,8 +145,27 @@ def __init__(self, **kwargs: Any) -> None: self._runtime_state = _GuardrailRuntimeState() self._stop_event = Event() self._thread = None - # TODO: Replace placeholder policy with RGB optical-flow guardrail logic. - self._policy = PassThroughGuardrailPolicy() + self._policy = self._build_policy() + + def _build_policy(self) -> GuardrailPolicy: + policy_config = OpticalFlowMagnitudePolicyConfig( + forward_motion_deadband_mps=self.config.forward_motion_deadband_mps, + clamp_forward_speed_mps=self.config.clamp_forward_speed_mps, + flow_downsample_width_px=self.config.flow_downsample_width_px, + forward_roi_top_fraction=self.config.forward_roi_top_fraction, + forward_roi_bottom_fraction=self.config.forward_roi_bottom_fraction, + forward_roi_width_fraction=self.config.forward_roi_width_fraction, + low_texture_variance_threshold=self.config.low_texture_variance_threshold, + occlusion_dark_pixel_threshold=self.config.occlusion_dark_pixel_threshold, + occlusion_bright_pixel_threshold=self.config.occlusion_bright_pixel_threshold, + occlusion_extreme_fraction_threshold=self.config.occlusion_extreme_fraction_threshold, + caution_flow_magnitude_threshold=self.config.caution_flow_magnitude_threshold, + stop_flow_magnitude_threshold=self.config.stop_flow_magnitude_threshold, + caution_frame_count=self.config.caution_frame_count, + stop_frame_count=self.config.stop_frame_count, + clear_frame_count=self.config.clear_frame_count, + ) + return OpticalFlowMagnitudeGuardrailPolicy(policy_config) @rpc def start(self) -> None: @@ -249,14 +315,11 @@ def _is_risk_fresh_locked(self, now: float) -> bool: def _build_health_locked(self, now: float) -> GuardrailHealth: """Build a health snapshot from the current cached inputs.""" - # TODO: Populate these from image-quality checks. return GuardrailHealth( has_previous_frame=self._runtime_state.previous_image is not None, image_fresh=self._is_image_fresh_locked(now), cmd_fresh=self._is_cmd_fresh_locked(now), risk_fresh=self._is_risk_fresh_locked(now), - low_texture=False, - occluded=False, ) def _resolved_cmd_for_latest_locked(self, now: float) -> Twist: From a047ae5a53081719a0864f00c9d16e60f8fe8648 Mon Sep 17 00:00:00 2001 From: Pratham Gala <90299416+pmg5408@users.noreply.github.com> Date: Sat, 4 Apr 2026 21:28:25 +0000 Subject: [PATCH 5/8] feat(control): harden RGB guardrail state handling and frame validation Tightens the guardrail's fail-closed behavior by validating frame-pair freshness, checking image quality across both flow frames, and preventing stop-threshold detections from passing through unchanged. Also improves worker-side scheduling so risk evaluation only runs on new image data and adds logging around startup, shutdown, and policy failures. --- dimos/control/safety/guardrail_policy.py | 67 ++++++++++++++----- .../control/safety/rgb_collision_guardrail.py | 60 ++++++++++++++--- 2 files changed, 99 insertions(+), 28 deletions(-) diff --git a/dimos/control/safety/guardrail_policy.py b/dimos/control/safety/guardrail_policy.py index fa6f5aa129..14dedf860c 100644 --- a/dimos/control/safety/guardrail_policy.py +++ b/dimos/control/safety/guardrail_policy.py @@ -42,6 +42,7 @@ class GuardrailHealth: image_fresh: bool cmd_fresh: bool risk_fresh: bool + frame_pair_fresh: bool @dataclass @@ -76,6 +77,7 @@ class OpticalFlowMagnitudePolicyConfig: caution_frame_count: int stop_frame_count: int clear_frame_count: int + stop_release_frame_count: int class GuardrailPolicy(Protocol): @@ -107,6 +109,7 @@ def __init__(self, config: OpticalFlowMagnitudePolicyConfig) -> None: self._caution_hits = 0 self._stop_hits = 0 self._clear_hits = 0 + self._below_stop_hits = 0 def evaluate( self, @@ -132,9 +135,17 @@ def evaluate( publish_immediately=True, ) + if not health.frame_pair_fresh: + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.SENSOR_DEGRADED, + "frame_pair_stale", + risk_score=1.0, + publish_immediately=True, + ) + forward_speed = float(incoming_cmd_vel.linear.x) if forward_speed <= self._config.forward_motion_deadband_mps: - self._reset_hysteresis() return self._pass_decision(incoming_cmd_vel, "forward_guard_inactive", 0.0) previous_gray, current_gray = self._prepare_gray_pair(previous_image, current_image) @@ -149,27 +160,19 @@ def evaluate( publish_immediately=True, ) - if self._is_occluded(current_roi): - self._reset_hysteresis() - return self._zero_decision( - GuardrailState.SENSOR_DEGRADED, - "forward_roi_occluded", - risk_score=1.0, - publish_immediately=True, - ) - - if self._is_low_texture(current_roi): + quality_failure_reason = self._quality_failure_reason(previous_roi, current_roi) + if quality_failure_reason is not None: self._reset_hysteresis() return self._zero_decision( GuardrailState.SENSOR_DEGRADED, - "forward_roi_low_texture", + quality_failure_reason, risk_score=1.0, publish_immediately=True, ) mean_flow_magnitude = self._mean_flow_magnitude(previous_roi, current_roi) next_state = self._next_state(mean_flow_magnitude) - self._active_state = next_state + self._hysteresis_state = next_state if next_state == GuardrailState.STOP_LATCHED: return self._stop_forward_decision( @@ -215,7 +218,7 @@ def _prepare_gray_pair( def _to_resized_gray(self, image: Image) -> GrayImage: gray = cast("GrayImage", image.to_grayscale().data) if gray.dtype != np.uint8: - gray = cv2.convertScaleAbs(gray) # type: ignore[call-overload] + gray = cast("GrayImage", cv2.convertScaleAbs(gray)) height, width = gray.shape[:2] if width <= 0 or height <= 0: @@ -247,6 +250,25 @@ def _extract_forward_rois( np.ascontiguousarray(current_gray[y0:y1, x0:x1]), ) + def _quality_failure_reason( + self, + previous_roi: GrayImage, + current_roi: GrayImage, + ) -> str | None: + if self._is_occluded(previous_roi): + return "previous_roi_occluded" + + if self._is_occluded(current_roi): + return "current_roi_occluded" + + if self._is_low_texture(previous_roi): + return "previous_roi_low_texture" + + if self._is_low_texture(current_roi): + return "current_roi_low_texture" + + return None + def _forward_roi_bounds(self, *, width: int, height: int) -> tuple[int, int, int, int]: roi_width = max(round(width * self._config.forward_roi_width_fraction), 2) x0 = max((width - roi_width) // 2, 0) @@ -287,46 +309,57 @@ def _mean_flow_magnitude(self, previous_roi: GrayImage, current_roi: GrayImage) def _next_state(self, mean_flow_magnitude: float) -> GuardrailState: if mean_flow_magnitude >= self._config.stop_flow_magnitude_threshold: self._stop_hits += 1 - # Stop-level flow is also caution-level flow. This lets us clamp - # first when stop evidence has not yet met its own persistence rule. self._caution_hits += 1 + self._below_stop_hits = 0 self._clear_hits = 0 elif mean_flow_magnitude >= self._config.caution_flow_magnitude_threshold: self._stop_hits = 0 self._caution_hits += 1 + self._below_stop_hits += 1 self._clear_hits = 0 else: self._stop_hits = 0 self._caution_hits = 0 + self._below_stop_hits += 1 self._clear_hits += 1 if self._hysteresis_state == GuardrailState.STOP_LATCHED: if self._stop_hits >= self._config.stop_frame_count: return GuardrailState.STOP_LATCHED + + if self._below_stop_hits < self._config.stop_release_frame_count: + return GuardrailState.STOP_LATCHED + if self._clear_hits >= self._config.clear_frame_count: return GuardrailState.PASS + return GuardrailState.CLAMP if self._hysteresis_state == GuardrailState.CLAMP: if self._stop_hits >= self._config.stop_frame_count: return GuardrailState.STOP_LATCHED + if self._clear_hits >= self._config.clear_frame_count: return GuardrailState.PASS + return GuardrailState.CLAMP if self._stop_hits >= self._config.stop_frame_count: return GuardrailState.STOP_LATCHED + if self._stop_hits > 0: + return GuardrailState.CLAMP + if self._caution_hits >= self._config.caution_frame_count: return GuardrailState.CLAMP return GuardrailState.PASS def _reset_hysteresis(self) -> None: - """Reset internal clamp/stop evidence without changing module output state.""" self._hysteresis_state = GuardrailState.PASS self._caution_hits = 0 self._stop_hits = 0 + self._below_stop_hits = 0 self._clear_hits = 0 def _pass_decision( diff --git a/dimos/control/safety/rgb_collision_guardrail.py b/dimos/control/safety/rgb_collision_guardrail.py index 7c649ac1a7..4386a55597 100644 --- a/dimos/control/safety/rgb_collision_guardrail.py +++ b/dimos/control/safety/rgb_collision_guardrail.py @@ -54,6 +54,7 @@ class RGBCollisionGuardrailConfig(ModuleConfig): risk_timeout_s: float = Field(default=0.25, gt=0.0) fail_closed_on_missing_image: bool = True publish_zero_on_stop: bool = True + frame_pair_max_gap_s: float = Field(default=0.2, gt=0.0) # Motion gating forward_motion_deadband_mps: float = Field(default=0.05, ge=0.0) @@ -77,6 +78,7 @@ class RGBCollisionGuardrailConfig(ModuleConfig): caution_frame_count: int = Field(default=2, ge=1) stop_frame_count: int = Field(default=2, ge=1) clear_frame_count: int = Field(default=3, ge=1) + stop_release_frame_count: int = Field(default=2, ge=1) @model_validator(mode="after") def validate_thresholds(self) -> Self: @@ -114,6 +116,8 @@ class _GuardrailRuntimeState: pending_cmd_update: bool = False pending_decision_publish: bool = False state: GuardrailState = GuardrailState.INIT + image_generation: int = 0 + last_evaluated_image_generation: int = -1 @dataclass(frozen=True) @@ -164,6 +168,7 @@ def _build_policy(self) -> GuardrailPolicy: caution_frame_count=self.config.caution_frame_count, stop_frame_count=self.config.stop_frame_count, clear_frame_count=self.config.clear_frame_count, + stop_release_frame_count=self.config.stop_release_frame_count, ) return OpticalFlowMagnitudeGuardrailPolicy(policy_config) @@ -187,6 +192,16 @@ def start(self) -> None: ) self._thread.start() + logger.info( + "RGB guardrail started", + risk_evaluation_hz=self.config.risk_evaluation_hz, + guarded_output_publish_hz=self.config.guarded_output_publish_hz, + command_timeout_s=self.config.command_timeout_s, + image_timeout_s=self.config.image_timeout_s, + risk_timeout_s=self.config.risk_timeout_s, + frame_pair_max_gap_s=self.config.frame_pair_max_gap_s, + ) + @rpc def stop(self) -> None: self._stop_event.set() @@ -198,8 +213,14 @@ def stop(self) -> None: if self._thread is not None: self._thread.join(timeout=_THREAD_JOIN_TIMEOUT_S) + if self._thread.is_alive(): + logger.warning( + "RGB guardrail worker thread did not stop within timeout", + timeout_s=_THREAD_JOIN_TIMEOUT_S, + ) self._thread = None + logger.info("RGB guardrail stopped") super().stop() def _on_color_image(self, image: Image) -> None: @@ -209,6 +230,7 @@ def _on_color_image(self, image: Image) -> None: self._runtime_state.previous_image_time = self._runtime_state.latest_image_time self._runtime_state.latest_image = image self._runtime_state.latest_image_time = now + self._runtime_state.image_generation += 1 self._condition.notify() def _on_incoming_cmd_vel(self, cmd_vel: Twist) -> None: @@ -247,7 +269,10 @@ def _decision_loop(self) -> None: health=risk_input.health, ) except Exception: - logger.exception("RGB guardrail policy evaluation failed") + logger.exception( + "RGB guardrail policy evaluation failed", + state=self._runtime_state.state.value, + ) policy_decision = self._build_sensor_degraded_decision( "policy_evaluation_failed" ) @@ -289,11 +314,17 @@ def _risk_evaluation_period_s(self) -> float: return 1.0 / self.config.risk_evaluation_hz def _should_recompute_risk_locked(self, now: float) -> bool: - """Return True when the next scheduled risk evaluation is due.""" next_risk_time = self._runtime_state.next_risk_time - if next_risk_time is None: - return True - return now >= next_risk_time + if next_risk_time is not None and now < next_risk_time: + return False + + if self._runtime_state.previous_image is None: + return False + + return ( + self._runtime_state.image_generation + != self._runtime_state.last_evaluated_image_generation + ) def _is_cmd_fresh_locked(self, now: float) -> bool: latest_cmd_time = self._runtime_state.latest_cmd_time @@ -307,6 +338,13 @@ def _is_image_fresh_locked(self, now: float) -> bool: return False return (now - latest_image_time) <= self.config.image_timeout_s + def _is_frame_pair_fresh_locked(self) -> bool: + previous_image_time = self._runtime_state.previous_image_time + latest_image_time = self._runtime_state.latest_image_time + if previous_image_time is None or latest_image_time is None: + return False + return (latest_image_time - previous_image_time) <= self.config.frame_pair_max_gap_s + def _is_risk_fresh_locked(self, now: float) -> bool: last_risk_time = self._runtime_state.last_risk_time if last_risk_time is None: @@ -314,16 +352,15 @@ def _is_risk_fresh_locked(self, now: float) -> bool: return (now - last_risk_time) <= self.config.risk_timeout_s def _build_health_locked(self, now: float) -> GuardrailHealth: - """Build a health snapshot from the current cached inputs.""" return GuardrailHealth( has_previous_frame=self._runtime_state.previous_image is not None, image_fresh=self._is_image_fresh_locked(now), cmd_fresh=self._is_cmd_fresh_locked(now), risk_fresh=self._is_risk_fresh_locked(now), + frame_pair_fresh=self._is_frame_pair_fresh_locked(), ) def _resolved_cmd_for_latest_locked(self, now: float) -> Twist: - """Resolve the command to publish for the latest cached upstream input.""" latest_cmd_vel = self._runtime_state.latest_cmd_vel if latest_cmd_vel is None: return Twist.zero() @@ -341,7 +378,6 @@ def _resolved_cmd_for_latest_locked(self, now: float) -> Twist: return last_decision.cmd_vel def _take_risk_evaluation_input_locked(self, now: float) -> _RiskEvaluationInput | None: - """Capture a consistent snapshot for policy evaluation and advance the risk deadline.""" previous_image = self._runtime_state.previous_image current_image = self._runtime_state.latest_image incoming_cmd_vel = self._runtime_state.latest_cmd_vel @@ -351,6 +387,8 @@ def _take_risk_evaluation_input_locked(self, now: float) -> _RiskEvaluationInput if previous_image is None or current_image is None or incoming_cmd_vel is None: return None + self._runtime_state.last_evaluated_image_generation = self._runtime_state.image_generation + return _RiskEvaluationInput( previous_image=previous_image, current_image=current_image, @@ -391,7 +429,6 @@ def _store_decision_locked( ) def _consume_publish_cmd_locked(self, now: float) -> Twist | None: - """Consume and return the next command that should be published, if any.""" if self._runtime_state.pending_decision_publish: self._runtime_state.pending_decision_publish = False return self._resolved_cmd_for_latest_locked(now) @@ -412,7 +449,6 @@ def _consume_publish_cmd_locked(self, now: float) -> Twist | None: return None def _should_republish_non_pass_output_locked(self, now: float) -> bool: - """Return True when a non-pass output should be republished on heartbeat.""" last_decision = self._runtime_state.last_decision if last_decision is None: return False @@ -427,7 +463,6 @@ def _should_republish_non_pass_output_locked(self, now: float) -> bool: return (now - last_publish_time) >= self._guarded_output_publish_period_s() def _next_wakeup_timeout_locked(self) -> float: - """Compute the next worker wakeup timeout from pending work and deadlines.""" now = time.monotonic() if self._runtime_state.pending_cmd_update or self._runtime_state.pending_decision_publish: @@ -514,6 +549,9 @@ def _select_fallback_decision_locked(self, now: float) -> GuardrailDecision | No if not self._is_image_fresh_locked(now): return self._build_sensor_degraded_decision("image_stale") + if not self._is_frame_pair_fresh_locked(): + return self._build_sensor_degraded_decision("frame_pair_stale") + if self._runtime_state.last_risk_time is None: return self._build_init_decision("waiting_for_first_risk_evaluation") From 378a608060e735af648c41cb49734effe37d4325 Mon Sep 17 00:00:00 2001 From: Pratham Gala <90299416+pmg5408@users.noreply.github.com> Date: Sat, 4 Apr 2026 21:30:15 +0000 Subject: [PATCH 6/8] test(control): add RGB guardrail policy and module coverage Adds focused tests for optical-flow policy transitions, fail-closed image quality handling, worker-thread command gating, and end-to-end guarded stream behavior. The suite covers startup, stale inputs, stop/clamp republication, fast-upstream handling, and concurrent stop safety. --- dimos/control/safety/test_guardrail_policy.py | 454 +++++++++++++++ .../safety/test_rgb_collision_guardrail.py | 524 ++++++++++++++++++ ...est_rgb_collision_guardrail_integration.py | 394 +++++++++++++ 3 files changed, 1372 insertions(+) create mode 100644 dimos/control/safety/test_guardrail_policy.py create mode 100644 dimos/control/safety/test_rgb_collision_guardrail.py create mode 100644 dimos/control/safety/test_rgb_collision_guardrail_integration.py diff --git a/dimos/control/safety/test_guardrail_policy.py b/dimos/control/safety/test_guardrail_policy.py new file mode 100644 index 0000000000..25c5eb52b1 --- /dev/null +++ b/dimos/control/safety/test_guardrail_policy.py @@ -0,0 +1,454 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np +import pytest + +from dimos.control.safety.guardrail_policy import ( + GuardrailHealth, + GuardrailState, + OpticalFlowMagnitudeGuardrailPolicy, + OpticalFlowMagnitudePolicyConfig, +) +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat + + +def _policy_config( + *, + caution_frame_count: int = 2, + stop_frame_count: int = 2, + clear_frame_count: int = 3, + stop_release_frame_count: int = 2, +) -> OpticalFlowMagnitudePolicyConfig: + return OpticalFlowMagnitudePolicyConfig( + forward_motion_deadband_mps=0.05, + clamp_forward_speed_mps=0.1, + flow_downsample_width_px=160, + forward_roi_top_fraction=0.45, + forward_roi_bottom_fraction=0.95, + forward_roi_width_fraction=0.5, + low_texture_variance_threshold=150.0, + occlusion_dark_pixel_threshold=20, + occlusion_bright_pixel_threshold=235, + occlusion_extreme_fraction_threshold=0.9, + caution_flow_magnitude_threshold=0.8, + stop_flow_magnitude_threshold=1.5, + caution_frame_count=caution_frame_count, + stop_frame_count=stop_frame_count, + clear_frame_count=clear_frame_count, + stop_release_frame_count=stop_release_frame_count, + ) + + +def _forward_cmd( + x: float = 0.4, + *, + linear_y: float = 0.0, + linear_z: float = 0.0, + angular_z: float = 0.2, +) -> Twist: + return Twist( + linear=[x, linear_y, linear_z], + angular=[0.0, 0.0, angular_z], + ) + + +def _fresh_health( + *, + has_previous_frame: bool = True, + image_fresh: bool = True, + frame_pair_fresh: bool = True, +) -> GuardrailHealth: + return GuardrailHealth( + has_previous_frame=has_previous_frame, + image_fresh=image_fresh, + cmd_fresh=True, + risk_fresh=True, + frame_pair_fresh=frame_pair_fresh, + ) + + +def _textured_gray_image(*, width: int = 160, height: int = 120, shift_x: int = 0) -> Image: + yy, xx = np.indices((height, width)) + pattern = ((xx * 7 + yy * 11 + shift_x * 13) % 256).astype(np.uint8) + return Image.from_numpy(pattern, format=ImageFormat.GRAY) + + +def _uniform_gray_image(value: int, *, width: int = 160, height: int = 120) -> Image: + return Image.from_numpy( + np.full((height, width), value, dtype=np.uint8), + format=ImageFormat.GRAY, + ) + + +@pytest.fixture +def image_pair() -> tuple[Image, Image]: + return (_textured_gray_image(), _textured_gray_image(shift_x=3)) + + +@pytest.mark.parametrize( + "cmd", + [ + pytest.param(_forward_cmd(0.03, angular_z=0.35), id="below_forward_deadband"), + pytest.param(_forward_cmd(-0.2, linear_y=0.1, angular_z=0.4), id="reverse_motion"), + pytest.param( + Twist(linear=[0.0, 0.0, 0.0], angular=[0.0, 0.0, 0.6]), + id="pure_yaw", + ), + ], +) +def test_forward_guard_inactive_passthrough(image_pair: tuple[Image, Image], cmd: Twist) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + + assert decision.state == GuardrailState.PASS + assert decision.reason == "forward_guard_inactive" + assert decision.cmd_vel == cmd + + +def test_missing_previous_frame_returns_init_zero(image_pair: tuple[Image, Image]) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(has_previous_frame=False), + ) + + assert decision.state == GuardrailState.INIT + assert decision.reason == "missing_previous_frame" + assert decision.cmd_vel == Twist.zero() + + +def test_stale_image_health_degrades_to_zero(image_pair: tuple[Image, Image]) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(image_fresh=False), + ) + + assert decision.state == GuardrailState.SENSOR_DEGRADED + assert decision.reason == "image_not_fresh" + assert decision.cmd_vel == Twist.zero() + assert decision.publish_immediately is True + + +@pytest.mark.parametrize( + ("bad_frame_position", "bad_frame_kind", "expected_reason"), + [ + pytest.param("previous", "black", "previous_roi_occluded", id="previous_black_occluded"), + pytest.param("current", "black", "current_roi_occluded", id="current_black_occluded"), + pytest.param("previous", "white", "previous_roi_occluded", id="previous_white_occluded"), + pytest.param("current", "white", "current_roi_occluded", id="current_white_occluded"), + pytest.param( + "previous", + "uniform_gray", + "previous_roi_low_texture", + id="previous_low_texture", + ), + pytest.param( + "current", + "uniform_gray", + "current_roi_low_texture", + id="current_low_texture", + ), + ], +) +def test_bad_previous_or_current_roi_fail_closes( + image_pair: tuple[Image, Image], + bad_frame_position: str, + bad_frame_kind: str, + expected_reason: str, +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + if bad_frame_kind == "black": + bad_image = _uniform_gray_image(0) + elif bad_frame_kind == "white": + bad_image = _uniform_gray_image(255) + else: + bad_image = _uniform_gray_image(127) + + previous_image, current_image = image_pair + if bad_frame_position == "previous": + previous_image = bad_image + else: + current_image = bad_image + + decision = policy.evaluate( + previous_image=previous_image, + current_image=current_image, + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ) + + assert decision.state == GuardrailState.SENSOR_DEGRADED + assert decision.reason == expected_reason + assert decision.cmd_vel == Twist.zero() + assert decision.publish_immediately is True + + +def test_caution_hysteresis_reaches_clamp(image_pair: tuple[Image, Image], mocker) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[0.9, 0.9], + ) + + first = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ) + second = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ) + + assert first.state == GuardrailState.PASS + assert second.state == GuardrailState.CLAMP + assert second.reason == "forward_flow_clamp" + assert second.cmd_vel.linear.x == pytest.approx(0.1) + + +def test_first_stop_strength_frame_clamps_immediately( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + cmd = _forward_cmd(0.45, angular_z=0.55) + + mocker.patch.object(policy, "_mean_flow_magnitude", return_value=1.8) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + + assert decision.state == GuardrailState.CLAMP + assert decision.reason == "forward_flow_clamp" + assert decision.cmd_vel.linear.x == pytest.approx(0.1) + assert decision.cmd_vel.angular.z == pytest.approx(cmd.angular.z) + + +def test_repeated_stop_strength_frames_reach_stop_latched( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + cmd = _forward_cmd(0.45, angular_z=0.55) + + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[1.8, 1.8], + ) + + first = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + second = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + + assert first.state == GuardrailState.CLAMP + assert first.reason == "forward_flow_clamp" + assert second.state == GuardrailState.STOP_LATCHED + assert second.reason == "forward_flow_stop" + assert second.cmd_vel.linear.x == pytest.approx(0.0) + assert second.cmd_vel.angular.z == pytest.approx(cmd.angular.z) + + +def test_stop_latched_does_not_release_on_first_clear_frame( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[1.8, 1.8, 0.0], + ) + + states = [ + policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ).state + for _ in range(3) + ] + + assert states == [ + GuardrailState.CLAMP, + GuardrailState.STOP_LATCHED, + GuardrailState.STOP_LATCHED, + ] + + +def test_stop_latched_recovery_requires_clear_frames( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config(clear_frame_count=3)) + + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[1.8, 1.8, 0.0, 0.0, 0.0], + ) + + states = [ + policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ).state + for _ in range(5) + ] + + assert states[0] == GuardrailState.CLAMP + assert states[1] == GuardrailState.STOP_LATCHED + assert states[2] != GuardrailState.PASS + assert states[3] != GuardrailState.PASS + assert states[4] == GuardrailState.PASS + + +def test_recovery_after_clear_frames_returns_to_pass( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[0.9, 0.9, 0.0, 0.0, 0.0], + ) + + states = [ + policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ).state + for _ in range(5) + ] + + assert states == [ + GuardrailState.PASS, + GuardrailState.CLAMP, + GuardrailState.CLAMP, + GuardrailState.CLAMP, + GuardrailState.PASS, + ] + + +def test_forward_deadband_does_not_reset_hysteresis( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config(caution_frame_count=2)) + + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[0.9, 0.9], + ) + + first = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(0.4), + health=_fresh_health(), + ) + inactive = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(0.03), + health=_fresh_health(), + ) + third = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(0.4), + health=_fresh_health(), + ) + + assert first.state == GuardrailState.PASS + assert inactive.state == GuardrailState.PASS + assert inactive.reason == "forward_guard_inactive" + assert third.state == GuardrailState.CLAMP + assert third.reason == "forward_flow_clamp" + + +def test_clamp_preserves_angular_terms(image_pair: tuple[Image, Image], mocker) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config(caution_frame_count=1)) + mocker.patch.object(policy, "_mean_flow_magnitude", return_value=0.9) + cmd = _forward_cmd(0.4, linear_y=0.15, linear_z=-0.1, angular_z=0.65) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + + assert decision.state == GuardrailState.CLAMP + assert decision.cmd_vel.linear.x == pytest.approx(0.1) + assert decision.cmd_vel.linear.y == pytest.approx(cmd.linear.y) + assert decision.cmd_vel.linear.z == pytest.approx(cmd.linear.z) + assert decision.cmd_vel.angular.z == pytest.approx(cmd.angular.z) + + +def test_stop_zeroes_only_linear_x(image_pair: tuple[Image, Image], mocker) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config(stop_frame_count=1)) + mocker.patch.object(policy, "_mean_flow_magnitude", return_value=1.8) + cmd = _forward_cmd(0.45, linear_y=0.2, linear_z=-0.1, angular_z=0.75) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + + assert decision.state == GuardrailState.STOP_LATCHED + assert decision.cmd_vel.linear.x == pytest.approx(0.0) + assert decision.cmd_vel.linear.y == pytest.approx(cmd.linear.y) + assert decision.cmd_vel.linear.z == pytest.approx(cmd.linear.z) + assert decision.cmd_vel.angular.z == pytest.approx(cmd.angular.z) diff --git a/dimos/control/safety/test_rgb_collision_guardrail.py b/dimos/control/safety/test_rgb_collision_guardrail.py new file mode 100644 index 0000000000..2a68a83e18 --- /dev/null +++ b/dimos/control/safety/test_rgb_collision_guardrail.py @@ -0,0 +1,524 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable +import queue +import threading +import time +from typing import Any, TypeVar + +import numpy as np +import pytest + +from dimos.control.safety.guardrail_policy import ( + GuardrailDecision, + GuardrailHealth, + GuardrailState, +) +from dimos.control.safety.rgb_collision_guardrail import RGBCollisionGuardrail +from dimos.core.stream import Out, Transport +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat + +T = TypeVar("T") + + +class FakeTransport(Transport[T]): + def __init__(self) -> None: + self._subscribers: list[Callable[[T], Any]] = [] + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def broadcast(self, selfstream: Out[T] | None, value: T) -> None: + for callback in list(self._subscribers): + callback(value) + + def subscribe( + self, + callback: Callable[[T], Any], + selfstream=None, # type: ignore[no-untyped-def] + ) -> Callable[[], None]: + self._subscribers.append(callback) + + def unsubscribe() -> None: + self._subscribers.remove(callback) + + return unsubscribe + + +class SequencePolicy: + def __init__(self, decisions: list[GuardrailDecision]) -> None: + self._decisions = decisions + self._index = 0 + + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: + if self._index < len(self._decisions): + decision = self._decisions[self._index] + self._index += 1 + return decision + return self._decisions[-1] + + +class RaisingPolicy: + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: + raise RuntimeError("synthetic policy failure") + + +class CountingPassPolicy: + def __init__(self) -> None: + self._lock = threading.Lock() + self._call_count = 0 + + @property + def call_count(self) -> int: + with self._lock: + return self._call_count + + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: + with self._lock: + self._call_count += 1 + + return GuardrailDecision( + state=GuardrailState.PASS, + cmd_vel=Twist( + linear=incoming_cmd_vel.linear, + angular=incoming_cmd_vel.angular, + ), + reason="counting_pass", + ) + + +def _textured_gray_image(*, width: int = 160, height: int = 120, shift_x: int = 0) -> Image: + yy, xx = np.indices((height, width)) + pattern = ((xx * 5 + yy * 9 + shift_x * 17) % 256).astype(np.uint8) + return Image.from_numpy(pattern, format=ImageFormat.GRAY) + + +def _cmd( + x: float = 0.35, + *, + linear_y: float = 0.0, + angular_z: float = 0.25, +) -> Twist: + return Twist( + linear=[x, linear_y, 0.0], + angular=[0.0, 0.0, angular_z], + ) + + +def _decision( + state: GuardrailState, + cmd_vel: Twist, + *, + reason: str = "test", + publish_immediately: bool = False, +) -> GuardrailDecision: + return GuardrailDecision( + state=state, + cmd_vel=cmd_vel, + reason=reason, + publish_immediately=publish_immediately, + ) + + +@pytest.fixture +def module() -> RGBCollisionGuardrail: + guardrail = RGBCollisionGuardrail( + guarded_output_publish_hz=50.0, + risk_evaluation_hz=50.0, + command_timeout_s=0.05, + image_timeout_s=0.05, + risk_timeout_s=0.05, + ) + yield guardrail + guardrail._close_module() + + +def _wait_for_output( + outputs: queue.Queue[Twist], + predicate: Callable[[Twist], bool], + *, + timeout_s: float = 0.5, +) -> Twist: + deadline = time.monotonic() + timeout_s + while time.monotonic() < deadline: + remaining = deadline - time.monotonic() + try: + candidate = outputs.get(timeout=max(remaining, 0.01)) + except queue.Empty: + continue + if predicate(candidate): + return candidate + raise AssertionError("Timed out waiting for matching guardrail output") + + +def _wait_for_decision( + guardrail: RGBCollisionGuardrail, + predicate: Callable[[GuardrailDecision], bool], + *, + timeout_s: float = 0.5, +) -> GuardrailDecision: + deadline = time.monotonic() + timeout_s + while time.monotonic() < deadline: + with guardrail._condition: + decision = guardrail._runtime_state.last_decision + if decision is not None and predicate(decision): + return decision + time.sleep(0.01) + + raise AssertionError("Timed out waiting for matching guardrail decision") + + +def _start_threaded_guardrail( + policy: Any, + **config_overrides: float, +) -> tuple[RGBCollisionGuardrail, FakeTransport[Image], FakeTransport[Twist], queue.Queue[Twist]]: + config: dict[str, float] = { + "guarded_output_publish_hz": 50.0, + "risk_evaluation_hz": 50.0, + "command_timeout_s": 0.3, + "image_timeout_s": 0.3, + "risk_timeout_s": 0.3, + } + config.update(config_overrides) + + guardrail = RGBCollisionGuardrail(**config) + image_transport: FakeTransport[Image] = FakeTransport() + cmd_transport: FakeTransport[Twist] = FakeTransport() + outputs: queue.Queue[Twist] = queue.Queue() + + guardrail.color_image.transport = image_transport + guardrail.incoming_cmd_vel.transport = cmd_transport + guardrail.safe_cmd_vel.subscribe(outputs.put) + guardrail._policy = policy + guardrail.start() + + return guardrail, image_transport, cmd_transport, outputs + + +def test_no_command_returns_init_zero(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + + with module._condition: + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.INIT + assert decision.reason == "no_command_received" + assert decision.cmd_vel == Twist.zero() + + +def test_waiting_for_first_image_returns_init_zero(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = now + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.INIT + assert decision.reason == "waiting_for_first_image" + assert decision.cmd_vel == Twist.zero() + + +def test_no_frame_pair_returns_init_zero(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = now + module._runtime_state.latest_image = _textured_gray_image() + module._runtime_state.latest_image_time = now + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.INIT + assert decision.reason == "waiting_for_frame_pair" + assert decision.cmd_vel == Twist.zero() + + +def test_waiting_for_first_risk_evaluation_returns_init_zero( + module: RGBCollisionGuardrail, +) -> None: + now = time.monotonic() + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = now + module._runtime_state.previous_image = _textured_gray_image() + module._runtime_state.previous_image_time = now + module._runtime_state.latest_image = _textured_gray_image(shift_x=2) + module._runtime_state.latest_image_time = now + module._runtime_state.last_risk_time = None + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.INIT + assert decision.reason == "waiting_for_first_risk_evaluation" + assert decision.cmd_vel == Twist.zero() + + +def test_stale_image_returns_sensor_degraded_zero(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + stale_time = now - 0.2 + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = now + module._runtime_state.previous_image = _textured_gray_image() + module._runtime_state.latest_image = _textured_gray_image(shift_x=2) + module._runtime_state.previous_image_time = stale_time + module._runtime_state.latest_image_time = stale_time + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.SENSOR_DEGRADED + assert decision.reason == "image_stale" + assert decision.cmd_vel == Twist.zero() + + +def test_stale_risk_returns_sensor_degraded_zero(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + stale_risk_time = now - 0.2 + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = now + module._runtime_state.previous_image = _textured_gray_image() + module._runtime_state.latest_image = _textured_gray_image(shift_x=2) + module._runtime_state.previous_image_time = now + module._runtime_state.latest_image_time = now + module._runtime_state.last_risk_time = stale_risk_time + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.SENSOR_DEGRADED + assert decision.reason == "risk_state_stale" + assert decision.cmd_vel == Twist.zero() + + +def test_stale_command_publishes_zero_output(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + stale_cmd_time = now - 0.2 + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = stale_cmd_time + module._runtime_state.last_decision = _decision(GuardrailState.PASS, _cmd()) + module._runtime_state.pending_cmd_update = True + cmd_to_publish = module._consume_publish_cmd_locked(now) + + assert cmd_to_publish == Twist.zero() + + +def test_policy_exception_fail_closes_to_zero() -> None: + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail( + RaisingPolicy(), + ) + + try: + cmd_transport.publish(_cmd(0.4, angular_z=0.3)) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + observed = _wait_for_output(outputs, lambda twist: twist == Twist.zero()) + assert observed == Twist.zero() + + decision = _wait_for_decision( + guardrail, + lambda d: d.state == GuardrailState.SENSOR_DEGRADED + and d.reason == "policy_evaluation_failed", + timeout_s=0.5, + ) + + assert decision.state == GuardrailState.SENSOR_DEGRADED + assert decision.reason == "policy_evaluation_failed" + + with guardrail._condition: + assert guardrail._runtime_state.state == GuardrailState.SENSOR_DEGRADED + finally: + guardrail.stop() + + +def test_pass_publishes_latest_upstream_command() -> None: + upstream_first = _cmd(0.3, angular_z=0.1) + upstream_second = _cmd(0.45, angular_z=0.35) + misleading_policy_cmd = _cmd(0.02, angular_z=-0.2) + policy = SequencePolicy([_decision(GuardrailState.PASS, misleading_policy_cmd, reason="pass")]) + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail(policy) + + try: + cmd_transport.publish(upstream_first) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + first_output = _wait_for_output(outputs, lambda twist: twist == upstream_first) + assert first_output == upstream_first + + cmd_transport.publish(upstream_second) + second_output = _wait_for_output(outputs, lambda twist: twist == upstream_second) + assert second_output == upstream_second + finally: + guardrail.stop() + + +@pytest.mark.parametrize( + ("state", "guarded_cmd"), + [ + ( + GuardrailState.CLAMP, + Twist(linear=[0.1, 0.0, 0.0], angular=[0.0, 0.0, 0.4]), + ), + ( + GuardrailState.STOP_LATCHED, + Twist(linear=[0.0, 0.0, 0.0], angular=[0.0, 0.0, 0.4]), + ), + ], +) +def test_non_pass_states_publish_guarded_output(state: GuardrailState, guarded_cmd: Twist) -> None: + upstream_cmd = _cmd(0.35, angular_z=0.4) + policy = SequencePolicy([_decision(state, guarded_cmd)]) + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail(policy) + + try: + cmd_transport.publish(upstream_cmd) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + published = _wait_for_output(outputs, lambda twist: twist == guarded_cmd) + assert published == guarded_cmd + finally: + guardrail.stop() + + +def test_non_pass_heartbeat_republishes_guarded_output() -> None: + guarded_cmd = Twist(linear=[0.1, 0.0, 0.0], angular=[0.0, 0.0, 0.5]) + policy = SequencePolicy([_decision(GuardrailState.CLAMP, guarded_cmd)]) + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail(policy) + + try: + cmd_transport.publish(_cmd(0.4, angular_z=0.5)) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + first = _wait_for_output(outputs, lambda twist: twist == guarded_cmd) + second = _wait_for_output(outputs, lambda twist: twist == guarded_cmd) + + assert first == guarded_cmd + assert second == guarded_cmd + finally: + guardrail.stop() + + +def test_non_pass_decision_can_publish_without_new_command() -> None: + upstream_cmd = _cmd(0.4, angular_z=0.3) + stop_cmd = Twist(linear=[0.0, 0.0, 0.0], angular=[0.0, 0.0, 0.3]) + policy = SequencePolicy( + [ + _decision(GuardrailState.PASS, upstream_cmd, reason="initial_pass"), + _decision( + GuardrailState.STOP_LATCHED, + stop_cmd, + reason="forced_stop", + publish_immediately=True, + ), + ] + ) + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail(policy) + + try: + cmd_transport.publish(upstream_cmd) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + first_output = _wait_for_output(outputs, lambda twist: twist == upstream_cmd) + assert first_output == upstream_cmd + + image_transport.publish(_textured_gray_image(shift_x=4)) + + autonomous_stop = _wait_for_output(outputs, lambda twist: twist == stop_cmd) + assert autonomous_stop == stop_cmd + finally: + guardrail.stop() + + +def test_fast_upstream_commands_reuse_last_risk_decision() -> None: + policy = CountingPassPolicy() + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail( + policy, + guarded_output_publish_hz=100.0, + risk_evaluation_hz=2.0, + command_timeout_s=1.0, + image_timeout_s=1.0, + risk_timeout_s=1.0, + ) + + first_cmd = _cmd(0.20, angular_z=0.10) + second_cmd = _cmd(0.32, angular_z=0.20) + third_cmd = _cmd(0.44, angular_z=0.30) + + try: + cmd_transport.publish(first_cmd) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + first_output = _wait_for_output(outputs, lambda twist: twist == first_cmd, timeout_s=0.6) + assert first_output == first_cmd + assert policy.call_count == 1 + + cmd_transport.publish(second_cmd) + second_output = _wait_for_output( + outputs, + lambda twist: twist == second_cmd, + timeout_s=0.2, + ) + assert second_output == second_cmd + + cmd_transport.publish(third_cmd) + third_output = _wait_for_output( + outputs, + lambda twist: twist == third_cmd, + timeout_s=0.2, + ) + assert third_output == third_cmd + + assert policy.call_count == 1 + assert policy.call_count < 3 + finally: + guardrail.stop() diff --git a/dimos/control/safety/test_rgb_collision_guardrail_integration.py b/dimos/control/safety/test_rgb_collision_guardrail_integration.py new file mode 100644 index 0000000000..4c9c0a9b0a --- /dev/null +++ b/dimos/control/safety/test_rgb_collision_guardrail_integration.py @@ -0,0 +1,394 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable +import queue +import threading +import time +from typing import Any, TypeVar + +import numpy as np +import pytest + +from dimos.control.safety.guardrail_policy import ( + GuardrailDecision, + GuardrailHealth, + GuardrailState, +) +from dimos.control.safety.rgb_collision_guardrail import RGBCollisionGuardrail +from dimos.core.stream import Out, Transport +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat + +T = TypeVar("T") + + +class FakeTransport(Transport[T]): + def __init__(self) -> None: + self._subscribers: list[Callable[[T], Any]] = [] + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def broadcast(self, selfstream: Out[T] | None, value: T) -> None: + for callback in list(self._subscribers): + callback(value) + + def subscribe( + self, + callback: Callable[[T], Any], + selfstream=None, + ) -> Callable[[], None]: + self._subscribers.append(callback) + + def unsubscribe() -> None: + self._subscribers.remove(callback) + + return unsubscribe + + +class SequencePolicy: + def __init__(self, decisions: list[GuardrailDecision]) -> None: + self._decisions = decisions + self._index = 0 + + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: + if self._index < len(self._decisions): + decision = self._decisions[self._index] + self._index += 1 + return decision + return self._decisions[-1] + + +def _textured_gray_image(*, width: int = 160, height: int = 120, shift_x: int = 0) -> Image: + yy, xx = np.indices((height, width)) + pattern = ((xx * 5 + yy * 9 + shift_x * 17) % 256).astype(np.uint8) + return Image.from_numpy(pattern, format=ImageFormat.GRAY) + + +def _black_gray_image(*, width: int = 160, height: int = 120) -> Image: + return Image.from_numpy( + np.zeros((height, width), dtype=np.uint8), + format=ImageFormat.GRAY, + ) + + +def _cmd( + x: float = 0.35, + *, + linear_y: float = 0.0, + angular_z: float = 0.25, +) -> Twist: + return Twist( + linear=[x, linear_y, 0.0], + angular=[0.0, 0.0, angular_z], + ) + + +def _decision( + state: GuardrailState, + cmd_vel: Twist, + *, + reason: str = "test", + publish_immediately: bool = False, +) -> GuardrailDecision: + return GuardrailDecision( + state=state, + cmd_vel=cmd_vel, + reason=reason, + publish_immediately=publish_immediately, + ) + + +def _start_guardrail( + **config_overrides: float, +) -> tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], +]: + config: dict[str, float] = { + "guarded_output_publish_hz": 20.0, + "risk_evaluation_hz": 20.0, + "command_timeout_s": 0.3, + "image_timeout_s": 0.3, + "risk_timeout_s": 0.3, + } + config.update(config_overrides) + + guardrail = RGBCollisionGuardrail(**config) + image_transport: FakeTransport[Image] = FakeTransport() + cmd_transport: FakeTransport[Twist] = FakeTransport() + outputs: queue.Queue[tuple[float, Twist]] = queue.Queue() + + guardrail.color_image.transport = image_transport + guardrail.incoming_cmd_vel.transport = cmd_transport + guardrail.safe_cmd_vel.subscribe(lambda msg: outputs.put((time.monotonic(), msg))) + guardrail.start() + + return guardrail, image_transport, cmd_transport, outputs + + +@pytest.fixture +def started_guardrail() -> tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], +]: + guardrail, image_transport, cmd_transport, outputs = _start_guardrail() + + yield guardrail, image_transport, cmd_transport, outputs + + guardrail.stop() + + +def _wait_for_output( + outputs: queue.Queue[tuple[float, Twist]], + predicate: Callable[[Twist], bool], + *, + timeout_s: float = 1.0, +) -> tuple[float, Twist]: + deadline = time.monotonic() + timeout_s + while time.monotonic() < deadline: + remaining = deadline - time.monotonic() + try: + ts, msg = outputs.get(timeout=max(remaining, 0.01)) + except queue.Empty: + continue + if predicate(msg): + return ts, msg + raise AssertionError("Timed out waiting for matching guardrail output") + + +@pytest.mark.slow +def test_stream_wiring_end_to_end_passes_upstream_twist( + started_guardrail: tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], + ], +) -> None: + guardrail, image_transport, cmd_transport, outputs = started_guardrail + + upstream = _cmd(0.03, angular_z=0.15) + + cmd_transport.publish(upstream) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + _, observed = _wait_for_output(outputs, lambda msg: msg == upstream) + assert observed == upstream + + +@pytest.mark.slow +def test_non_pass_output_is_republished_while_upstream_is_quiet() -> None: + guardrail, image_transport, cmd_transport, outputs = _start_guardrail( + guarded_output_publish_hz=20.0, + risk_evaluation_hz=20.0, + command_timeout_s=0.5, + image_timeout_s=0.5, + risk_timeout_s=0.5, + ) + guarded = Twist(linear=[0.1, 0.0, 0.0], angular=[0.0, 0.0, 0.2]) + + try: + guardrail._policy = SequencePolicy( + [_decision(GuardrailState.CLAMP, guarded, reason="forced_clamp")] + ) + + cmd_transport.publish(_cmd(0.4, angular_z=0.2)) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + _wait_for_output(outputs, lambda msg: msg == guarded, timeout_s=1.0) + _wait_for_output(outputs, lambda msg: msg == guarded, timeout_s=1.0) + + t1, msg1 = _wait_for_output(outputs, lambda msg: msg == guarded, timeout_s=1.0) + t2, msg2 = _wait_for_output(outputs, lambda msg: msg == guarded, timeout_s=1.0) + + assert msg1 == guarded + assert msg2 == guarded + + interval = t2 - t1 + expected_period = 1.0 / guardrail.config.guarded_output_publish_hz + + assert expected_period * 0.5 <= interval <= expected_period * 2.0 + finally: + guardrail.stop() + + +@pytest.mark.slow +def test_forced_stop_never_leaks_positive_linear_x_under_concurrent_updates() -> None: + guardrail, image_transport, cmd_transport, outputs = _start_guardrail( + guarded_output_publish_hz=30.0, + risk_evaluation_hz=30.0, + command_timeout_s=0.5, + image_timeout_s=0.5, + risk_timeout_s=0.5, + ) + stop_cmd = Twist(linear=[0.0, 0.0, 0.0], angular=[0.0, 0.0, 0.2]) + + stop_event = threading.Event() + errors: list[Exception] = [] + + def publish_images() -> None: + try: + shift = 0 + image_transport.publish(_textured_gray_image(shift_x=shift)) + while not stop_event.is_set(): + shift += 1 + image_transport.publish(_textured_gray_image(shift_x=shift)) + time.sleep(0.01) + except Exception as exc: + errors.append(exc) + + def publish_commands() -> None: + try: + speeds = [0.4, 0.25, 0.5, 0.15, 0.35] + idx = 0 + while not stop_event.is_set(): + cmd_transport.publish(_cmd(speeds[idx % len(speeds)], angular_z=0.2)) + idx += 1 + time.sleep(0.01) + except Exception as exc: + errors.append(exc) + + try: + guardrail._policy = SequencePolicy( + [ + _decision( + GuardrailState.STOP_LATCHED, + stop_cmd, + reason="forced_stop", + publish_immediately=True, + ) + ] + ) + + image_thread = threading.Thread(target=publish_images, daemon=True) + cmd_thread = threading.Thread(target=publish_commands, daemon=True) + + image_thread.start() + cmd_thread.start() + + observed_outputs: list[Twist] = [] + while len(observed_outputs) < 8: + _, observed = _wait_for_output( + outputs, lambda msg: isinstance(msg, Twist), timeout_s=1.0 + ) + observed_outputs.append(observed) + + assert observed_outputs + assert all(twist.linear.x == pytest.approx(0.0) for twist in observed_outputs) + assert all(twist == stop_cmd or twist == Twist.zero() for twist in observed_outputs) + finally: + stop_event.set() + if "image_thread" in locals(): + image_thread.join(timeout=1.0) + if "cmd_thread" in locals(): + cmd_thread.join(timeout=1.0) + guardrail.stop() + + assert not errors + assert not image_thread.is_alive() + assert not cmd_thread.is_alive() + + +@pytest.mark.slow +def test_black_frame_end_to_end_fail_closes_to_zero( + started_guardrail: tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], + ], +) -> None: + guardrail, image_transport, cmd_transport, outputs = started_guardrail + + cmd_transport.publish(_cmd(0.3, angular_z=0.1)) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_black_gray_image()) + + _, observed = _wait_for_output(outputs, lambda msg: msg == Twist.zero(), timeout_s=1.0) + assert observed == Twist.zero() + + +@pytest.mark.slow +def test_stale_image_end_to_end_fail_closes_without_new_command() -> None: + guardrail, image_transport, cmd_transport, outputs = _start_guardrail( + guarded_output_publish_hz=25.0, + risk_evaluation_hz=25.0, + command_timeout_s=0.5, + image_timeout_s=0.08, + risk_timeout_s=0.5, + ) + # Keep the initial command below deadband so this test is about + # autonomous stale-image fail-close, not optical-flow threshold tuning. + upstream = _cmd(0.03, angular_z=0.1) + + try: + cmd_transport.publish(upstream) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + _, pass_output = _wait_for_output(outputs, lambda msg: msg == upstream, timeout_s=1.0) + assert pass_output == upstream + + _, degraded_output = _wait_for_output( + outputs, + lambda msg: msg == Twist.zero(), + timeout_s=0.4, + ) + assert degraded_output == Twist.zero() + finally: + guardrail.stop() + + +@pytest.mark.slow +def test_stale_command_end_to_end_fail_closes_to_zero( + started_guardrail: tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], + ], +) -> None: + guardrail, image_transport, cmd_transport, outputs = started_guardrail + + # Keep the initial command below deadband so this test is about + # stale-command fail-close, not optical-flow clamp behavior. + upstream = _cmd(0.03, angular_z=0.1) + + cmd_transport.publish(upstream) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + _wait_for_output(outputs, lambda msg: msg == upstream, timeout_s=1.0) + _, observed = _wait_for_output(outputs, lambda msg: msg == Twist.zero(), timeout_s=1.0) + + assert observed == Twist.zero() From 86b6901e8cd7c05b619a6cc19d2d9e62996ae5a4 Mon Sep 17 00:00:00 2001 From: Pratham Gala <90299416+pmg5408@users.noreply.github.com> Date: Sat, 4 Apr 2026 23:24:24 +0000 Subject: [PATCH 7/8] test(control): share guardrail test helpers and tighten exception logging Extracts duplicated guardrail test helpers into a shared conftest to keep the safety test suite consistent as it evolves. Also captures guardrail state under lock before exception logging so failure logs reflect a stable runtime snapshot. --- .../control/safety/rgb_collision_guardrail.py | 5 +- dimos/control/safety/test_guardrail_policy.py | 9 +- .../safety/test_rgb_collision_guardrail.py | 94 ++------------- ...est_rgb_collision_guardrail_integration.py | 113 ++++-------------- 4 files changed, 39 insertions(+), 182 deletions(-) diff --git a/dimos/control/safety/rgb_collision_guardrail.py b/dimos/control/safety/rgb_collision_guardrail.py index 4386a55597..605279e378 100644 --- a/dimos/control/safety/rgb_collision_guardrail.py +++ b/dimos/control/safety/rgb_collision_guardrail.py @@ -269,9 +269,12 @@ def _decision_loop(self) -> None: health=risk_input.health, ) except Exception: + with self._condition: + _logged_state = self._runtime_state.state.value + logger.exception( "RGB guardrail policy evaluation failed", - state=self._runtime_state.state.value, + state=_logged_state, ) policy_decision = self._build_sensor_degraded_decision( "policy_evaluation_failed" diff --git a/dimos/control/safety/test_guardrail_policy.py b/dimos/control/safety/test_guardrail_policy.py index 25c5eb52b1..58003dcf75 100644 --- a/dimos/control/safety/test_guardrail_policy.py +++ b/dimos/control/safety/test_guardrail_policy.py @@ -23,6 +23,9 @@ OpticalFlowMagnitudeGuardrailPolicy, OpticalFlowMagnitudePolicyConfig, ) +from dimos.control.safety.test_utils import ( + _textured_gray_image, +) from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.sensor_msgs.Image import Image, ImageFormat @@ -82,12 +85,6 @@ def _fresh_health( ) -def _textured_gray_image(*, width: int = 160, height: int = 120, shift_x: int = 0) -> Image: - yy, xx = np.indices((height, width)) - pattern = ((xx * 7 + yy * 11 + shift_x * 13) % 256).astype(np.uint8) - return Image.from_numpy(pattern, format=ImageFormat.GRAY) - - def _uniform_gray_image(value: int, *, width: int = 160, height: int = 120) -> Image: return Image.from_numpy( np.full((height, width), value, dtype=np.uint8), diff --git a/dimos/control/safety/test_rgb_collision_guardrail.py b/dimos/control/safety/test_rgb_collision_guardrail.py index 2a68a83e18..f6c7212828 100644 --- a/dimos/control/safety/test_rgb_collision_guardrail.py +++ b/dimos/control/safety/test_rgb_collision_guardrail.py @@ -14,13 +14,12 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Iterator import queue import threading import time from typing import Any, TypeVar -import numpy as np import pytest from dimos.control.safety.guardrail_policy import ( @@ -29,59 +28,19 @@ GuardrailState, ) from dimos.control.safety.rgb_collision_guardrail import RGBCollisionGuardrail -from dimos.core.stream import Out, Transport +from dimos.control.safety.test_utils import ( + FakeTransport, + SequencePolicy, + _cmd, + _decision, + _textured_gray_image, +) from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.Image import Image T = TypeVar("T") -class FakeTransport(Transport[T]): - def __init__(self) -> None: - self._subscribers: list[Callable[[T], Any]] = [] - - def start(self) -> None: - pass - - def stop(self) -> None: - pass - - def broadcast(self, selfstream: Out[T] | None, value: T) -> None: - for callback in list(self._subscribers): - callback(value) - - def subscribe( - self, - callback: Callable[[T], Any], - selfstream=None, # type: ignore[no-untyped-def] - ) -> Callable[[], None]: - self._subscribers.append(callback) - - def unsubscribe() -> None: - self._subscribers.remove(callback) - - return unsubscribe - - -class SequencePolicy: - def __init__(self, decisions: list[GuardrailDecision]) -> None: - self._decisions = decisions - self._index = 0 - - def evaluate( - self, - previous_image: Image, - current_image: Image, - incoming_cmd_vel: Twist, - health: GuardrailHealth, - ) -> GuardrailDecision: - if self._index < len(self._decisions): - decision = self._decisions[self._index] - self._index += 1 - return decision - return self._decisions[-1] - - class RaisingPolicy: def evaluate( self, @@ -123,41 +82,8 @@ def evaluate( ) -def _textured_gray_image(*, width: int = 160, height: int = 120, shift_x: int = 0) -> Image: - yy, xx = np.indices((height, width)) - pattern = ((xx * 5 + yy * 9 + shift_x * 17) % 256).astype(np.uint8) - return Image.from_numpy(pattern, format=ImageFormat.GRAY) - - -def _cmd( - x: float = 0.35, - *, - linear_y: float = 0.0, - angular_z: float = 0.25, -) -> Twist: - return Twist( - linear=[x, linear_y, 0.0], - angular=[0.0, 0.0, angular_z], - ) - - -def _decision( - state: GuardrailState, - cmd_vel: Twist, - *, - reason: str = "test", - publish_immediately: bool = False, -) -> GuardrailDecision: - return GuardrailDecision( - state=state, - cmd_vel=cmd_vel, - reason=reason, - publish_immediately=publish_immediately, - ) - - @pytest.fixture -def module() -> RGBCollisionGuardrail: +def module() -> Iterator[RGBCollisionGuardrail]: guardrail = RGBCollisionGuardrail( guarded_output_publish_hz=50.0, risk_evaluation_hz=50.0, diff --git a/dimos/control/safety/test_rgb_collision_guardrail_integration.py b/dimos/control/safety/test_rgb_collision_guardrail_integration.py index 4c9c0a9b0a..46e969cb84 100644 --- a/dimos/control/safety/test_rgb_collision_guardrail_integration.py +++ b/dimos/control/safety/test_rgb_collision_guardrail_integration.py @@ -14,80 +14,32 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Iterator import queue import threading import time -from typing import Any, TypeVar +from typing import TypeVar import numpy as np import pytest from dimos.control.safety.guardrail_policy import ( - GuardrailDecision, - GuardrailHealth, GuardrailState, ) from dimos.control.safety.rgb_collision_guardrail import RGBCollisionGuardrail -from dimos.core.stream import Out, Transport +from dimos.control.safety.test_utils import ( + FakeTransport, + SequencePolicy, + _cmd, + _decision, + _textured_gray_image, +) from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.sensor_msgs.Image import Image, ImageFormat T = TypeVar("T") -class FakeTransport(Transport[T]): - def __init__(self) -> None: - self._subscribers: list[Callable[[T], Any]] = [] - - def start(self) -> None: - pass - - def stop(self) -> None: - pass - - def broadcast(self, selfstream: Out[T] | None, value: T) -> None: - for callback in list(self._subscribers): - callback(value) - - def subscribe( - self, - callback: Callable[[T], Any], - selfstream=None, - ) -> Callable[[], None]: - self._subscribers.append(callback) - - def unsubscribe() -> None: - self._subscribers.remove(callback) - - return unsubscribe - - -class SequencePolicy: - def __init__(self, decisions: list[GuardrailDecision]) -> None: - self._decisions = decisions - self._index = 0 - - def evaluate( - self, - previous_image: Image, - current_image: Image, - incoming_cmd_vel: Twist, - health: GuardrailHealth, - ) -> GuardrailDecision: - if self._index < len(self._decisions): - decision = self._decisions[self._index] - self._index += 1 - return decision - return self._decisions[-1] - - -def _textured_gray_image(*, width: int = 160, height: int = 120, shift_x: int = 0) -> Image: - yy, xx = np.indices((height, width)) - pattern = ((xx * 5 + yy * 9 + shift_x * 17) % 256).astype(np.uint8) - return Image.from_numpy(pattern, format=ImageFormat.GRAY) - - def _black_gray_image(*, width: int = 160, height: int = 120) -> Image: return Image.from_numpy( np.zeros((height, width), dtype=np.uint8), @@ -95,33 +47,6 @@ def _black_gray_image(*, width: int = 160, height: int = 120) -> Image: ) -def _cmd( - x: float = 0.35, - *, - linear_y: float = 0.0, - angular_z: float = 0.25, -) -> Twist: - return Twist( - linear=[x, linear_y, 0.0], - angular=[0.0, 0.0, angular_z], - ) - - -def _decision( - state: GuardrailState, - cmd_vel: Twist, - *, - reason: str = "test", - publish_immediately: bool = False, -) -> GuardrailDecision: - return GuardrailDecision( - state=state, - cmd_vel=cmd_vel, - reason=reason, - publish_immediately=publish_immediately, - ) - - def _start_guardrail( **config_overrides: float, ) -> tuple[ @@ -153,11 +78,13 @@ def _start_guardrail( @pytest.fixture -def started_guardrail() -> tuple[ - RGBCollisionGuardrail, - FakeTransport[Image], - FakeTransport[Twist], - queue.Queue[tuple[float, Twist]], +def started_guardrail() -> Iterator[ + tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], + ] ]: guardrail, image_transport, cmd_transport, outputs = _start_guardrail() @@ -255,6 +182,8 @@ def test_forced_stop_never_leaks_positive_linear_x_under_concurrent_updates() -> stop_event = threading.Event() errors: list[Exception] = [] + image_thread: threading.Thread | None = None + cmd_thread: threading.Thread | None = None def publish_images() -> None: try: @@ -308,13 +237,15 @@ def publish_commands() -> None: assert all(twist == stop_cmd or twist == Twist.zero() for twist in observed_outputs) finally: stop_event.set() - if "image_thread" in locals(): + if image_thread is not None: image_thread.join(timeout=1.0) - if "cmd_thread" in locals(): + if cmd_thread is not None: cmd_thread.join(timeout=1.0) guardrail.stop() assert not errors + assert image_thread is not None + assert cmd_thread is not None assert not image_thread.is_alive() assert not cmd_thread.is_alive() From 0a07d6b547f26300bb27b1c76ba1bcbd500afc4b Mon Sep 17 00:00:00 2001 From: Pratham Gala <90299416+pmg5408@users.noreply.github.com> Date: Sat, 4 Apr 2026 23:29:19 +0000 Subject: [PATCH 8/8] test(control): extract shared RGB guardrail test helpers Moves duplicated guardrail test helpers into a shared test utility module so the safety test suite stays consistent as coverage grows. This keeps the module and integration tests aligned without relying on duplicated local helper definitions. --- dimos/control/safety/test_utils.py | 110 +++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 dimos/control/safety/test_utils.py diff --git a/dimos/control/safety/test_utils.py b/dimos/control/safety/test_utils.py new file mode 100644 index 0000000000..7e3326dc2f --- /dev/null +++ b/dimos/control/safety/test_utils.py @@ -0,0 +1,110 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, TypeVar + +import numpy as np + +from dimos.control.safety.guardrail_policy import ( + GuardrailDecision, + GuardrailHealth, + GuardrailState, +) +from dimos.core.stream import Out, Transport +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat + +T = TypeVar("T") + + +class FakeTransport(Transport[T]): + def __init__(self) -> None: + self._subscribers: list[Callable[[T], Any]] = [] + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def broadcast(self, selfstream: Out[T] | None, value: T) -> None: + for callback in list(self._subscribers): + callback(value) + + def subscribe( + self, + callback: Callable[[T], Any], + selfstream=None, # type: ignore[no-untyped-def] + ) -> Callable[[], None]: + self._subscribers.append(callback) + + def unsubscribe() -> None: + self._subscribers.remove(callback) + + return unsubscribe + + +class SequencePolicy: + def __init__(self, decisions: list[GuardrailDecision]) -> None: + self._decisions = decisions + self._index = 0 + + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: + if self._index < len(self._decisions): + decision = self._decisions[self._index] + self._index += 1 + return decision + return self._decisions[-1] + + +def _textured_gray_image(*, width: int = 160, height: int = 120, shift_x: int = 0) -> Image: + yy, xx = np.indices((height, width)) + pattern = ((xx * 5 + yy * 9 + shift_x * 17) % 256).astype(np.uint8) + return Image.from_numpy(pattern, format=ImageFormat.GRAY) + + +def _cmd( + x: float = 0.35, + *, + linear_y: float = 0.0, + angular_z: float = 0.25, +) -> Twist: + return Twist( + linear=[x, linear_y, 0.0], + angular=[0.0, 0.0, angular_z], + ) + + +def _decision( + state: GuardrailState, + cmd_vel: Twist, + *, + reason: str = "test", + publish_immediately: bool = False, +) -> GuardrailDecision: + return GuardrailDecision( + state=state, + cmd_vel=cmd_vel, + reason=reason, + publish_immediately=publish_immediately, + )