diff --git a/dimos/control/safety/guardrail_policy.py b/dimos/control/safety/guardrail_policy.py new file mode 100644 index 0000000000..14dedf860c --- /dev/null +++ b/dimos/control/safety/guardrail_policy.py @@ -0,0 +1,432 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Protocol, cast + +import cv2 +import numpy as np +from numpy.typing import NDArray + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image + +GrayImage = NDArray[np.uint8] + + +class GuardrailState(str, Enum): + INIT = "init" + PASS = "pass" + CLAMP = "clamp" + STOP_LATCHED = "stop_latched" + SENSOR_DEGRADED = "sensor_degraded" + + +@dataclass(frozen=True) +class GuardrailHealth: + has_previous_frame: bool + image_fresh: bool + cmd_fresh: bool + risk_fresh: bool + frame_pair_fresh: bool + + +@dataclass +class GuardrailDecision: + """Policy result consumed by the guardrail worker. + + publish_immediately requests an immediate worker-side publish on the next + loop iteration. It does not bypass command freshness checks. + """ + + state: GuardrailState + cmd_vel: Twist + reason: str + risk_score: float = 0.0 + publish_immediately: bool = False + + +@dataclass(frozen=True) +class OpticalFlowMagnitudePolicyConfig: + forward_motion_deadband_mps: float + clamp_forward_speed_mps: float + flow_downsample_width_px: int + forward_roi_top_fraction: float + forward_roi_bottom_fraction: float + forward_roi_width_fraction: float + low_texture_variance_threshold: float + occlusion_dark_pixel_threshold: int + occlusion_bright_pixel_threshold: int + occlusion_extreme_fraction_threshold: float + caution_flow_magnitude_threshold: float + stop_flow_magnitude_threshold: float + caution_frame_count: int + stop_frame_count: int + clear_frame_count: int + stop_release_frame_count: int + + +class GuardrailPolicy(Protocol): + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: ... + + +class OpticalFlowMagnitudeGuardrailPolicy(GuardrailPolicy): + """Forward-motion RGB guardrail using flow magnitude in a central lower ROI.""" + + # V1 keeps Farneback internals fixed to reduce tuning surface. Promote + # these to config only after hardware tuning shows they need adjustment. + _FARNEBACK_PYR_SCALE = 0.5 + _FARNEBACK_LEVELS = 3 + _FARNEBACK_WINDOW_SIZE = 15 + _FARNEBACK_ITERATIONS = 3 + _FARNEBACK_POLY_N = 5 + _FARNEBACK_POLY_SIGMA = 1.2 + _FARNEBACK_FLAGS = 0 + + def __init__(self, config: OpticalFlowMagnitudePolicyConfig) -> None: + self._config = config + self._hysteresis_state = GuardrailState.PASS + self._caution_hits = 0 + self._stop_hits = 0 + self._clear_hits = 0 + self._below_stop_hits = 0 + + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: + if not health.has_previous_frame: + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.INIT, + "missing_previous_frame", + risk_score=0.0, + ) + + if not health.image_fresh: + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.SENSOR_DEGRADED, + "image_not_fresh", + risk_score=1.0, + publish_immediately=True, + ) + + if not health.frame_pair_fresh: + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.SENSOR_DEGRADED, + "frame_pair_stale", + risk_score=1.0, + publish_immediately=True, + ) + + forward_speed = float(incoming_cmd_vel.linear.x) + if forward_speed <= self._config.forward_motion_deadband_mps: + return self._pass_decision(incoming_cmd_vel, "forward_guard_inactive", 0.0) + + previous_gray, current_gray = self._prepare_gray_pair(previous_image, current_image) + previous_roi, current_roi = self._extract_forward_rois(previous_gray, current_gray) + + if previous_roi.size == 0 or current_roi.size == 0: + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.SENSOR_DEGRADED, + "invalid_forward_roi", + risk_score=1.0, + publish_immediately=True, + ) + + quality_failure_reason = self._quality_failure_reason(previous_roi, current_roi) + if quality_failure_reason is not None: + self._reset_hysteresis() + return self._zero_decision( + GuardrailState.SENSOR_DEGRADED, + quality_failure_reason, + risk_score=1.0, + publish_immediately=True, + ) + + mean_flow_magnitude = self._mean_flow_magnitude(previous_roi, current_roi) + next_state = self._next_state(mean_flow_magnitude) + self._hysteresis_state = next_state + + if next_state == GuardrailState.STOP_LATCHED: + return self._stop_forward_decision( + incoming_cmd_vel, + "forward_flow_stop", + mean_flow_magnitude, + ) + + if next_state == GuardrailState.CLAMP: + reason = ( + "forward_flow_clamp" + if mean_flow_magnitude >= self._config.caution_flow_magnitude_threshold + else "forward_flow_recovery" + ) + return self._clamp_forward_decision( + incoming_cmd_vel, + reason, + mean_flow_magnitude, + ) + + return self._pass_decision( + incoming_cmd_vel, + "forward_flow_clear", + mean_flow_magnitude, + ) + + def _prepare_gray_pair( + self, + previous_image: Image, + current_image: Image, + ) -> tuple[GrayImage, GrayImage]: + previous_gray = self._to_resized_gray(previous_image) + current_gray = self._to_resized_gray(current_image) + + shared_height = min(previous_gray.shape[0], current_gray.shape[0]) + shared_width = min(previous_gray.shape[1], current_gray.shape[1]) + + return ( + np.ascontiguousarray(previous_gray[:shared_height, :shared_width]), + np.ascontiguousarray(current_gray[:shared_height, :shared_width]), + ) + + def _to_resized_gray(self, image: Image) -> GrayImage: + gray = cast("GrayImage", image.to_grayscale().data) + if gray.dtype != np.uint8: + gray = cast("GrayImage", cv2.convertScaleAbs(gray)) + + height, width = gray.shape[:2] + if width <= 0 or height <= 0: + raise ValueError("Image has invalid dimensions") + + target_width = min(width, self._config.flow_downsample_width_px) + if target_width == width: + return cast("GrayImage", np.ascontiguousarray(gray)) + + scale = target_width / float(width) + target_height = max(round(height * scale), 2) + resized = cv2.resize( # type: ignore[call-overload] + gray, + (target_width, target_height), + interpolation=cv2.INTER_AREA, + ) + return cast("GrayImage", np.ascontiguousarray(resized)) + + def _extract_forward_rois( + self, + previous_gray: GrayImage, + current_gray: GrayImage, + ) -> tuple[GrayImage, GrayImage]: + height, width = current_gray.shape + x0, x1, y0, y1 = self._forward_roi_bounds(width=width, height=height) + + return ( + np.ascontiguousarray(previous_gray[y0:y1, x0:x1]), + np.ascontiguousarray(current_gray[y0:y1, x0:x1]), + ) + + def _quality_failure_reason( + self, + previous_roi: GrayImage, + current_roi: GrayImage, + ) -> str | None: + if self._is_occluded(previous_roi): + return "previous_roi_occluded" + + if self._is_occluded(current_roi): + return "current_roi_occluded" + + if self._is_low_texture(previous_roi): + return "previous_roi_low_texture" + + if self._is_low_texture(current_roi): + return "current_roi_low_texture" + + return None + + def _forward_roi_bounds(self, *, width: int, height: int) -> tuple[int, int, int, int]: + roi_width = max(round(width * self._config.forward_roi_width_fraction), 2) + x0 = max((width - roi_width) // 2, 0) + x1 = min(x0 + roi_width, width) + + y0 = min(max(round(height * self._config.forward_roi_top_fraction), 0), height - 1) + y1 = min(max(round(height * self._config.forward_roi_bottom_fraction), y0 + 1), height) + + return x0, x1, y0, y1 + + def _is_low_texture(self, roi: GrayImage) -> bool: + return float(np.var(roi)) < self._config.low_texture_variance_threshold + + def _is_occluded(self, roi: GrayImage) -> bool: + dark_fraction = float(np.mean(roi <= self._config.occlusion_dark_pixel_threshold)) + bright_fraction = float(np.mean(roi >= self._config.occlusion_bright_pixel_threshold)) + return ( + max(dark_fraction, bright_fraction) >= self._config.occlusion_extreme_fraction_threshold + ) + + def _mean_flow_magnitude(self, previous_roi: GrayImage, current_roi: GrayImage) -> float: + flow = cv2.calcOpticalFlowFarneback( # type: ignore[call-overload] + previous_roi, + current_roi, + cast("Any", None), + self._FARNEBACK_PYR_SCALE, + self._FARNEBACK_LEVELS, + self._FARNEBACK_WINDOW_SIZE, + self._FARNEBACK_ITERATIONS, + self._FARNEBACK_POLY_N, + self._FARNEBACK_POLY_SIGMA, + self._FARNEBACK_FLAGS, + ) + + magnitude, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + return float(np.mean(magnitude)) + + def _next_state(self, mean_flow_magnitude: float) -> GuardrailState: + if mean_flow_magnitude >= self._config.stop_flow_magnitude_threshold: + self._stop_hits += 1 + self._caution_hits += 1 + self._below_stop_hits = 0 + self._clear_hits = 0 + elif mean_flow_magnitude >= self._config.caution_flow_magnitude_threshold: + self._stop_hits = 0 + self._caution_hits += 1 + self._below_stop_hits += 1 + self._clear_hits = 0 + else: + self._stop_hits = 0 + self._caution_hits = 0 + self._below_stop_hits += 1 + self._clear_hits += 1 + + if self._hysteresis_state == GuardrailState.STOP_LATCHED: + if self._stop_hits >= self._config.stop_frame_count: + return GuardrailState.STOP_LATCHED + + if self._below_stop_hits < self._config.stop_release_frame_count: + return GuardrailState.STOP_LATCHED + + if self._clear_hits >= self._config.clear_frame_count: + return GuardrailState.PASS + + return GuardrailState.CLAMP + + if self._hysteresis_state == GuardrailState.CLAMP: + if self._stop_hits >= self._config.stop_frame_count: + return GuardrailState.STOP_LATCHED + + if self._clear_hits >= self._config.clear_frame_count: + return GuardrailState.PASS + + return GuardrailState.CLAMP + + if self._stop_hits >= self._config.stop_frame_count: + return GuardrailState.STOP_LATCHED + + if self._stop_hits > 0: + return GuardrailState.CLAMP + + if self._caution_hits >= self._config.caution_frame_count: + return GuardrailState.CLAMP + + return GuardrailState.PASS + + def _reset_hysteresis(self) -> None: + self._hysteresis_state = GuardrailState.PASS + self._caution_hits = 0 + self._stop_hits = 0 + self._below_stop_hits = 0 + self._clear_hits = 0 + + def _pass_decision( + self, + incoming_cmd_vel: Twist, + reason: str, + risk_score: float, + ) -> GuardrailDecision: + cmd_vel = Twist( + linear=incoming_cmd_vel.linear, + angular=incoming_cmd_vel.angular, + ) + return GuardrailDecision( + state=GuardrailState.PASS, + cmd_vel=cmd_vel, + reason=reason, + risk_score=risk_score, + ) + + def _clamp_forward_decision( + self, + incoming_cmd_vel: Twist, + reason: str, + risk_score: float, + ) -> GuardrailDecision: + cmd_vel = Twist( + linear=incoming_cmd_vel.linear, + angular=incoming_cmd_vel.angular, + ) + cmd_vel.linear.x = min(float(cmd_vel.linear.x), self._config.clamp_forward_speed_mps) + return GuardrailDecision( + state=GuardrailState.CLAMP, + cmd_vel=cmd_vel, + reason=reason, + risk_score=risk_score, + ) + + def _stop_forward_decision( + self, + incoming_cmd_vel: Twist, + reason: str, + risk_score: float, + ) -> GuardrailDecision: + cmd_vel = Twist( + linear=incoming_cmd_vel.linear, + angular=incoming_cmd_vel.angular, + ) + cmd_vel.linear.x = 0.0 + return GuardrailDecision( + state=GuardrailState.STOP_LATCHED, + cmd_vel=cmd_vel, + reason=reason, + risk_score=risk_score, + ) + + def _zero_decision( + self, + state: GuardrailState, + reason: str, + *, + risk_score: float = 0.0, + publish_immediately: bool = False, + ) -> GuardrailDecision: + return GuardrailDecision( + state=state, + cmd_vel=Twist.zero(), + reason=reason, + risk_score=risk_score, + publish_immediately=publish_immediately, + ) diff --git a/dimos/control/safety/rgb_collision_guardrail.py b/dimos/control/safety/rgb_collision_guardrail.py new file mode 100644 index 0000000000..605279e378 --- /dev/null +++ b/dimos/control/safety/rgb_collision_guardrail.py @@ -0,0 +1,567 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from threading import Condition, Event, Thread +import time +from typing import Any, Self + +from pydantic import Field, model_validator +from reactivex.disposable import Disposable + +from dimos.control.safety.guardrail_policy import ( + GuardrailDecision, + GuardrailHealth, + GuardrailPolicy, + GuardrailState, + OpticalFlowMagnitudeGuardrailPolicy, + OpticalFlowMagnitudePolicyConfig, +) +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image +from dimos.utils.logging_config import setup_logger + +_THREAD_JOIN_TIMEOUT_S = 2.0 + + +logger = setup_logger() + + +class RGBCollisionGuardrailConfig(ModuleConfig): + # Scheduling + guarded_output_publish_hz: float = Field(default=10.0, gt=0.0) + risk_evaluation_hz: float = Field(default=10.0, gt=0.0) + + # Freshness and fail-closed behavior + command_timeout_s: float = Field(default=0.25, gt=0.0) + image_timeout_s: float = Field(default=0.25, gt=0.0) + risk_timeout_s: float = Field(default=0.25, gt=0.0) + fail_closed_on_missing_image: bool = True + publish_zero_on_stop: bool = True + frame_pair_max_gap_s: float = Field(default=0.2, gt=0.0) + + # Motion gating + forward_motion_deadband_mps: float = Field(default=0.05, ge=0.0) + clamp_forward_speed_mps: float = Field(default=0.1, ge=0.0) + + # Forward ROI geometry + flow_downsample_width_px: int = Field(default=160, ge=32) + forward_roi_top_fraction: float = Field(default=0.45, ge=0.0, le=1.0) + forward_roi_bottom_fraction: float = Field(default=0.95, ge=0.0, le=1.0) + forward_roi_width_fraction: float = Field(default=0.5, gt=0.0, le=1.0) + + # Image-quality checks + low_texture_variance_threshold: float = Field(default=150.0, ge=0.0) + occlusion_dark_pixel_threshold: int = Field(default=20, ge=0, le=255) + occlusion_bright_pixel_threshold: int = Field(default=235, ge=0, le=255) + occlusion_extreme_fraction_threshold: float = Field(default=0.9, ge=0.0, le=1.0) + + # Flow thresholds and hysteresis + caution_flow_magnitude_threshold: float = Field(default=0.8, ge=0.0) + stop_flow_magnitude_threshold: float = Field(default=1.5, ge=0.0) + caution_frame_count: int = Field(default=2, ge=1) + stop_frame_count: int = Field(default=2, ge=1) + clear_frame_count: int = Field(default=3, ge=1) + stop_release_frame_count: int = Field(default=2, ge=1) + + @model_validator(mode="after") + def validate_thresholds(self) -> Self: + if self.forward_roi_top_fraction >= self.forward_roi_bottom_fraction: + raise ValueError( + "forward_roi_top_fraction must be less than forward_roi_bottom_fraction" + ) + + if self.occlusion_dark_pixel_threshold >= self.occlusion_bright_pixel_threshold: + raise ValueError( + "occlusion_dark_pixel_threshold must be less than occlusion_bright_pixel_threshold" + ) + + if self.caution_flow_magnitude_threshold > self.stop_flow_magnitude_threshold: + raise ValueError( + "caution_flow_magnitude_threshold must be less than or equal to " + "stop_flow_magnitude_threshold" + ) + + return self + + +@dataclass +class _GuardrailRuntimeState: + latest_image: Image | None = None + previous_image: Image | None = None + latest_image_time: float | None = None + previous_image_time: float | None = None + latest_cmd_vel: Twist | None = None + latest_cmd_time: float | None = None + last_decision: GuardrailDecision | None = None + last_risk_time: float | None = None + last_publish_time: float | None = None + next_risk_time: float | None = None + pending_cmd_update: bool = False + pending_decision_publish: bool = False + state: GuardrailState = GuardrailState.INIT + image_generation: int = 0 + last_evaluated_image_generation: int = -1 + + +@dataclass(frozen=True) +class _RiskEvaluationInput: + previous_image: Image + current_image: Image + incoming_cmd_vel: Twist + health: GuardrailHealth + + +class RGBCollisionGuardrail(Module[RGBCollisionGuardrailConfig]): + """RGB-only motion guardrail for direct Twist control.""" + + default_config = RGBCollisionGuardrailConfig + + color_image: In[Image] + incoming_cmd_vel: In[Twist] + safe_cmd_vel: Out[Twist] + + _condition: Condition + _runtime_state: _GuardrailRuntimeState + _stop_event: Event + _thread: Thread | None + _policy: GuardrailPolicy + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._condition = Condition() + self._runtime_state = _GuardrailRuntimeState() + self._stop_event = Event() + self._thread = None + self._policy = self._build_policy() + + def _build_policy(self) -> GuardrailPolicy: + policy_config = OpticalFlowMagnitudePolicyConfig( + forward_motion_deadband_mps=self.config.forward_motion_deadband_mps, + clamp_forward_speed_mps=self.config.clamp_forward_speed_mps, + flow_downsample_width_px=self.config.flow_downsample_width_px, + forward_roi_top_fraction=self.config.forward_roi_top_fraction, + forward_roi_bottom_fraction=self.config.forward_roi_bottom_fraction, + forward_roi_width_fraction=self.config.forward_roi_width_fraction, + low_texture_variance_threshold=self.config.low_texture_variance_threshold, + occlusion_dark_pixel_threshold=self.config.occlusion_dark_pixel_threshold, + occlusion_bright_pixel_threshold=self.config.occlusion_bright_pixel_threshold, + occlusion_extreme_fraction_threshold=self.config.occlusion_extreme_fraction_threshold, + caution_flow_magnitude_threshold=self.config.caution_flow_magnitude_threshold, + stop_flow_magnitude_threshold=self.config.stop_flow_magnitude_threshold, + caution_frame_count=self.config.caution_frame_count, + stop_frame_count=self.config.stop_frame_count, + clear_frame_count=self.config.clear_frame_count, + stop_release_frame_count=self.config.stop_release_frame_count, + ) + return OpticalFlowMagnitudeGuardrailPolicy(policy_config) + + @rpc + def start(self) -> None: + super().start() + self._stop_event.clear() + + with self._condition: + self._runtime_state.next_risk_time = time.monotonic() + + self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) + self._disposables.add( + Disposable(self.incoming_cmd_vel.subscribe(self._on_incoming_cmd_vel)) + ) + + self._thread = Thread( + target=self._decision_loop, + name=f"{self.__class__.__name__}-thread", + daemon=True, + ) + self._thread.start() + + logger.info( + "RGB guardrail started", + risk_evaluation_hz=self.config.risk_evaluation_hz, + guarded_output_publish_hz=self.config.guarded_output_publish_hz, + command_timeout_s=self.config.command_timeout_s, + image_timeout_s=self.config.image_timeout_s, + risk_timeout_s=self.config.risk_timeout_s, + frame_pair_max_gap_s=self.config.frame_pair_max_gap_s, + ) + + @rpc + def stop(self) -> None: + self._stop_event.set() + with self._condition: + self._condition.notify_all() + + if self.config.publish_zero_on_stop: + self.safe_cmd_vel.publish(Twist.zero()) + + if self._thread is not None: + self._thread.join(timeout=_THREAD_JOIN_TIMEOUT_S) + if self._thread.is_alive(): + logger.warning( + "RGB guardrail worker thread did not stop within timeout", + timeout_s=_THREAD_JOIN_TIMEOUT_S, + ) + self._thread = None + + logger.info("RGB guardrail stopped") + super().stop() + + def _on_color_image(self, image: Image) -> None: + now = time.monotonic() + with self._condition: + self._runtime_state.previous_image = self._runtime_state.latest_image + self._runtime_state.previous_image_time = self._runtime_state.latest_image_time + self._runtime_state.latest_image = image + self._runtime_state.latest_image_time = now + self._runtime_state.image_generation += 1 + self._condition.notify() + + def _on_incoming_cmd_vel(self, cmd_vel: Twist) -> None: + now = time.monotonic() + with self._condition: + self._runtime_state.latest_cmd_vel = cmd_vel + self._runtime_state.latest_cmd_time = now + self._runtime_state.pending_cmd_update = True + self._condition.notify() + + def _decision_loop(self) -> None: + while not self._stop_event.is_set(): + risk_input: _RiskEvaluationInput | None = None + + with self._condition: + timeout_s = self._next_wakeup_timeout_locked() + self._condition.wait(timeout=timeout_s) + + if self._stop_event.is_set(): + return + + now = time.monotonic() + if self._should_recompute_risk_locked(now): + risk_input = self._take_risk_evaluation_input_locked(now) + + # Evaluate a consistent snapshot outside the condition lock so image + # and command callbacks stay cheap. If newer inputs arrive while + # evaluation runs, the next wakeup will process that fresher snapshot. + policy_decision: GuardrailDecision | None = None + if risk_input is not None: + try: + policy_decision = self._policy.evaluate( + previous_image=risk_input.previous_image, + current_image=risk_input.current_image, + incoming_cmd_vel=risk_input.incoming_cmd_vel, + health=risk_input.health, + ) + except Exception: + with self._condition: + _logged_state = self._runtime_state.state.value + + logger.exception( + "RGB guardrail policy evaluation failed", + state=_logged_state, + ) + policy_decision = self._build_sensor_degraded_decision( + "policy_evaluation_failed" + ) + + cmd_vel_to_publish: Twist | None = None + publish_time: float | None = None + + with self._condition: + now = time.monotonic() + + if policy_decision is not None: + self._store_decision_locked( + policy_decision, + now, + update_risk_time=True, + ) + else: + fallback_decision = self._select_fallback_decision_locked(now) + if fallback_decision is not None: + self._store_decision_locked( + fallback_decision, + now, + update_risk_time=False, + ) + + cmd_vel_to_publish = self._consume_publish_cmd_locked(now) + if cmd_vel_to_publish is not None: + publish_time = now + + if cmd_vel_to_publish is not None and publish_time is not None: + self.safe_cmd_vel.publish(cmd_vel_to_publish) + with self._condition: + self._runtime_state.last_publish_time = publish_time + + def _guarded_output_publish_period_s(self) -> float: + return 1.0 / self.config.guarded_output_publish_hz + + def _risk_evaluation_period_s(self) -> float: + return 1.0 / self.config.risk_evaluation_hz + + def _should_recompute_risk_locked(self, now: float) -> bool: + next_risk_time = self._runtime_state.next_risk_time + if next_risk_time is not None and now < next_risk_time: + return False + + if self._runtime_state.previous_image is None: + return False + + return ( + self._runtime_state.image_generation + != self._runtime_state.last_evaluated_image_generation + ) + + def _is_cmd_fresh_locked(self, now: float) -> bool: + latest_cmd_time = self._runtime_state.latest_cmd_time + if latest_cmd_time is None: + return False + return (now - latest_cmd_time) <= self.config.command_timeout_s + + def _is_image_fresh_locked(self, now: float) -> bool: + latest_image_time = self._runtime_state.latest_image_time + if latest_image_time is None: + return False + return (now - latest_image_time) <= self.config.image_timeout_s + + def _is_frame_pair_fresh_locked(self) -> bool: + previous_image_time = self._runtime_state.previous_image_time + latest_image_time = self._runtime_state.latest_image_time + if previous_image_time is None or latest_image_time is None: + return False + return (latest_image_time - previous_image_time) <= self.config.frame_pair_max_gap_s + + def _is_risk_fresh_locked(self, now: float) -> bool: + last_risk_time = self._runtime_state.last_risk_time + if last_risk_time is None: + return False + return (now - last_risk_time) <= self.config.risk_timeout_s + + def _build_health_locked(self, now: float) -> GuardrailHealth: + return GuardrailHealth( + has_previous_frame=self._runtime_state.previous_image is not None, + image_fresh=self._is_image_fresh_locked(now), + cmd_fresh=self._is_cmd_fresh_locked(now), + risk_fresh=self._is_risk_fresh_locked(now), + frame_pair_fresh=self._is_frame_pair_fresh_locked(), + ) + + def _resolved_cmd_for_latest_locked(self, now: float) -> Twist: + latest_cmd_vel = self._runtime_state.latest_cmd_vel + if latest_cmd_vel is None: + return Twist.zero() + + if not self._is_cmd_fresh_locked(now): + return Twist.zero() + + last_decision = self._runtime_state.last_decision + if last_decision is None: + return Twist.zero() + + if last_decision.state == GuardrailState.PASS: + return latest_cmd_vel + + return last_decision.cmd_vel + + def _take_risk_evaluation_input_locked(self, now: float) -> _RiskEvaluationInput | None: + previous_image = self._runtime_state.previous_image + current_image = self._runtime_state.latest_image + incoming_cmd_vel = self._runtime_state.latest_cmd_vel + + self._runtime_state.next_risk_time = now + self._risk_evaluation_period_s() + + if previous_image is None or current_image is None or incoming_cmd_vel is None: + return None + + self._runtime_state.last_evaluated_image_generation = self._runtime_state.image_generation + + return _RiskEvaluationInput( + previous_image=previous_image, + current_image=current_image, + incoming_cmd_vel=incoming_cmd_vel, + health=self._build_health_locked(now), + ) + + def _store_decision_locked( + self, + decision: GuardrailDecision, + now: float, + *, + update_risk_time: bool, + ) -> None: + had_previous_decision = self._runtime_state.last_decision is not None + previous_state = self._runtime_state.state + + self._runtime_state.last_decision = decision + if update_risk_time: + self._runtime_state.last_risk_time = now + self._runtime_state.state = decision.state + + # Request an immediate publish when the policy says so or when the + # high-level state changes. Freshness checks still apply later. + if ( + decision.publish_immediately + or not had_previous_decision + or previous_state != decision.state + ): + self._runtime_state.pending_decision_publish = True + + if previous_state != decision.state: + logger.info( + "RGB guardrail state changed", + previous_state=previous_state.value, + state=decision.state.value, + reason=decision.reason, + ) + + def _consume_publish_cmd_locked(self, now: float) -> Twist | None: + if self._runtime_state.pending_decision_publish: + self._runtime_state.pending_decision_publish = False + return self._resolved_cmd_for_latest_locked(now) + + if self._runtime_state.pending_cmd_update: + self._runtime_state.pending_cmd_update = False + return self._resolved_cmd_for_latest_locked(now) + + latest_cmd_time = self._runtime_state.latest_cmd_time + if latest_cmd_time is not None and not self._is_cmd_fresh_locked(now): + return Twist.zero() + + if self._should_republish_non_pass_output_locked(now): + last_decision = self._runtime_state.last_decision + if last_decision is not None: + return last_decision.cmd_vel + + return None + + def _should_republish_non_pass_output_locked(self, now: float) -> bool: + last_decision = self._runtime_state.last_decision + if last_decision is None: + return False + + if last_decision.state == GuardrailState.PASS: + return False + + last_publish_time = self._runtime_state.last_publish_time + if last_publish_time is None: + return True + + return (now - last_publish_time) >= self._guarded_output_publish_period_s() + + def _next_wakeup_timeout_locked(self) -> float: + now = time.monotonic() + + if self._runtime_state.pending_cmd_update or self._runtime_state.pending_decision_publish: + return 0.0 + + timeouts: list[float] = [self._risk_evaluation_period_s()] + + next_risk_time = self._runtime_state.next_risk_time + if next_risk_time is not None: + timeouts.append(max(next_risk_time - now, 0.0)) + + latest_cmd_time = self._runtime_state.latest_cmd_time + if latest_cmd_time is not None and self._is_cmd_fresh_locked(now): + timeouts.append(max((latest_cmd_time + self.config.command_timeout_s) - now, 0.0)) + + latest_image_time = self._runtime_state.latest_image_time + if latest_image_time is not None and self._is_image_fresh_locked(now): + timeouts.append(max((latest_image_time + self.config.image_timeout_s) - now, 0.0)) + + last_risk_time = self._runtime_state.last_risk_time + if last_risk_time is not None and self._is_risk_fresh_locked(now): + timeouts.append(max((last_risk_time + self.config.risk_timeout_s) - now, 0.0)) + + if self._should_republish_non_pass_output_locked(now): + timeouts.append(0.0) + else: + last_decision = self._runtime_state.last_decision + last_publish_time = self._runtime_state.last_publish_time + if ( + last_decision is not None + and last_decision.state != GuardrailState.PASS + and last_publish_time is not None + ): + next_publish_time = last_publish_time + self._guarded_output_publish_period_s() + timeouts.append(max(next_publish_time - now, 0.0)) + + return min(timeouts) + + def _build_zero_decision( + self, + state: GuardrailState, + reason: str, + *, + publish_immediately: bool = False, + risk_score: float = 0.0, + ) -> GuardrailDecision: + return GuardrailDecision( + state=state, + cmd_vel=Twist.zero(), + reason=reason, + risk_score=risk_score, + publish_immediately=publish_immediately, + ) + + def _build_init_decision(self, reason: str) -> GuardrailDecision: + return self._build_zero_decision( + GuardrailState.INIT, + reason, + publish_immediately=True, + ) + + def _build_sensor_degraded_decision(self, reason: str) -> GuardrailDecision: + return self._build_zero_decision( + GuardrailState.SENSOR_DEGRADED, + reason, + publish_immediately=True, + risk_score=1.0, + ) + + def _select_fallback_decision_locked(self, now: float) -> GuardrailDecision | None: + if self._runtime_state.latest_cmd_vel is None: + return self._build_init_decision("no_command_received") + + if self._runtime_state.latest_image is None: + if self.config.fail_closed_on_missing_image: + return self._build_init_decision("waiting_for_first_image") + return None + + if self._runtime_state.previous_image is None: + if self.config.fail_closed_on_missing_image: + return self._build_init_decision("waiting_for_frame_pair") + return None + + if not self._is_image_fresh_locked(now): + return self._build_sensor_degraded_decision("image_stale") + + if not self._is_frame_pair_fresh_locked(): + return self._build_sensor_degraded_decision("frame_pair_stale") + + if self._runtime_state.last_risk_time is None: + return self._build_init_decision("waiting_for_first_risk_evaluation") + + if not self._is_risk_fresh_locked(now): + return self._build_sensor_degraded_decision("risk_state_stale") + + return None + + +rgb_collision_guardrail = RGBCollisionGuardrail.blueprint diff --git a/dimos/control/safety/test_guardrail_policy.py b/dimos/control/safety/test_guardrail_policy.py new file mode 100644 index 0000000000..58003dcf75 --- /dev/null +++ b/dimos/control/safety/test_guardrail_policy.py @@ -0,0 +1,451 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import numpy as np +import pytest + +from dimos.control.safety.guardrail_policy import ( + GuardrailHealth, + GuardrailState, + OpticalFlowMagnitudeGuardrailPolicy, + OpticalFlowMagnitudePolicyConfig, +) +from dimos.control.safety.test_utils import ( + _textured_gray_image, +) +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat + + +def _policy_config( + *, + caution_frame_count: int = 2, + stop_frame_count: int = 2, + clear_frame_count: int = 3, + stop_release_frame_count: int = 2, +) -> OpticalFlowMagnitudePolicyConfig: + return OpticalFlowMagnitudePolicyConfig( + forward_motion_deadband_mps=0.05, + clamp_forward_speed_mps=0.1, + flow_downsample_width_px=160, + forward_roi_top_fraction=0.45, + forward_roi_bottom_fraction=0.95, + forward_roi_width_fraction=0.5, + low_texture_variance_threshold=150.0, + occlusion_dark_pixel_threshold=20, + occlusion_bright_pixel_threshold=235, + occlusion_extreme_fraction_threshold=0.9, + caution_flow_magnitude_threshold=0.8, + stop_flow_magnitude_threshold=1.5, + caution_frame_count=caution_frame_count, + stop_frame_count=stop_frame_count, + clear_frame_count=clear_frame_count, + stop_release_frame_count=stop_release_frame_count, + ) + + +def _forward_cmd( + x: float = 0.4, + *, + linear_y: float = 0.0, + linear_z: float = 0.0, + angular_z: float = 0.2, +) -> Twist: + return Twist( + linear=[x, linear_y, linear_z], + angular=[0.0, 0.0, angular_z], + ) + + +def _fresh_health( + *, + has_previous_frame: bool = True, + image_fresh: bool = True, + frame_pair_fresh: bool = True, +) -> GuardrailHealth: + return GuardrailHealth( + has_previous_frame=has_previous_frame, + image_fresh=image_fresh, + cmd_fresh=True, + risk_fresh=True, + frame_pair_fresh=frame_pair_fresh, + ) + + +def _uniform_gray_image(value: int, *, width: int = 160, height: int = 120) -> Image: + return Image.from_numpy( + np.full((height, width), value, dtype=np.uint8), + format=ImageFormat.GRAY, + ) + + +@pytest.fixture +def image_pair() -> tuple[Image, Image]: + return (_textured_gray_image(), _textured_gray_image(shift_x=3)) + + +@pytest.mark.parametrize( + "cmd", + [ + pytest.param(_forward_cmd(0.03, angular_z=0.35), id="below_forward_deadband"), + pytest.param(_forward_cmd(-0.2, linear_y=0.1, angular_z=0.4), id="reverse_motion"), + pytest.param( + Twist(linear=[0.0, 0.0, 0.0], angular=[0.0, 0.0, 0.6]), + id="pure_yaw", + ), + ], +) +def test_forward_guard_inactive_passthrough(image_pair: tuple[Image, Image], cmd: Twist) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + + assert decision.state == GuardrailState.PASS + assert decision.reason == "forward_guard_inactive" + assert decision.cmd_vel == cmd + + +def test_missing_previous_frame_returns_init_zero(image_pair: tuple[Image, Image]) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(has_previous_frame=False), + ) + + assert decision.state == GuardrailState.INIT + assert decision.reason == "missing_previous_frame" + assert decision.cmd_vel == Twist.zero() + + +def test_stale_image_health_degrades_to_zero(image_pair: tuple[Image, Image]) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(image_fresh=False), + ) + + assert decision.state == GuardrailState.SENSOR_DEGRADED + assert decision.reason == "image_not_fresh" + assert decision.cmd_vel == Twist.zero() + assert decision.publish_immediately is True + + +@pytest.mark.parametrize( + ("bad_frame_position", "bad_frame_kind", "expected_reason"), + [ + pytest.param("previous", "black", "previous_roi_occluded", id="previous_black_occluded"), + pytest.param("current", "black", "current_roi_occluded", id="current_black_occluded"), + pytest.param("previous", "white", "previous_roi_occluded", id="previous_white_occluded"), + pytest.param("current", "white", "current_roi_occluded", id="current_white_occluded"), + pytest.param( + "previous", + "uniform_gray", + "previous_roi_low_texture", + id="previous_low_texture", + ), + pytest.param( + "current", + "uniform_gray", + "current_roi_low_texture", + id="current_low_texture", + ), + ], +) +def test_bad_previous_or_current_roi_fail_closes( + image_pair: tuple[Image, Image], + bad_frame_position: str, + bad_frame_kind: str, + expected_reason: str, +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + if bad_frame_kind == "black": + bad_image = _uniform_gray_image(0) + elif bad_frame_kind == "white": + bad_image = _uniform_gray_image(255) + else: + bad_image = _uniform_gray_image(127) + + previous_image, current_image = image_pair + if bad_frame_position == "previous": + previous_image = bad_image + else: + current_image = bad_image + + decision = policy.evaluate( + previous_image=previous_image, + current_image=current_image, + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ) + + assert decision.state == GuardrailState.SENSOR_DEGRADED + assert decision.reason == expected_reason + assert decision.cmd_vel == Twist.zero() + assert decision.publish_immediately is True + + +def test_caution_hysteresis_reaches_clamp(image_pair: tuple[Image, Image], mocker) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[0.9, 0.9], + ) + + first = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ) + second = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ) + + assert first.state == GuardrailState.PASS + assert second.state == GuardrailState.CLAMP + assert second.reason == "forward_flow_clamp" + assert second.cmd_vel.linear.x == pytest.approx(0.1) + + +def test_first_stop_strength_frame_clamps_immediately( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + cmd = _forward_cmd(0.45, angular_z=0.55) + + mocker.patch.object(policy, "_mean_flow_magnitude", return_value=1.8) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + + assert decision.state == GuardrailState.CLAMP + assert decision.reason == "forward_flow_clamp" + assert decision.cmd_vel.linear.x == pytest.approx(0.1) + assert decision.cmd_vel.angular.z == pytest.approx(cmd.angular.z) + + +def test_repeated_stop_strength_frames_reach_stop_latched( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + cmd = _forward_cmd(0.45, angular_z=0.55) + + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[1.8, 1.8], + ) + + first = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + second = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + + assert first.state == GuardrailState.CLAMP + assert first.reason == "forward_flow_clamp" + assert second.state == GuardrailState.STOP_LATCHED + assert second.reason == "forward_flow_stop" + assert second.cmd_vel.linear.x == pytest.approx(0.0) + assert second.cmd_vel.angular.z == pytest.approx(cmd.angular.z) + + +def test_stop_latched_does_not_release_on_first_clear_frame( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[1.8, 1.8, 0.0], + ) + + states = [ + policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ).state + for _ in range(3) + ] + + assert states == [ + GuardrailState.CLAMP, + GuardrailState.STOP_LATCHED, + GuardrailState.STOP_LATCHED, + ] + + +def test_stop_latched_recovery_requires_clear_frames( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config(clear_frame_count=3)) + + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[1.8, 1.8, 0.0, 0.0, 0.0], + ) + + states = [ + policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ).state + for _ in range(5) + ] + + assert states[0] == GuardrailState.CLAMP + assert states[1] == GuardrailState.STOP_LATCHED + assert states[2] != GuardrailState.PASS + assert states[3] != GuardrailState.PASS + assert states[4] == GuardrailState.PASS + + +def test_recovery_after_clear_frames_returns_to_pass( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config()) + + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[0.9, 0.9, 0.0, 0.0, 0.0], + ) + + states = [ + policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(), + health=_fresh_health(), + ).state + for _ in range(5) + ] + + assert states == [ + GuardrailState.PASS, + GuardrailState.CLAMP, + GuardrailState.CLAMP, + GuardrailState.CLAMP, + GuardrailState.PASS, + ] + + +def test_forward_deadband_does_not_reset_hysteresis( + image_pair: tuple[Image, Image], mocker +) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config(caution_frame_count=2)) + + mocker.patch.object( + policy, + "_mean_flow_magnitude", + side_effect=[0.9, 0.9], + ) + + first = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(0.4), + health=_fresh_health(), + ) + inactive = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(0.03), + health=_fresh_health(), + ) + third = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=_forward_cmd(0.4), + health=_fresh_health(), + ) + + assert first.state == GuardrailState.PASS + assert inactive.state == GuardrailState.PASS + assert inactive.reason == "forward_guard_inactive" + assert third.state == GuardrailState.CLAMP + assert third.reason == "forward_flow_clamp" + + +def test_clamp_preserves_angular_terms(image_pair: tuple[Image, Image], mocker) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config(caution_frame_count=1)) + mocker.patch.object(policy, "_mean_flow_magnitude", return_value=0.9) + cmd = _forward_cmd(0.4, linear_y=0.15, linear_z=-0.1, angular_z=0.65) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + + assert decision.state == GuardrailState.CLAMP + assert decision.cmd_vel.linear.x == pytest.approx(0.1) + assert decision.cmd_vel.linear.y == pytest.approx(cmd.linear.y) + assert decision.cmd_vel.linear.z == pytest.approx(cmd.linear.z) + assert decision.cmd_vel.angular.z == pytest.approx(cmd.angular.z) + + +def test_stop_zeroes_only_linear_x(image_pair: tuple[Image, Image], mocker) -> None: + policy = OpticalFlowMagnitudeGuardrailPolicy(_policy_config(stop_frame_count=1)) + mocker.patch.object(policy, "_mean_flow_magnitude", return_value=1.8) + cmd = _forward_cmd(0.45, linear_y=0.2, linear_z=-0.1, angular_z=0.75) + + decision = policy.evaluate( + previous_image=image_pair[0], + current_image=image_pair[1], + incoming_cmd_vel=cmd, + health=_fresh_health(), + ) + + assert decision.state == GuardrailState.STOP_LATCHED + assert decision.cmd_vel.linear.x == pytest.approx(0.0) + assert decision.cmd_vel.linear.y == pytest.approx(cmd.linear.y) + assert decision.cmd_vel.linear.z == pytest.approx(cmd.linear.z) + assert decision.cmd_vel.angular.z == pytest.approx(cmd.angular.z) diff --git a/dimos/control/safety/test_rgb_collision_guardrail.py b/dimos/control/safety/test_rgb_collision_guardrail.py new file mode 100644 index 0000000000..f6c7212828 --- /dev/null +++ b/dimos/control/safety/test_rgb_collision_guardrail.py @@ -0,0 +1,450 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable, Iterator +import queue +import threading +import time +from typing import Any, TypeVar + +import pytest + +from dimos.control.safety.guardrail_policy import ( + GuardrailDecision, + GuardrailHealth, + GuardrailState, +) +from dimos.control.safety.rgb_collision_guardrail import RGBCollisionGuardrail +from dimos.control.safety.test_utils import ( + FakeTransport, + SequencePolicy, + _cmd, + _decision, + _textured_gray_image, +) +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image + +T = TypeVar("T") + + +class RaisingPolicy: + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: + raise RuntimeError("synthetic policy failure") + + +class CountingPassPolicy: + def __init__(self) -> None: + self._lock = threading.Lock() + self._call_count = 0 + + @property + def call_count(self) -> int: + with self._lock: + return self._call_count + + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: + with self._lock: + self._call_count += 1 + + return GuardrailDecision( + state=GuardrailState.PASS, + cmd_vel=Twist( + linear=incoming_cmd_vel.linear, + angular=incoming_cmd_vel.angular, + ), + reason="counting_pass", + ) + + +@pytest.fixture +def module() -> Iterator[RGBCollisionGuardrail]: + guardrail = RGBCollisionGuardrail( + guarded_output_publish_hz=50.0, + risk_evaluation_hz=50.0, + command_timeout_s=0.05, + image_timeout_s=0.05, + risk_timeout_s=0.05, + ) + yield guardrail + guardrail._close_module() + + +def _wait_for_output( + outputs: queue.Queue[Twist], + predicate: Callable[[Twist], bool], + *, + timeout_s: float = 0.5, +) -> Twist: + deadline = time.monotonic() + timeout_s + while time.monotonic() < deadline: + remaining = deadline - time.monotonic() + try: + candidate = outputs.get(timeout=max(remaining, 0.01)) + except queue.Empty: + continue + if predicate(candidate): + return candidate + raise AssertionError("Timed out waiting for matching guardrail output") + + +def _wait_for_decision( + guardrail: RGBCollisionGuardrail, + predicate: Callable[[GuardrailDecision], bool], + *, + timeout_s: float = 0.5, +) -> GuardrailDecision: + deadline = time.monotonic() + timeout_s + while time.monotonic() < deadline: + with guardrail._condition: + decision = guardrail._runtime_state.last_decision + if decision is not None and predicate(decision): + return decision + time.sleep(0.01) + + raise AssertionError("Timed out waiting for matching guardrail decision") + + +def _start_threaded_guardrail( + policy: Any, + **config_overrides: float, +) -> tuple[RGBCollisionGuardrail, FakeTransport[Image], FakeTransport[Twist], queue.Queue[Twist]]: + config: dict[str, float] = { + "guarded_output_publish_hz": 50.0, + "risk_evaluation_hz": 50.0, + "command_timeout_s": 0.3, + "image_timeout_s": 0.3, + "risk_timeout_s": 0.3, + } + config.update(config_overrides) + + guardrail = RGBCollisionGuardrail(**config) + image_transport: FakeTransport[Image] = FakeTransport() + cmd_transport: FakeTransport[Twist] = FakeTransport() + outputs: queue.Queue[Twist] = queue.Queue() + + guardrail.color_image.transport = image_transport + guardrail.incoming_cmd_vel.transport = cmd_transport + guardrail.safe_cmd_vel.subscribe(outputs.put) + guardrail._policy = policy + guardrail.start() + + return guardrail, image_transport, cmd_transport, outputs + + +def test_no_command_returns_init_zero(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + + with module._condition: + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.INIT + assert decision.reason == "no_command_received" + assert decision.cmd_vel == Twist.zero() + + +def test_waiting_for_first_image_returns_init_zero(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = now + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.INIT + assert decision.reason == "waiting_for_first_image" + assert decision.cmd_vel == Twist.zero() + + +def test_no_frame_pair_returns_init_zero(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = now + module._runtime_state.latest_image = _textured_gray_image() + module._runtime_state.latest_image_time = now + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.INIT + assert decision.reason == "waiting_for_frame_pair" + assert decision.cmd_vel == Twist.zero() + + +def test_waiting_for_first_risk_evaluation_returns_init_zero( + module: RGBCollisionGuardrail, +) -> None: + now = time.monotonic() + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = now + module._runtime_state.previous_image = _textured_gray_image() + module._runtime_state.previous_image_time = now + module._runtime_state.latest_image = _textured_gray_image(shift_x=2) + module._runtime_state.latest_image_time = now + module._runtime_state.last_risk_time = None + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.INIT + assert decision.reason == "waiting_for_first_risk_evaluation" + assert decision.cmd_vel == Twist.zero() + + +def test_stale_image_returns_sensor_degraded_zero(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + stale_time = now - 0.2 + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = now + module._runtime_state.previous_image = _textured_gray_image() + module._runtime_state.latest_image = _textured_gray_image(shift_x=2) + module._runtime_state.previous_image_time = stale_time + module._runtime_state.latest_image_time = stale_time + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.SENSOR_DEGRADED + assert decision.reason == "image_stale" + assert decision.cmd_vel == Twist.zero() + + +def test_stale_risk_returns_sensor_degraded_zero(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + stale_risk_time = now - 0.2 + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = now + module._runtime_state.previous_image = _textured_gray_image() + module._runtime_state.latest_image = _textured_gray_image(shift_x=2) + module._runtime_state.previous_image_time = now + module._runtime_state.latest_image_time = now + module._runtime_state.last_risk_time = stale_risk_time + decision = module._select_fallback_decision_locked(now) + + assert decision is not None + assert decision.state == GuardrailState.SENSOR_DEGRADED + assert decision.reason == "risk_state_stale" + assert decision.cmd_vel == Twist.zero() + + +def test_stale_command_publishes_zero_output(module: RGBCollisionGuardrail) -> None: + now = time.monotonic() + stale_cmd_time = now - 0.2 + + with module._condition: + module._runtime_state.latest_cmd_vel = _cmd() + module._runtime_state.latest_cmd_time = stale_cmd_time + module._runtime_state.last_decision = _decision(GuardrailState.PASS, _cmd()) + module._runtime_state.pending_cmd_update = True + cmd_to_publish = module._consume_publish_cmd_locked(now) + + assert cmd_to_publish == Twist.zero() + + +def test_policy_exception_fail_closes_to_zero() -> None: + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail( + RaisingPolicy(), + ) + + try: + cmd_transport.publish(_cmd(0.4, angular_z=0.3)) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + observed = _wait_for_output(outputs, lambda twist: twist == Twist.zero()) + assert observed == Twist.zero() + + decision = _wait_for_decision( + guardrail, + lambda d: d.state == GuardrailState.SENSOR_DEGRADED + and d.reason == "policy_evaluation_failed", + timeout_s=0.5, + ) + + assert decision.state == GuardrailState.SENSOR_DEGRADED + assert decision.reason == "policy_evaluation_failed" + + with guardrail._condition: + assert guardrail._runtime_state.state == GuardrailState.SENSOR_DEGRADED + finally: + guardrail.stop() + + +def test_pass_publishes_latest_upstream_command() -> None: + upstream_first = _cmd(0.3, angular_z=0.1) + upstream_second = _cmd(0.45, angular_z=0.35) + misleading_policy_cmd = _cmd(0.02, angular_z=-0.2) + policy = SequencePolicy([_decision(GuardrailState.PASS, misleading_policy_cmd, reason="pass")]) + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail(policy) + + try: + cmd_transport.publish(upstream_first) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + first_output = _wait_for_output(outputs, lambda twist: twist == upstream_first) + assert first_output == upstream_first + + cmd_transport.publish(upstream_second) + second_output = _wait_for_output(outputs, lambda twist: twist == upstream_second) + assert second_output == upstream_second + finally: + guardrail.stop() + + +@pytest.mark.parametrize( + ("state", "guarded_cmd"), + [ + ( + GuardrailState.CLAMP, + Twist(linear=[0.1, 0.0, 0.0], angular=[0.0, 0.0, 0.4]), + ), + ( + GuardrailState.STOP_LATCHED, + Twist(linear=[0.0, 0.0, 0.0], angular=[0.0, 0.0, 0.4]), + ), + ], +) +def test_non_pass_states_publish_guarded_output(state: GuardrailState, guarded_cmd: Twist) -> None: + upstream_cmd = _cmd(0.35, angular_z=0.4) + policy = SequencePolicy([_decision(state, guarded_cmd)]) + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail(policy) + + try: + cmd_transport.publish(upstream_cmd) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + published = _wait_for_output(outputs, lambda twist: twist == guarded_cmd) + assert published == guarded_cmd + finally: + guardrail.stop() + + +def test_non_pass_heartbeat_republishes_guarded_output() -> None: + guarded_cmd = Twist(linear=[0.1, 0.0, 0.0], angular=[0.0, 0.0, 0.5]) + policy = SequencePolicy([_decision(GuardrailState.CLAMP, guarded_cmd)]) + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail(policy) + + try: + cmd_transport.publish(_cmd(0.4, angular_z=0.5)) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + first = _wait_for_output(outputs, lambda twist: twist == guarded_cmd) + second = _wait_for_output(outputs, lambda twist: twist == guarded_cmd) + + assert first == guarded_cmd + assert second == guarded_cmd + finally: + guardrail.stop() + + +def test_non_pass_decision_can_publish_without_new_command() -> None: + upstream_cmd = _cmd(0.4, angular_z=0.3) + stop_cmd = Twist(linear=[0.0, 0.0, 0.0], angular=[0.0, 0.0, 0.3]) + policy = SequencePolicy( + [ + _decision(GuardrailState.PASS, upstream_cmd, reason="initial_pass"), + _decision( + GuardrailState.STOP_LATCHED, + stop_cmd, + reason="forced_stop", + publish_immediately=True, + ), + ] + ) + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail(policy) + + try: + cmd_transport.publish(upstream_cmd) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + first_output = _wait_for_output(outputs, lambda twist: twist == upstream_cmd) + assert first_output == upstream_cmd + + image_transport.publish(_textured_gray_image(shift_x=4)) + + autonomous_stop = _wait_for_output(outputs, lambda twist: twist == stop_cmd) + assert autonomous_stop == stop_cmd + finally: + guardrail.stop() + + +def test_fast_upstream_commands_reuse_last_risk_decision() -> None: + policy = CountingPassPolicy() + guardrail, image_transport, cmd_transport, outputs = _start_threaded_guardrail( + policy, + guarded_output_publish_hz=100.0, + risk_evaluation_hz=2.0, + command_timeout_s=1.0, + image_timeout_s=1.0, + risk_timeout_s=1.0, + ) + + first_cmd = _cmd(0.20, angular_z=0.10) + second_cmd = _cmd(0.32, angular_z=0.20) + third_cmd = _cmd(0.44, angular_z=0.30) + + try: + cmd_transport.publish(first_cmd) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + first_output = _wait_for_output(outputs, lambda twist: twist == first_cmd, timeout_s=0.6) + assert first_output == first_cmd + assert policy.call_count == 1 + + cmd_transport.publish(second_cmd) + second_output = _wait_for_output( + outputs, + lambda twist: twist == second_cmd, + timeout_s=0.2, + ) + assert second_output == second_cmd + + cmd_transport.publish(third_cmd) + third_output = _wait_for_output( + outputs, + lambda twist: twist == third_cmd, + timeout_s=0.2, + ) + assert third_output == third_cmd + + assert policy.call_count == 1 + assert policy.call_count < 3 + finally: + guardrail.stop() diff --git a/dimos/control/safety/test_rgb_collision_guardrail_integration.py b/dimos/control/safety/test_rgb_collision_guardrail_integration.py new file mode 100644 index 0000000000..46e969cb84 --- /dev/null +++ b/dimos/control/safety/test_rgb_collision_guardrail_integration.py @@ -0,0 +1,325 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable, Iterator +import queue +import threading +import time +from typing import TypeVar + +import numpy as np +import pytest + +from dimos.control.safety.guardrail_policy import ( + GuardrailState, +) +from dimos.control.safety.rgb_collision_guardrail import RGBCollisionGuardrail +from dimos.control.safety.test_utils import ( + FakeTransport, + SequencePolicy, + _cmd, + _decision, + _textured_gray_image, +) +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat + +T = TypeVar("T") + + +def _black_gray_image(*, width: int = 160, height: int = 120) -> Image: + return Image.from_numpy( + np.zeros((height, width), dtype=np.uint8), + format=ImageFormat.GRAY, + ) + + +def _start_guardrail( + **config_overrides: float, +) -> tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], +]: + config: dict[str, float] = { + "guarded_output_publish_hz": 20.0, + "risk_evaluation_hz": 20.0, + "command_timeout_s": 0.3, + "image_timeout_s": 0.3, + "risk_timeout_s": 0.3, + } + config.update(config_overrides) + + guardrail = RGBCollisionGuardrail(**config) + image_transport: FakeTransport[Image] = FakeTransport() + cmd_transport: FakeTransport[Twist] = FakeTransport() + outputs: queue.Queue[tuple[float, Twist]] = queue.Queue() + + guardrail.color_image.transport = image_transport + guardrail.incoming_cmd_vel.transport = cmd_transport + guardrail.safe_cmd_vel.subscribe(lambda msg: outputs.put((time.monotonic(), msg))) + guardrail.start() + + return guardrail, image_transport, cmd_transport, outputs + + +@pytest.fixture +def started_guardrail() -> Iterator[ + tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], + ] +]: + guardrail, image_transport, cmd_transport, outputs = _start_guardrail() + + yield guardrail, image_transport, cmd_transport, outputs + + guardrail.stop() + + +def _wait_for_output( + outputs: queue.Queue[tuple[float, Twist]], + predicate: Callable[[Twist], bool], + *, + timeout_s: float = 1.0, +) -> tuple[float, Twist]: + deadline = time.monotonic() + timeout_s + while time.monotonic() < deadline: + remaining = deadline - time.monotonic() + try: + ts, msg = outputs.get(timeout=max(remaining, 0.01)) + except queue.Empty: + continue + if predicate(msg): + return ts, msg + raise AssertionError("Timed out waiting for matching guardrail output") + + +@pytest.mark.slow +def test_stream_wiring_end_to_end_passes_upstream_twist( + started_guardrail: tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], + ], +) -> None: + guardrail, image_transport, cmd_transport, outputs = started_guardrail + + upstream = _cmd(0.03, angular_z=0.15) + + cmd_transport.publish(upstream) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + _, observed = _wait_for_output(outputs, lambda msg: msg == upstream) + assert observed == upstream + + +@pytest.mark.slow +def test_non_pass_output_is_republished_while_upstream_is_quiet() -> None: + guardrail, image_transport, cmd_transport, outputs = _start_guardrail( + guarded_output_publish_hz=20.0, + risk_evaluation_hz=20.0, + command_timeout_s=0.5, + image_timeout_s=0.5, + risk_timeout_s=0.5, + ) + guarded = Twist(linear=[0.1, 0.0, 0.0], angular=[0.0, 0.0, 0.2]) + + try: + guardrail._policy = SequencePolicy( + [_decision(GuardrailState.CLAMP, guarded, reason="forced_clamp")] + ) + + cmd_transport.publish(_cmd(0.4, angular_z=0.2)) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + _wait_for_output(outputs, lambda msg: msg == guarded, timeout_s=1.0) + _wait_for_output(outputs, lambda msg: msg == guarded, timeout_s=1.0) + + t1, msg1 = _wait_for_output(outputs, lambda msg: msg == guarded, timeout_s=1.0) + t2, msg2 = _wait_for_output(outputs, lambda msg: msg == guarded, timeout_s=1.0) + + assert msg1 == guarded + assert msg2 == guarded + + interval = t2 - t1 + expected_period = 1.0 / guardrail.config.guarded_output_publish_hz + + assert expected_period * 0.5 <= interval <= expected_period * 2.0 + finally: + guardrail.stop() + + +@pytest.mark.slow +def test_forced_stop_never_leaks_positive_linear_x_under_concurrent_updates() -> None: + guardrail, image_transport, cmd_transport, outputs = _start_guardrail( + guarded_output_publish_hz=30.0, + risk_evaluation_hz=30.0, + command_timeout_s=0.5, + image_timeout_s=0.5, + risk_timeout_s=0.5, + ) + stop_cmd = Twist(linear=[0.0, 0.0, 0.0], angular=[0.0, 0.0, 0.2]) + + stop_event = threading.Event() + errors: list[Exception] = [] + image_thread: threading.Thread | None = None + cmd_thread: threading.Thread | None = None + + def publish_images() -> None: + try: + shift = 0 + image_transport.publish(_textured_gray_image(shift_x=shift)) + while not stop_event.is_set(): + shift += 1 + image_transport.publish(_textured_gray_image(shift_x=shift)) + time.sleep(0.01) + except Exception as exc: + errors.append(exc) + + def publish_commands() -> None: + try: + speeds = [0.4, 0.25, 0.5, 0.15, 0.35] + idx = 0 + while not stop_event.is_set(): + cmd_transport.publish(_cmd(speeds[idx % len(speeds)], angular_z=0.2)) + idx += 1 + time.sleep(0.01) + except Exception as exc: + errors.append(exc) + + try: + guardrail._policy = SequencePolicy( + [ + _decision( + GuardrailState.STOP_LATCHED, + stop_cmd, + reason="forced_stop", + publish_immediately=True, + ) + ] + ) + + image_thread = threading.Thread(target=publish_images, daemon=True) + cmd_thread = threading.Thread(target=publish_commands, daemon=True) + + image_thread.start() + cmd_thread.start() + + observed_outputs: list[Twist] = [] + while len(observed_outputs) < 8: + _, observed = _wait_for_output( + outputs, lambda msg: isinstance(msg, Twist), timeout_s=1.0 + ) + observed_outputs.append(observed) + + assert observed_outputs + assert all(twist.linear.x == pytest.approx(0.0) for twist in observed_outputs) + assert all(twist == stop_cmd or twist == Twist.zero() for twist in observed_outputs) + finally: + stop_event.set() + if image_thread is not None: + image_thread.join(timeout=1.0) + if cmd_thread is not None: + cmd_thread.join(timeout=1.0) + guardrail.stop() + + assert not errors + assert image_thread is not None + assert cmd_thread is not None + assert not image_thread.is_alive() + assert not cmd_thread.is_alive() + + +@pytest.mark.slow +def test_black_frame_end_to_end_fail_closes_to_zero( + started_guardrail: tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], + ], +) -> None: + guardrail, image_transport, cmd_transport, outputs = started_guardrail + + cmd_transport.publish(_cmd(0.3, angular_z=0.1)) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_black_gray_image()) + + _, observed = _wait_for_output(outputs, lambda msg: msg == Twist.zero(), timeout_s=1.0) + assert observed == Twist.zero() + + +@pytest.mark.slow +def test_stale_image_end_to_end_fail_closes_without_new_command() -> None: + guardrail, image_transport, cmd_transport, outputs = _start_guardrail( + guarded_output_publish_hz=25.0, + risk_evaluation_hz=25.0, + command_timeout_s=0.5, + image_timeout_s=0.08, + risk_timeout_s=0.5, + ) + # Keep the initial command below deadband so this test is about + # autonomous stale-image fail-close, not optical-flow threshold tuning. + upstream = _cmd(0.03, angular_z=0.1) + + try: + cmd_transport.publish(upstream) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + _, pass_output = _wait_for_output(outputs, lambda msg: msg == upstream, timeout_s=1.0) + assert pass_output == upstream + + _, degraded_output = _wait_for_output( + outputs, + lambda msg: msg == Twist.zero(), + timeout_s=0.4, + ) + assert degraded_output == Twist.zero() + finally: + guardrail.stop() + + +@pytest.mark.slow +def test_stale_command_end_to_end_fail_closes_to_zero( + started_guardrail: tuple[ + RGBCollisionGuardrail, + FakeTransport[Image], + FakeTransport[Twist], + queue.Queue[tuple[float, Twist]], + ], +) -> None: + guardrail, image_transport, cmd_transport, outputs = started_guardrail + + # Keep the initial command below deadband so this test is about + # stale-command fail-close, not optical-flow clamp behavior. + upstream = _cmd(0.03, angular_z=0.1) + + cmd_transport.publish(upstream) + image_transport.publish(_textured_gray_image()) + image_transport.publish(_textured_gray_image(shift_x=2)) + + _wait_for_output(outputs, lambda msg: msg == upstream, timeout_s=1.0) + _, observed = _wait_for_output(outputs, lambda msg: msg == Twist.zero(), timeout_s=1.0) + + assert observed == Twist.zero() diff --git a/dimos/control/safety/test_utils.py b/dimos/control/safety/test_utils.py new file mode 100644 index 0000000000..7e3326dc2f --- /dev/null +++ b/dimos/control/safety/test_utils.py @@ -0,0 +1,110 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, TypeVar + +import numpy as np + +from dimos.control.safety.guardrail_policy import ( + GuardrailDecision, + GuardrailHealth, + GuardrailState, +) +from dimos.core.stream import Out, Transport +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat + +T = TypeVar("T") + + +class FakeTransport(Transport[T]): + def __init__(self) -> None: + self._subscribers: list[Callable[[T], Any]] = [] + + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def broadcast(self, selfstream: Out[T] | None, value: T) -> None: + for callback in list(self._subscribers): + callback(value) + + def subscribe( + self, + callback: Callable[[T], Any], + selfstream=None, # type: ignore[no-untyped-def] + ) -> Callable[[], None]: + self._subscribers.append(callback) + + def unsubscribe() -> None: + self._subscribers.remove(callback) + + return unsubscribe + + +class SequencePolicy: + def __init__(self, decisions: list[GuardrailDecision]) -> None: + self._decisions = decisions + self._index = 0 + + def evaluate( + self, + previous_image: Image, + current_image: Image, + incoming_cmd_vel: Twist, + health: GuardrailHealth, + ) -> GuardrailDecision: + if self._index < len(self._decisions): + decision = self._decisions[self._index] + self._index += 1 + return decision + return self._decisions[-1] + + +def _textured_gray_image(*, width: int = 160, height: int = 120, shift_x: int = 0) -> Image: + yy, xx = np.indices((height, width)) + pattern = ((xx * 5 + yy * 9 + shift_x * 17) % 256).astype(np.uint8) + return Image.from_numpy(pattern, format=ImageFormat.GRAY) + + +def _cmd( + x: float = 0.35, + *, + linear_y: float = 0.0, + angular_z: float = 0.25, +) -> Twist: + return Twist( + linear=[x, linear_y, 0.0], + angular=[0.0, 0.0, angular_z], + ) + + +def _decision( + state: GuardrailState, + cmd_vel: Twist, + *, + reason: str = "test", + publish_immediately: bool = False, +) -> GuardrailDecision: + return GuardrailDecision( + state=state, + cmd_vel=cmd_vel, + reason=reason, + publish_immediately=publish_immediately, + )