diff --git a/.pylintrc b/.pylintrc index c6f676501..806d3bc3d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -54,10 +54,6 @@ persistent=yes # the version used to run pylint. py-version=3.11 -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no diff --git a/bec_lib/bec_lib/bl_state_machine.py b/bec_lib/bec_lib/bl_state_machine.py new file mode 100644 index 000000000..57fc834d3 --- /dev/null +++ b/bec_lib/bec_lib/bl_state_machine.py @@ -0,0 +1,108 @@ +""" +Module for managing aggregated beamline states based on configuration files. + +Example of the YAML configuration file: +``` yaml +alignment: # AggregatedStateConfig -> can have different labels and for each label, different devices + transition_metadata: # optional field for metadata for each label + field: value + devices: + samx: + value: 0 + abs_tol: 0.1 + low_limit: + value: -20 + abs_tol: 0.1 + high_limit: + value: 20 + abs_tol: 0.1 + signals: + velocity: + value: 5 + abs_tol: 0.1 + bpm4i: + value: 100 + abs_tol: 10 +measurement: + devices: + samx: + value: 19 + abs_tol: 0.1 + signals: + velocity: + value: 20 + samy: + value: 0 + abs_tol: 0.1 +test: + devices: + samy: + value: 0 + abs_tol: 0.1 + bpm4i: + value: 100 + abs_tol: 10 +``` + +""" + +from __future__ import annotations + +import yaml + +from bec_lib.bl_state_manager import BeamlineStateManager +from bec_lib.bl_states import AggregatedStateConfig + + +class BeamlineStateMachine: + + def __init__(self, manager: BeamlineStateManager) -> None: + self._manager = manager + self._configs: dict[str, AggregatedStateConfig] = {} + + def load_from_config( + self, name: str, config_path: str | None = None, config_dict: dict | None = None + ) -> None: + """ + Load a state configuration from a YAML file or a dictionary. If None or both are provided, + an error will be raised. Config must be states for an AggregatedStateConfig or a dictionary/YAML file that + can be parsed into one. Please check AggregatedStateConfig state field for the expected format of the configuration. + + Args: + name (str): The name of the aggregated state to load. + config_path (str | None): The path to the YAML configuration file. + config_dict (dict | None): A dictionary containing the configuration. If provided, this will be used instead of loading from a file. + """ + self._check_inputs(config_path=config_path, config_dict=config_dict) + if config_path: + with open(config_path, "r", encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + config = AggregatedStateConfig(name=name, states=config_dict) + self._manager.add(config) + + def update_config( + self, name: str, config_path: str | None = None, config_dict: dict | None = None + ) -> None: + """ + Update a state configuration from a YAML file or a dictionary. If None or both are provided, + an error will be raised. Config must be states for an AggregatedStateConfig or a dictionary/YAML file that + can be parsed into one. Please check AggregatedStateConfig state field for the expected format of the configuration. + + Args: + name (str): The name of the aggregated state to update. + config_path (str | None): The path to the YAML configuration file. + config_dict (dict | None): A dictionary containing the configuration. If provided, this will be used instead of loading from a file. + """ + self._check_inputs(config_path=config_path, config_dict=config_dict) + if config_path: + with open(config_path, "r", encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + # Load the new state + config = AggregatedStateConfig(name=name, states=config_dict) + self._manager._update_state(config) + + def _check_inputs(self, config_path: str | None, config_dict: dict | None) -> None: + if (config_path is None and config_dict is None) or ( + config_path is not None and config_dict is not None + ): + raise ValueError("Either config_path or config_dict must be provided, but not both.") diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index 8d32959c2..d7b61d54f 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -1,11 +1,15 @@ +"""Module defining beamline states and their evaluation logic.""" + from __future__ import annotations import functools import keyword import traceback from abc import ABC, abstractmethod -from typing import Callable, ClassVar, Generic, Type, TypeVar, cast +from dataclasses import dataclass +from typing import Any, Callable, Generic, Literal, Type, TypeVar, cast +import yaml from pydantic import BaseModel, field_validator, model_validator from bec_lib import messages @@ -54,7 +58,7 @@ class BeamlineStateConfig(BaseModel): Base Configuration for a beamline state. """ - state_type: ClassVar[str] = "BeamlineState" + state_type: str | None = "BeamlineState" name: str title: str | None = None @@ -81,7 +85,7 @@ class DeviceStateConfig(BeamlineStateConfig): Configuration for a device-based beamline state. """ - state_type: ClassVar[str] = "DeviceBeamlineState" + state_type: str | None = "DeviceBeamlineState" device: DeviceBase | str signal: DeviceBase | str | None = None @@ -114,13 +118,68 @@ class DeviceWithinLimitsStateConfig(DeviceStateConfig): Configuration for a device within limits beamline state. """ - state_type: ClassVar[str] = "DeviceWithinLimitsState" + state_type: str | None = "DeviceWithinLimitsState" low_limit: float | None = None high_limit: float | None = None tolerance: float = 0.1 +class SignalConfig(BaseModel): + """Target value for a signal inside a named machine state.""" + + value: float | int | str | bool + abs_tol: float = 0.0 + + +class DeviceConfig(BaseModel): + """Configuration for a device inside a named machine state.""" + + abs_tol: float = 0.0 + value: float | int | str | bool | None = None + low_limit: SignalConfig | None = None + high_limit: SignalConfig | None = None + signals: dict[str, SignalConfig] | None = None + + @model_validator(mode="after") + def validate_config(self) -> DeviceConfig: + """ + Validate that either value, low_limit, high_limit, or signals are provided. + """ + if ( + self.value is None + and self.low_limit is None + and self.high_limit is None + and self.signals is None + ): + raise ValueError( + "At least one of value, low_limit, high_limit, or signals must be provided." + ) + return self + + +class SubDeviceStateConfig(BaseModel): + """ + Configuration for a sub-state with a specific label. + This is a device/signal mappping to either a DeviceConfig or SignalConfig. + """ + + devices: dict[str, DeviceConfig | SignalConfig] + transition_metadata: dict[str, Any] | None = None + + +class AggregatedStateConfig(BeamlineStateConfig): + """ + Configuration for a state machine driven by multiple device signals. + + Keys of the states dictionary are the labels of the different states. + """ + + state_type: str | None = "AggregatedState" + + states: dict[str, SubDeviceStateConfig] + + C = TypeVar("C", bound=BeamlineStateConfig) D = TypeVar("D", bound=DeviceStateConfig) @@ -322,6 +381,446 @@ def _update_device_state(self, msg_obj: MessageObject) -> messages.BeamlineState return self.evaluate(msg) +SignalSource = Literal["readback", "configuration", "limits"] + + +@dataclass(frozen=True) +class ResolvedStateSignal: + label: str + device_name: str + signal_name: str + expected_value: float | int | str | bool + abs_tolerance: float | int + source: SignalSource + + +class AggregatedState(BeamlineState[AggregatedStateConfig]): + """Beamline state that infers the current named state from multiple device signals.""" + + CONFIG_CLASS = AggregatedStateConfig + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Mapping from signal updates to affected state labels, used for efficient evaluation when a signal update is received + self._signal_info_to_labels: dict[tuple[str, SignalSource, str], set[str]] = {} + # Mapping from state labels to the list of signal requirements that define that state + self._requirements_for_label: dict[str, list[ResolvedStateSignal]] = {} + # Set of subscriptions to signal updates + self._subscriptions: set[tuple[str, SignalSource]] = set() + # Cache of the latest signal values + self._signal_value_cache: dict[tuple[str, SignalSource, str], Any] = {} + # List of currently active state labels + self._current_labels: list[str] = [] + + @staticmethod + def _endpoint(device: str, source: SignalSource): + """Static method to get the appropriate message endpoint based on the signal source.""" + if source == "readback": + return MessageEndpoints.device_readback(device) + if source == "configuration": + return MessageEndpoints.device_read_configuration(device) + if source == "limits": + return MessageEndpoints.device_limits(device) + raise ValueError( + f"Invalid signal source '{source}', please use 'readback', 'configuration', or 'limits'." + ) + + def _get_device_manager(self): + """Get the device manager.""" + if self.device_manager is None: + # pylint: disable=import-outside-toplevel + from bec_lib.client import BECClient + + bec = BECClient() + return bec.device_manager + return self.device_manager + + @staticmethod + def _get_signal_source(signal_info: dict[str, Any], error_prefix: str) -> SignalSource: + """ + Determine the signal source (readback, configuration, or limits) based on the signal information. + + Args: + signal_info (dict[str, Any]): The signal information dictionary containing at least the "kind_str" key. + error_prefix (str): A prefix to use in error messages for better context. + + Returns: + SignalSource: A string literal indicating the signal source, one of "readback", "configuration", or "limits". + """ + kind_str = str(signal_info.get("kind_str", "")).lower() + if "hinted" in kind_str or "normal" in kind_str: + return "readback" + if "config" in kind_str: + return "configuration" + raise ValueError( + f"{error_prefix} Unsupported kind: '{kind_str}' for signal : \n {yaml.dump(signal_info, indent=4)}" + ) + + @staticmethod + def _resolve_signal( + device_name: str, signal_name: str, device_manager: DeviceManagerBase, error_prefix: str + ) -> tuple[str, SignalSource]: + """ + Resolve the signal information for a given device and signal name. + + Args: + device_name (str): The name of the device. + signal_name (str): The name of the signal. + device_manager (DeviceManagerBase): The device manager instance. + error_prefix (str): A prefix to use in error messages for better context. + + Returns: + tuple[str, SignalSource]: A tuple containing the object name and the signal source. + """ + devices = device_manager.devices + try: + if not isinstance(device_name, str): + raise ValueError( + f"{error_prefix} Device name must be a string, got {type(device_name)}" + ) + device_obj: DeviceBase = devices[device_name] + except KeyError: + raise ValueError(f"{error_prefix} Device '{device_name}' not found.") from None + + # Special handling for limits, as they are not regular signals. + if signal_name in ["low_limit", "low_limit_travel"]: + return "low", "limits" + if signal_name in ["high_limit", "high_limit_travel"]: + return "high", "limits" + + signal_info = None + # This case is relevant if we are looking at a Signal directly + if device_name == signal_name and len(device_obj.root._info["signals"]) == 0: + signal_info = {"obj_name": signal_name, "kind_str": "hinted"} + # Case where we have a signal specified as a dotted name, e.g. + elif "." in signal_name: + try: + signal_obj = devices[signal_name] + except AttributeError: + raise ValueError( + f"{error_prefix} Signal '{signal_name}' not found for device '{device_name}'." + ) from None + if signal_obj.parent != device_obj: + raise ValueError( + f"{error_prefix} Signal '{signal_name}' does not belong to device '{device_name}'." + ) + signal_component = ".".join(signal_name.split(".")[1:]) + signal_info = device_obj.root._info["signals"].get(signal_component) + # Case where the signal is specified as the signal + else: + signal_info = device_obj.root._info["signals"].get(signal_name) + if signal_info is None: + for candidate in device_obj.root._info["signals"].values(): + if candidate.get("obj_name") == signal_name: + signal_info = candidate + break + + if signal_info is None: + raise ValueError( + f"{error_prefix} Signal '{signal_name}' not found for device '{device_name}'." + ) + + obj_name = signal_info.get("obj_name") + signal_source = AggregatedState._get_signal_source(signal_info, error_prefix) + return obj_name, signal_source + + @staticmethod + def get_state_requirements( + label: str, + state_config: SubDeviceStateConfig, + device_manager: DeviceManagerBase, + error_prefix: str, + ) -> list[ResolvedStateSignal]: + """ + Get the state requirements for a given label and state configuration. + + Args: + label (str): The label for the state. + state_config (SubDeviceStateConfig): The state configuration. + device_manager (DeviceManagerBase): The device manager instance. + error_prefix (str): A prefix to use in error messages for better context. + + Returns: + list[ResolvedStateSignal]: A list of resolved state signals. + """ + state_requirements: list[ResolvedStateSignal] = [] + for device_name, config in state_config.devices.items(): + if isinstance(config, SignalConfig): + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + device_name, + config.value, + config.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + elif isinstance(config, DeviceConfig): + # If a value is specified for the device, add it as a requirement + if config.value is not None: + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + device_name, + config.value, + config.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + if config.low_limit is not None: + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + "low_limit", + config.low_limit.value, + config.low_limit.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + if config.high_limit is not None: + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + "high_limit", + config.high_limit.value, + config.high_limit.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + for signal_name, signal_config in (config.signals or {}).items(): + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + signal_name, + signal_config.value, + signal_config.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + return state_requirements + + def _build_rules(self) -> None: + """Build the internal rules and mappings for state evaluation based on the configuration.""" + self._signal_info_to_labels.clear() + self._requirements_for_label.clear() + self._subscriptions.clear() + for label, device_configs in self.config.states.items(): + state_requirements: list[ResolvedStateSignal] = AggregatedState.get_state_requirements( + label, device_configs, self._get_device_manager(), self._error_prefix + ) + for requirement in state_requirements: + device_name = requirement.device_name + signal_name = requirement.signal_name + source = requirement.source + self._subscriptions.add((device_name, source)) + self._signal_info_to_labels.setdefault( + (device_name, source, signal_name), set() + ).add(label) + self._requirements_for_label[label] = state_requirements + + @staticmethod + def _build_requirement_for_signal( + device_name: str, + signal_name: str, + value: Any, + abs_tol: float, + label: str, + device_manager: DeviceManagerBase, + error_prefix: str, + ) -> ResolvedStateSignal: + """ + Build a ResolvedStateSignal for a given device, signal, and expected value. + + Args: + device_name (str): The name of the device. + signal_name (str): The name of the signal. + value (Any): The expected value for the signal. + abs_tol (float): The absolute tolerance for comparing the signal value. + label (str): The label of the state that this requirement belongs to. + device_manager (DeviceManagerBase): The device manager instance. + error_prefix (str): A prefix to use in error messages for better context. + + Returns: + ResolvedStateSignal: The resolved state signal requirement. + """ + resolved_signal_name, source = AggregatedState._resolve_signal( + device_name, signal_name, device_manager, error_prefix + ) + + return ResolvedStateSignal( + label=label, + device_name=device_name, + signal_name=resolved_signal_name, + expected_value=value, + abs_tolerance=abs_tol, + source=source, + ) + + def start(self) -> None: + if self.started: + return + + if self.connector is None: + raise RuntimeError("Redis connector is not set.") + msg = None + try: + self._build_rules() + affected_labels = self._fill_cache() + msg = self.evaluate(affected_labels=affected_labels) + except Exception as exc: + self._handle_state_exception(exc) + + if msg is not None: + self._emit_state(msg) + for device, source in self._subscriptions: + self.connector.register( + self._endpoint(device, source), + cb=self._update_aggregated_state, + device=device, + source=source, + ) + super().start() + + def _fill_cache(self) -> set[str]: + """Fill the signal value cache with the current values and return the set of affected state labels.""" + affected_labels: set[str] = set() + for device, source in self._subscriptions: + endpoint = self._endpoint(device, source) + msg = self.connector.get(endpoint) + if msg is not None: + affected_labels.update(self._cache_message(device, source, msg)) + return affected_labels + + def _cache_message( + self, device: str, source: SignalSource, msg: messages.DeviceMessage + ) -> set[str]: + """Cache the signal values from a device message and return the set of affected state labels.""" + affected_labels: set[str] = set() + for signal_name, signal_data in msg.signals.items(): + key = (device, source, signal_name) + labels = self._signal_info_to_labels.get(key) + if labels is None: # signal not relevant for any state + continue + self._signal_value_cache[key] = signal_data.get("value") + affected_labels.update(labels) + return affected_labels + + def stop(self) -> None: + """Stop the state manager and unregister all subscriptions.""" + if not self.started: + return + if self.connector is not None: + for device, source in self._subscriptions: + self.connector.unregister( + self._endpoint(device, source), cb=self._update_aggregated_state + ) + super().stop() + + def _update_aggregated_state( + self, msg_obj: MessageObject, device: str, source: SignalSource, **_kwargs + ) -> None: + """Update the aggregated state based on a new device message.""" + try: + msg: messages.DeviceMessage = msg_obj.value # type: ignore ; we know it's a DeviceMessage + affected_labels = self._cache_message(device, source, msg) + if affected_labels: + state_msg = self.evaluate(affected_labels=affected_labels) + if state_msg is not None: + self._emit_state(state_msg) + except Exception as exc: + self._handle_state_exception(exc) + + def evaluate( + self, affected_labels: set[str] | None = None + ) -> messages.BeamlineStateMessage | None: + """ + Evaluate the current state based on the cached signal values and return a BeamlineStateMessage. + + Args: + affected_labels (set[str] | None): The set of state labels that are affected by + the latest signal update. If None, all states will be evaluated. + + Returns: + messages.BeamlineStateMessage | None: The resulting state message after evaluation, or None + if no state could be evaluated. + """ + if affected_labels is None: + return None + # We need to always extend the affected labels with the current labels, + # as the signal that updated might be not relevant for the currently active state, + # but the state should still be checked for validity. + affected_labels.update(self._current_labels) + matching_labels = [label for label in affected_labels if self._label_matches(label)] + if matching_labels: + self._current_labels = sorted(matching_labels) + state_msg = messages.BeamlineStateMessage( + name=self.config.name, status="valid", label="|".join(matching_labels) + ) + return state_msg + + self._current_labels = [] + state_msg = messages.BeamlineStateMessage( + name=self.config.name, status="invalid", label="No matching state" + ) + return state_msg + + def _label_matches(self, label: str) -> bool: + """Check if the given label matches the current signal values based on the defined requirements.""" + requirements = self._requirements_for_label.get(label, []) + return bool(requirements) and all( + self._requirement_matches(requirement) for requirement in requirements + ) + + def _requirement_matches(self, requirement: ResolvedStateSignal) -> bool: + """Check if the given requirement matches the current signal values.""" + key = (requirement.device_name, requirement.source, requirement.signal_name) + cached_value = self._signal_value_cache.get(key, None) + if cached_value is None: + return False + + expected_value = requirement.expected_value + # If expected value is a user parameter, fetch the lates value from the device manager + if isinstance(expected_value, str) and expected_value.startswith("user_parameter:"): + # In this case, we fetch the latest user_parameter value from the device manager + dev_obj = self._get_device_manager().devices.get(requirement.device_name, None) + if dev_obj is None: + return False + expected_value = dev_obj.user_parameter.get( + expected_value.split("user_parameter:")[1], None + ) + if expected_value is None: + return False + + try: + # Cast to float to make sure comparison with abs works as expected. + value = float(cached_value) + comparison_value = float(expected_value) + return abs(value - comparison_value) <= requirement.abs_tolerance + # Catch TypeError and ValueError in case the value is not a number or cannot be cast to float, + # in that case we fall back to exact equality. + except (TypeError, ValueError): + try: + result = cached_value == expected_value + except (TypeError, ValueError): + return False + # In case this comparison runs on comparing two arrays. + # We do not consider this comparsion as valid currently. + try: + return bool(result) + except (TypeError, ValueError): + return False + + class ShutterState(DeviceBeamlineState[DeviceStateConfig]): """ A state that checks if the shutter is open. diff --git a/bec_lib/bec_lib/client.py b/bec_lib/bec_lib/client.py index 689f41f1b..e3cbc952b 100644 --- a/bec_lib/bec_lib/client.py +++ b/bec_lib/bec_lib/client.py @@ -20,6 +20,7 @@ from bec_lib.alarm_handler import AlarmHandler, Alarms from bec_lib.bec_service import BECService +from bec_lib.bl_state_machine import BeamlineStateMachine from bec_lib.bl_state_manager import BeamlineStateManager from bec_lib.callback_handler import CallbackHandler, EventType from bec_lib.config_helper import ConfigHelperUser @@ -162,6 +163,7 @@ def __init__( self._username = "" self._system_user = "" self.beamline_states = None + self.state_machine = None self.messaging: MessagingContainer = None # type: ignore def __new__(cls, *args, forced=False, **kwargs): @@ -241,6 +243,7 @@ def _start_services(self): self.device_monitor = DeviceMonitorPlugin(self.connector) self._update_username() self.beamline_states = BeamlineStateManager(client=self) + self.state_machine = BeamlineStateMachine(manager=self.beamline_states) def alarms(self, severity=Alarms.WARNING): """get the next alarm with at least the specified severity""" diff --git a/bec_lib/bec_lib/tests/utils.py b/bec_lib/bec_lib/tests/utils.py index a69190030..f193e5c01 100644 --- a/bec_lib/bec_lib/tests/utils.py +++ b/bec_lib/bec_lib/tests/utils.py @@ -430,8 +430,10 @@ def get_device_info_mock(device_name, device_class) -> messages.DeviceInfoMessag return messages.DeviceInfoMessage( device="rt_controller", info=positioner_info_mock_with_user_access(device_name) ) - elif device_name == "samx": - return messages.DeviceInfoMessage(device="samx", info=positioner_info_mock(device_name)) + elif device_name in ["samx", "samy"]: + return messages.DeviceInfoMessage( + device=device_name, info=positioner_info_mock(device_name) + ) elif device_name == "dyn_signals": return DYN_SIGNALS_MSG elif device_name == "eiger": @@ -440,19 +442,7 @@ def get_device_info_mock(device_name, device_class) -> messages.DeviceInfoMessag if device_base_class == "positioner": signals = positioner_info_mock(device_name)["device_info"]["signals"] elif device_base_class == "signal": - signals = { - device_name: { - "metadata": { - "connected": True, - "read_access": True, - "write_access": False, - "timestamp": 0, - "status": None, - "severity": None, - "precision": None, - } - } - } + signals = {} else: signals = {} dev_info = { diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index e6a936e8e..2a930bc8a 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -3,10 +3,13 @@ import inspect from unittest import mock +import numpy as np import pytest +import yaml from pydantic import BaseModel from bec_lib import bl_states, messages +from bec_lib.bl_state_machine import BeamlineStateMachine from bec_lib.bl_state_manager import ( BeamlineStateClientBase, BeamlineStateManager, @@ -196,6 +199,505 @@ def test_device_within_limits_state(self, connected_connector, dm_with_devices): assert state.evaluate(invalid).status == "invalid" assert state.evaluate(missing).status == "invalid" + @pytest.fixture(scope="function") + def aggregated_state_config(self): + """Fixture for an test aggregated state configuration.""" + return bl_states.AggregatedStateConfig( + name="alignment", + states={ + "alignment": { + "devices": { + "samx": { + "value": 0, + "abs_tol": 0.1, + "low_limit": {"value": -20, "abs_tol": 0.1}, + "high_limit": {"value": 20, "abs_tol": 0.1}, + }, + "bpm4i": {"value": 0, "abs_tol": 0.1}, + } + }, + "measurement": { + "devices": { + "samx": { + "value": 19, + "abs_tol": 0.1, + "low_limit": {"value": -20, "abs_tol": 0.1}, + "high_limit": {"value": 20, "abs_tol": 0.1}, + "signals": {"velocity": {"value": 5, "abs_tol": 0.1}}, + }, + "bpm4i": {"value": 2, "abs_tol": 0.1}, + } + }, + "test": {"devices": {"bpm4i": {"value": 0, "abs_tol": 0.1}}}, + "string_state": {"devices": {"bpm3i": {"value": "ok"}}}, + "state_with_user_param": {"devices": {"samx": {"value": "user_parameter:test"}}}, + }, + ) + + def test_aggregated_state_init_and_start( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + """ + Test the initialization of the AggregatedState. + + Based on the provided configuration, we expect certain callbacks to be registered with the + Redis connector. This test checks this which essentially checks the proper functionality + of the 'start' method. + """ + + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + # We should now have subscriptions on samx limits, readback and read_configuration, and bpm4i & bpm4i + info = [ + MessageEndpoints.device_readback("samx"), + MessageEndpoints.device_read_configuration("samx"), + MessageEndpoints.device_limits("samx"), + MessageEndpoints.device_readback("bpm4i"), + MessageEndpoints.device_readback("bpm3i"), + ] + for endpoint in info: + assert endpoint.endpoint in state.connector._topics_cb + + def test_aggregated_state_evaluation( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + """ + Test the evaluation of the AggregatedState when receiving message updates. This should trigger a state evaluation for + the affected labels and the current state, and if the state changes, a new state should be published. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + + with ( + mock.patch.object(state, "evaluate", return_value=None) as evaluate, + mock.patch.object(state, "_emit_state") as emit_state, + ): + # Test triggering evaluation for multiple labels + # samx affects alignment and measurement, so both should be evaluated. + msg_with_2_states = messages.DeviceMessage( + signals={"samx": {"value": 5.0, "timestamp": 1.0}} + ) + msg_obj = MessageObject( + value=msg_with_2_states, topic=MessageEndpoints.device_readback("samx").endpoint + ) + state._update_aggregated_state(msg_obj, device="samx", source="readback") + evaluate.assert_called_once_with( + affected_labels=set(["state_with_user_param", "alignment", "measurement"]) + ) + emit_state.assert_not_called() # As evaluate is mocked to return None, _emit_state should not be called + + def test_aggregated_state_evaluate( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + """ + Test the evaluate method. + We manually cache the relevant messages and then call evaluate with the affected label. + We then check if the output message has the expected status and label, and if the current labels are updated correctly. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state._build_rules() + # Assume that we are currently in test + state._current_labels = ["test"] + state._cache_message( + "samx", + "readback", + messages.DeviceMessage( + signals={"samx": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ), + ) + state._cache_message( + "samx", + "configuration", + messages.DeviceMessage( + signals={"samx_velocity": {"value": 5, "timestamp": 1.0}}, + metadata={"stream": "baseline"}, + ), + ) + state._cache_message( + "samx", + "limits", + messages.DeviceMessage( + signals={ + "low": {"value": -20, "timestamp": 1.0}, + "high": {"value": 20, "timestamp": 1.0}, + }, + metadata={"stream": "baseline"}, + ), + ) + state._cache_message( + "bpm4i", + "readback", + messages.DeviceMessage( + signals={"bpm4i": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ), + ) + + msg = state.evaluate(affected_labels={"alignment"}) + + assert msg.status == "valid" + # The order of the labels is not guaranteed + assert msg.label in ["alignment|test", "test|alignment"] + assert set(state._current_labels) == set(["alignment", "test"]) + dm_with_devices.devices["samx"].user_parameter["test"] = 0 + msg = state.evaluate(affected_labels={"alignment", "state_with_user_param"}) + assert msg.status == "valid" + assert set(msg.label.split("|")) == set(["alignment", "state_with_user_param", "test"]) + assert set(state._current_labels) == set(["alignment", "state_with_user_param", "test"]) + + state._cache_message( + "samx", + "readback", + messages.DeviceMessage( + signals={"samx": {"value": 3, "timestamp": 2.0}}, metadata={"stream": "primary"} + ), + ) + + msg = state.evaluate(affected_labels={"alignment"}) + + assert msg.status == "valid" + assert msg.label == "test" + assert state._current_labels == ["test"] + + state._cache_message( + "bpm4i", + "readback", + messages.DeviceMessage( + signals={"bpm4i": {"value": 2, "timestamp": 2.0}}, metadata={"stream": "primary"} + ), + ) + + msg = state.evaluate(affected_labels={"alignment", "test", "measurement"}) + + assert msg.status == "invalid" + assert msg.label == "No matching state" + assert state._current_labels == [] + + def test_aggregated_state_exception_handling( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + """ + Test that if an exception is raised during the evaluation of the state, this is properly handled and an alarm is raised. + We check that the evaluate method is called and that if it raises an exception, the raise_alarm method of the connector + is called, and a state with status "unknown" and label "broken state" is published. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + msg = messages.DeviceMessage( + signals={"samx": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ) + msg_obj = MessageObject(value=msg, topic=MessageEndpoints.device_readback("samx").endpoint) + + with ( + mock.patch.object( + state, "evaluate", side_effect=RuntimeError("broken state") + ) as evaluate, + mock.patch.object(connected_connector, "raise_alarm") as raise_alarm, + ): + state._update_aggregated_state(msg_obj, device="samx", source="readback") + + evaluate.assert_called_once_with( + affected_labels={"state_with_user_param", "alignment", "measurement"} + ) + raise_alarm.assert_called_once() + out = connected_connector.xread( + MessageEndpoints.beamline_state("alignment"), from_start=True + ) + assert out[-1]["data"].status == "unknown" + assert out[-1]["data"].label == "broken state" + assert state.raised_warning is True + + def test_aggregated_state_transitions_between_labels( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + """ + Test the transitions between different labels of the aggregated state. We simulate the messages that would trigger + the transitions and check that the output message has the expected status and label, and that the current labels are updated correctly. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + + def update(device, source, signals): + msg = messages.DeviceMessage(signals=signals, metadata={"stream": "primary"}) + msg_obj = MessageObject(value=msg, topic=state._endpoint(device, source).endpoint) + state._update_aggregated_state(msg_obj, device=device, source=source) + out = connected_connector.xread( + MessageEndpoints.beamline_state("alignment"), from_start=True + ) + return out[-1]["data"] + + msg = update("samx", "configuration", {"samx_velocity": {"value": 5, "timestamp": 1.0}}) + assert msg.status == "invalid" + + update( + "samx", + "limits", + {"low": {"value": -20, "timestamp": 1.0}, "high": {"value": 20, "timestamp": 1.0}}, + ) + update("samx", "readback", {"samx": {"value": 0, "timestamp": 1.0}}) + msg = update("bpm4i", "readback", {"bpm4i": {"value": 0, "timestamp": 1.0}}) + assert msg.status == "valid" + assert set(msg.label.split("|")) == {"alignment", "test"} + + msg = update("samx", "readback", {"samx": {"value": 19, "timestamp": 2.0}}) + assert msg.status == "valid" + assert msg.label == "test" + + msg = update("bpm4i", "readback", {"bpm4i": {"value": 2, "timestamp": 2.0}}) + assert msg.status == "valid" + assert msg.label == "measurement" + + @pytest.mark.parametrize( + ("cached_value", "expected_value", "abs_tolerance", "matches"), + [ + (1.05, 1.0, 0.1, True), + (1.2, 1.0, 0.1, False), + (5, 5, 0.0, True), + (np.int64(5), 5, 0.0, True), + (np.float64(1.05), 1.0, 0.1, True), + ("ok", "ok", 0.0, True), + ("not-ok", "ok", 0.0, False), + ([1, 2], 1, 0.0, False), + (np.array([1.0, 2.0]), 1.0, 0.1, False), + (np.array([1.0, 2.0]), np.array([1.0, 2.0]), 0.0, False), + ], + ) + def test_aggregated_state_requirement_matches( + self, + connected_connector, + dm_with_devices, + aggregated_state_config, + cached_value, + expected_value, + abs_tolerance, + matches, + ): + """ + Test the evaluation of requirements in the aggregated state. We manually set the signal value + cache and then call the _requirement_matches method with a requirement, and check if the output is as expected. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + requirement = bl_states.ResolvedStateSignal( + label="alignment", + device_name="bpm4i", + signal_name="bpm4i", + expected_value=expected_value, + abs_tolerance=abs_tolerance, + source="readback", + ) + state._signal_value_cache[("bpm4i", "readback", "bpm4i")] = cached_value + + assert state._requirement_matches(requirement) is matches + + def test_device_config_requires_at_least_one_target(self): + with pytest.raises(ValueError, match="At least one of value"): + bl_states.DeviceConfig() + + def test_aggregated_state_endpoint_rejects_unknown_source(self): + with pytest.raises(ValueError, match="Invalid signal source"): + bl_states.AggregatedState._endpoint("samx", "unknown") + + def test_aggregated_state_get_device_manager_falls_back_to_client(self): + state = bl_states.AggregatedState( + name="alignment", states={"label": {"devices": {"samx": {"value": 0}}}} + ) + client = mock.MagicMock() + + with mock.patch("bec_lib.client.BECClient", return_value=client): + assert state._get_device_manager() is client.device_manager + + def test_aggregated_state_get_signal_source_rejects_unsupported_kind(self): + with pytest.raises(ValueError, match="Unsupported kind"): + bl_states.AggregatedState._get_signal_source( + {"kind_str": "omitted", "obj_name": "samx_unused"}, "test" + ) + + def test_aggregated_state_resolve_signal_edge_cases(self, dm_with_devices): + assert bl_states.AggregatedState._resolve_signal( + "samx", "low_limit_travel", dm_with_devices, "test" + ) == ("low", "limits") + assert bl_states.AggregatedState._resolve_signal( + "samx", "high_limit_travel", dm_with_devices, "test" + ) == ("high", "limits") + assert bl_states.AggregatedState._resolve_signal( + "samx", "samx_velocity", dm_with_devices, "test" + ) == ("samx_velocity", "configuration") + + with pytest.raises(ValueError, match="Device 'missing' not found"): + bl_states.AggregatedState._resolve_signal("missing", "missing", dm_with_devices, "test") + with pytest.raises(ValueError, match="Device name must be a string"): + bl_states.AggregatedState._resolve_signal(1, "samx", dm_with_devices, "test") + with pytest.raises(ValueError, match="Signal 'missing_signal' not found"): + bl_states.AggregatedState._resolve_signal( + "samx", "missing_signal", dm_with_devices, "test" + ) + with pytest.raises(ValueError, match="Unsupported kind"): + bl_states.AggregatedState._resolve_signal("samx", "unused", dm_with_devices, "test") + + def test_aggregated_state_resolve_dotted_signal_edge_cases(self, dm_with_devices): + assert bl_states.AggregatedState._resolve_signal( + "samx", "samx.velocity", dm_with_devices, "test" + ) == ("samx_velocity", "configuration") + + with pytest.raises(ValueError, match="does not belong"): + bl_states.AggregatedState._resolve_signal( + "samx", "samy.velocity", dm_with_devices, "test" + ) + + devices = mock.MagicMock() + devices.__getitem__.side_effect = [dm_with_devices.devices["samx"], AttributeError] + manager = mock.MagicMock(devices=devices) + with pytest.raises(ValueError, match="Signal 'samx.missing' not found"): + bl_states.AggregatedState._resolve_signal("samx", "samx.missing", manager, "test") + + def test_aggregated_state_start_edge_cases( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.started = True + with mock.patch.object(state, "_build_rules") as build_rules: + state.start() + build_rules.assert_not_called() + + state = bl_states.AggregatedState( + config=aggregated_state_config, redis_connector=None, device_manager=dm_with_devices + ) + with pytest.raises(RuntimeError, match="Redis connector is not set"): + state.start() + + def test_aggregated_state_start_handles_rule_build_error( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + + with ( + mock.patch.object(state, "_build_rules", side_effect=RuntimeError("bad rules")), + mock.patch.object(state, "_handle_state_exception") as handle_exception, + ): + state.start() + + handle_exception.assert_called_once() + assert state.started is True + + def test_aggregated_state_fill_cache_uses_existing_messages( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state._build_rules() + connected_connector.set_and_publish( + MessageEndpoints.device_readback("samx"), + messages.DeviceMessage(signals={"samx": {"value": 0, "timestamp": 1.0}}), + ) + + affected_labels = state._fill_cache() + + assert affected_labels == {"alignment", "measurement", "state_with_user_param"} + assert state._signal_value_cache[("samx", "readback", "samx")] == 0 + + def test_aggregated_state_cache_ignores_irrelevant_signals( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state._build_rules() + + affected_labels = state._cache_message( + "samx", + "readback", + messages.DeviceMessage( + signals={"samx_unused": {"value": 1, "timestamp": 1.0}}, + metadata={"stream": "primary"}, + ), + ) + + assert affected_labels == set() + assert ("samx", "readback", "samx_unused") not in state._signal_value_cache + + def test_aggregated_state_stop_unregisters_subscriptions( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + + with mock.patch.object(connected_connector, "unregister") as unregister: + state.stop() + + assert unregister.call_count == len(state._subscriptions) + assert state.started is False + + def test_aggregated_state_stop_is_noop_before_start( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + + with mock.patch.object(connected_connector, "unregister") as unregister: + state.stop() + + unregister.assert_not_called() + + def test_aggregated_state_evaluate_without_affected_labels( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + + assert state.evaluate() is None + class TestBeamlineStateManager: def test_manager_registers_for_state_updates(self, connected_connector): @@ -321,3 +823,83 @@ def test_show_all_prints_table(self, state_manager, capsys): captured = capsys.readouterr() assert "shutter_open" in (captured.out + captured.err) + + +class TestStateMachine: + + @pytest.fixture() + def state_machine(self, state_manager): + state_machine = BeamlineStateMachine(manager=state_manager) + return state_machine + + @pytest.fixture() + def config_dict(self): + return { + "alignment": { + "devices": { + "samx": { + "value": 0, + "abs_tol": 0.1, + "signals": {"velocity": {"value": 5, "abs_tol": 0.1}}, + } + } + } + } + + def test_load_from_config_with_dict( + self, state_machine: BeamlineStateMachine, tmp_path, config_dict + ): + """Test loading configuration from a dictionary or file.""" + + # Load valid configuration from dictionary + with mock.patch.object(state_machine._manager, "add") as manager_add: + state_machine.load_from_config( + name="alignment", config_path=None, config_dict=config_dict + ) + manager_add.assert_called_once_with( + bl_states.AggregatedStateConfig(name="alignment", states=config_dict) + ) + # Loading with both config_path and config_dict should raise an error + with pytest.raises(ValueError): + state_machine.load_from_config( + name="alignment", config_path="path/to/config.yaml", config_dict=config_dict + ) + # Loading with neither config_path nor config_dict should raise an error + with pytest.raises(ValueError): + state_machine.load_from_config(name="alignment", config_path=None, config_dict=None) + + # Loading from file should work. + config_path = tmp_path / "config.yaml" + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config_dict, f) + state_machine.load_from_config(name="alignment", config_path=str(config_path)) + manager_add.assert_called_with( + bl_states.AggregatedStateConfig(name="alignment", states=config_dict) + ) + + def test_update_config(self, state_machine: BeamlineStateMachine, config_dict, tmp_path): + """Test update method of state machine.""" + with mock.patch.object(state_machine._manager, "_update_state") as manager_update: + config = bl_states.AggregatedStateConfig(name="alignment", states=config_dict) + state_machine.update_config(name="alignment", config_dict=config_dict) + manager_update.assert_called_once_with(config) + + manager_update.reset_mock() + + # Invalid updates should raise an error + with pytest.raises(ValueError): + state_machine.update_config(name="alignment", config_dict=None) + manager_update.assert_not_called() + + with pytest.raises(ValueError): + state_machine.update_config( + name="alignment", config_path="path/to/config.yaml", config_dict=config_dict + ) + manager_update.assert_not_called() + manager_update.reset_mock() + # Updating from file should work. + config_path = tmp_path / "config.yaml" + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config_dict, f) + state_machine.update_config(name="alignment", config_path=str(config_path)) + manager_update.assert_called_once_with(config) diff --git a/bec_server/bec_server/device_server/devices/device_serializer.py b/bec_server/bec_server/device_server/devices/device_serializer.py index 63bdc1e4e..1caf254b7 100644 --- a/bec_server/bec_server/device_server/devices/device_serializer.py +++ b/bec_server/bec_server/device_server/devices/device_serializer.py @@ -202,6 +202,8 @@ def get_device_info( "kind_str": signal_obj.kind.name, "doc": doc, "describe": signal_obj.describe().get(signal_obj.name, {}), + "read_access": getattr(signal_obj, "read_access", None), + "write_access": getattr(signal_obj, "write_access", None), # pylint: disable=protected-access "metadata": signal_obj._metadata, "labels": sorted(signal_obj._ophyd_labels_), diff --git a/bec_server/bec_server/scan_server/scans/state_transition_scan.py b/bec_server/bec_server/scan_server/scans/state_transition_scan.py new file mode 100644 index 000000000..08af25fa9 --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/state_transition_scan.py @@ -0,0 +1,356 @@ +""" +Updated move scan implementation for coordinated motor repositioning commands. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Tuple + +from bec_lib.alarm_handler import AlarmBase, Alarms +from bec_lib.bl_states import AggregatedState, SubDeviceStateConfig +from bec_lib.device import DeviceBase, Positioner, Signal +from bec_lib.endpoints import MessageEndpoints +from bec_lib.logger import bec_logger +from bec_lib.messages import AlarmMessage, ErrorInfo +from bec_server.scan_server.scans.scan_base import ScanBase +from bec_server.scan_server.scans.scan_modifier import scan_hook + +if TYPE_CHECKING: + from bec_lib.bl_states import AggregatedStateConfig, ResolvedStateSignal + from bec_lib.messages import AvailableBeamlineStatesMessage + +logger = bec_logger.logger + + +class StateTransitionScanError(AlarmBase): + """Exception raised when an RPC call fails.""" + + def __init__(self, exc_type: str, message: str, compact_message: str) -> None: + alarm = AlarmMessage( + severity=Alarms.MAJOR, + info=ErrorInfo( + exception_type=exc_type, + error_message=message, + compact_error_message=compact_message, + ), + ) + super().__init__(alarm, Alarms.MAJOR, handled=False) + + +class StateTransitionScan(ScanBase): + + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = None + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_state_transition" + + # We set is_scan to False to separate this class from the other scans in the user interface + is_scan = False + + def __init__(self, *args, state_name: str, target_label: str, **kwargs): + """ + State transition scan that moves a motor in between two states. + The main purpose of this scan is to be used in conjunction with state + management in BEC, and transitioning the beamline in-between different aggregated states. + """ + super().__init__(**kwargs) + self.state_name = state_name + self.target_label = target_label + + # We need to sort the devices and signals in the config, and identify which of them are motor setpoint/readback pairs + # and which of them are just readouts and thereby can not be set within the transition. + self._signals_to_set: list[Tuple[Signal, Any]] = [] + self._limits_to_set: dict[str, Tuple[Positioner, float, float]] = {} + self._devices_to_set: list[Tuple[Positioner, float]] = [] + + # pylint: disable=protected-access + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + # Check if the state and the target label exists, if yes, fetch the configuration for the target state + self.config_for_label = self._fetch_config_for_label(self.state_name, self.target_label) + requirements: list[ResolvedStateSignal] = AggregatedState.get_state_requirements( + self.target_label, self.config_for_label, self.device_manager, "StateTransitionScan" + ) + self._signals_to_set, self._limits_to_set, self._devices_to_set = ( + self._fetch_devices_signals_and_limits_to_set(requirements) + ) + + self.update_scan_info(scan_report_devices=[dev for dev, _ in self._devices_to_set]) + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + self.actions.open_scan() + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + self.actions.stage_all_devices() + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + self.actions.pre_scan_all_devices() + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + # Set the signals first... because otherwise there can be an issue with the live updates if + # TODO we set the scan_report_instruction_readback for the motors and one of the signal is also a motor. + self._set_signals() + # Motors + motors = [element[0] for element in self._devices_to_set] + target_positions = [element[1] for element in self._devices_to_set] + current_positions = self.components.get_start_positions(motors) + # TODO Check how this can be managed in view of the live updates. If we move the signal section further down, + # We get issues with the DeviceProgressBar live updates, and in this ordering, we have an issue that multiple + # Live displays seem to be triggered. This has to be investigated with care. + # self.actions.add_scan_report_instruction_readback( + # devices=motors, start=current_positions, stop=target_positions + # ) + self.components.move_and_wait(motors, target_positions) + # Limits + self._set_limits() + + @scan_hook + def at_each_point(self): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + self.actions.complete_all_devices() + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() + + @scan_hook + def close_scan(self): + """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + + ################# + ## Custom Methods + ################# + + def _set_signals(self): + """Method to set signals in the transition.""" + for signal_obj, target_value in self._signals_to_set: + # Check if signal is writable before setting it, if not skip. + if self._check_if_signal_has_write_access(signal_obj): + signal_obj.set(target_value).wait() + + def _set_limits(self): + """Method to set limits for devices in the transition.""" + for dev_name, (dev_obj, low_limit, high_limit) in self._limits_to_set.items(): + dev_obj.limits = [low_limit, high_limit] + + def _check_if_signal_has_write_access(self, signal_obj: Signal) -> bool: + """ + Check if a signal has write access based on its signal information. The issue here is that signals of a + device follow a slightly different pattern. Therefore, we have to first check "_info" for signals + and if that is empty, check '_signal_info' for sub-signals of devices. + + Args: + signal_obj (Signal): Signal object to check. + Returns: + bool: True if the signal has write access, False otherwise. + """ + write_access = signal_obj._info.get("write_access", None) + if write_access is None: + write_access = signal_obj._signal_info.get("write_access", False) + return write_access + + def _fetch_devices_signals_and_limits_to_set( + self, requirements: list[ResolvedStateSignal] + ) -> Tuple[dict, dict, dict]: + """ + This method fetches the device signals, limits and devices to set based on a list of state requirements. + It returns a tuple containing three elements: + - signals_to_set (list[Tuple[Signal, Any]]): List of signals to set with their target values. + - limits_to_set (dict[str, Tuple[Positioner, float, float]]): Dictionary of devices and their limits to set. + - devices_to_set (list[Tuple[Positioner, float]]): List of devices to set with their target positions. + + Args: + requirements (list[ResolvedStateSignal]): List of state requirements to fetch the device signals and limits for. + + Returns: + Tuple containing: + - signals_to_set (list[Tuple[Signal, Any]]): List of signals to set with their target values. + - limits_to_set (dict[str, Tuple[Positioner, float, float]]): Dictionary of devices and their limits to set. + - devices_to_set (list[Tuple[Positioner, float]]): List of devices to set with their target positions. + """ + _signals_to_set: list[Tuple[Signal, Any]] = [] + _limits_to_set: dict[str, Tuple[Positioner, float, float]] = {} + _devices_to_set: list[Tuple[Positioner, float]] = [] + for req in requirements: + dev_obj: DeviceBase = self.device_manager.devices.get(req.device_name) + # Device not found + if dev_obj is None: + raise StateTransitionScanError( + exc_type="DeviceNotFound", + message=f"Device {req.device_name} not found in device manager.", + compact_message=f"Device {req.device_name} not found.", + ) + expected_value = self._get_expected_value(req) + # First we handle Signals logic + if isinstance(dev_obj, Signal): + _signals_to_set.append((dev_obj, expected_value)) + continue + # Positioner and Device logic. Devices must implement .set for this to work, otherwise we can not set them and we raise an error + if isinstance(dev_obj, DeviceBase): + # Handle motor-specific logic here + # First we handle logic for motions of the motor. Device_name and signal_name will be equivalent here + if req.signal_name == req.device_name: + _devices_to_set.append((dev_obj, expected_value)) + continue + if req.source == "limits": + if req.device_name not in _limits_to_set: + _limits_to_set[req.device_name] = ( + dev_obj, + dev_obj.low_limit, + dev_obj.high_limit, + ) + if req.signal_name == "low": + _limits_to_set[req.device_name] = ( + dev_obj, + expected_value, + _limits_to_set[req.device_name][2], + ) + elif req.signal_name == "high": + _limits_to_set[req.device_name] = ( + dev_obj, + _limits_to_set[req.device_name][1], + expected_value, + ) + continue + signal_obj = self._get_signal_object(dev_obj, req.signal_name) + if signal_obj is None: + raise StateTransitionScanError( + exc_type="SignalNotFound", + message=f"Signal {req.signal_name} for device {req.device_name} not found in device manager.", + compact_message=f"Signal {req.signal_name} for device {req.device_name} not found.", + ) + _signals_to_set.append((signal_obj, expected_value)) + continue + # Return the collected signals, limits and devices to set + return _signals_to_set, _limits_to_set, _devices_to_set + + def _get_expected_value(self, requirement: ResolvedStateSignal) -> Any: + expected_value = requirement.expected_value + # If expected value is a user parameter, fetch the lates value from the device manager + if isinstance(expected_value, str) and expected_value.startswith("user_parameter:"): + dev_obj = self.device_manager.devices.get(requirement.device_name, None) + if dev_obj is None: + raise StateTransitionScanError( + exc_type="DeviceNotFound", + message=f"Device {requirement.device_name} not found in device manager.", + compact_message=f"Device {requirement.device_name} not found.", + ) + expected_value = dev_obj.user_parameter.get( + expected_value.split("user_parameter:")[1], None + ) + if expected_value is None: + raise StateTransitionScanError( + exc_type="UserParameterNotFound", + message=f"User parameter {expected_value.split('user_parameter:')[1]} for device {requirement.device_name} not found in device manager.", + compact_message=f"User parameter {expected_value.split('user_parameter:')[1]} for device {requirement.device_name} not found.", + ) + return expected_value + + def _get_signal_object(self, device_obj: DeviceBase, signal_name: str) -> Signal: + for component_name, info in device_obj._info["signals"].items(): + if info["obj_name"] == signal_name: + return getattr(device_obj, component_name) + + def _fetch_config_for_label(self, state_name: str, target_label: str) -> SubDeviceStateConfig: + available_states_msg: AvailableBeamlineStatesMessage = self.redis_connector.get_last( + MessageEndpoints.available_beamline_states() + ) + if available_states_msg is None: + raise ValueError( + "No available beamline states found in Redis. Cannot fetch configuration for state transition scan." + ) + configs = [ + state for state in available_states_msg["data"].states if state.name == state_name + ] + if len(configs) == 0: + raise ValueError(f"State {state_name} not found in available states.") + elif len(configs) > 1: # Should not be possible, but just in case + raise ValueError(f"Multiple states with name {state_name} found in available states.") + config: AggregatedStateConfig = configs[0] + if config.state_type != "AggregatedState": + raise ValueError( + f"State {state_name} is not an aggregated state. Transitions are only supported for aggregated states." + ) + available_labels = list(config.parameters["states"].keys()) + if target_label not in available_labels: + raise ValueError( + f"Target label {target_label} not found in state {state_name}. Available labels: {available_labels}" + ) + return SubDeviceStateConfig.model_validate(config.parameters["states"][target_label]) diff --git a/bec_server/bec_server/scan_server/tests/scan_fixtures.py b/bec_server/bec_server/scan_server/tests/scan_fixtures.py index e7f3b23e9..2c51f6aea 100644 --- a/bec_server/bec_server/scan_server/tests/scan_fixtures.py +++ b/bec_server/bec_server/scan_server/tests/scan_fixtures.py @@ -123,6 +123,18 @@ def full_name(self): def limits(self): return self._limits + @limits.setter + def limits(self, value): + self._limits = tuple(value) + + @property + def low_limit(self): + return self._limits[0] + + @property + def high_limit(self): + return self._limits[1] + @property def enabled(self): return self._enabled @@ -226,6 +238,18 @@ def full_name(self): def limits(self): return self._limits + @limits.setter + def limits(self, value): + self._limits = tuple(value) + + @property + def low_limit(self): + return self._limits[0] + + @property + def high_limit(self): + return self._limits[1] + @property def enabled(self): return self._enabled @@ -360,6 +384,7 @@ def v4_scan_assembler(readout_priority: ReadoutPriorityContainer, device_manager def _assemble_scan(scan_type, *scan_args, **scan_kwargs): scan_id = scan_kwargs.pop("scan_id", "scan-id-test") + connector = scan_kwargs.pop("connector", None) or ConnectorMock("") try: scan_cls = scan_classes[scan_type] @@ -367,7 +392,6 @@ def _assemble_scan(scan_type, *scan_args, **scan_kwargs): available = ", ".join(sorted(scan_classes)) raise KeyError(f"Unknown scan type '{scan_type}'. Available: {available}") from exc - connector = ConnectorMock("") instruction_handler = InstructionHandler(connector) device_names = sorted( set(_infer_v4_device_names(scan_cls, scan_args, scan_kwargs)) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py new file mode 100644 index 000000000..39bdb40e4 --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py @@ -0,0 +1,163 @@ +from unittest import mock + +import pytest +from ophyd_devices.sim.sim_positioner import SimPositioner + +from bec_lib import messages +from bec_lib.device import Positioner, _PermissiveDeviceModel +from bec_lib.endpoints import MessageEndpoints +from bec_server.device_server.devices.device_serializer import get_device_info +from bec_server.scan_server.tests.scan_hook_tests import ( + assert_close_scan_waits_for_baseline_and_closes, + assert_pre_scan_called, + assert_prepare_scan_reads_baseline_devices, + assert_scan_open_called, + assert_stage_all_devices_called, + assert_unstage_all_devices_called, + run_scan_tests, +) + +ACQUIRE_DEFAULT_HOOK_TESTS = [ + ("open_scan", [assert_scan_open_called]), + ("stage", [assert_stage_all_devices_called]), + ("pre_scan", [assert_pre_scan_called]), + ("unstage", [assert_unstage_all_devices_called]), + ("close_scan", [assert_close_scan_waits_for_baseline_and_closes]), +] + + +@pytest.fixture +def state_transition_connector(connected_connector): + connected_connector.xadd( + MessageEndpoints.available_beamline_states(), + { + "data": messages.AvailableBeamlineStatesMessage( + states=[ + messages.BeamlineStateConfig( + name="test", + title="Test state", + state_type="AggregatedState", + parameters={ + "states": { + "alignment": { + "devices": { + "samx": { + "value": 1.5, + "low_limit": {"value": -2}, + "high_limit": {"value": 2}, + "signals": {"velocity": {"value": 0.5}}, + }, + "samy": { + "value": 0.5, + "low_limit": {"value": -1}, + "high_limit": {"value": 1}, + }, + } + } + } + }, + ) + ] + ) + }, + ) + return connected_connector + + +@pytest.fixture +def simulated_samx(device_manager): + # dev_obj = SimPositioner(name="samx") + name = "samx" + dev = SimPositioner(name=name) + config = _PermissiveDeviceModel( + enabled=True, + deviceClass="ophyd_devices.sim.sim_positioner.SimPositioner", + readoutPriority="baseline", + ) + info = get_device_info(dev, connect=True) + dev_man_obj = Positioner( + name=name, info=info, config=config, class_name=config.deviceClass, parent=device_manager + ) + return dev_man_obj + + +@pytest.mark.timeout(20) +@pytest.mark.parametrize(("hook_name", "hook_tests"), ACQUIRE_DEFAULT_HOOK_TESTS) +def test_state_transition_default_hooks( + v4_scan_assembler, state_transition_connector, nth_done_status_mock, hook_name, hook_tests +): + """Test default hooks open_scan, stage, pre_scan, unstage, and close_scan for the StateTransitionScan.""" + scan = v4_scan_assembler( + "_v4_state_transition", + state_name="test", + target_label="alignment", + connector=state_transition_connector, + ) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +@pytest.mark.timeout(20) +def test_state_transition_prepare_scan( + v4_scan_assembler, state_transition_connector, device_manager, simulated_samx +): + """Test prepare scan hook for the StateTransitionScan.""" + device_manager.add_device(simulated_samx, replace=True) + scan = v4_scan_assembler( + "_v4_state_transition", + state_name="test", + target_label="alignment", + connector=state_transition_connector, + ) + + scan.prepare_scan() + + devices_to_set = {(device.name, value) for device, value in scan._devices_to_set} + limits_to_set = { + device_name: (device.name, low_limit, high_limit) + for device_name, (device, low_limit, high_limit) in scan._limits_to_set.items() + } + signals_to_set = {(signal.full_name, value) for signal, value in scan._signals_to_set} + + assert devices_to_set == {("samx", 1.5), ("samy", 0.5)} + assert limits_to_set == {"samx": ("samx", -2, 2), "samy": ("samy", -1, 1)} + assert signals_to_set == {("samx_velocity", 0.5)} + + +@pytest.mark.timeout(20) +def test_state_transition_scan_core( + v4_scan_assembler, state_transition_connector, device_manager, simulated_samx +): + device_manager.add_device(simulated_samx, replace=True) + scan = v4_scan_assembler( + "_v4_state_transition", + state_name="test", + target_label="alignment", + connector=state_transition_connector, + ) + scan.prepare_scan() + signal_by_name = {signal.full_name: signal for signal, _ in scan._signals_to_set} + velocity_set_status = mock.MagicMock() + signal_by_name["samx_velocity"].set = mock.MagicMock(return_value=velocity_set_status) + scan.components.get_start_positions = mock.MagicMock(return_value=[0, 0]) + with ( + mock.patch.object( + scan.components, "get_start_positions", return_value=[0, 0] + ) as mock_get_start_positions, + mock.patch.object(scan.components, "move_and_wait") as mock_move_and_wait, + mock.patch.object( + scan.actions, "add_scan_report_instruction_readback" + ) as mock_add_scan_report_instruction_readback, + mock.patch.object(scan, "_set_limits") as mock_set_limits, + ): + scan.scan_core() + mock_get_start_positions.assert_called_once() + # mock_add_scan_report_instruction_readback.assert_called_once_with( + # devices=[scan.device_manager.devices["samx"], scan.device_manager.devices["samy"]], + # start=[0, 0], + # stop=[1.5, 0.5], + # ) + mock_move_and_wait.assert_called_once_with( + [scan.device_manager.devices["samx"], scan.device_manager.devices["samy"]], [1.5, 0.5] + ) + mock_set_limits.assert_called_once()