From 73d407188fd5958d4e8c118650f2a62ccf1adc2f Mon Sep 17 00:00:00 2001 From: perl_d Date: Tue, 10 Feb 2026 15:09:13 +0100 Subject: [PATCH 1/8] refactor: mark all message dicts as jsonable --- .../tests/client_tests/test_live_table.py | 11 +- .../tests/end-2-end/test_scans_lib_e2e.py | 100 ++++++++--- bec_lib/bec_lib/bec_service.py | 1 + bec_lib/bec_lib/messages.py | 167 +++++++++++++----- bec_lib/bec_lib/one_way_registry.py | 97 ++++++++++ bec_lib/bec_lib/serialization_registry.py | 1 - bec_lib/tests/test_bec_messages.py | 14 +- bec_lib/tests/test_config_helper.py | 3 +- bec_lib/tests/test_serializer.py | 6 +- .../tests_scan_server/test_scan_stubs.py | 37 +++- .../tests_scan_server/test_scan_worker.py | 2 +- .../tests/tests_scan_server/test_scans.py | 6 +- 12 files changed, 349 insertions(+), 96 deletions(-) create mode 100644 bec_lib/bec_lib/one_way_registry.py diff --git a/bec_ipython_client/tests/client_tests/test_live_table.py b/bec_ipython_client/tests/client_tests/test_live_table.py index 891803831..53c8ccf56 100644 --- a/bec_ipython_client/tests/client_tests/test_live_table.py +++ b/bec_ipython_client/tests/client_tests/test_live_table.py @@ -303,8 +303,11 @@ def test_print_table_data_hinted_value_with_precision( @pytest.mark.parametrize( "value,expected", [ - (np.int32(1), "1.00"), - (np.float64(1.00000), "1.00"), + # Commented out cases are not supported in unstructured serialized data, because msgpack doesn't distinguish + # lists, tuples, or sets. To support this, ScanMessage must be refactored to support the type information directly + # except for numpy arrays, which are currently special-cased but will be removed in a future refactor. + # (np.int32(1), "1.00"), + # (np.float64(1.00000), "1.00"), (0, "0.00"), (1, "1.00"), (0.000, "0.00"), @@ -314,10 +317,10 @@ def test_print_table_data_hinted_value_with_precision( ("False", "False"), ("0", "0"), ("1", "1"), - ((0, 1), "(0, 1)"), + # ((0, 1), "(0, 1)"), ({"value": 0}, "{'value': 0}"), (np.array([0, 1]), "[0 1]"), - ({1, 2}, "{1, 2}"), + # ({1, 2}, "{1, 2}"), ], ) def test_print_table_data_variants(self, client_with_grid_scan, value, expected): diff --git a/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py b/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py index c81f8f3df..90ff9572e 100644 --- a/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py @@ -22,7 +22,9 @@ def test_grid_scan_lib(bec_client_lib): bec.metadata.update({"unit_test": "test_grid_scan_bec_client_lib"}) dev = bec.device_manager.devices scans.umv(dev.samx, 0, dev.samy, 0, relative=False) - status = scans.grid_scan(dev.samx, -5, 5, 10, dev.samy, -5, 5, 10, exp_time=0.01, relative=True) + status = scans.grid_scan( + dev.samx, -5, 5, 10, dev.samy, -5, 5, 10, exp_time=0.01, relative=True + ) status.wait(num_points=True, file_written=True) assert len(status.scan.live_data) == 100 assert status.scan.num_points == 100 @@ -34,7 +36,9 @@ def test_grid_scan_lib_cancel(bec_client_lib): scans = bec.scans bec.metadata.update({"unit_test": "test_grid_scan_bec_client_lib"}) dev = bec.device_manager.devices - status = scans.grid_scan(dev.samx, -5, 5, 10, dev.samy, -5, 5, 10, exp_time=1, relative=False) + status = scans.grid_scan( + dev.samx, -5, 5, 10, dev.samy, -5, 5, 10, exp_time=1, relative=False + ) time.sleep(0.5) status.cancel() @@ -52,10 +56,14 @@ def test_mv_scan_lib(bec_client_lib): current_pos_samx = dev.samx.read()["samx"]["value"] current_pos_samy = dev.samy.read()["samy"]["value"] assert np.isclose( - current_pos_samx, 10, atol=dev.samx._config["deviceConfig"].get("tolerance", 0.05) + current_pos_samx, + 10, + atol=dev.samx._config["deviceConfig"].get("tolerance", 0.05), ) assert np.isclose( - current_pos_samy, 20, atol=dev.samy._config["deviceConfig"].get("tolerance", 0.05) + current_pos_samy, + 20, + atol=dev.samy._config["deviceConfig"].get("tolerance", 0.05), ) @@ -105,7 +113,9 @@ def dummy_callback(data, metadata): reference_container["metadata"] = metadata reference_container["data"].append(data) - s = scans.line_scan(dev.samx, 0, 1, steps=10, relative=False, async_callback=dummy_callback) + s = scans.line_scan( + dev.samx, 0, 1, steps=10, relative=False, async_callback=dummy_callback + ) s.wait() while len(reference_container["data"]) < 10: time.sleep(0.1) @@ -128,7 +138,9 @@ def scan_status_update(msg): pos = yield dev.samx.position cb_executed.set() - bec_client_lib.connector.register(MessageEndpoints.scan_status(), cb=scan_status_update) + bec_client_lib.connector.register( + MessageEndpoints.scan_status(), cb=scan_status_update + ) s = scans.line_scan(dev.samx, 0, 1, steps=10, exp_time=0.2, relative=False) s.wait() cb_executed.wait() @@ -145,10 +157,17 @@ def test_config_updates(bec_client_lib): assert dev.rt_controller.limits == [-50, 50] dev.rt_controller.velocity.set(10).wait() - assert dev.rt_controller.velocity.read(cached=True)["rt_controller_velocity"]["value"] == 10 + assert ( + dev.rt_controller.velocity.read(cached=True)["rt_controller_velocity"]["value"] + == 10 + ) assert dev.rt_controller.velocity.read()["rt_controller_velocity"]["value"] == 10 - assert dev.rt_controller.read_configuration()["rt_controller_velocity"]["value"] == 10 - assert dev.rt_controller.read_configuration()["rt_controller_velocity"]["value"] == 10 + assert ( + dev.rt_controller.read_configuration()["rt_controller_velocity"]["value"] == 10 + ) + assert ( + dev.rt_controller.read_configuration()["rt_controller_velocity"]["value"] == 10 + ) dev.rt_controller.velocity.put(5) assert dev.rt_controller.velocity.get() == 5 @@ -179,7 +198,13 @@ def test_dap_fit(bec_client_lib): dev.bpm4i.sim.select_model("GaussianModel") params = dev.bpm4i.sim.params params.update( - {"noise": "uniform", "noise_multiplier": 10, "center": 5, "sigma": 1, "amplitude": 200} + { + "noise": "uniform", + "noise_multiplier": 10, + "center": 5, + "sigma": 1, + "amplitude": 200, + } ) dev.bpm4i.sim.params = params time.sleep(1) @@ -347,11 +372,18 @@ def test_dap_fit(bec_client_lib): ], ) def test_config_reload( - bec_test_config_file_path, bec_client_lib, config, raises_error, deletes_config, disabled_device + bec_test_config_file_path, + bec_client_lib, + config, + raises_error, + deletes_config, + disabled_device, ): bec = bec_client_lib bec.metadata.update({"unit_test": "test_config_reload"}) - runtime_config_file_path = bec_test_config_file_path.parent / "e2e_runtime_config_test.yaml" + runtime_config_file_path = ( + bec_test_config_file_path.parent / "e2e_runtime_config_test.yaml" + ) # write new config to disk with open(runtime_config_file_path, "w") as f: @@ -369,7 +401,9 @@ def test_config_reload( else: assert len(bec.device_manager.devices) == num_devices else: - bec.config.update_session_with_file(runtime_config_file_path, force=True, validate=False) + bec.config.update_session_with_file( + runtime_config_file_path, force=True, validate=False + ) assert len(bec.device_manager.devices) == 2 for dev in disabled_device: assert bec.device_manager.devices[dev].enabled is False @@ -378,7 +412,9 @@ def test_config_reload( def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_client_lib): bec = bec_client_lib bec.metadata.update({"unit_test": "test_config_reload"}) - runtime_config_file_path = bec_test_config_file_path.parent / "e2e_runtime_config_test.yaml" + runtime_config_file_path = ( + bec_test_config_file_path.parent / "e2e_runtime_config_test.yaml" + ) config = { "hexapod": { @@ -406,7 +442,8 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie # set hexapod to fail bec.connector.set( - f"e2e_test_hexapod_fail", messages.DeviceStatusMessage(device="hexapod", status=1) + f"e2e_test_hexapod_fail", + messages.DeviceStatusMessage(device="hexapod", status=1), ) # write new config to disk @@ -414,7 +451,9 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie f.write(yaml.dump(config)) with pytest.raises(DeviceConfigError): - bec.config.update_session_with_file(runtime_config_file_path, force=True, validate=False) + bec.config.update_session_with_file( + runtime_config_file_path, force=True, validate=False + ) assert len(bec.device_manager.devices) == 2 assert bec.device_manager.devices["eyefoc"].enabled is True @@ -422,7 +461,8 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie # set hexapod to pass bec.connector.set( - f"e2e_test_hexapod_fail", messages.DeviceStatusMessage(device="hexapod", status=0) + f"e2e_test_hexapod_fail", + messages.DeviceStatusMessage(device="hexapod", status=0), ) bec.config.update_session_with_file(runtime_config_file_path, force=True) @@ -453,11 +493,15 @@ def test_config_add_remove_device(bec_client_lib): } bec.device_manager.config_helper.send_config_request(action="add", config=config) with pytest.raises(DeviceConfigError) as config_error: - bec.device_manager.config_helper.send_config_request(action="add", config=config) + bec.device_manager.config_helper.send_config_request( + action="add", config=config + ) assert config_error.match("Device new_device already exists") assert "new_device" in dev - bec.device_manager.config_helper.send_config_request(action="remove", config={"new_device": {}}) + bec.device_manager.config_helper.send_config_request( + action="remove", config={"new_device": {}} + ) assert "new_device" not in dev device_config_msg = bec.connector.get(MessageEndpoints.device_config()) @@ -468,7 +512,9 @@ def test_config_add_remove_device(bec_client_lib): config["new_device"]["deviceClass"] = "ophyd_devices.doesnt_exist" with pytest.raises(DeviceConfigError) as config_error: - bec.device_manager.config_helper.send_config_request(action="add", config=config) + bec.device_manager.config_helper.send_config_request( + action="add", config=config + ) assert config_error.match("module 'ophyd_devices' has no attribute 'doesnt_exist'") assert "new_device" not in dev assert "samx" in dev @@ -579,7 +625,9 @@ def test_image_analysis(bec_client_lib): assert (fit_res[1]["stats"]["min"] == 0.0).all() assert (np.isclose(fit_res[1]["stats"]["mean"], 3.3, atol=0.5)).all() # Center of mass is not in the middle due to hot (fluctuating) pixels - assert (np.isclose(fit_res[1]["stats"]["center_of_mass"], [49.5, 40.8], atol=2)).all() + assert ( + np.isclose(fit_res[1]["stats"]["center_of_mass"], [49.5, 40.8], atol=2) + ).all() @pytest.mark.timeout(100) @@ -597,7 +645,11 @@ def test_bl_state(bec_client_lib): tolerance=1, ) samx_config = DeviceWithinLimitsStateConfig( - name="samx_within_limits", device="samx", low_limit=-10, high_limit=10, tolerance=1 + name="samx_within_limits", + device="samx", + low_limit=-10, + high_limit=10, + tolerance=1, ) bec.beamline_states.add(hexapod_config) @@ -632,7 +684,9 @@ def test_bl_state(bec_client_lib): bec.beamline_states.delete("hexapod_x_within_limits") assert not hasattr(bec.beamline_states, "hexapod_x_within_limits") - bec.beamline_states.samx_within_limits.update_parameters(low_limit=-5, high_limit=25) + bec.beamline_states.samx_within_limits.update_parameters( + low_limit=-5, high_limit=25 + ) bec.beamline_states.show_all() while bec.beamline_states.samx_within_limits.get()["status"] != "valid": diff --git a/bec_lib/bec_lib/bec_service.py b/bec_lib/bec_lib/bec_service.py index 9659effbb..b0d1e7fb7 100644 --- a/bec_lib/bec_lib/bec_service.py +++ b/bec_lib/bec_lib/bec_service.py @@ -259,6 +259,7 @@ def _update_existing_services(self) -> None: msgs = [ self.connector.get(MessageEndpoints.service_status(service)) for service in services ] + print(msgs) self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None} msgs = [self.connector.get(MessageEndpoints.metrics(service)) for service in services] self._services_metric = {msg.content["name"]: msg for msg in msgs if msg is not None} diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index e09751510..c1109a397 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -9,13 +9,92 @@ from enum import Enum, auto from importlib.metadata import PackageNotFoundError from importlib.metadata import version as importlib_version -from typing import Annotated, Any, ClassVar, Literal, Self, Union +from types import NoneType +from typing import Annotated, Any, ClassVar, Literal, Self from uuid import uuid4 +import msgpack import numpy as np -from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + Field, + ValidationError, + WithJsonSchema, + field_validator, + model_validator, +) +from typing_extensions import TypeAliasType from bec_lib.metadata_schema import get_metadata_schema_for_scan +from bec_lib.one_way_registry import OneWaySerializationRegistry + +_one_way_registry = OneWaySerializationRegistry() + + +def _sanitize_one_way(data: Any) -> Any: + # TODO: Temporary fix for standardizing message structure, will be replaced + # by encoders in a future iteration + if isinstance(data, np.ndarray): + return data + if isinstance(data, np.bool_): + return bool(data) + if isinstance(data, (np.float16, np.float32, np.float64)): + return float(data) + if isinstance(data, (np.int16, np.int32, np.int64, np.uint16, np.uint32, np.uint64)): + return int(data) + if isinstance(data, (list, tuple, set)): + return [_sanitize_one_way(x) for x in data] + if isinstance(data, dict): + return {_sanitize_one_way(k): _sanitize_one_way(v) for k, v in data.items()} + return _one_way_registry.encode(data) + + +def _ignore_ndarray(data: Any) -> Any: + if isinstance(data, np.ndarray): + return [] + raise ValueError(f"Cannot serialize unknown type for {data}: {type(data)}") + + +def _test_packable(data: Any): + try: + msgpack.dumps(data, default=_ignore_ndarray) + except Exception as e: + raise ValueError(f"Non-JSONable/msgpackable data in {data}!") from e + + +def _validate_packable(data: Any) -> Any: + # Skip sanitization if the data is already valid + if isinstance(data, int | float | str | bool | NoneType): + return data + if isinstance(data, np.bool_): + return bool(data) + try: + _test_packable(data) + return data + # Recursively check if we should replace anything which is not supposed to be decoded to a custom + # type on the other end + except ValueError: + data = _sanitize_one_way(data) + _test_packable(data) + return data + + +Jsonable = TypeAliasType( + "Jsonable", + Annotated[ + int | float | str | bool | None | list["Jsonable"] | dict[str, "Jsonable"] | np.ndarray, + BeforeValidator(_validate_packable), + ], +) + +JsonableDict = TypeAliasType( + "JsonableDict", + Annotated[ + dict[str, Jsonable], BeforeValidator(_validate_packable), WithJsonSchema({"type": "object"}) + ], +) class ProcedureWorkerStatus(Enum): @@ -44,8 +123,9 @@ class BECMessage(BaseModel): """ + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") msg_type: ClassVar[str] - metadata: dict = Field(default_factory=dict) + metadata: JsonableDict = Field(default_factory=dict) @field_validator("metadata") @classmethod @@ -142,7 +222,7 @@ class ScanQueueMessage(BECMessage): msg_type: ClassVar[str] = "scan_queue_message" scan_type: str - parameter: dict + parameter: JsonableDict queue: str = Field(default="primary") allow_restart: bool = Field( default=True, @@ -226,18 +306,18 @@ class ScanStatusMessage(BECMessage): scan_type: Literal["step", "fly"] | None = Field(default=None, description="Type of scan") dataset_number: int | None = None scan_report_devices: list[str] | None = None - user_metadata: dict | None = None + user_metadata: JsonableDict | None = None readout_priority: ( dict[Literal["monitored", "baseline", "async", "continuous", "on_request"], list[str]] | None ) = None scan_parameters: dict[ - Literal["exp_time", "frames_per_trigger", "settling_time", "readout_time"] | str, Any + Literal["exp_time", "frames_per_trigger", "settling_time", "readout_time"] | str, Jsonable ] = Field(default_factory=dict) - request_inputs: dict[Literal["arg_bundle", "inputs", "kwargs"], Any] = Field( + request_inputs: dict[Literal["arg_bundle", "inputs", "kwargs"], Jsonable] = Field( default_factory=dict ) - info: dict + info: JsonableDict timestamp: float = Field(default_factory=time.time) def __str__(self): @@ -303,7 +383,7 @@ class ScanQueueModificationMessage(BECMessage): "release_lock", "user_completed", ] - parameter: dict + parameter: JsonableDict queue: str = Field(default="primary") @@ -551,7 +631,7 @@ class DeviceInstructionMessage(BECMessage): "publish_data_as_read", "close_scan_group", ] - parameter: dict + parameter: JsonableDict class ErrorInfo(BaseModel): @@ -749,7 +829,7 @@ class DeviceInfoMessage(BECMessage): msg_type: ClassVar[str] = "device_info_message" device: str - info: dict + info: JsonableDict class DeviceMonitor2DMessage(BECMessage): @@ -769,8 +849,6 @@ class DeviceMonitor2DMessage(BECMessage): data: np.ndarray timestamp: float = Field(default_factory=time.time) - metadata: dict | None = Field(default_factory=dict) - # Needed for pydantic to accept numpy arrays model_config = ConfigDict(arbitrary_types_allowed=True) @@ -810,8 +888,6 @@ class DeviceMonitor1DMessage(BECMessage): data: np.ndarray timestamp: float = Field(default_factory=time.time) - metadata: dict | None = Field(default_factory=dict) - # Needed for pydantic to accept numpy arrays model_config = ConfigDict(arbitrary_types_allowed=True) @@ -869,7 +945,7 @@ class DeviceUserROIMessage(BECMessage): device: str signal: str roi_type: str = Field(description="Type of the ROI, e.g. 'rectangle', 'circle', 'polygon'") - roi: dict = Field( + roi: JsonableDict = Field( description="Dictionary containing the ROI information, e.g. {'x': 100, 'y': 200, 'width': 50, 'height': 50}" ) timestamp: float = Field(default_factory=time.time) @@ -889,7 +965,7 @@ class ScanMessage(BECMessage): msg_type: ClassVar[str] = "scan_message" point_id: int scan_id: str - data: dict + data: JsonableDict class ScanHistoryMessage(BECMessage): @@ -923,7 +999,7 @@ class ScanHistoryMessage(BECMessage): end_time: float scan_name: str num_points: int - request_inputs: dict | None = None + request_inputs: JsonableDict | None = None stored_data_info: dict[str, dict[str, _StoredDataInfo]] | None = None @@ -951,7 +1027,7 @@ class ScanBaselineMessage(BECMessage): msg_type: ClassVar[str] = "scan_baseline_message" scan_id: str - data: dict + data: JsonableDict ConfigAction = Literal["add", "set", "update", "reload", "remove", "reset", "cancel"] @@ -969,7 +1045,7 @@ class DeviceConfigMessage(BECMessage): msg_type: ClassVar[str] = "device_config_message" action: ConfigAction | None = Field(default=None, validate_default=True) - config: dict | None = Field(default=None) + config: JsonableDict | None = Field(default=None) @model_validator(mode="after") @classmethod @@ -1015,7 +1091,7 @@ class LogMessage(BECMessage): log_type: Literal[ "trace", "debug", "info", "success", "warning", "error", "critical", "console_log" ] - log_msg: dict | str + log_msg: JsonableDict | str class AlarmMessage(BECMessage): @@ -1137,8 +1213,8 @@ class FileContentMessage(BECMessage): msg_type: ClassVar[str] = "file_content_message" file_path: str - data: dict - scan_info: dict + data: JsonableDict + scan_info: JsonableDict class VariableMessage(BECMessage): @@ -1179,7 +1255,7 @@ class ServiceMetricMessage(BECMessage): msg_type: ClassVar[str] = "service_metric_message" name: str - metrics: dict + metrics: JsonableDict class _StrDynamicMetricValue(BaseModel): @@ -1241,7 +1317,7 @@ class ProcessedDataMessage(BECMessage): """ msg_type: ClassVar[str] = "processed_data_message" - data: dict | list[dict] + data: JsonableDict | list[JsonableDict] class DAPConfigMessage(BECMessage): @@ -1253,7 +1329,7 @@ class DAPConfigMessage(BECMessage): """ msg_type: ClassVar[str] = "dap_config_message" - config: dict + config: JsonableDict class DAPRequestMessage(BECMessage): @@ -1269,7 +1345,7 @@ class DAPRequestMessage(BECMessage): msg_type: ClassVar[str] = "dap_request_message" dap_cls: str dap_type: Literal["continuous", "on_demand"] - config: dict + config: JsonableDict class DAPResponseMessage(BECMessage): @@ -1299,7 +1375,7 @@ class AvailableResourceMessage(BECMessage): """ msg_type: ClassVar[str] = "available_resource_message" - resource: dict | list[dict] | BaseModel | list[BaseModel] + resource: JsonableDict | list[JsonableDict] | BECMessage | list[BECMessage] class ProgressMessage(BECMessage): @@ -1327,7 +1403,7 @@ class GUIConfigMessage(BECMessage): """ msg_type: ClassVar[str] = "gui_config_message" - config: dict + config: JsonableDict class GUIDataMessage(BECMessage): @@ -1339,7 +1415,7 @@ class GUIDataMessage(BECMessage): """ msg_type: ClassVar[str] = "gui_data_message" - data: dict + data: JsonableDict class GUIInstructionMessage(BECMessage): @@ -1352,7 +1428,7 @@ class GUIInstructionMessage(BECMessage): msg_type: ClassVar[str] = "gui_instruction_message" action: str - parameter: dict + parameter: JsonableDict class GUIAutoUpdateConfigMessage(BECMessage): @@ -1389,7 +1465,7 @@ class GUIRegistryStateMessage(BECMessage): "container_proxy", "skip_rpc_namespace", ], - str | bool | dict | None, + str | bool | JsonableDict | None, ], ] @@ -1403,7 +1479,7 @@ class ServiceResponseMessage(BECMessage): """ msg_type: ClassVar[str] = "service_response_message" - response: dict + response: JsonableDict class CredentialsMessage(BECMessage): @@ -1415,7 +1491,7 @@ class CredentialsMessage(BECMessage): """ msg_type: ClassVar[str] = "credentials_message" - credentials: dict + credentials: JsonableDict class RawMessage(BECMessage): @@ -1428,7 +1504,7 @@ class RawMessage(BECMessage): """ msg_type: ClassVar[str] = "raw_message" - data: Any + data: Jsonable model_config = ConfigDict(arbitrary_types_allowed=True) @@ -1695,8 +1771,7 @@ class MessagingConfig(BaseModel): AvailableMessagingServices = Annotated[ - Union[SignalServiceInfo, SciLogServiceInfo, TeamsServiceInfo], - Field(discriminator="service_type"), + SignalServiceInfo | SciLogServiceInfo | TeamsServiceInfo, Field(discriminator="service_type") ] @@ -1764,7 +1839,6 @@ class EndpointInfoMessage(BECMessage): msg_type: ClassVar[str] = "endpoint_info_message" endpoint: str - metadata: dict | None = Field(default_factory=dict) class ScriptExecutionInfoMessage(BECMessage): @@ -1797,8 +1871,6 @@ class MacroUpdateMessage(BECMessage): macro_name: str | None = None file_path: str | None = None - metadata: dict | None = Field(default_factory=dict) - @model_validator(mode="after") @classmethod def check_macro(cls, values): @@ -1873,13 +1945,13 @@ class MessagingServiceGiphyContent(BaseModel): giphy_url: str -MessagingServiceContent = Union[ - MessagingServiceTextContent, - MessagingServiceFileContent, - MessagingServiceTagsContent, - MessagingServiceStickerContent, - MessagingServiceGiphyContent, -] +MessagingServiceContent = ( + MessagingServiceTextContent + | MessagingServiceFileContent + | MessagingServiceTagsContent + | MessagingServiceStickerContent + | MessagingServiceGiphyContent +) class MessagingServiceMessage(BECMessage): @@ -1897,7 +1969,6 @@ class MessagingServiceMessage(BECMessage): service_name: Literal["signal", "teams", "scilog"] message: list[MessagingServiceContent] scope: str | list[str] | None = None - metadata: dict | None = Field(default_factory=dict) class FeedbackMessage(BECMessage): diff --git a/bec_lib/bec_lib/one_way_registry.py b/bec_lib/bec_lib/one_way_registry.py new file mode 100644 index 000000000..d1f92a90e --- /dev/null +++ b/bec_lib/bec_lib/one_way_registry.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any, Callable, Type + +from bec_lib.device import DeviceBase + + +class OneWayBECCodec(ABC): + """Abstract base class for custom encoders""" + + obj_type: Type | list[Type] + + @staticmethod + @abstractmethod + def encode(obj: Any) -> Any: + """Encode an object into a serializable format.""" + + +class BECDeviceEncoder(OneWayBECCodec): + obj_type = DeviceBase + + @staticmethod + def encode(obj: DeviceBase) -> str: + if hasattr(obj, "_compile_function_path"): + # pylint: disable=protected-access + return obj._compile_function_path() + return obj.name + + +class OneWaySerializationRegistry: + """Registry for serialization codecs""" + + def __init__(self): + self._registry: dict[str, tuple[Type, Callable]] = {} + + self.register_codec(BECDeviceEncoder) + + def register_codec(self, codec: Type[OneWayBECCodec]): + """ + Register a codec for a specific BECCodec subclass. + This method allows for easy registration of custom encoders and decoders + for BECMessage and other types. + + Args: + codec: A subclass of BECCodec that implements encode and decode methods. + Raises: + ValueError: If a codec for the specified type is already registered. + """ + if isinstance(codec.obj_type, list): + for cls in codec.obj_type: + self.register(cls, codec.encode) + else: + self.register(codec.obj_type, codec.encode) + + def register(self, cls: Type, encoder: Callable): + """Register a codec for a specific type.""" + + if cls.__name__ in self._registry: + raise ValueError(f"Codec for {cls} already registered.") + self._registry[cls.__name__] = (cls, encoder) + self.get_codec.cache_clear() # Clear the cache when a new codec is registered + + @lru_cache(maxsize=2000) + def get_codec(self, cls: Type) -> tuple[Type, Callable] | None: + """Get the codec for a specific type.""" + codec = self._registry.get(cls.__name__) + if codec: + return codec + for _, (registered_cls, encoder) in self._registry.items(): + if issubclass(cls, registered_cls): + return registered_cls, encoder + return None + + def is_registered(self, cls: Type) -> bool: + """ + Check if a codec is registered for a specific type. + Args: + cls: The class type to check for a registered codec. + Returns: + bool: True if a codec is registered for the type, False otherwise. + """ + return self.get_codec(cls) is not None + + def encode(self, obj): + """Encode an object using the registered codec.""" + codec = self.get_codec(type(obj)) + if not codec: + return obj # No codec registered for this type + _, encoder = codec + try: + return encoder(obj) + except Exception as e: + raise ValueError( + f"Serialization failed: Failed to encode {obj.__class__.__name__} with codec {encoder}: {e}" + ) from e diff --git a/bec_lib/bec_lib/serialization_registry.py b/bec_lib/bec_lib/serialization_registry.py index e9aa923c8..7bcbfefcc 100644 --- a/bec_lib/bec_lib/serialization_registry.py +++ b/bec_lib/bec_lib/serialization_registry.py @@ -19,7 +19,6 @@ def __init__(self): self._legacy_codecs = [] # can be removed in future versions, see issue #516 self.register_codec(bec_codecs.BECMessageEncoder) - self.register_codec(bec_codecs.BECDeviceEncoder) self.register_codec(bec_codecs.EndpointInfoEncoder) self.register_codec(bec_codecs.SetEncoder) self.register_codec(bec_codecs.BECTypeEncoder) diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index ee700fcdc..3af5721b0 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -431,15 +431,13 @@ def test_DeviceInstructionMessage(): def test_DeviceMonitor2DMessage(): # Test 2D data - msg = messages.DeviceMonitor2DMessage( - device="eiger", data=np.random.rand(2, 100), metadata=None - ) + msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(2, 100)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg assert res_loaded.metadata == {} # Test rgb image, i.e. image with 3 channels - msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(3, 3), metadata=None) + msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(3, 3)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -456,7 +454,7 @@ def test_DeviceMonitor2DMessage(): def test_DeviceMonitor1DMessage(): # Test 2D data - msg = messages.DeviceMonitor1DMessage(device="eiger", data=np.random.rand(100), metadata=None) + msg = messages.DeviceMonitor1DMessage(device="eiger", data=np.random.rand(100)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -704,3 +702,9 @@ def test_feedback_message(): assert res_loaded == msg assert res_loaded.username == getpass.getuser() assert res_loaded.versions == messages.ServiceVersions._get_version_numbers() + + +def test_message_with_np_array_in_dict(): + arr = np.zeros(5) + msg = messages.ScanMessage(point_id=0, scan_id="", data={"device": {"value": arr}}, metadata={}) + assert isinstance(msg.data["device"]["value"], np.ndarray) diff --git a/bec_lib/tests/test_config_helper.py b/bec_lib/tests/test_config_helper.py index 72a9431e9..8cf830815 100644 --- a/bec_lib/tests/test_config_helper.py +++ b/bec_lib/tests/test_config_helper.py @@ -419,7 +419,8 @@ def test_config_helper_get_config_conflicts( config.update(dev_cfg) config_in_redis.append(config) with mock.patch.object(config_helper._device_manager.connector, "get") as mock_get: - mock_get.return_value = messages.AvailableResourceMessage(resource=config_in_redis) + available_resource_message = messages.AvailableResourceMessage(resource=config_in_redis) + mock_get.return_value = available_resource_message conflicts = config_helper._get_config_conflicts(new_config) assert conflicts == expected_conflicts diff --git a/bec_lib/tests/test_serializer.py b/bec_lib/tests/test_serializer.py index 93f5990de..e1a254e7b 100644 --- a/bec_lib/tests/test_serializer.py +++ b/bec_lib/tests/test_serializer.py @@ -10,6 +10,7 @@ from bec_lib.device import DeviceBase from bec_lib.devicemanager import DeviceManagerBase from bec_lib.endpoints import MessageEndpoints +from bec_lib.one_way_registry import OneWaySerializationRegistry from bec_lib.serialization import MsgpackSerialization, json_ext, msgpack @@ -81,10 +82,11 @@ class DummyModel(BaseModel): assert data.model_dump() == converted_data -def test_device_serializer(serializer): +def test_device_serializer(): + serializer = OneWaySerializationRegistry() device_manager = mock.MagicMock(spec=DeviceManagerBase) dummy = DeviceBase(name="dummy", parent=device_manager) - assert serializer.loads(serializer.dumps(dummy)) == "dummy" + assert serializer.encode(dummy) == "dummy" def test_enum_serializer(serializer): diff --git a/bec_server/tests/tests_scan_server/test_scan_stubs.py b/bec_server/tests/tests_scan_server/test_scan_stubs.py index 4ab51e9bd..98ab65bfc 100644 --- a/bec_server/tests/tests_scan_server/test_scan_stubs.py +++ b/bec_server/tests/tests_scan_server/test_scan_stubs.py @@ -49,7 +49,11 @@ def stubs(): device="rtx", action="kickoff", parameter={ - "configure": {"num_pos": 5, "positions": [1, 2, 3, 4, 5], "exp_time": 2} + "configure": { + "num_pos": 5, + "positions": [1, 2, 3, 4, 5], + "exp_time": 2, + } }, metadata={}, ), @@ -57,7 +61,9 @@ def stubs(): ], ) def test_kickoff(stubs, device, parameter, metadata, reference_msg): - msg = list(stubs.kickoff(device=device, parameter=parameter, metadata=metadata, wait=False)) + msg = list( + stubs.kickoff(device=device, parameter=parameter, metadata=metadata, wait=False) + ) reference_msg.metadata["device_instr_id"] = msg[0].metadata["device_instr_id"] assert msg[0] == reference_msg @@ -74,12 +80,16 @@ def test_kickoff(stubs, device, parameter, metadata, reference_msg): False, ), ( - messages.ProgressMessage(value=10, max_value=100, done=False, metadata={"RID": "rid"}), + messages.ProgressMessage( + value=10, max_value=100, done=False, metadata={"RID": "rid"} + ), 10, False, ), ( - messages.DeviceStatusMessage(device="samx", status=0, metadata={"RID": "rid"}), + messages.DeviceStatusMessage( + device="samx", status=0, metadata={"RID": "rid"} + ), None, True, ), @@ -96,7 +106,9 @@ def test_device_progress(stubs, msg, ret_value, raised_error): def test_send_rpc_and_wait(stubs, ScanStubStatusMock): - with mock.patch.object(stubs, "_get_result_from_status", return_value="msg") as get_rpc: + with mock.patch.object( + stubs, "_get_result_from_status", return_value="msg" + ) as get_rpc: original_rpc = stubs.send_rpc with mock.patch.object(stubs, "send_rpc") as mock_rpc: @@ -106,12 +118,19 @@ def mock_rpc_func(*args, **kwargs): mock_rpc.side_effect = mock_rpc_func - instructions = list(stubs.send_rpc_and_wait("sim_profile", "readback_profile")) + instructions = list( + stubs.send_rpc_and_wait("sim_profile", "readback_profile") + ) rpc_call_1 = instructions[0] - instructions = list(stubs.send_rpc_and_wait("sim_profile", "readback_profile")) + instructions = list( + stubs.send_rpc_and_wait("sim_profile", "readback_profile") + ) rpc_call_2 = instructions[0] assert rpc_call_1 != rpc_call_2 - assert rpc_call_1.metadata["device_instr_id"] != rpc_call_2.metadata["device_instr_id"] + assert ( + rpc_call_1.metadata["device_instr_id"] + != rpc_call_2.metadata["device_instr_id"] + ) def test_stage(stubs): @@ -131,6 +150,7 @@ def test_rpc_call_returns_status(stubs): fake_status = mock.MagicMock() fake_status._result_is_status = True fake_status.wait = mock.MagicMock() + fake_status._device_instr_id = "fake device instruction ID" with mock.patch.object(stubs, "_create_status", return_value=fake_status): result = stubs._rpc_call("samx", "velocity.set", 10) @@ -146,6 +166,7 @@ def test_rpc_call_returns_dict(stubs): fake_status._result_is_status = False fake_status.result = expected fake_status.wait = mock.MagicMock() + fake_status._device_instr_id = "fake device instruction ID" with mock.patch.object(stubs, "_create_status", return_value=fake_status): result = stubs._rpc_call("samx", "velocity.set", 10) diff --git a/bec_server/tests/tests_scan_server/test_scan_worker.py b/bec_server/tests/tests_scan_server/test_scan_worker.py index 46cf59370..6de30a457 100644 --- a/bec_server/tests/tests_scan_server/test_scan_worker.py +++ b/bec_server/tests/tests_scan_server/test_scan_worker.py @@ -293,7 +293,7 @@ def test_initialize_scan_info(scan_worker_mock, msg): assert worker.current_scan_info["scan_msgs"] == [] assert worker.current_scan_info["monitor_sync"] == "bec" assert worker.current_scan_info["frames_per_trigger"] == 1 - assert worker.current_scan_info["args"] == {"samx": (-5, 5, 5), "samy": (-1, 1, 2)} + assert worker.current_scan_info["args"] == {"samx": [-5, 5, 5], "samy": [-1, 1, 2]} assert worker.current_scan_info["kwargs"] == msg.parameter["kwargs"] assert "samx" in worker.current_scan_info["readout_priority"]["monitored"] assert "samy" in worker.current_scan_info["readout_priority"]["baseline"] diff --git a/bec_server/tests/tests_scan_server/test_scans.py b/bec_server/tests/tests_scan_server/test_scans.py index 50c293983..900103311 100644 --- a/bec_server/tests/tests_scan_server/test_scans.py +++ b/bec_server/tests/tests_scan_server/test_scans.py @@ -167,7 +167,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx", "samy"], "start": [0, 0], - "end": np.array([1.0, 2.0]), + "end": [1.0, 2.0], } }, metadata={ @@ -211,7 +211,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx", "samy", "samz"], "start": [0, 0, 0], - "end": np.array([1.0, 2.0, 3.0]), + "end": [1.0, 2.0, 3.0], } }, metadata={ @@ -264,7 +264,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx"], "start": [0], - "end": np.array([1.0]), + "end": [1.0], } }, metadata={ From 99fce2891e0fefebd0bceac8b3a35e3a3e0affbb Mon Sep 17 00:00:00 2001 From: David Perl Date: Tue, 10 Feb 2026 23:34:06 +0100 Subject: [PATCH 2/8] refactor: separate message serialization and encoding --- bec_lib/bec_lib/bec_serializable.py | 23 ++++++++++++++++++++ bec_lib/bec_lib/devicemanager.py | 8 ++++++- bec_lib/bec_lib/messages.py | 24 +++++++++++++++------ bec_lib/bec_lib/serialization.py | 2 ++ bec_lib/bec_lib/serialization_registry.py | 7 +++++- bec_lib/tests/test_bec_messages.py | 13 ++++++++++- bec_server/bec_server/procedures/manager.py | 1 + 7 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 bec_lib/bec_lib/bec_serializable.py diff --git a/bec_lib/bec_lib/bec_serializable.py b/bec_lib/bec_lib/bec_serializable.py new file mode 100644 index 000000000..1fa68442e --- /dev/null +++ b/bec_lib/bec_lib/bec_serializable.py @@ -0,0 +1,23 @@ +from pydantic import BaseModel, ConfigDict, computed_field + + +class BecCodecInfo(BaseModel): + type_name: str + + +class BECSerializable(BaseModel): + """A base class for serializable BEC objects, especially BEC messages. + Fields in subclasses which use non-primitive types must be in structured, + type-hinted objects, and their encoders and JSON schema should be defined in + this class.""" + + model_config = ConfigDict( + json_schema_serialization_defaults_required=True, + arbitrary_types_allowed=True, + extra="forbid", + ) + + @computed_field() + @property + def bec_codec(self) -> BecCodecInfo: + return BecCodecInfo(type_name=self.__class__.__name__) diff --git a/bec_lib/bec_lib/devicemanager.py b/bec_lib/bec_lib/devicemanager.py index e77dd87d5..7fbbcb0b1 100644 --- a/bec_lib/bec_lib/devicemanager.py +++ b/bec_lib/bec_lib/devicemanager.py @@ -693,15 +693,21 @@ def _get_redis_device_config(self) -> list: return devices.content["resource"] def _add_multiple_devices_with_log(self, devices: Iterable[tuple[dict, DeviceInfoMessage]]): + override = self._allow_override try: - override = self._allow_override self._allow_override = True logs = (self._add_device(*conf_msg) for conf_msg in devices if conf_msg is not None) + if set(logs) == {None}: + logger.warning("No devices added!") + return logger.info(f"Adding new devices:\n" + ", ".join(f"{name}: {t}" for name, t in logs)) # type: ignore # filtered finally: self._allow_override = override def _add_device(self, dev: dict, msg: DeviceInfoMessage) -> tuple[str, str] | None: + if msg is None: + logger.error(f"No device info in Redis for: {dev}") + return None name = msg.content["device"] info = msg.content["info"] diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index c1109a397..8789cd36a 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -27,6 +27,7 @@ ) from typing_extensions import TypeAliasType +from bec_lib.bec_serializable import BECSerializable from bec_lib.metadata_schema import get_metadata_schema_for_scan from bec_lib.one_way_registry import OneWaySerializationRegistry @@ -114,7 +115,7 @@ class BECStatus(Enum): ERROR = -1 -class BECMessage(BaseModel): +class BECMessage(BECSerializable): """Base Model class for BEC Messages Args: @@ -123,7 +124,6 @@ class BECMessage(BaseModel): """ - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") msg_type: ClassVar[str] metadata: JsonableDict = Field(default_factory=dict) @@ -137,6 +137,13 @@ def check_metadata(cls, v): """ return v or {} + @model_validator(mode="before") + @classmethod + def _strip_codec_info(cls, data: Any): + if isinstance(data, dict): + data.pop("bec_codec", None) + return data + @property def content(self): """Return the content of the message""" @@ -174,6 +181,11 @@ def __hash__(self) -> int: return self.model_dump_json().__hash__() +# To correctly encode a message in another message, pydantic should know it is to be dumped +# as the concrete type it is, and not only the fields from BECMessage +SpecificMessageType = TypeVar("MessageType", bound=BECMessage) + + class BundleMessage(BECMessage): """Message type to send a bundle of BECMessages. @@ -189,7 +201,7 @@ class BundleMessage(BECMessage): """ msg_type: ClassVar[str] = "bundle_message" - messages: list = Field(default_factory=list[BECMessage]) + messages: list[SpecificMessageType] = Field(default_factory=list) def append(self, msg: BECMessage): """Append a new BECMessage to the bundle""" @@ -1363,19 +1375,19 @@ class DAPResponseMessage(BECMessage): success: bool data: tuple | None = Field(default_factory=lambda: ({}, None)) error: str | None = None - dap_request: BECMessage | None = Field(default=None) + dap_request: SpecificMessageType | None = Field(default=None) class AvailableResourceMessage(BECMessage): """Message for available resources such as scans, data processing plugins etc Args: - resource (dict, list[dict], BECMessage, list[BECMessage]): Resource description + resource (dict, list[dict], BECMessage, list[BECMessage]): Resource description - may contain only one type of BECMessage metadata (dict, optional): Metadata. Defaults to None. """ msg_type: ClassVar[str] = "available_resource_message" - resource: JsonableDict | list[JsonableDict] | BECMessage | list[BECMessage] + resource: JsonableDict | list[JsonableDict] | SpecificMessageType | list[SpecificMessageType] class ProgressMessage(BECMessage): diff --git a/bec_lib/bec_lib/serialization.py b/bec_lib/bec_lib/serialization.py index f512c279e..fe3ffce1a 100644 --- a/bec_lib/bec_lib/serialization.py +++ b/bec_lib/bec_lib/serialization.py @@ -36,6 +36,8 @@ class BECMessagePack(SerializationRegistry): def dumps(self, obj): """Pack object `obj` and return packed bytes.""" + if isinstance(obj, BECMessage): + obj = obj.model_dump(mode="python", fallback=self.encode) return msgpack_module.packb(obj, default=self.encode) def loads(self, raw_bytes): diff --git a/bec_lib/bec_lib/serialization_registry.py b/bec_lib/bec_lib/serialization_registry.py index 7bcbfefcc..e2e6c8d81 100644 --- a/bec_lib/bec_lib/serialization_registry.py +++ b/bec_lib/bec_lib/serialization_registry.py @@ -4,6 +4,7 @@ from typing import Callable, Type from bec_lib import codecs as bec_codecs +from bec_lib import messages from bec_lib.logger import bec_logger logger = bec_logger.logger @@ -18,7 +19,6 @@ def __init__(self): self._registry: dict[str, tuple[Type, Callable, Callable]] = {} self._legacy_codecs = [] # can be removed in future versions, see issue #516 - self.register_codec(bec_codecs.BECMessageEncoder) self.register_codec(bec_codecs.EndpointInfoEncoder) self.register_codec(bec_codecs.SetEncoder) self.register_codec(bec_codecs.BECTypeEncoder) @@ -97,6 +97,11 @@ def encode(self, obj): def decode(self, data): """Decode an object using the registered codec.""" + if isinstance(data, dict) and "bec_codec" in data: + codec_info = data.pop("bec_codec") + msg_cls = messages.__dict__.get(codec_info.get("type_name")) + if msg_cls is not None: + return msg_cls.model_validate(data) if not isinstance(data, dict) or "__bec_codec__" not in data: return data codec_info = data["__bec_codec__"] diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index 3af5721b0..1b3c6c0aa 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -19,7 +19,7 @@ def test_bec_message_msgpack_serialization_version(version): assert "Unsupported BECMessage version" in str(exception.value) else: res = MsgpackSerialization.dumps(msg) - res_expected = b"\x81\xad__bec_codec__\x83\xacencoder_name\xaaBECMessage\xa9type_name\xb8DeviceInstructionMessage\xa4data\x84\xa8metadata\x81\xa3RID\xa41234\xa6device\xa4samx\xa6action\xa3set\xa9parameter\x81\xa3set\xcb?\xe0\x00\x00\x00\x00\x00\x00" + res_expected = b"\x85\xa8metadata\x81\xa3RID\xa41234\xa6device\xa4samx\xa6action\xa3set\xa9parameter\x81\xa3set\xcb?\xe0\x00\x00\x00\x00\x00\x00\xa9bec_codec\x81\xa9type_name\xb8DeviceInstructionMessage" assert res == res_expected res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -708,3 +708,14 @@ def test_message_with_np_array_in_dict(): arr = np.zeros(5) msg = messages.ScanMessage(point_id=0, scan_id="", data={"device": {"value": arr}}, metadata={}) assert isinstance(msg.data["device"]["value"], np.ndarray) + + +def test_message_service_config(): + msg = messages.MessagingServiceConfig( + metadata={}, service_name="signal", scopes=["*"], enabled=True + ) + dump = msg.model_dump(mode="python") + assert dump["service_name"] == "signal" + resource_msg = messages.AvailableResourceMessage(resource=[msg]) + resource_msg_dump = resource_msg.model_dump(mode="python") + assert resource_msg_dump["resource"][0]["service_name"] == "signal" diff --git a/bec_server/bec_server/procedures/manager.py b/bec_server/bec_server/procedures/manager.py index 9f6d5487d..e9e6e4773 100644 --- a/bec_server/bec_server/procedures/manager.py +++ b/bec_server/bec_server/procedures/manager.py @@ -51,6 +51,7 @@ def _log_on_end(future: Future): def _resolve_dict(msg: dict[str, Any] | _T, MsgType: type[_T]) -> _T: if isinstance(msg, dict): + msg.pop("bec_codec", None) return MsgType.model_validate(msg) return msg From 53f8eeeae5ffce0bd6430c198549758f76c6d855 Mon Sep 17 00:00:00 2001 From: perl_d Date: Wed, 11 Feb 2026 12:24:59 +0100 Subject: [PATCH 3/8] refactor: replace custom data in msgs with struct - Replaces available scans dict with a new model - Replaces numpy usage with structured data or casting to list depending on what seems appropriate --- .../client_tests/test_ipython_live_updates.py | 6 +- .../tests/client_tests/test_live_table.py | 11 +- .../tests/end-2-end/test_scans_e2e.py | 2 +- .../tests/end-2-end/test_scans_lib_e2e.py | 26 +-- bec_lib/bec_lib/bec_serializable.py | 10 + bec_lib/bec_lib/bec_service.py | 1 - bec_lib/bec_lib/device.py | 6 +- bec_lib/bec_lib/messages.py | 175 ++++++++++-------- bec_lib/bec_lib/scans.py | 9 +- bec_lib/bec_lib/signature_serializer.py | 2 +- bec_lib/tests/test_bec_messages.py | 6 +- bec_lib/tests/test_config_helper.py | 28 +-- bec_lib/tests/test_devices.py | 81 +++++--- bec_lib/tests/test_file_utils.py | 8 +- bec_lib/tests/test_scan_context.py | 2 +- bec_lib/tests/test_signature_serializer.py | 12 +- ...signature_serializer_with_future_import.py | 10 +- .../bec_server/device_server/device_server.py | 4 +- .../devices/device_serializer.py | 9 +- .../device_server/devices/devicemanager.py | 3 +- bec_server/bec_server/procedures/manager.py | 2 +- .../bec_server/scan_server/scan_assembler.py | 8 +- .../bec_server/scan_server/scan_gui_models.py | 13 +- .../bec_server/scan_server/scan_manager.py | 36 ++-- .../scan_server/scan_plugins/otf_scan.py | 3 +- .../bec_server/scan_server/scan_stubs.py | 4 +- bec_server/bec_server/scan_server/scans.py | 14 +- .../test_config_handler.py | 2 +- .../tests_device_server/test_rpc_handler.py | 4 +- .../test_async_file_writer.py | 2 +- .../tests_scan_server/test_scan_assembler.py | 23 ++- .../tests_scan_server/test_scan_guard.py | 18 +- .../tests_scan_server/test_scan_gui_models.py | 5 +- .../test_scan_server_queue.py | 74 ++++---- .../test_scan_server_scan_manager.py | 2 +- .../tests_scan_server/test_scan_worker.py | 8 +- .../tests/tests_scan_server/test_scans.py | 62 +++---- 37 files changed, 385 insertions(+), 306 deletions(-) diff --git a/bec_ipython_client/tests/client_tests/test_ipython_live_updates.py b/bec_ipython_client/tests/client_tests/test_ipython_live_updates.py index f26808044..f74224009 100644 --- a/bec_ipython_client/tests/client_tests/test_ipython_live_updates.py +++ b/bec_ipython_client/tests/client_tests/test_ipython_live_updates.py @@ -23,7 +23,7 @@ def queue_elements(bec_client_mock): client = bec_client_mock request_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -52,7 +52,7 @@ def queue_elements(bec_client_mock): def sample_request_msg(): return messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -232,7 +232,7 @@ def test_available_req_blocks_multiple_blocks(bec_client_mock): request_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "test_rid"}, ) diff --git a/bec_ipython_client/tests/client_tests/test_live_table.py b/bec_ipython_client/tests/client_tests/test_live_table.py index 53c8ccf56..d66e51228 100644 --- a/bec_ipython_client/tests/client_tests/test_live_table.py +++ b/bec_ipython_client/tests/client_tests/test_live_table.py @@ -50,7 +50,7 @@ def client_with_grid_scan(bec_client_mock): client = bec_client_mock request_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -88,7 +88,7 @@ def test_sort_devices(self): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ), @@ -134,7 +134,7 @@ def test_wait_for_request_acceptance(self, client_with_grid_scan): def test_run_update(self, bec_client_mock, scan_item): request_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -161,7 +161,7 @@ def test_run_update(self, bec_client_mock, scan_item): def test_run_update_without_monitored_devices(self, bec_client_mock, scan_item): request_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -305,7 +305,6 @@ def test_print_table_data_hinted_value_with_precision( [ # Commented out cases are not supported in unstructured serialized data, because msgpack doesn't distinguish # lists, tuples, or sets. To support this, ScanMessage must be refactored to support the type information directly - # except for numpy arrays, which are currently special-cased but will be removed in a future refactor. # (np.int32(1), "1.00"), # (np.float64(1.00000), "1.00"), (0, "0.00"), @@ -319,7 +318,7 @@ def test_print_table_data_hinted_value_with_precision( ("1", "1"), # ((0, 1), "(0, 1)"), ({"value": 0}, "{'value': 0}"), - (np.array([0, 1]), "[0 1]"), + # (np.array([0, 1]), "[0 1]"), # ({1, 2}, "{1, 2}"), ], ) diff --git a/bec_ipython_client/tests/end-2-end/test_scans_e2e.py b/bec_ipython_client/tests/end-2-end/test_scans_e2e.py index 238883c11..0961c2667 100644 --- a/bec_ipython_client/tests/end-2-end/test_scans_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_scans_e2e.py @@ -918,7 +918,7 @@ def test_scan_repeat_decorator(bec_ipython_client_fixture): "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, } diff --git a/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py b/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py index 90ff9572e..4aa7d5dd2 100644 --- a/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py @@ -243,7 +243,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "ophyd_devices.SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -256,7 +256,7 @@ def test_dap_fit(bec_client_lib): "tolerance": 0.01, "update_frequency": 400, }, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -270,7 +270,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "ophyd_devices.SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -284,7 +284,7 @@ def test_dap_fit(bec_client_lib): "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -298,7 +298,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "ophyd_devices.SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -307,7 +307,7 @@ def test_dap_fit(bec_client_lib): "deviceClass": "ophyd_devices.utils.bec_utils.DeviceClassConnectionError", "deviceConfig": {}, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -321,7 +321,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -330,7 +330,7 @@ def test_dap_fit(bec_client_lib): "deviceClass": "ophyd_devices.utils.bec_utils.DeviceClassInitError", "deviceConfig": {}, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -344,7 +344,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -353,7 +353,7 @@ def test_dap_fit(bec_client_lib): "deviceClass": "ophyd_devices.WrongDeviceClass", "deviceConfig": {}, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -420,7 +420,7 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie "hexapod": { "deviceClass": "ophyd_devices.sim.sim_test_devices.SimPositionerWithDescribeFailure", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -434,7 +434,7 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -486,7 +486,7 @@ def test_config_add_remove_device(bec_client_lib): "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, } diff --git a/bec_lib/bec_lib/bec_serializable.py b/bec_lib/bec_lib/bec_serializable.py index 1fa68442e..a3180a61d 100644 --- a/bec_lib/bec_lib/bec_serializable.py +++ b/bec_lib/bec_lib/bec_serializable.py @@ -1,3 +1,4 @@ +import numpy as np from pydantic import BaseModel, ConfigDict, computed_field @@ -21,3 +22,12 @@ class BECSerializable(BaseModel): @property def bec_codec(self) -> BecCodecInfo: return BecCodecInfo(type_name=self.__class__.__name__) + + def __eq__(self, other): + if type(other) is not type(self): + return False + try: + np.testing.assert_equal(self.model_dump(), other.model_dump()) + return True + except AssertionError: + return False diff --git a/bec_lib/bec_lib/bec_service.py b/bec_lib/bec_lib/bec_service.py index b0d1e7fb7..9659effbb 100644 --- a/bec_lib/bec_lib/bec_service.py +++ b/bec_lib/bec_lib/bec_service.py @@ -259,7 +259,6 @@ def _update_existing_services(self) -> None: msgs = [ self.connector.get(MessageEndpoints.service_status(service)) for service in services ] - print(msgs) self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None} msgs = [self.connector.get(MessageEndpoints.metrics(service)) for service in services] self._services_metric = {msg.content["name"]: msg for msg in msgs if msg is not None} diff --git a/bec_lib/bec_lib/device.py b/bec_lib/bec_lib/device.py index 2ffbe1d52..c2f48d39d 100644 --- a/bec_lib/bec_lib/device.py +++ b/bec_lib/bec_lib/device.py @@ -346,7 +346,7 @@ def _prepare_rpc_msg( client: BECClient = self.root.parent.parent msg = messages.ScanQueueMessage( scan_type="device_rpc", - parameter=params, + parameter=messages.sanitize_one_way_encodable(params), queue=client.queue.get_default_scan_queue(), # type: ignore metadata={"RID": request_id, "response": True}, ) @@ -1164,8 +1164,8 @@ def limits(self): if not limit_msg: return [0, 0] limits = [ - limit_msg.content["signals"].get("low", {}).get("value", 0), - limit_msg.content["signals"].get("high", {}).get("value", 0), + limit_msg.signals.get("low", {}).get("value", 0), + limit_msg.signals.get("high", {}).get("value", 0), ] return limits diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 8789cd36a..b8ced733d 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -6,20 +6,24 @@ import uuid import warnings from copy import deepcopy -from enum import Enum, auto +from enum import Enum, StrEnum, auto from importlib.metadata import PackageNotFoundError from importlib.metadata import version as importlib_version -from types import NoneType -from typing import Annotated, Any, ClassVar, Literal, Self +from typing import Annotated, Any, ClassVar, Literal, Mapping, Self, TypeVar, Union from uuid import uuid4 -import msgpack import numpy as np from pydantic import ( BaseModel, BeforeValidator, ConfigDict, + FailFast, Field, + Strict, + StrictBool, + StrictFloat, + StrictInt, + StrictStr, ValidationError, WithJsonSchema, field_validator, @@ -34,67 +38,31 @@ _one_way_registry = OneWaySerializationRegistry() -def _sanitize_one_way(data: Any) -> Any: - # TODO: Temporary fix for standardizing message structure, will be replaced - # by encoders in a future iteration - if isinstance(data, np.ndarray): - return data - if isinstance(data, np.bool_): - return bool(data) - if isinstance(data, (np.float16, np.float32, np.float64)): - return float(data) - if isinstance(data, (np.int16, np.int32, np.int64, np.uint16, np.uint32, np.uint64)): - return int(data) +def sanitize_one_way_encodable(data: Any) -> Any: + """Sanitize any data which can be serialized in a json-compatible format and is not supposed to be decoded, + for example, a parameter dict containing devices""" if isinstance(data, (list, tuple, set)): - return [_sanitize_one_way(x) for x in data] - if isinstance(data, dict): - return {_sanitize_one_way(k): _sanitize_one_way(v) for k, v in data.items()} + return [sanitize_one_way_encodable(x) for x in data] + if isinstance(data, Mapping): + return { + sanitize_one_way_encodable(k): sanitize_one_way_encodable(v) for k, v in data.items() + } return _one_way_registry.encode(data) -def _ignore_ndarray(data: Any) -> Any: - if isinstance(data, np.ndarray): - return [] - raise ValueError(f"Cannot serialize unknown type for {data}: {type(data)}") - - -def _test_packable(data: Any): - try: - msgpack.dumps(data, default=_ignore_ndarray) - except Exception as e: - raise ValueError(f"Non-JSONable/msgpackable data in {data}!") from e - - -def _validate_packable(data: Any) -> Any: - # Skip sanitization if the data is already valid - if isinstance(data, int | float | str | bool | NoneType): - return data - if isinstance(data, np.bool_): - return bool(data) - try: - _test_packable(data) - return data - # Recursively check if we should replace anything which is not supposed to be decoded to a custom - # type on the other end - except ValueError: - data = _sanitize_one_way(data) - _test_packable(data) - return data - +JsonableScalar = TypeAliasType("JsonableScalar", StrictInt | StrictFloat | StrictStr | StrictBool) Jsonable = TypeAliasType( "Jsonable", - Annotated[ - int | float | str | bool | None | list["Jsonable"] | dict[str, "Jsonable"] | np.ndarray, - BeforeValidator(_validate_packable), - ], + JsonableScalar + | None + | Annotated[list["Jsonable"], Strict(), FailFast()] + | Annotated[dict[StrictStr, "Jsonable"], Strict()], ) JsonableDict = TypeAliasType( "JsonableDict", - Annotated[ - dict[str, Jsonable], BeforeValidator(_validate_packable), WithJsonSchema({"type": "object"}) - ], + Annotated[dict[StrictStr, Jsonable], WithJsonSchema({"type": "object"}), Strict()], ) @@ -127,16 +95,6 @@ class BECMessage(BECSerializable): msg_type: ClassVar[str] metadata: JsonableDict = Field(default_factory=dict) - @field_validator("metadata") - @classmethod - def check_metadata(cls, v): - """Validate the metadata, return empty dict if None - - Args: - v (dict, None): Metadata dictionary - """ - return v or {} - @model_validator(mode="before") @classmethod def _strip_codec_info(cls, data: Any): @@ -151,18 +109,6 @@ def content(self): content.pop("metadata", None) return content - def __eq__(self, other): - if not isinstance(other, BECMessage): - # don't attempt to compare against unrelated types - return False - - try: - np.testing.assert_equal(self.model_dump(), other.model_dump()) - except AssertionError: - return False - - return self.msg_type == other.msg_type and self.metadata == other.metadata - def loads(self): warnings.warn( "BECMessage.loads() is deprecated and should not be used anymore. When calling Connector methods, it can be omitted. When a message needs to be deserialized call the appropriate function from bec_lib.serialization", @@ -183,7 +129,7 @@ def __hash__(self) -> int: # To correctly encode a message in another message, pydantic should know it is to be dumped # as the concrete type it is, and not only the fields from BECMessage -SpecificMessageType = TypeVar("MessageType", bound=BECMessage) +SpecificMessageType = TypeVar("SpecificMessageType", bound=BECMessage) class BundleMessage(BECMessage): @@ -201,7 +147,7 @@ class BundleMessage(BECMessage): """ msg_type: ClassVar[str] = "bundle_message" - messages: list[SpecificMessageType] = Field(default_factory=list) + messages: Annotated[list[SpecificMessageType], Field(default_factory=list)] def append(self, msg: BECMessage): """Append a new BECMessage to the bundle""" @@ -671,6 +617,46 @@ def _ensure_error_info_if_error(self): return self +# TODO: remove when deprecated usages of SignalReading are cleaned up +logger = None + + +def lazy_ensure_logger(): + global logger + if logger is None: + from bec_lib.logger import bec_logger + + logger = bec_logger.logger + + +class SignalReading(BECSerializable): + value: int | float | list[int] | list[float] | np.ndarray | None | str + timestamp: float | list[float] | None = None + + def keys(self): + lazy_ensure_logger() + logger.warning( + "Dictionary usage of SignalReading is deprecated; please replace it with a different access pattern." + ) + return ["value", "timestamp"] + + def get(self, item: Literal["value", "timestamp"], default=Any): + """Allow dictionary-style access for legacy reasons.""" + lazy_ensure_logger() + logger.warning( + "Get-access on SignalReading is deprecated; Just access the model.value field." + ) + if item not in ["value", "timestamp"]: + raise KeyError('SignalReading only has "value" and "timestamp" fields!') + return getattr(self, item) + + def __getitem__(self, item: str): + return self.get(item) + + def items(self): + return dict(self).items() + + class DeviceMessage(BECMessage): """Message type for sending device readings from the device server @@ -683,7 +669,7 @@ class DeviceMessage(BECMessage): """ msg_type: ClassVar[str] = "device_message" - signals: dict[str, dict[Literal["value", "timestamp"], Any]] + signals: dict[str, SignalReading] @field_validator("metadata") @classmethod @@ -1378,6 +1364,29 @@ class DAPResponseMessage(BECMessage): dap_request: SpecificMessageType | None = Field(default=None) +class ScanArgType(StrEnum): + DEVICE = "device" + FLOAT = "float" + INT = "int" + BOOL = "boolean" + STR = "str" + LIST = "list" + DICT = "dict" + + +class AvailableScan(BECMessage): + """Information about an available scan""" + + class_name: str + base_class: str + arg_input: dict[str, Jsonable | ScanArgType] + gui_config: JsonableDict + required_kwargs: list[str] | dict[str, ScanArgType] + arg_bundle_size: JsonableDict + doc: str | None = None + signature: list[JsonableDict] + + class AvailableResourceMessage(BECMessage): """Message for available resources such as scans, data processing plugins etc @@ -1387,7 +1396,13 @@ class AvailableResourceMessage(BECMessage): """ msg_type: ClassVar[str] = "available_resource_message" - resource: JsonableDict | list[JsonableDict] | SpecificMessageType | list[SpecificMessageType] + resource: ( + JsonableDict + | list[JsonableDict] + | SpecificMessageType + | list[SpecificMessageType] + | dict[str, SpecificMessageType] + ) class ProgressMessage(BECMessage): diff --git a/bec_lib/bec_lib/scans.py b/bec_lib/bec_lib/scans.py index 041d8d72f..5dccfa8b1 100644 --- a/bec_lib/bec_lib/scans.py +++ b/bec_lib/bec_lib/scans.py @@ -21,6 +21,7 @@ from bec_lib.device import DeviceBase from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger +from bec_lib.messages import ScanArgType # moved from here to messages - for compat with plugins from bec_lib.scan_repeat import _scan_repeat_depth from bec_lib.scan_report import ScanReport from bec_lib.signature_serializer import deserialize_dtype, dict_to_signature @@ -365,14 +366,14 @@ def prepare_scan_request( return messages.ScanQueueMessage( scan_type=scan_name, - parameter=params, + parameter=messages.sanitize_one_way_encodable(params), queue=scan_queue, metadata=metadata, allow_restart=allow_restart, ) @staticmethod - def _parameter_bundler(args: tuple, bundle_size: int) -> tuple | dict: + def _parameter_bundler(args: tuple, bundle_size: int) -> list | dict: """ Bundle the arguments into the correct format for the scan server. If the bundle size is 0, return the arguments as is. @@ -383,11 +384,11 @@ def _parameter_bundler(args: tuple, bundle_size: int) -> tuple | dict: bundle_size: number of parameters per bundle Returns: - tuple | dict: bundled arguments + list | dict: bundled arguments """ if not bundle_size: - return args + return list(args) params = {} for cmds in partition(bundle_size, args): params[cmds[0]] = list(cmds[1:]) diff --git a/bec_lib/bec_lib/signature_serializer.py b/bec_lib/bec_lib/signature_serializer.py index 577978e51..dc09d649a 100644 --- a/bec_lib/bec_lib/signature_serializer.py +++ b/bec_lib/bec_lib/signature_serializer.py @@ -124,7 +124,7 @@ def _merge_literals(vals: Generator[str | dict, None, None]) -> Generator[str | if _literal_args == [None]: yield "NoneType" elif _literal_args: - yield {"Literal": tuple(_literal_args)} + yield {"Literal": list(_literal_args)} def serialize_dtype(dtype: object) -> list[str | dict] | str | dict: diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index 1b3c6c0aa..3b2cf6f44 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -706,8 +706,10 @@ def test_feedback_message(): def test_message_with_np_array_in_dict(): arr = np.zeros(5) - msg = messages.ScanMessage(point_id=0, scan_id="", data={"device": {"value": arr}}, metadata={}) - assert isinstance(msg.data["device"]["value"], np.ndarray) + with pytest.raises(pydantic.ValidationError) as e: + msg = messages.BECMessage(metadata={"value": arr}) + assert e.match("metadata.value") + assert e.match("should be a valid") def test_message_service_config(): diff --git a/bec_lib/tests/test_config_helper.py b/bec_lib/tests/test_config_helper.py index 8cf830815..fc0b10014 100644 --- a/bec_lib/tests/test_config_helper.py +++ b/bec_lib/tests/test_config_helper.py @@ -72,7 +72,7 @@ def test_config_helper_save_current_session(config_helper): "enabled": True, "readOnly": False, "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "deviceConfig": { "delay": 1, "labels": "pinz", @@ -93,7 +93,7 @@ def test_config_helper_save_current_session(config_helper): "enabled": True, "readOnly": False, "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, "readoutPriority": "monitored", "onFailure": "retry", @@ -238,7 +238,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "deviceConfig": { @@ -254,7 +254,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -265,7 +265,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": True, "deviceConfig": { @@ -281,7 +281,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -295,7 +295,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "deviceConfig": { @@ -311,7 +311,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -322,7 +322,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "deviceConfig": { @@ -338,7 +338,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -352,7 +352,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "deviceConfig": { @@ -368,7 +368,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -379,7 +379,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "deviceConfig": { @@ -395,7 +395,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, diff --git a/bec_lib/tests/test_devices.py b/bec_lib/tests/test_devices.py index 75bf0e920..4b69ef78c 100644 --- a/bec_lib/tests/test_devices.py +++ b/bec_lib/tests/test_devices.py @@ -61,9 +61,15 @@ def test_read(dev: Any): res = dev.samx.read(cached=True) mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) assert res == { - "samx": {"value": 0, "timestamp": 1701105880.1711318}, - "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, - "samx_motor_is_moving": {"value": 0, "timestamp": 1701105880.16935}, + "samx": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1711318} + ), + "samx_setpoint": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1693492} + ), + "samx_motor_is_moving": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.16935} + ), } @@ -79,15 +85,25 @@ def test_read_filtered_hints(dev: Any): ) res = dev.samx.read(cached=True, filter_to_hints=True) mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) - assert res == {"samx": {"value": 0, "timestamp": 1701105880.1711318}} + assert res == { + "samx": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1711318} + ) + } def test_read_use_read(dev: Any): with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get: data = { - "samx": {"value": 0, "timestamp": 1701105880.1711318}, - "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, - "samx_motor_is_moving": {"value": 0, "timestamp": 1701105880.16935}, + "samx": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1711318} + ), + "samx_setpoint": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1693492} + ), + "samx_motor_is_moving": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.16935} + ), } mock_get.return_value = messages.DeviceMessage( signals=data, metadata={"scan_id": "scan_id", "scan_type": "scan_type"} @@ -100,11 +116,21 @@ def test_read_use_read(dev: Any): def test_read_nested_device(dev: Any): with mock.patch.object(dev.dyn_signals.root.parent.connector, "get") as mock_get: data = { - "dyn_signals_messages_message1": {"value": 0, "timestamp": 1701105880.0716832}, - "dyn_signals_messages_message2": {"value": 0, "timestamp": 1701105880.071722}, - "dyn_signals_messages_message3": {"value": 0, "timestamp": 1701105880.071739}, - "dyn_signals_messages_message4": {"value": 0, "timestamp": 1701105880.071753}, - "dyn_signals_messages_message5": {"value": 0, "timestamp": 1701105880.071766}, + "dyn_signals_messages_message1": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.0716832} + ), + "dyn_signals_messages_message2": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.071722} + ), + "dyn_signals_messages_message3": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.071739} + ), + "dyn_signals_messages_message4": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.071753} + ), + "dyn_signals_messages_message5": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.071766} + ), } mock_get.return_value = messages.DeviceMessage( signals=data, metadata={"scan_id": "scan_id", "scan_type": "scan_type"} @@ -139,7 +165,11 @@ def test_read_kind_hinted( if cached: mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) mock_run.assert_not_called() - assert res == {"samx": {"value": 0, "timestamp": 1701105880.1711318}} + assert res == { + "samx": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1711318} + ) + } else: mock_run.assert_called_once_with(cached=False, fcn=dev.samx.readback.read) mock_get.assert_not_called() @@ -207,10 +237,10 @@ def test_read_configuration_cached( @pytest.mark.parametrize( ["mock_rpc", "method", "args", "kwargs", "expected_call"], [ - ("_get_rpc_response", "set", (1,), {}, (mock.ANY, mock.ANY)), - ("_run_rpc_call", "set", (1,), {}, ("samx", "setpoint.set", 1)), - ("_run_rpc_call", "put", (1,), {"wait": True}, ("samx", "setpoint.set", 1)), - ("_run_rpc_call", "put", (1,), {}, ("samx", "setpoint.put", 1)), + ("_get_rpc_response", "set", [1], {}, (mock.ANY, mock.ANY)), + ("_run_rpc_call", "set", [1], {}, ("samx", "setpoint.set", 1)), + ("_run_rpc_call", "put", [1], {"wait": True}, ("samx", "setpoint.set", 1)), + ("_run_rpc_call", "put", [1], {}, ("samx", "setpoint.put", 1)), ], ) def test_run_rpc_call(dev: Any, mock_rpc, method, args, kwargs, expected_call): @@ -334,7 +364,7 @@ def device_config(): "readoutPriority": "monitored", "deviceClass": "SimCamera", "deviceConfig": {"device_access": True, "labels": "eiger", "name": "eiger"}, - "deviceTags": {"detector"}, + "deviceTags": ["detector"], } @@ -368,7 +398,12 @@ def device_obj(device_config: dict[str, Any]): def test_create_device_saves_config( device_obj: DeviceBaseWithConfig, device_config: dict[str, Any] ): - assert {k: v for k, v in device_obj._config.items() if k in device_config} == device_config + assert ( + messages.sanitize_one_way_encodable( + {k: v for k, v in device_obj._config.items() if k in device_config} + ) + == device_config + ) def test_device_enabled(device_obj: DeviceBaseWithConfig, device_config: dict[str, Any]): @@ -462,7 +497,7 @@ def test_status_wait(): @pytest.fixture def device_w_tags(dev_w_config: Callable[..., DeviceBaseWithConfig]): - yield dev_w_config({"deviceTags": {"tag1", "tag2"}}) + yield dev_w_config({"deviceTags": ["tag1", "tag2"]}) @pytest.mark.parametrize( @@ -500,7 +535,7 @@ def test_properties(dev_w_config: Callable[..., DeviceBaseWithConfig], config, a @pytest.mark.parametrize( ["config", "method", "value"], - [({"deviceTags": {"tag1", "tag2"}}, "get_device_tags", {"tag1", "tag2"})], + [({"deviceTags": ["tag1", "tag2"]}, "get_device_tags", {"tag1", "tag2"})], ) def test_methods(dev_w_config: Callable[..., DeviceBaseWithConfig], config, method, value): assert getattr(dev_w_config(config), method)() == value @@ -627,7 +662,7 @@ def test_show_all(dm_with_override): "readOnly": False, "deviceClass": "Class1", "readoutPriority": "monitored", - "deviceTags": {"tag1", "tag2"}, + "deviceTags": ["tag1", "tag2"], }, parent=parent, ) @@ -639,7 +674,7 @@ def test_show_all(dm_with_override): "readOnly": True, "deviceClass": "Class2", "readoutPriority": "baseline", - "deviceTags": {"tag3", "tag4"}, + "deviceTags": ["tag3", "tag4"], }, parent=parent, ) diff --git a/bec_lib/tests/test_file_utils.py b/bec_lib/tests/test_file_utils.py index cbb8b908e..047deeaf8 100644 --- a/bec_lib/tests/test_file_utils.py +++ b/bec_lib/tests/test_file_utils.py @@ -40,7 +40,7 @@ def scan_msg(): yield ScanStatusMessage( scan_id="1111", scan_parameters={"system_config": {"file_directory": None, "file_suffix": None}}, - info={"scan_number": 5, "file_components": ("S00000-00999/S00005/S00005", "h5")}, + info={"scan_number": 5, "file_components": ["S00000-00999/S00005/S00005", "h5"]}, status="closed", ) @@ -266,7 +266,7 @@ def test_compile_file_components_valid_paths(kwargs, expected_path, description) ScanStatusMessage( scan_id="1111", scan_parameters={"system_config": {"file_directory": None, "file_suffix": None}}, - info={"scan_number": 5, "file_components": ("S00000-00999/S00005/S00005", "h5")}, + info={"scan_number": 5, "file_components": ["S00000-00999/S00005/S00005", "h5"]}, status="closed", ) ), @@ -276,7 +276,7 @@ def test_compile_file_components_valid_paths(kwargs, expected_path, description) scan_parameters={ "system_config": {"file_directory": "/my_dir", "file_suffix": None} }, - info={"scan_number": 5, "file_components": ("/my_dir/S00005", "h5")}, + info={"scan_number": 5, "file_components": ["/my_dir/S00005", "h5"]}, status="closed", ) ), @@ -288,7 +288,7 @@ def test_compile_file_components_valid_paths(kwargs, expected_path, description) }, info={ "scan_number": 5, - "file_components": ("S00000-00999/S00005_sampleA/S00005", "h5"), + "file_components": ["S00000-00999/S00005_sampleA/S00005", "h5"], }, status="closed", ) diff --git a/bec_lib/tests/test_scan_context.py b/bec_lib/tests/test_scan_context.py index b8fe2b037..105f3a3f1 100644 --- a/bec_lib/tests/test_scan_context.py +++ b/bec_lib/tests/test_scan_context.py @@ -140,7 +140,7 @@ def test_parameter_bundler(bec_client_mock): assert res == {dev.samx: [-5, 5, 5]} res = client.scans._parameter_bundler((-5, 5, 5), 0) - assert res == (-5, 5, 5) + assert res == [-5, 5, 5] @pytest.mark.parametrize( diff --git a/bec_lib/tests/test_signature_serializer.py b/bec_lib/tests/test_signature_serializer.py index d585cb15a..28c37eebb 100644 --- a/bec_lib/tests/test_signature_serializer.py +++ b/bec_lib/tests/test_signature_serializer.py @@ -43,7 +43,7 @@ def test_func(a: Literal[1, 2, 3] | None = None): "name": "a", "kind": "POSITIONAL_OR_KEYWORD", "default": None, - "annotation": {"Literal": (1, 2, 3, None)}, + "annotation": {"Literal": [1, 2, 3, None]}, } ] @@ -59,7 +59,7 @@ def test_func(a, b: Literal["test", None], *args, **kwargs): "name": "b", "kind": "POSITIONAL_OR_KEYWORD", "default": "_empty", - "annotation": {"Literal": ("test", None)}, + "annotation": {"Literal": ["test", None]}, }, {"name": "args", "kind": "VAR_POSITIONAL", "default": "_empty", "annotation": "_empty"}, {"name": "kwargs", "kind": "VAR_KEYWORD", "default": "_empty", "annotation": "_empty"}, @@ -83,13 +83,13 @@ def test_func( "name": "b", "kind": "POSITIONAL_OR_KEYWORD", "default": "_empty", - "annotation": {"Literal": ("test", None)}, + "annotation": {"Literal": ["test", None]}, }, { "name": "c", "kind": "POSITIONAL_OR_KEYWORD", "default": 1, - "annotation": {"Literal": (1, 2, 3)}, + "annotation": {"Literal": [1, 2, 3]}, }, { "name": "d", @@ -265,7 +265,7 @@ def test_func(step: Annotated[float, scan_argument] | None = None): (float, "float"), (bool, "bool"), (inspect._empty, "_empty"), - (Literal[1, 2, 3], {"Literal": (1, 2, 3)}), + (Literal[1, 2, 3], {"Literal": [1, 2, 3]}), (Union[int, str], ["int", "str"]), (Optional[str], ["str", "NoneType"]), (DeviceBase, "DeviceBase"), @@ -336,7 +336,7 @@ def test_serialize_dtype(dtype_in, dtype_out): ("float", float), ("bool", bool), ("_empty", inspect._empty), - ({"Literal": (1, 2, 3)}, Literal[1, 2, 3]), + ({"Literal": [1, 2, 3]}, Literal[1, 2, 3]), (["int", "str"], Union[int, str]), (["str", "NoneType"], Optional[str]), ("NoneType", None), diff --git a/bec_lib/tests/test_signature_serializer_with_future_import.py b/bec_lib/tests/test_signature_serializer_with_future_import.py index d895aac22..038245cfd 100644 --- a/bec_lib/tests/test_signature_serializer_with_future_import.py +++ b/bec_lib/tests/test_signature_serializer_with_future_import.py @@ -19,7 +19,7 @@ def test_func(a: Literal[1, 2, 3] | None = None): "name": "a", "kind": "POSITIONAL_OR_KEYWORD", "default": None, - "annotation": {"Literal": (1, 2, 3, None)}, + "annotation": {"Literal": [1, 2, 3, None]}, } ] @@ -34,7 +34,7 @@ def test_func(a: Literal[1, 2, 3] | None | Literal["a", "b", "c"]): "name": "a", "kind": "POSITIONAL_OR_KEYWORD", "default": "_empty", - "annotation": {"Literal": (1, 2, 3, None, "a", "b", "c")}, + "annotation": {"Literal": [1, 2, 3, None, "a", "b", "c"]}, } ] @@ -52,7 +52,7 @@ def test_func(a: Literal[1, 2, 3] | "SomeUnknownType" | Literal["a", "b", "c"]): "name": "a", "kind": "POSITIONAL_OR_KEYWORD", "default": "_empty", - "annotation": {"Literal": (1, 2, 3, "a", "b", "c")}, + "annotation": {"Literal": [1, 2, 3, "a", "b", "c"]}, } ] @@ -61,7 +61,7 @@ def test_serialize_dtype_imported_imported_func_arg(): sig = inspect.signature(literal_union_test_func) anno = sig.parameters["a"].annotation assert serialize_dtype(anno) == serialize_dtype(Union[Literal["a", "b", "c"], EnumTest]) - assert serialize_dtype(anno) == {"Literal": ("a", "b", "c")} + assert serialize_dtype(anno) == {"Literal": ["a", "b", "c"]} def test_signature_serializer_parses_untion_on_imported_func(): @@ -71,7 +71,7 @@ def test_signature_serializer_parses_untion_on_imported_func(): "name": "a", "kind": "POSITIONAL_OR_KEYWORD", "default": "_empty", - "annotation": {"Literal": ("a", "b", "c")}, + "annotation": {"Literal": ["a", "b", "c"]}, } ] diff --git a/bec_server/bec_server/device_server/device_server.py b/bec_server/bec_server/device_server/device_server.py index 3651eddd1..da8b82d11 100644 --- a/bec_server/bec_server/device_server/device_server.py +++ b/bec_server/bec_server/device_server/device_server.py @@ -19,7 +19,7 @@ from bec_lib.device import OnFailure from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger -from bec_lib.messages import BECStatus +from bec_lib.messages import BECStatus, sanitize_one_way_encodable from bec_lib.serialization import json_ext from bec_lib.utils.rpc_utils import rgetattr from bec_server.device_server.devices.devicemanager import DeviceManagerDS @@ -790,7 +790,7 @@ def status_callback(self, status): def _update_read_configuration(self, obj: OphydObject, metadata: dict, pipe) -> None: dev_config_msg = messages.DeviceMessage( - signals=obj.root.read_configuration(), metadata=metadata + signals=sanitize_one_way_encodable(obj.root.read_configuration()), metadata=metadata ) self.connector.set_and_publish( MessageEndpoints.device_read_configuration(obj.root.name), dev_config_msg, pipe diff --git a/bec_server/bec_server/device_server/devices/device_serializer.py b/bec_server/bec_server/device_server/devices/device_serializer.py index a7d168601..f0a93cbd7 100644 --- a/bec_server/bec_server/device_server/devices/device_serializer.py +++ b/bec_server/bec_server/device_server/devices/device_serializer.py @@ -12,6 +12,7 @@ from ophyd_devices import BECDeviceBase, ComputedSignal from ophyd_devices.utils.bec_signals import BECMessageSignal +from bec_lib import messages from bec_lib.bec_errors import DeviceConfigError from bec_lib.device import DeviceBaseWithConfig from bec_lib.logger import bec_logger @@ -183,7 +184,9 @@ def get_device_info( "kind_int": kind, "kind_str": Kind(kind).name, "doc": doc, - "describe": signal_obj.describe().get(signal_obj.name, {}), + "describe": messages.sanitize_one_way_encodable( + signal_obj.describe().get(signal_obj.name, {}) + ), # pylint: disable=protected-access "metadata": signal_obj._metadata, "labels": sorted(signal_obj._ophyd_labels_), @@ -201,7 +204,9 @@ def get_device_info( "kind_int": signal_obj.kind.value, "kind_str": signal_obj.kind.name, "doc": doc, - "describe": signal_obj.describe().get(signal_obj.name, {}), + "describe": messages.sanitize_one_way_encodable( + signal_obj.describe().get(signal_obj.name, {}) + ), # pylint: disable=protected-access "metadata": signal_obj._metadata, "labels": sorted(signal_obj._ophyd_labels_), diff --git a/bec_server/bec_server/device_server/devices/devicemanager.py b/bec_server/bec_server/device_server/devices/devicemanager.py index 8857aaa7c..f22c6081f 100644 --- a/bec_server/bec_server/device_server/devices/devicemanager.py +++ b/bec_server/bec_server/device_server/devices/devicemanager.py @@ -117,7 +117,8 @@ def initialize_device_buffer(self, connector): if not isinstance(self.obj, ophyd.Signal): # signals have the same read and read_configuration values; no need to publish twice dev_config_msg = messages.DeviceMessage( - signals=self.obj.read_configuration(), metadata={} + signals=messages.sanitize_one_way_encodable(self.obj.read_configuration()), + metadata={}, ) connector.set_and_publish( MessageEndpoints.device_read_configuration(self.name), dev_config_msg, pipe=pipe diff --git a/bec_server/bec_server/procedures/manager.py b/bec_server/bec_server/procedures/manager.py index e9e6e4773..5f6015fff 100644 --- a/bec_server/bec_server/procedures/manager.py +++ b/bec_server/bec_server/procedures/manager.py @@ -98,7 +98,7 @@ def __init__(self, redis: str, worker_type: type[ProcedureWorker]): MessageEndpoints.available_procedures(), AvailableResourceMessage( resource={ - name: procedure_registry.get_info(name) + name: list(procedure_registry.get_info(name)) for name in procedure_registry.available() } ), diff --git a/bec_server/bec_server/scan_server/scan_assembler.py b/bec_server/bec_server/scan_server/scan_assembler.py index 57b0eb1e2..9b1bc525d 100644 --- a/bec_server/bec_server/scan_server/scan_assembler.py +++ b/bec_server/bec_server/scan_server/scan_assembler.py @@ -34,8 +34,8 @@ def is_scan_message(self, msg: messages.ScanQueueMessage) -> bool: Returns: bool: True if the message is a scan message, False otherwise """ - scan = msg.content.get("scan_type") - cls_name = self.scan_manager.available_scans[scan]["class"] + scan = msg.scan_type + cls_name = self.scan_manager.available_scans[scan].class_name scan_cls = self.scan_manager.scan_dict[cls_name] return issubclass(scan_cls, ScanBase) @@ -55,8 +55,8 @@ def assemble_device_instructions( Returns: RequestBase: Scan instance of the initialized scan class """ - scan = msg.content.get("scan_type") - cls_name = self.scan_manager.available_scans[scan]["class"] + scan = msg.scan_type + cls_name = self.scan_manager.available_scans[scan].class_name scan_cls = self.scan_manager.scan_dict[cls_name] logger.info(f"Preparing instructions of request of type {scan} / {scan_cls.__name__}") diff --git a/bec_server/bec_server/scan_server/scan_gui_models.py b/bec_server/bec_server/scan_server/scan_gui_models.py index 17a398b30..a6afe4280 100644 --- a/bec_server/bec_server/scan_server/scan_gui_models.py +++ b/bec_server/bec_server/scan_server/scan_gui_models.py @@ -9,6 +9,7 @@ from pydantic_core import PydanticCustomError from bec_lib.device import DeviceBase +from bec_lib.messages import Jsonable, sanitize_one_way_encodable from bec_lib.signature_serializer import signature_to_dict from bec_server.scan_server.scans import ScanArgType, ScanBase @@ -26,9 +27,9 @@ class GUIInput(BaseModel): arg: bool = Field(False) name: str = Field(None, validate_default=True) - type: Optional[ - Literal["DeviceBase", "device", "float", "int", "bool", "str", "list", "dict"] - ] = Field(None, validate_default=True) + type: ( + Literal["DeviceBase", "device", "float", "int", "bool", "str", "list", "dict"] | Jsonable + ) = Field(None, validate_default=True) display_name: Optional[str] = Field(None, validate_default=True) tooltip: Optional[str] = Field(None, validate_default=True) default: Optional[Any] = Field(None, validate_default=True) @@ -83,7 +84,7 @@ def validate_name(cls, v, values): def validate_field(cls, v, values): # args cannot be validated with the current implementation of signature of scans if values.data["arg"]: - return v + return sanitize_one_way_encodable(v) signature = context_signature.get() if v is None: name = values.data.get("name", None) @@ -96,7 +97,7 @@ def validate_field(cls, v, values): for entry in signature: if entry["name"] == name: v = entry["annotation"] - return v + return sanitize_one_way_encodable(v) @field_validator("tooltip") @classmethod @@ -225,7 +226,7 @@ class GUIConfig(BaseModel): scan_class_name: str arg_group: Optional[GUIArgGroup] = Field(None) - kwarg_groups: list[GUIGroup] = Field(None) + kwarg_groups: list[GUIGroup] | None = Field(None) signature: list[dict] = Field(..., exclude=True) docstring: str = Field(..., exclude=True) diff --git a/bec_server/bec_server/scan_server/scan_manager.py b/bec_server/bec_server/scan_server/scan_manager.py index a383bad52..53ccf18b5 100644 --- a/bec_server/bec_server/scan_server/scan_manager.py +++ b/bec_server/bec_server/scan_server/scan_manager.py @@ -4,12 +4,13 @@ import inspect -from bec_lib import plugin_helper +from bec_lib import messages, plugin_helper from bec_lib.device import DeviceBase from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from bec_lib.messages import AvailableResourceMessage from bec_lib.signature_serializer import serialize_dtype, signature_to_dict + from bec_server.scan_server.scan_gui_models import GUIConfig from . import scans as scans_module @@ -37,7 +38,7 @@ def __init__(self, *, parent): Scan Manager loads and manages the available scans. """ self.parent = parent - self.available_scans = {} + self.available_scans: dict[str, messages.AvailableScan] = {} self.scan_dict: dict[str, type[scans_module.RequestBase]] = {} self._plugins = {} self.load_plugins() @@ -63,7 +64,9 @@ def update_available_scans(self): members: list[tuple[str, type]] = inspect.getmembers( scans_module, predicate=inspect.isclass ) - members.extend((name, cls) for name, cls in self._plugins.items() if inspect.isclass(cls)) + members.extend( + (name, cls) for name, cls in self._plugins.items() if inspect.isclass(cls) + ) for name, scan_cls in members: is_scan = issubclass(scan_cls, scans_module.RequestBase) @@ -93,17 +96,20 @@ def update_available_scans(self): elif hasattr(scan_cls, "gui_config"): # type: ignore gui_visibility = scan_cls.gui_config # type: ignore - self.available_scans[scan_cls.scan_name] = { - "class": scan_cls.__name__, - "base_class": base_cls, - "arg_input": self.convert_arg_input(scan_cls.arg_input), - "required_kwargs": scan_cls.required_kwargs, - "arg_bundle_size": scan_cls.arg_bundle_size, - "doc": scan_cls.__doc__ or scan_cls.__init__.__doc__, - "signature": signature_to_dict(scan_cls.__init__), - "gui_visibility": gui_visibility, - "gui_config": gui_config, # deprecated! - should be removed - } + self.available_scans[scan_cls.scan_name] = ( + messages.AvailableScan.model_validate( + { + "class_name": scan_cls.__name__, + "base_class": base_cls, + "arg_input": self.convert_arg_input(scan_cls.arg_input), + "gui_config": gui_config, + "required_kwargs": scan_cls.required_kwargs, + "arg_bundle_size": scan_cls.arg_bundle_size, + "doc": scan_cls.__doc__ or scan_cls.__init__.__doc__, + "signature": signature_to_dict(scan_cls.__init__), + } + ) + ) def validate_gui_config(self, scan_cls) -> dict: """ @@ -154,5 +160,5 @@ def publish_available_scans(self): """send all available scans to the broker""" self.parent.connector.set( MessageEndpoints.available_scans(), - AvailableResourceMessage(resource=self.available_scans), + messages.AvailableResourceMessage(resource=self.available_scans), ) diff --git a/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py b/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py index 93f3b4e75..faac453af 100644 --- a/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py +++ b/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py @@ -1,7 +1,8 @@ import time from bec_lib.logger import bec_logger -from bec_server.scan_server.scans import ScanArgType, ScanBase, SyncFlyScanBase +from bec_lib.messages import ScanArgType +from bec_server.scan_server.scans import ScanBase, SyncFlyScanBase logger = bec_logger.logger diff --git a/bec_server/bec_server/scan_server/scan_stubs.py b/bec_server/bec_server/scan_server/scan_stubs.py index e39877870..46fe5d0f4 100644 --- a/bec_server/bec_server/scan_server/scan_stubs.py +++ b/bec_server/bec_server/scan_server/scan_stubs.py @@ -337,7 +337,9 @@ def _exclude_nones(input_dict: dict): def _device_msg(self, **kwargs) -> messages.DeviceInstructionMessage: """""" - msg = messages.DeviceInstructionMessage(**kwargs) + msg = messages.DeviceInstructionMessage.model_validate( + messages.sanitize_one_way_encodable(kwargs) + ) msg.metadata = {**self.device_msg_metadata(), **msg.metadata} return msg diff --git a/bec_server/bec_server/scan_server/scans.py b/bec_server/bec_server/scan_server/scans.py index dff71a199..5dae168de 100644 --- a/bec_server/bec_server/scan_server/scans.py +++ b/bec_server/bec_server/scan_server/scans.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import enum import threading import time import uuid @@ -16,6 +15,7 @@ from bec_lib.devicemanager import DeviceManagerBase from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger +from bec_lib.messages import ScanArgType from bec_server.scan_server.instruction_handler import InstructionHandler from .errors import LimitError, ScanAbortion @@ -25,16 +25,6 @@ logger = bec_logger.logger -class ScanArgType(str, enum.Enum): - DEVICE = "device" - FLOAT = "float" - INT = "int" - BOOL = "boolean" - STR = "str" - LIST = "list" - DICT = "dict" - - def unpack_scan_args(scan_args: dict[str, Any]) -> list: """unpack_scan_args unpacks the scan arguments and returns them as a tuple. @@ -941,7 +931,7 @@ def scan_report_instructions(self): "RID": self.metadata["RID"], "devices": self.scan_motors, "start": self.start_pos, - "end": self.positions[0], + "end": list(self.positions[0]), } } ) diff --git a/bec_server/tests/tests_device_server/test_config_handler.py b/bec_server/tests/tests_device_server/test_config_handler.py index 192892141..8cddb920c 100644 --- a/bec_server/tests/tests_device_server/test_config_handler.py +++ b/bec_server/tests/tests_device_server/test_config_handler.py @@ -128,7 +128,7 @@ def test_parse_config_request_add_remove(dm_with_devices): "tolerance": 0.01, "update_frequency": 400, }, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "name": "new_device", diff --git a/bec_server/tests/tests_device_server/test_rpc_handler.py b/bec_server/tests/tests_device_server/test_rpc_handler.py index 172d0ce86..b1d2948e5 100644 --- a/bec_server/tests/tests_device_server/test_rpc_handler.py +++ b/bec_server/tests/tests_device_server/test_rpc_handler.py @@ -58,7 +58,7 @@ def test_execute_rpc_call(rpc_cls: RPCHandler, instr_params): msg = messages.DeviceInstructionMessage( device="device", action="rpc", - parameter=instr_params, + parameter=messages.sanitize_one_way_encodable(instr_params), metadata={"RID": "RID", "device_instr_id": "diid"}, ) out = rpc_cls._execute_rpc_call(rpc_var=rpc_var, instr=msg) @@ -80,7 +80,7 @@ def test_execute_rpc_call_var(rpc_cls: RPCHandler, instr_params: dict): msg = messages.DeviceInstructionMessage( device="device", action="rpc", - parameter=instr_params, + parameter=messages.sanitize_one_way_encodable(instr_params), metadata={"RID": "RID", "device_instr_id": "diid"}, ) out = rpc_cls._execute_rpc_call(rpc_var=rpc_var, instr=msg) diff --git a/bec_server/tests/tests_file_writer/test_async_file_writer.py b/bec_server/tests/tests_file_writer/test_async_file_writer.py index 31b354ffd..a79e77bf8 100644 --- a/bec_server/tests/tests_file_writer/test_async_file_writer.py +++ b/bec_server/tests/tests_file_writer/test_async_file_writer.py @@ -549,7 +549,7 @@ def test_async_writer_raises_on_wrong_data_type(async_writer): # Send invalid data (not a DeviceMessage) invalid_data = messages.DeviceMessage( - signals={"monitor_async": {"value": {"data": None}, "timestamp": 1}}, + signals={"monitor_async": {"value": None, "timestamp": 1}}, metadata={"async_update": {"type": "add", "max_shape": [None]}}, ) diff --git a/bec_server/tests/tests_scan_server/test_scan_assembler.py b/bec_server/tests/tests_scan_server/test_scan_assembler.py index 0156fd57e..b0a7d4f41 100644 --- a/bec_server/tests/tests_scan_server/test_scan_assembler.py +++ b/bec_server/tests/tests_scan_server/test_scan_assembler.py @@ -39,7 +39,7 @@ def run(self): # Fermat scan with args and kwargs, matching the FermatSpiralScan signature messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"steps": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"steps": 3}}, queue="primary", ), { @@ -120,7 +120,7 @@ def run(self): # Line scan with arg bundle messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"steps": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"steps": 3}}, queue="primary", ), {"arg_bundle": ["samx", -5, 5, "samy", -5, 5], "inputs": {}, "kwargs": {"steps": 3}}, @@ -147,12 +147,23 @@ def run(self): ) def test_scan_assembler_request_inputs(msg, request_inputs_expected, scan_assembler): + def _available_scan(clss: str): + return messages.AvailableScan( + class_name=clss, + base_class="", + arg_input={}, + gui_config={}, + required_kwargs=[], + arg_bundle_size={}, + signature=[], + ) + class MockScanManager: available_scans = { - "fermat_scan": {"class": "FermatSpiralScan"}, - "line_scan": {"class": "LineScan"}, - "custom_scan": {"class": "CustomScan"}, - "custom_scan2": {"class": "CustomScan2"}, + "fermat_scan": _available_scan("FermatSpiralScan"), + "line_scan": _available_scan("LineScan"), + "custom_scan": _available_scan("CustomScan"), + "custom_scan2": _available_scan("CustomScan2"), } scan_dict = { "FermatSpiralScan": FermatSpiralScan, diff --git a/bec_server/tests/tests_scan_server/test_scan_guard.py b/bec_server/tests/tests_scan_server/test_scan_guard.py index ce3670108..8709f1910 100644 --- a/bec_server/tests/tests_scan_server/test_scan_guard.py +++ b/bec_server/tests/tests_scan_server/test_scan_guard.py @@ -24,7 +24,7 @@ def scan_guard_mock(scan_server_mock): ( messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) ), @@ -73,7 +73,7 @@ def test_device_rpc_is_valid(scan_guard_mock, device, func, is_valid): ( messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", metadata={"client_info": {"acl_user": "default"}}, ), @@ -125,7 +125,7 @@ def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock): request = messages.ScanQueueMessage( scan_type="unknown_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) @@ -142,7 +142,7 @@ def test_check_valid_scan_accepts_known_scan(scan_guard_mock): request = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) @@ -209,7 +209,7 @@ def test_append_to_scan_queue(scan_guard_mock): sg = scan_guard_mock msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) with mock.patch.object(sg.device_manager.connector, "send") as send: @@ -221,7 +221,7 @@ def test_scan_queue_request_callback(scan_guard_mock): sg = scan_guard_mock msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) msg_obj = MessageObject(MessageEndpoints.scan_queue_request("default").endpoint, msg) @@ -255,7 +255,7 @@ def test_handle_scan_request(scan_guard_mock): sg = scan_guard_mock msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) with mock.patch.object(sg, "_is_valid_scan_request") as valid: @@ -336,7 +336,7 @@ def test_handle_scan_request_rejected(scan_guard_mock): sg = scan_guard_mock msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) with mock.patch.object(sg, "_is_valid_scan_request") as valid: @@ -350,7 +350,7 @@ def test_is_valid_scan_request_returns_scan_status_on_error(scan_guard_mock): sg = scan_guard_mock msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", metadata={"client_info": {"acl_user": "default"}}, ) diff --git a/bec_server/tests/tests_scan_server/test_scan_gui_models.py b/bec_server/tests/tests_scan_server/test_scan_gui_models.py index 346231e10..ff312261a 100644 --- a/bec_server/tests/tests_scan_server/test_scan_gui_models.py +++ b/bec_server/tests/tests_scan_server/test_scan_gui_models.py @@ -3,8 +3,9 @@ import pytest from pydantic import ValidationError +from bec_lib.messages import ScanArgType from bec_server.scan_server.scan_gui_models import GUIConfig -from bec_server.scan_server.scans import ScanArgType, ScanBase +from bec_server.scan_server.scans import ScanBase class GoodScan(ScanBase): # pragma: no cover @@ -210,7 +211,7 @@ def test_gui_config_good_scan_dump(): "expert": False, "name": "optim_trajectory", "tooltip": None, - "type": {"Literal": ("path", None)}, + "type": {"Literal": ["path", None]}, }, ], } diff --git a/bec_server/tests/tests_scan_server/test_scan_server_queue.py b/bec_server/tests/tests_scan_server/test_scan_server_queue.py index 0d0033d51..0a607165a 100644 --- a/bec_server/tests/tests_scan_server/test_scan_server_queue.py +++ b/bec_server/tests/tests_scan_server/test_scan_server_queue.py @@ -80,7 +80,7 @@ def test_queuemanager_add_to_queue(queuemanager_mock, queue): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue=queue, metadata={"RID": "something"}, ) @@ -122,7 +122,7 @@ def test_queuemanager_add_to_queue_restarts_queue_if_worker_is_dead(queuemanager msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -137,7 +137,7 @@ def test_queuemanager_add_to_queue_error_send_alarm(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -153,7 +153,7 @@ def test_queuemanager_scan_queue_callback(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -229,7 +229,7 @@ def test_set_pause(queuemanager_mock): # Add a queue item so worker_status has something to operate on msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -253,7 +253,7 @@ def test_set_pause_does_not_change_non_running_worker(queuemanager_mock): # Add a queue item msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -300,7 +300,7 @@ def test_set_abort(queuemanager_mock): queue_manager.connector.message_sent = [] msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -329,7 +329,7 @@ def test_set_abort_with_scan_id(queuemanager_mock): queue_manager.connector.message_sent = [] msg = messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-1, 1)}, "kwargs": {"steps": 10, "relative": False}}, + parameter={"args": {"samx": [-1, 1]}, "kwargs": {"steps": 10, "relative": False}}, queue="primary", metadata={"RID": "something"}, ) @@ -358,7 +358,7 @@ def test_set_abort_with_scan_id_not_active(queuemanager_mock): queue_manager.connector.message_sent = [] msg = messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-1, 1)}, "kwargs": {"steps": 10, "relative": False}}, + parameter={"args": {"samx": [-1, 1]}, "kwargs": {"steps": 10, "relative": False}}, queue="primary", metadata={"RID": "something"}, ) @@ -381,7 +381,7 @@ def test_set_abort_with_wrong_scan_id(queuemanager_mock): queue_manager.connector.message_sent = [] msg = messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-1, 1)}, "kwargs": {"steps": 10, "relative": False}}, + parameter={"args": {"samx": [-1, 1]}, "kwargs": {"steps": 10, "relative": False}}, queue="primary", metadata={"RID": "something"}, ) @@ -439,7 +439,7 @@ def test_set_restart(queuemanager_mock): queue_manager.queues["primary"] = ScanQueue(queue_manager, queue_name="primary") msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -476,7 +476,7 @@ def test_set_restart_no_active_scan(queuemanager_mock): queue_manager.queues["primary"] = ScanQueue(queue_manager, queue_name="primary") msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -502,7 +502,7 @@ def test_set_user_completed(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-1, 1)}, "kwargs": {"steps": 10, "relative": False}}, + parameter={"args": {"samx": [-1, 1]}, "kwargs": {"steps": 10, "relative": False}}, queue="primary", metadata={"RID": "something"}, ) @@ -523,7 +523,7 @@ def test_request_block(scan_server_mock): scan_server = scan_server_mock msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -538,7 +538,7 @@ def test_request_block(scan_server_mock): ( messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -546,7 +546,7 @@ def test_request_block(scan_server_mock): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -572,7 +572,7 @@ def test_remove_queue_item(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -586,7 +586,7 @@ def test_invalid_scan_specified_in_message(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="fake test scan which does not exist!", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -601,7 +601,7 @@ def test_set_clear(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -693,7 +693,7 @@ def test_request_block_queue_append(): req_block_queue = RequestBlockQueue(mock.MagicMock(), mock.MagicMock()) msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -711,7 +711,7 @@ def test_request_block_queue_append(): ( messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ), @@ -720,7 +720,7 @@ def test_request_block_queue_append(): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "something"}, ), @@ -729,7 +729,7 @@ def test_request_block_queue_append(): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "existing_scan_def_id"}, ), @@ -766,7 +766,7 @@ def test_append_request_block(): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "something"}, ), @@ -775,7 +775,7 @@ def test_append_request_block(): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "existing_scan_def_id"}, ), @@ -806,7 +806,7 @@ def test_update_point_id(scan_queue_msg, scan_id): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "existing_scan_def_id"}, ), @@ -833,7 +833,7 @@ def test_update_point_id_takes_max(scan_queue_msg, scan_id): ( messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ), @@ -842,7 +842,7 @@ def test_update_point_id_takes_max(scan_queue_msg, scan_id): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ), @@ -851,7 +851,7 @@ def test_update_point_id_takes_max(scan_queue_msg, scan_id): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "existing_scan_def_id"}, ), @@ -860,7 +860,7 @@ def test_update_point_id_takes_max(scan_queue_msg, scan_id): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "dataset_id_on_hold": True}, ), @@ -891,7 +891,7 @@ def test_pull_request_block_non_empty_rb(): req_block_queue = RequestBlockQueue(mock.MagicMock(), mock.MagicMock()) scan_queue_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -948,7 +948,7 @@ def test_queue_manager_get_active_scan_id(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -967,7 +967,7 @@ def test_queue_manager_get_active_scan_id_wo_rbl_returns_None(queuemanager_mock) queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -979,7 +979,7 @@ def test_request_block_queue_next(): req_block_queue = RequestBlockQueue(mock.MagicMock(), mock.MagicMock()) msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -995,7 +995,7 @@ def test_request_block_queue_next_raises_stopiteration(): req_block_queue = RequestBlockQueue(mock.MagicMock(), mock.MagicMock()) msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -1012,7 +1012,7 @@ def test_request_block_queue_next_updates_point_id(): req_block_queue = RequestBlockQueue(mock.MagicMock(), mock.MagicMock()) msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "scan_def_id"}, ) @@ -1073,7 +1073,7 @@ def test_queue_order_change(queuemanager_mock, order_msg, position): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-5, 5)}, "kwargs": {"steps": 3}}, + parameter={"args": {"samx": [-5, 5]}, "kwargs": {"steps": 3}}, queue="primary", metadata={"RID": "something"}, ) diff --git a/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py b/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py index 293ac4110..391cdeb9d 100644 --- a/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py +++ b/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py @@ -3,8 +3,8 @@ import pytest from bec_lib.device import Device, DeviceBase, Positioner +from bec_lib.messages import ScanArgType from bec_server.scan_server.scan_manager import ScanManager -from bec_server.scan_server.scans import ScanArgType @pytest.fixture diff --git a/bec_server/tests/tests_scan_server/test_scan_worker.py b/bec_server/tests/tests_scan_server/test_scan_worker.py index 6de30a457..b7c37a772 100644 --- a/bec_server/tests/tests_scan_server/test_scan_worker.py +++ b/bec_server/tests/tests_scan_server/test_scan_worker.py @@ -101,7 +101,7 @@ def test_publish_data_as_read(scan_worker_mock): def test_publish_data_as_read_multiple(scan_worker_mock): worker = scan_worker_mock - data = [{"samx": {}}, {"samy": {}}] + data = [{"samx": {"value": None}}, {"samy": {"value": None}}] devices = ["samx", "samy"] instr = messages.DeviceInstructionMessage( device=devices, @@ -217,7 +217,7 @@ def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id): messages.ScanQueueMessage( scan_type="grid_scan", parameter={ - "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, + "args": {"samx": [-5, 5, 5], "samy": [-1, 1, 2]}, "kwargs": { "exp_time": 1, "relative": True, @@ -234,7 +234,7 @@ def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id): messages.ScanQueueMessage( scan_type="grid_scan", parameter={ - "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, + "args": {"samx": [-5, 5, 5], "samy": [-1, 1, 2]}, "kwargs": { "exp_time": 1, "relative": True, @@ -251,7 +251,7 @@ def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id): messages.ScanQueueMessage( scan_type="grid_scan", parameter={ - "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, + "args": {"samx": [-5, 5, 5], "samy": [-1, 1, 2]}, "kwargs": { "exp_time": 1, "relative": True, diff --git a/bec_server/tests/tests_scan_server/test_scans.py b/bec_server/tests/tests_scan_server/test_scans.py index 900103311..16d3beeaf 100644 --- a/bec_server/tests/tests_scan_server/test_scans.py +++ b/bec_server/tests/tests_scan_server/test_scans.py @@ -69,7 +69,7 @@ def test_unpack_scan_args_valid_input(): ( messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,), "samy": (2,)}, "kwargs": {}}, + parameter={"args": {"samx": [1], "samy": [2]}, "kwargs": {}}, queue="primary", ), [ @@ -90,7 +90,7 @@ def test_unpack_scan_args_valid_input(): ( messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,), "samy": (2,), "samz": (3,)}, "kwargs": {}}, + parameter={"args": {"samx": [1], "samy": [2], "samz": [3]}, "kwargs": {}}, queue="primary", ), [ @@ -116,7 +116,7 @@ def test_unpack_scan_args_valid_input(): ), ( messages.ScanQueueMessage( - scan_type="mv", parameter={"args": {"samx": (1,)}, "kwargs": {}}, queue="primary" + scan_type="mv", parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary" ), [ messages.DeviceInstructionMessage( @@ -154,7 +154,7 @@ def offset_mock(): ( messages.ScanQueueMessage( scan_type="umv", - parameter={"args": {"samx": (1,), "samy": (2,)}, "kwargs": {}}, + parameter={"args": {"samx": [1], "samy": [2]}, "kwargs": {}}, queue="primary", metadata={"RID": "0bab7ee3-b384-4571-b...0fff984c05"}, ), @@ -198,7 +198,7 @@ def offset_mock(): ( messages.ScanQueueMessage( scan_type="umv", - parameter={"args": {"samx": (1,), "samy": (2,), "samz": (3,)}, "kwargs": {}}, + parameter={"args": {"samx": [1], "samy": [2], "samz": [3]}, "kwargs": {}}, queue="primary", metadata={"RID": "0bab7ee3-b384-4571-b...0fff984c05"}, ), @@ -251,7 +251,7 @@ def offset_mock(): ( messages.ScanQueueMessage( scan_type="umv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "0bab7ee3-b384-4571-b...0fff984c05"}, ), @@ -297,8 +297,8 @@ def test_scan_updated_move(mv_msg, reference_msg_list, scan_assembler, ScanStubS mock_get_from_rpc.return_value = { dev: {"value": value} for dev, value in zip( - reference_msg_list[0].content["parameter"]["readback"]["devices"], - reference_msg_list[0].content["parameter"]["readback"]["start"], + reference_msg_list[0].parameter["readback"]["devices"], + reference_msg_list[0].parameter["readback"]["start"], ) } @@ -322,7 +322,7 @@ def mock_rpc_func(*args, **kwargs): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", ), [ @@ -473,7 +473,7 @@ def offset_mock(): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 2), "samy": (-5, 5, 2)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 2], "samy": [-5, 5, 2]}, "kwargs": {}}, queue="primary", ), [ @@ -648,7 +648,7 @@ def offset_mock(): ( messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ), [ @@ -668,7 +668,7 @@ def offset_mock(): messages.ScanQueueMessage( scan_type="fermat_scan", parameter={ - "args": {"samx": (-5, 5), "samy": (-5, 5)}, + "args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3, "spiral_type": 1}, }, queue="primary", @@ -721,7 +721,7 @@ def offset_mock(): }, scan_type="cont_line_scan", parameter={ - "args": ("samx", -1, 1), + "args": ["samx", -1, 1], "kwargs": { "steps": 3, "exp_time": 0.1, @@ -744,7 +744,7 @@ def offset_mock(): metadata={"readout_priority": "monitored"}, device="samx", action="rpc", - parameter={"device": "samx", "func": "velocity.get", "args": (), "kwargs": {}}, + parameter={"device": "samx", "func": "velocity.get", "args": [], "kwargs": {}}, ), messages.DeviceInstructionMessage( metadata={"readout_priority": "monitored"}, @@ -753,7 +753,7 @@ def offset_mock(): parameter={ "device": "samx", "func": "acceleration.get", - "args": (), + "args": [], "kwargs": {}, }, ), @@ -761,7 +761,7 @@ def offset_mock(): metadata={"readout_priority": "monitored"}, device="samx", action="rpc", - parameter={"device": "samx", "func": "read", "args": (), "kwargs": {}}, + parameter={"device": "samx", "func": "read", "args": [], "kwargs": {}}, ), messages.DeviceInstructionMessage( metadata={"readout_priority": "monitored"}, @@ -1053,7 +1053,7 @@ def pre_scan_macro(devices: dict, request: RequestBase): macros = inspect.getsource(pre_scan_macro).encode() scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) args = unpack_scan_args(scan_msg.content.get("parameter", {}).get("args", [])) @@ -1077,7 +1077,7 @@ def pre_scan_macro(devices: dict, request: RequestBase): # device_manager = DMMock() # device_manager.add_device("samx") # parameter = { -# "args": {"samx": (-5, 5), "samy": (-5, 5)}, +# "args": {"samx": [-5, 5], "samy": [-5, 5]}, # "kwargs": {"step": 3}, # } # request = RequestBase(device_manager=device_manager, parameter=parameter) @@ -1099,7 +1099,7 @@ def test_round_roi_scan(): scan_msg = messages.ScanQueueMessage( scan_type="round_roi_scan", parameter={ - "args": {"samx": (10,), "samy": (10,)}, + "args": {"samx": [10], "samy": [10]}, "kwargs": {"dr": 2, "nth": 4, "exp_time": 2, "relative": True}, }, queue="primary", @@ -1211,7 +1211,7 @@ def pre_scan_macro(devices: dict, request: RequestBase): macros = [inspect.getsource(pre_scan_macro).encode()] scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) args = unpack_scan_args(scan_msg.content.get("parameter", {}).get("args", [])) @@ -1228,7 +1228,7 @@ def test_scan_report_devices(): device_manager.add_device("samy") scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) args = unpack_scan_args(scan_msg.content.get("parameter", {}).get("args", [])) @@ -1253,7 +1253,7 @@ def run(self): device_manager.add_device("samy") scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) request = RequestBaseMock( @@ -1305,7 +1305,7 @@ def run(self): device_manager.add_device("samz") scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) request = RequestBaseMock( @@ -1318,7 +1318,7 @@ def run(self): assert request.scan_motors == ["samx"] request.arg_bundle_size = {"bundle": 2, "min": None, "max": None} - request.caller_args = {"samz": (-2, 2), "samy": (-1, 2)} + request.caller_args = {"samz": [-2, 2], "samy": [-1, 2]} request.update_scan_motors() assert request.scan_motors == ["samz", "samy"] @@ -1340,7 +1340,7 @@ def _calculate_positions(self): scan_msg = messages.ScanQueueMessage( scan_type="", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) with pytest.raises(ValueError) as exc_info: @@ -1358,7 +1358,7 @@ def test_scan_base_set_position_offset(): scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", parameter={ - "args": {"samx": (-5, 5), "samy": (-5, 5)}, + "args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3, "relative": False}, }, queue="primary", @@ -1385,7 +1385,7 @@ def test_round_scan_fly_simupdate_scan_motors(): device_manager.add_device("flyer_sim") scan_msg = messages.ScanQueueMessage( scan_type="round_scan_fly", - parameter={"args": {"flyer_sim": (0, 50, 5, 3)}, "kwargs": {"realtive": True}}, + parameter={"args": {"flyer_sim": [0, 50, 5, 3]}, "kwargs": {"realtive": True}}, queue="primary", ) request = RoundScanFlySim( @@ -1408,7 +1408,7 @@ def test_round_scan_fly_sim_prepare_positions(): device_manager.add_device("flyer_sim") scan_msg = messages.ScanQueueMessage( scan_type="round_scan_fly", - parameter={"args": {"flyer_sim": (0, 50, 5, 3)}, "kwargs": {"realtive": True}}, + parameter={"args": {"flyer_sim": [0, 50, 5, 3]}, "kwargs": {"realtive": True}}, queue="primary", ) request = RoundScanFlySim( @@ -1433,7 +1433,7 @@ def test_round_scan_fly_sim_prepare_positions(): @pytest.mark.parametrize( - "in_args,reference_positions", [((1, 5, 1, 1), [[0, -3], [0, -7], [0, 7]])] + "in_args,reference_positions", [([1, 5, 1, 1], [[0, -3], [0, -7], [0, 7]])] ) def test_round_scan_fly_sim_calculate_positions(in_args, reference_positions): device_manager = DMMock() @@ -1458,7 +1458,7 @@ def test_round_scan_fly_sim_calculate_positions(in_args, reference_positions): @pytest.mark.parametrize( - "in_args,reference_positions", [((1, 5, 1, 1), [[0, -3], [0, -7], [0, 7]])] + "in_args,reference_positions", [([1, 5, 1, 1], [[0, -3], [0, -7], [0, 7]])] ) def test_round_scan_fly_sim_scan_core(in_args, reference_positions, scan_assembler): scan_msg = messages.ScanQueueMessage( @@ -2139,7 +2139,7 @@ def fake_set(*args, **kwargs): "device": "samx", "func": "read", "rpc_id": "rpc_id", - "args": (), + "args": [], "kwargs": {}, }, ), From ea61f2d4d526be6270a0617a316d9c3813ea5a94 Mon Sep 17 00:00:00 2001 From: perl_d Date: Tue, 24 Feb 2026 11:30:44 +0100 Subject: [PATCH 4/8] refactor: relax jsonable strictness --- bec_lib/bec_lib/messages.py | 21 +++++++-------------- bec_lib/bec_lib/scans.py | 31 +++++++++++++++++-------------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index b8ced733d..838b0f140 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -12,17 +12,14 @@ from typing import Annotated, Any, ClassVar, Literal, Mapping, Self, TypeVar, Union from uuid import uuid4 +import msgpack import numpy as np from pydantic import ( BaseModel, BeforeValidator, ConfigDict, - FailFast, Field, Strict, - StrictBool, - StrictFloat, - StrictInt, StrictStr, ValidationError, WithJsonSchema, @@ -50,19 +47,15 @@ def sanitize_one_way_encodable(data: Any) -> Any: return _one_way_registry.encode(data) -JsonableScalar = TypeAliasType("JsonableScalar", StrictInt | StrictFloat | StrictStr | StrictBool) +def _try_dump(v): + msgpack.dumps(v) + return v -Jsonable = TypeAliasType( - "Jsonable", - JsonableScalar - | None - | Annotated[list["Jsonable"], Strict(), FailFast()] - | Annotated[dict[StrictStr, "Jsonable"], Strict()], -) + +Jsonable = TypeAliasType("Jsonable", Annotated[Any, BeforeValidator(_try_dump)]) JsonableDict = TypeAliasType( - "JsonableDict", - Annotated[dict[StrictStr, Jsonable], WithJsonSchema({"type": "object"}), Strict()], + "JsonableDict", Annotated[dict[str, Jsonable], WithJsonSchema({"type": "object"})] ) diff --git a/bec_lib/bec_lib/scans.py b/bec_lib/bec_lib/scans.py index 5dccfa8b1..b7b103df5 100644 --- a/bec_lib/bec_lib/scans.py +++ b/bec_lib/bec_lib/scans.py @@ -22,6 +22,7 @@ from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from bec_lib.messages import ScanArgType # moved from here to messages - for compat with plugins +from bec_lib.messages import AvailableResourceMessage, AvailableScan from bec_lib.scan_repeat import _scan_repeat_depth from bec_lib.scan_report import ScanReport from bec_lib.signature_serializer import deserialize_dtype, dict_to_signature @@ -41,7 +42,7 @@ class ScanObject: """ScanObject is a class for scans""" - def __init__(self, scan_name: str, scan_info: dict, client: BECClient = None) -> None: + def __init__(self, scan_name: str, scan_info: AvailableScan, client: BECClient = None) -> None: self.scan_name = scan_name self.scan_info = scan_info self.client = client @@ -187,14 +188,17 @@ def __init__(self, parent): def _import_scans(self): """Import scans from the scan server""" - available_scans = self.parent.connector.get(MessageEndpoints.available_scans()) - if available_scans is None: + available_scans_msg: AvailableResourceMessage | None = self.parent.connector.get( + MessageEndpoints.available_scans() + ) + if available_scans_msg is None: logger.warning("No scans available. Are redis and the BEC server running?") return - for scan_name, scan_info in available_scans.resource.items(): + available_scans: dict[str, AvailableScan] = available_scans_msg.resource + for scan_name, scan_info in available_scans.items(): self._available_scans[scan_name] = ScanObject(scan_name, scan_info, client=self.parent) setattr(self, scan_name, self._available_scans[scan_name].run) - setattr(getattr(self, scan_name), "__doc__", scan_info.get("doc")) + setattr(getattr(self, scan_name), "__doc__", scan_info.doc) setattr( getattr(self, scan_name), "__signature__", @@ -289,7 +293,7 @@ def get_arg_type(in_type: str | dict | list): @staticmethod def prepare_scan_request( - scan_name: str, scan_info: dict, *args, **kwargs + scan_name: str, scan_info: AvailableScan, *args, **kwargs ) -> messages.ScanQueueMessage: """Prepare scan request message with given scan arguments @@ -307,20 +311,20 @@ def prepare_scan_request( """ scan_queue = kwargs.pop("scan_queue", "primary") # check that all required keyword arguments have been specified - if not all(req_kwarg in kwargs for req_kwarg in scan_info.get("required_kwargs")): + if not all(req_kwarg in kwargs for req_kwarg in scan_info.required_kwargs): raise TypeError( - f"{scan_info.get('doc')}\n Not all required keyword arguments have been" - f" specified. The required arguments are: {scan_info.get('required_kwargs')}" + f"{scan_info.doc}\n Not all required keyword arguments have been" + f" specified. The required arguments are: {scan_info.required_kwargs}" ) # check that all required arguments have been specified - arg_input = list(scan_info.get("arg_input", {}).values()) - arg_bundle_size = scan_info.get("arg_bundle_size", {}) + arg_input = list(scan_info.arg_input.values()) + arg_bundle_size = scan_info.arg_bundle_size bundle_size = arg_bundle_size.get("bundle") if len(arg_input) > 0: if len(args) % len(arg_input) != 0: raise TypeError( - f"{scan_info.get('doc')}\n {scan_name} takes multiples of" + f"{scan_info.doc}\n {scan_name} takes multiples of" f" {len(arg_input)} arguments ({len(args)} given)." ) # check that all specified devices in args are different objects @@ -329,8 +333,7 @@ def prepare_scan_request( continue if args.count(arg) > 1: raise TypeError( - f"{scan_info.get('doc')}\n All specified devices must be different" - f" objects." + f"{scan_info.doc}\n All specified devices must be different objects." ) # check that all arguments are of the correct type From e0581fdcb71a834984776125f4a214e2da7a2e49 Mon Sep 17 00:00:00 2001 From: David Perl Date: Mon, 2 Mar 2026 08:43:53 +0100 Subject: [PATCH 5/8] refactor: adjust logic which uses now structured objects --- bec_lib/bec_lib/devicemanager.py | 2 +- bec_lib/bec_lib/messages.py | 67 ++++++++++-------------------- bec_lib/bec_lib/scans.py | 6 +-- bec_lib/tests/test_bec_messages.py | 13 +----- 4 files changed, 28 insertions(+), 60 deletions(-) diff --git a/bec_lib/bec_lib/devicemanager.py b/bec_lib/bec_lib/devicemanager.py index 7fbbcb0b1..f16ff40f1 100644 --- a/bec_lib/bec_lib/devicemanager.py +++ b/bec_lib/bec_lib/devicemanager.py @@ -77,7 +77,7 @@ def _rgetattr_safe(obj, attr, *args): return None -class DeviceContainer(dict): +class DeviceContainer(dict[str, DeviceBase]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for arg in args: diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 838b0f140..4f59ff304 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -2,6 +2,7 @@ from __future__ import annotations import getpass +import sys import time import uuid import warnings @@ -9,7 +10,7 @@ from enum import Enum, StrEnum, auto from importlib.metadata import PackageNotFoundError from importlib.metadata import version as importlib_version -from typing import Annotated, Any, ClassVar, Literal, Mapping, Self, TypeVar, Union +from typing import Annotated, Any, ClassVar, Literal, Mapping, NotRequired, Self, TypeVar from uuid import uuid4 import msgpack @@ -18,9 +19,8 @@ BaseModel, BeforeValidator, ConfigDict, + FailFast, Field, - Strict, - StrictStr, ValidationError, WithJsonSchema, field_validator, @@ -32,6 +32,11 @@ from bec_lib.metadata_schema import get_metadata_schema_for_scan from bec_lib.one_way_registry import OneWaySerializationRegistry +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict # Pydantic needs the typing_extensions version on 3.11 + _one_way_registry = OneWaySerializationRegistry() @@ -47,8 +52,17 @@ def sanitize_one_way_encodable(data: Any) -> Any: return _one_way_registry.encode(data) +# Temporary enforcement of only primitive values which can be serialized in otherwise uptyped dicts +# This can be removed when the refactor is complete + + def _try_dump(v): - msgpack.dumps(v) + if isinstance(v, BECMessage): + return v # ignore where we have mixed types + try: + msgpack.dumps(v) + except TypeError as e: + raise ValueError(str(v)) from e return v @@ -610,44 +624,9 @@ def _ensure_error_info_if_error(self): return self -# TODO: remove when deprecated usages of SignalReading are cleaned up -logger = None - - -def lazy_ensure_logger(): - global logger - if logger is None: - from bec_lib.logger import bec_logger - - logger = bec_logger.logger - - -class SignalReading(BECSerializable): - value: int | float | list[int] | list[float] | np.ndarray | None | str - timestamp: float | list[float] | None = None - - def keys(self): - lazy_ensure_logger() - logger.warning( - "Dictionary usage of SignalReading is deprecated; please replace it with a different access pattern." - ) - return ["value", "timestamp"] - - def get(self, item: Literal["value", "timestamp"], default=Any): - """Allow dictionary-style access for legacy reasons.""" - lazy_ensure_logger() - logger.warning( - "Get-access on SignalReading is deprecated; Just access the model.value field." - ) - if item not in ["value", "timestamp"]: - raise KeyError('SignalReading only has "value" and "timestamp" fields!') - return getattr(self, item) - - def __getitem__(self, item: str): - return self.get(item) - - def items(self): - return dict(self).items() +class SignalReading(TypedDict): + value: NotRequired[int | float | list[int] | list[float] | np.ndarray | None | str] + timestamp: NotRequired[float | list[float]] class DeviceMessage(BECMessage): @@ -1391,9 +1370,9 @@ class AvailableResourceMessage(BECMessage): msg_type: ClassVar[str] = "available_resource_message" resource: ( JsonableDict - | list[JsonableDict] + | Annotated[list[JsonableDict], FailFast] | SpecificMessageType - | list[SpecificMessageType] + | Annotated[list[SpecificMessageType], FailFast] | dict[str, SpecificMessageType] ) diff --git a/bec_lib/bec_lib/scans.py b/bec_lib/bec_lib/scans.py index b7b103df5..5980cad12 100644 --- a/bec_lib/bec_lib/scans.py +++ b/bec_lib/bec_lib/scans.py @@ -340,7 +340,7 @@ def prepare_scan_request( for ii, arg in enumerate(args): if not isinstance(arg, Scans.get_arg_type(arg_input[ii % len(arg_input)])): raise TypeError( - f"{scan_info.get('doc')}\n Argument {ii} must be of type" + f"{scan_info.doc}\n Argument {ii} must be of type" f" {arg_input[ii%len(arg_input)]}, not {type(arg).__name__}." ) @@ -356,12 +356,12 @@ def prepare_scan_request( max_bundles = arg_bundle_size.get("max") if min_bundles and num_bundles < min_bundles: raise TypeError( - f"{scan_info.get('doc')}\n {scan_name} requires at least {min_bundles} bundles" + f"{scan_info.doc}\n {scan_name} requires at least {min_bundles} bundles" f" of arguments ({num_bundles} given)." ) if max_bundles and num_bundles > max_bundles: raise TypeError( - f"{scan_info.get('doc')}\n {scan_name} requires at most {max_bundles} bundles" + f"{scan_info.doc}\n {scan_name} requires at most {max_bundles} bundles" f" of arguments ({num_bundles} given)." ) # Check if we are in a "restart" decorator context diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index 3b2cf6f44..0c9a4c0ef 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -709,15 +709,4 @@ def test_message_with_np_array_in_dict(): with pytest.raises(pydantic.ValidationError) as e: msg = messages.BECMessage(metadata={"value": arr}) assert e.match("metadata.value") - assert e.match("should be a valid") - - -def test_message_service_config(): - msg = messages.MessagingServiceConfig( - metadata={}, service_name="signal", scopes=["*"], enabled=True - ) - dump = msg.model_dump(mode="python") - assert dump["service_name"] == "signal" - resource_msg = messages.AvailableResourceMessage(resource=[msg]) - resource_msg_dump = resource_msg.model_dump(mode="python") - assert resource_msg_dump["resource"][0]["service_name"] == "signal" + assert e.match("input_type=ndarray") From 02f14a70cebd5f1750e7d3490c045b6d82454c1c Mon Sep 17 00:00:00 2001 From: David Perl Date: Mon, 2 Mar 2026 11:26:07 +0100 Subject: [PATCH 6/8] fix: revert test objects --- bec_lib/tests/test_devices.py | 56 +++++----------------- bec_lib/tests/test_metadata_schema.py | 17 +++++-- bec_lib/tests/test_scan_object.py | 68 +++++++++++++++++---------- 3 files changed, 69 insertions(+), 72 deletions(-) diff --git a/bec_lib/tests/test_devices.py b/bec_lib/tests/test_devices.py index 4b69ef78c..9f99b33c8 100644 --- a/bec_lib/tests/test_devices.py +++ b/bec_lib/tests/test_devices.py @@ -61,15 +61,9 @@ def test_read(dev: Any): res = dev.samx.read(cached=True) mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) assert res == { - "samx": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.1711318} - ), - "samx_setpoint": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.1693492} - ), - "samx_motor_is_moving": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.16935} - ), + "samx": {"value": 0, "timestamp": 1701105880.1711318}, + "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, + "samx_motor_is_moving": {"value": 0, "timestamp": 1701105880.16935}, } @@ -85,25 +79,15 @@ def test_read_filtered_hints(dev: Any): ) res = dev.samx.read(cached=True, filter_to_hints=True) mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) - assert res == { - "samx": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.1711318} - ) - } + assert res == {"samx": {"value": 0, "timestamp": 1701105880.1711318}} def test_read_use_read(dev: Any): with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get: data = { - "samx": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.1711318} - ), - "samx_setpoint": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.1693492} - ), - "samx_motor_is_moving": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.16935} - ), + "samx": {"value": 0, "timestamp": 1701105880.1711318}, + "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, + "samx_motor_is_moving": {"value": 0, "timestamp": 1701105880.16935}, } mock_get.return_value = messages.DeviceMessage( signals=data, metadata={"scan_id": "scan_id", "scan_type": "scan_type"} @@ -116,21 +100,11 @@ def test_read_use_read(dev: Any): def test_read_nested_device(dev: Any): with mock.patch.object(dev.dyn_signals.root.parent.connector, "get") as mock_get: data = { - "dyn_signals_messages_message1": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.0716832} - ), - "dyn_signals_messages_message2": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.071722} - ), - "dyn_signals_messages_message3": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.071739} - ), - "dyn_signals_messages_message4": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.071753} - ), - "dyn_signals_messages_message5": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.071766} - ), + "dyn_signals_messages_message1": {"value": 0, "timestamp": 1701105880.0716832}, + "dyn_signals_messages_message2": {"value": 0, "timestamp": 1701105880.071722}, + "dyn_signals_messages_message3": {"value": 0, "timestamp": 1701105880.071739}, + "dyn_signals_messages_message4": {"value": 0, "timestamp": 1701105880.071753}, + "dyn_signals_messages_message5": {"value": 0, "timestamp": 1701105880.071766}, } mock_get.return_value = messages.DeviceMessage( signals=data, metadata={"scan_id": "scan_id", "scan_type": "scan_type"} @@ -165,11 +139,7 @@ def test_read_kind_hinted( if cached: mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) mock_run.assert_not_called() - assert res == { - "samx": messages.SignalReading.model_validate( - {"value": 0, "timestamp": 1701105880.1711318} - ) - } + assert res == {"samx": {"value": 0, "timestamp": 1701105880.1711318}} else: mock_run.assert_called_once_with(cached=False, fcn=dev.samx.readback.read) mock_get.assert_not_called() diff --git a/bec_lib/tests/test_metadata_schema.py b/bec_lib/tests/test_metadata_schema.py index 5b10e71f4..04d1a131e 100644 --- a/bec_lib/tests/test_metadata_schema.py +++ b/bec_lib/tests/test_metadata_schema.py @@ -4,7 +4,7 @@ from pydantic import ValidationError from bec_lib import metadata_schema -from bec_lib.messages import ScanQueueMessage +from bec_lib.messages import AvailableScan, ScanQueueMessage from bec_lib.metadata_schema import BasicScanMetadata from bec_lib.scans import Scans @@ -99,23 +99,32 @@ def test_default_schema_is_used_as_fallback(): def test_prepare_scan_request_produces_conforming_message(): + available_scan = AvailableScan( + class_name="test", + base_class="", + arg_input={}, + gui_config={}, + required_kwargs=[], + arg_bundle_size={}, + signature=[], + ) with patch.dict(metadata_schema._METADATA_SCHEMA_REGISTRY, TEST_REGISTRY, clear=True): with pytest.raises(ValidationError): Scans.prepare_scan_request( scan_name="fake_scan_with_extra_metadata", - scan_info={"required_kwargs": []}, + scan_info=available_scan, system_config={}, ) with pytest.raises(ValidationError): Scans.prepare_scan_request( scan_name="fake_scan_with_extra_metadata", - scan_info={"required_kwargs": []}, + scan_info=available_scan, system_config={}, user_metadata={"number_field": "string"}, ) msg = Scans.prepare_scan_request( scan_name="fake_scan_with_extra_metadata", - scan_info={"required_kwargs": []}, + scan_info=available_scan, system_config={}, user_metadata={"number_field": 123}, ) diff --git a/bec_lib/tests/test_scan_object.py b/bec_lib/tests/test_scan_object.py index e51d63da8..0f3898ab4 100644 --- a/bec_lib/tests/test_scan_object.py +++ b/bec_lib/tests/test_scan_object.py @@ -2,26 +2,43 @@ import pytest +from bec_lib.messages import AvailableScan from bec_lib.scans import ScanObject +def _mock_scan_info(info: dict): + defaults = { + "class_name": "test", + "base_class": "", + "arg_input": {}, + "gui_config": {}, + "required_kwargs": [], + "arg_bundle_size": {}, + "signature": [], + } + defaults.update(info) + return AvailableScan.model_validate(defaults) + + @pytest.fixture def scan_obj(bec_client_mock): - scan_info = { - "class": "FermatSpiralScan", - "arg_input": {"device": "device", "start": "float", "stop": "float"}, - "required_kwargs": ["step", "relative"], - "arg_bundle_size": {"bundle": 3, "min": 2, "max": 2}, - "doc": ( - "\n A scan following Fermat's spiral.\n\n Args:\n *args: pairs" - " of device / start position / end position / steps arguments\n relative:" - " Start from an absolute or relative position\n burst: number of acquisition" - " per point\n optim_trajectory: routine used for the trajectory" - " optimization, e.g. 'corridor'. Default: None\n\n Returns:\n\n " - " Examples:\n >>> scans.fermat_scan(dev.motor1, -5, 5, dev.motor2, -5, 5," - ' step=0.5, exp_time=0.1, relative=True, optim_trajectory="corridor")\n\n ' - ), - } + scan_info = _mock_scan_info( + { + "class_name": "FermatSpiralScan", + "arg_input": {"device": "device", "start": "float", "stop": "float"}, + "required_kwargs": ["step", "relative"], + "arg_bundle_size": {"bundle": 3, "min": 2, "max": 2}, + "doc": ( + "\n A scan following Fermat's spiral.\n\n Args:\n *args: pairs" + " of device / start position / end position / steps arguments\n relative:" + " Start from an absolute or relative position\n burst: number of acquisition" + " per point\n optim_trajectory: routine used for the trajectory" + " optimization, e.g. 'corridor'. Default: None\n\n Returns:\n\n " + " Examples:\n >>> scans.fermat_scan(dev.motor1, -5, 5, dev.motor2, -5, 5," + ' step=0.5, exp_time=0.1, relative=True, optim_trajectory="corridor")\n\n ' + ), + } + ) scan_name = "fermat_scan" obj = ScanObject(scan_name, scan_info, bec_client_mock) with mock.patch.object(bec_client_mock, "alarm_handler"): @@ -30,16 +47,17 @@ def scan_obj(bec_client_mock): @pytest.fixture def scan_obj_no_args(bec_client_mock): - scan_info = { - "class": "TimeScan", - "base_class": "ScanBase", - "arg_input": {}, - "gui_config": {"scan_class_name": "TimeScan", "arg_group": "", "kwarg_groups": ""}, - "required_kwargs": ["points", "interval"], - "arg_bundle_size": {"bundle": 0, "min": None, "max": None}, - "doc": '\n Trigger and readout devices at a fixed interval.\n Note that the interval time cannot be less than the exposure time.\n The effective "sleep" time between points is\n sleep_time = interval - exp_time\n\n Args:\n points: number of points\n interval: time interval between points\n exp_time: exposure time in s\n burst: number of acquisition per point\n\n Returns:\n ScanReport\n\n Examples:\n >>> scans.time_scan(points=10, interval=1.5, exp_time=0.1, relative=True)\n\n ', - "signature": "", - } + scan_info = _mock_scan_info( + { + "class_name": "TimeScan", + "base_class": "ScanBase", + "arg_input": {}, + "gui_config": {"scan_class_name": "TimeScan", "arg_group": "", "kwarg_groups": ""}, + "required_kwargs": ["points", "interval"], + "arg_bundle_size": {"bundle": 0, "min": None, "max": None}, + "doc": '\n Trigger and readout devices at a fixed interval.\n Note that the interval time cannot be less than the exposure time.\n The effective "sleep" time between points is\n sleep_time = interval - exp_time\n\n Args:\n points: number of points\n interval: time interval between points\n exp_time: exposure time in s\n burst: number of acquisition per point\n\n Returns:\n ScanReport\n\n Examples:\n >>> scans.time_scan(points=10, interval=1.5, exp_time=0.1, relative=True)\n\n ', + } + ) scan_name = "fermat_scan" obj = ScanObject(scan_name, scan_info, bec_client_mock) with mock.patch.object(bec_client_mock, "alarm_handler"): From fd01f74902c885135880ae7d854ff2c52474c9e2 Mon Sep 17 00:00:00 2001 From: David Perl Date: Tue, 3 Mar 2026 08:51:21 +0100 Subject: [PATCH 7/8] chore: tidy no-longer-used code --- .../tests/end-2-end/test_scans_lib_e2e.py | 134 ++++---------- bec_lib/bec_lib/atlas_models.py | 16 +- bec_lib/bec_lib/codecs.py | 24 --- bec_lib/bec_lib/device.py | 161 +++++++++++----- bec_lib/bec_lib/messages.py | 4 +- bec_lib/bec_lib/serialization.py | 3 +- bec_lib/bec_lib/serialization_registry.py | 2 - bec_lib/tests/test_devices.py | 174 +++++++++++++----- .../bec_server/scihub/atlas/config_handler.py | 4 +- .../tests_scan_server/test_scan_stubs.py | 35 +--- 10 files changed, 316 insertions(+), 241 deletions(-) diff --git a/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py b/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py index 4aa7d5dd2..fdcaf31e2 100644 --- a/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py @@ -22,9 +22,7 @@ def test_grid_scan_lib(bec_client_lib): bec.metadata.update({"unit_test": "test_grid_scan_bec_client_lib"}) dev = bec.device_manager.devices scans.umv(dev.samx, 0, dev.samy, 0, relative=False) - status = scans.grid_scan( - dev.samx, -5, 5, 10, dev.samy, -5, 5, 10, exp_time=0.01, relative=True - ) + status = scans.grid_scan(dev.samx, -5, 5, 10, dev.samy, -5, 5, 10, exp_time=0.01, relative=True) status.wait(num_points=True, file_written=True) assert len(status.scan.live_data) == 100 assert status.scan.num_points == 100 @@ -36,9 +34,7 @@ def test_grid_scan_lib_cancel(bec_client_lib): scans = bec.scans bec.metadata.update({"unit_test": "test_grid_scan_bec_client_lib"}) dev = bec.device_manager.devices - status = scans.grid_scan( - dev.samx, -5, 5, 10, dev.samy, -5, 5, 10, exp_time=1, relative=False - ) + status = scans.grid_scan(dev.samx, -5, 5, 10, dev.samy, -5, 5, 10, exp_time=1, relative=False) time.sleep(0.5) status.cancel() @@ -56,14 +52,10 @@ def test_mv_scan_lib(bec_client_lib): current_pos_samx = dev.samx.read()["samx"]["value"] current_pos_samy = dev.samy.read()["samy"]["value"] assert np.isclose( - current_pos_samx, - 10, - atol=dev.samx._config["deviceConfig"].get("tolerance", 0.05), + current_pos_samx, 10, atol=dev.samx._config["deviceConfig"].get("tolerance", 0.05) ) assert np.isclose( - current_pos_samy, - 20, - atol=dev.samy._config["deviceConfig"].get("tolerance", 0.05), + current_pos_samy, 20, atol=dev.samy._config["deviceConfig"].get("tolerance", 0.05) ) @@ -113,9 +105,7 @@ def dummy_callback(data, metadata): reference_container["metadata"] = metadata reference_container["data"].append(data) - s = scans.line_scan( - dev.samx, 0, 1, steps=10, relative=False, async_callback=dummy_callback - ) + s = scans.line_scan(dev.samx, 0, 1, steps=10, relative=False, async_callback=dummy_callback) s.wait() while len(reference_container["data"]) < 10: time.sleep(0.1) @@ -138,9 +128,7 @@ def scan_status_update(msg): pos = yield dev.samx.position cb_executed.set() - bec_client_lib.connector.register( - MessageEndpoints.scan_status(), cb=scan_status_update - ) + bec_client_lib.connector.register(MessageEndpoints.scan_status(), cb=scan_status_update) s = scans.line_scan(dev.samx, 0, 1, steps=10, exp_time=0.2, relative=False) s.wait() cb_executed.wait() @@ -157,17 +145,10 @@ def test_config_updates(bec_client_lib): assert dev.rt_controller.limits == [-50, 50] dev.rt_controller.velocity.set(10).wait() - assert ( - dev.rt_controller.velocity.read(cached=True)["rt_controller_velocity"]["value"] - == 10 - ) + assert dev.rt_controller.velocity.read(cached=True)["rt_controller_velocity"]["value"] == 10 assert dev.rt_controller.velocity.read()["rt_controller_velocity"]["value"] == 10 - assert ( - dev.rt_controller.read_configuration()["rt_controller_velocity"]["value"] == 10 - ) - assert ( - dev.rt_controller.read_configuration()["rt_controller_velocity"]["value"] == 10 - ) + assert dev.rt_controller.read_configuration()["rt_controller_velocity"]["value"] == 10 + assert dev.rt_controller.read_configuration()["rt_controller_velocity"]["value"] == 10 dev.rt_controller.velocity.put(5) assert dev.rt_controller.velocity.get() == 5 @@ -198,13 +179,7 @@ def test_dap_fit(bec_client_lib): dev.bpm4i.sim.select_model("GaussianModel") params = dev.bpm4i.sim.params params.update( - { - "noise": "uniform", - "noise_multiplier": 10, - "center": 5, - "sigma": 1, - "amplitude": 200, - } + {"noise": "uniform", "noise_multiplier": 10, "center": 5, "sigma": 1, "amplitude": 200} ) dev.bpm4i.sim.params = params time.sleep(1) @@ -243,7 +218,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "ophyd_devices.SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -256,7 +231,7 @@ def test_dap_fit(bec_client_lib): "tolerance": 0.01, "update_frequency": 400, }, - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "enabled": True, "readOnly": False, }, @@ -270,7 +245,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "ophyd_devices.SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -284,7 +259,7 @@ def test_dap_fit(bec_client_lib): "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "enabled": True, "readOnly": False, }, @@ -298,7 +273,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "ophyd_devices.SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -307,7 +282,7 @@ def test_dap_fit(bec_client_lib): "deviceClass": "ophyd_devices.utils.bec_utils.DeviceClassConnectionError", "deviceConfig": {}, "readoutPriority": "baseline", - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "enabled": True, "readOnly": False, }, @@ -321,7 +296,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -330,7 +305,7 @@ def test_dap_fit(bec_client_lib): "deviceClass": "ophyd_devices.utils.bec_utils.DeviceClassInitError", "deviceConfig": {}, "readoutPriority": "baseline", - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "enabled": True, "readOnly": False, }, @@ -344,7 +319,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -353,7 +328,7 @@ def test_dap_fit(bec_client_lib): "deviceClass": "ophyd_devices.WrongDeviceClass", "deviceConfig": {}, "readoutPriority": "baseline", - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "enabled": True, "readOnly": False, }, @@ -372,18 +347,11 @@ def test_dap_fit(bec_client_lib): ], ) def test_config_reload( - bec_test_config_file_path, - bec_client_lib, - config, - raises_error, - deletes_config, - disabled_device, + bec_test_config_file_path, bec_client_lib, config, raises_error, deletes_config, disabled_device ): bec = bec_client_lib bec.metadata.update({"unit_test": "test_config_reload"}) - runtime_config_file_path = ( - bec_test_config_file_path.parent / "e2e_runtime_config_test.yaml" - ) + runtime_config_file_path = bec_test_config_file_path.parent / "e2e_runtime_config_test.yaml" # write new config to disk with open(runtime_config_file_path, "w") as f: @@ -401,9 +369,7 @@ def test_config_reload( else: assert len(bec.device_manager.devices) == num_devices else: - bec.config.update_session_with_file( - runtime_config_file_path, force=True, validate=False - ) + bec.config.update_session_with_file(runtime_config_file_path, force=True, validate=False) assert len(bec.device_manager.devices) == 2 for dev in disabled_device: assert bec.device_manager.devices[dev].enabled is False @@ -412,15 +378,13 @@ def test_config_reload( def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_client_lib): bec = bec_client_lib bec.metadata.update({"unit_test": "test_config_reload"}) - runtime_config_file_path = ( - bec_test_config_file_path.parent / "e2e_runtime_config_test.yaml" - ) + runtime_config_file_path = bec_test_config_file_path.parent / "e2e_runtime_config_test.yaml" config = { "hexapod": { "deviceClass": "ophyd_devices.sim.sim_test_devices.SimPositionerWithDescribeFailure", "deviceConfig": {}, - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -434,7 +398,7 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "enabled": True, "readOnly": False, }, @@ -442,8 +406,7 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie # set hexapod to fail bec.connector.set( - f"e2e_test_hexapod_fail", - messages.DeviceStatusMessage(device="hexapod", status=1), + f"e2e_test_hexapod_fail", messages.DeviceStatusMessage(device="hexapod", status=1) ) # write new config to disk @@ -451,9 +414,7 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie f.write(yaml.dump(config)) with pytest.raises(DeviceConfigError): - bec.config.update_session_with_file( - runtime_config_file_path, force=True, validate=False - ) + bec.config.update_session_with_file(runtime_config_file_path, force=True, validate=False) assert len(bec.device_manager.devices) == 2 assert bec.device_manager.devices["eyefoc"].enabled is True @@ -461,8 +422,7 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie # set hexapod to pass bec.connector.set( - f"e2e_test_hexapod_fail", - messages.DeviceStatusMessage(device="hexapod", status=0), + f"e2e_test_hexapod_fail", messages.DeviceStatusMessage(device="hexapod", status=0) ) bec.config.update_session_with_file(runtime_config_file_path, force=True) @@ -486,22 +446,18 @@ def test_config_add_remove_device(bec_client_lib): "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": ["user motors"], + "deviceTags": {"user motors"}, "enabled": True, "readOnly": False, } } bec.device_manager.config_helper.send_config_request(action="add", config=config) with pytest.raises(DeviceConfigError) as config_error: - bec.device_manager.config_helper.send_config_request( - action="add", config=config - ) + bec.device_manager.config_helper.send_config_request(action="add", config=config) assert config_error.match("Device new_device already exists") assert "new_device" in dev - bec.device_manager.config_helper.send_config_request( - action="remove", config={"new_device": {}} - ) + bec.device_manager.config_helper.send_config_request(action="remove", config={"new_device": {}}) assert "new_device" not in dev device_config_msg = bec.connector.get(MessageEndpoints.device_config()) @@ -512,9 +468,7 @@ def test_config_add_remove_device(bec_client_lib): config["new_device"]["deviceClass"] = "ophyd_devices.doesnt_exist" with pytest.raises(DeviceConfigError) as config_error: - bec.device_manager.config_helper.send_config_request( - action="add", config=config - ) + bec.device_manager.config_helper.send_config_request(action="add", config=config) assert config_error.match("module 'ophyd_devices' has no attribute 'doesnt_exist'") assert "new_device" not in dev assert "samx" in dev @@ -609,13 +563,13 @@ def test_image_analysis(bec_client_lib): dev.eiger.sim.select_model("gaussian") dev.eiger.sim.params = { "amplitude": 100, - "center_offset": np.array([0, 0]), - "covariance": np.array([[1, 0], [0, 1]]), + "center_offset": [0, 0], + "covariance": [[1, 0], [0, 1]], "noise": "uniform", "noise_multiplier": 10, - "hot_pixel_coords": np.array([[24, 24], [50, 20], [4, 40]]), + "hot_pixel_coords": [[24, 24], [50, 20], [4, 40]], "hot_pixel_types": ["fluctuating", "constant", "fluctuating"], - "hot_pixel_values": np.array([1000.0, 10000.0, 1000.0]), + "hot_pixel_values": [1000.0, 10000.0, 1000.0], } res = scans.line_scan(dev.samx, -5, 5, steps=10, relative=False, exp_time=0) @@ -625,9 +579,7 @@ def test_image_analysis(bec_client_lib): assert (fit_res[1]["stats"]["min"] == 0.0).all() assert (np.isclose(fit_res[1]["stats"]["mean"], 3.3, atol=0.5)).all() # Center of mass is not in the middle due to hot (fluctuating) pixels - assert ( - np.isclose(fit_res[1]["stats"]["center_of_mass"], [49.5, 40.8], atol=2) - ).all() + assert (np.isclose(fit_res[1]["stats"]["center_of_mass"], [49.5, 40.8], atol=2)).all() @pytest.mark.timeout(100) @@ -645,11 +597,7 @@ def test_bl_state(bec_client_lib): tolerance=1, ) samx_config = DeviceWithinLimitsStateConfig( - name="samx_within_limits", - device="samx", - low_limit=-10, - high_limit=10, - tolerance=1, + name="samx_within_limits", device="samx", low_limit=-10, high_limit=10, tolerance=1 ) bec.beamline_states.add(hexapod_config) @@ -684,9 +632,7 @@ def test_bl_state(bec_client_lib): bec.beamline_states.delete("hexapod_x_within_limits") assert not hasattr(bec.beamline_states, "hexapod_x_within_limits") - bec.beamline_states.samx_within_limits.update_parameters( - low_limit=-5, high_limit=25 - ) + bec.beamline_states.samx_within_limits.update_parameters(low_limit=-5, high_limit=25) bec.beamline_states.show_all() while bec.beamline_states.samx_within_limits.get()["status"] != "valid": diff --git a/bec_lib/bec_lib/atlas_models.py b/bec_lib/bec_lib/atlas_models.py index 484bfb216..70ac1c5e9 100644 --- a/bec_lib/bec_lib/atlas_models.py +++ b/bec_lib/bec_lib/atlas_models.py @@ -7,9 +7,17 @@ import hashlib import json from enum import Enum -from typing import AbstractSet, Any, Literal, TypeVar - -from pydantic import BaseModel, Field, PrivateAttr, create_model, field_validator, model_validator +from typing import AbstractSet, Annotated, Any, Literal, TypeVar + +from pydantic import ( + BaseModel, + Field, + PlainSerializer, + PrivateAttr, + create_model, + field_validator, + model_validator, +) from pydantic_core import PydanticUndefined from bec_lib.utils.json_extended import ExtendedEncoder @@ -42,7 +50,7 @@ class _DeviceModelCore(BaseModel): deviceConfig: dict | None = None connectionTimeout: float = 5.0 description: str = "" - deviceTags: set[str] = set() + deviceTags: Annotated[set[str], Field(default_factory=set), PlainSerializer(list)] needs: list[str] = [] onFailure: Literal["buffer", "retry", "raise"] = "retry" readOnly: bool = False diff --git a/bec_lib/bec_lib/codecs.py b/bec_lib/bec_lib/codecs.py index 0df6fb31a..2ce4fdee0 100644 --- a/bec_lib/bec_lib/codecs.py +++ b/bec_lib/bec_lib/codecs.py @@ -98,18 +98,6 @@ def decode(type_name: str, data: str) -> str: return data -class PydanticEncoder(BECCodec): - obj_type: Type = BaseModel - - @staticmethod - def encode(obj: BaseModel) -> dict: - return obj.model_dump() - - @staticmethod - def decode(type_name: str, data: dict) -> dict: - return data - - class EndpointInfoEncoder(BECCodec): obj_type: Type = EndpointInfo @@ -130,18 +118,6 @@ def decode(type_name: str, data: dict) -> EndpointInfo: ) -class SetEncoder(BECCodec): - obj_type: Type = set - - @staticmethod - def encode(obj: set) -> list: - return list(obj) - - @staticmethod - def decode(type_name: str, data: list) -> set: - return set(data) - - class BECTypeEncoder(BECCodec): obj_type: Type = type diff --git a/bec_lib/bec_lib/device.py b/bec_lib/bec_lib/device.py index c2f48d39d..a4b035250 100644 --- a/bec_lib/bec_lib/device.py +++ b/bec_lib/bec_lib/device.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable import numpy as np -from pydantic import ConfigDict +from pydantic import ConfigDict, field_serializer from rich.console import Console from rich.table import Table from typeguard import typechecked @@ -41,7 +41,9 @@ logger = bec_logger.logger _MAX_RECURSION_DEPTH = 100 -rpc_method_context: ContextVar[Callable | None] = ContextVar("rpc_method_context", default=None) +rpc_method_context: ContextVar[Callable | None] = ContextVar( + "rpc_method_context", default=None +) class RPCError(AlarmBase): @@ -148,7 +150,9 @@ def __eq__(self, __value: object) -> bool: return False @staticmethod - def _on_status_update(msg: dict[str, messages.DeviceReqStatusMessage], parent: Status): + def _on_status_update( + msg: dict[str, messages.DeviceReqStatusMessage], parent: Status + ): # pylint: disable=protected-access parent._request_status = msg["data"] parent._set_done() @@ -156,7 +160,8 @@ def _on_status_update(msg: dict[str, messages.DeviceReqStatusMessage], parent: S def _set_done(self): self._status_done.set() self._connector.unregister( - MessageEndpoints.device_req_status(self._request_id), cb=self._on_status_update + MessageEndpoints.device_req_status(self._request_id), + cb=self._on_status_update, ) def wait(self, timeout=None, raise_on_failure=True): @@ -169,9 +174,10 @@ def wait(self, timeout=None, raise_on_failure=True): raise_on_failure (bool, optional): If True, an RPCError is raised if the request fails. Defaults to True. """ try: - if not self._status_done.wait(timeout): - raise TimeoutError("The request has not been completed within the specified time.") + raise TimeoutError( + "The request has not been completed within the specified time." + ) finally: self._set_done() @@ -185,11 +191,21 @@ def wait(self, timeout=None, raise_on_failure=True): class _PermissiveDeviceModel(_DeviceModelCore): model_config = ConfigDict(extra="allow") + @field_serializer("deviceTags") + def serialize_devicetags(self, value: set[str], info): + if info.mode == "json": + return list[value] + else: + return value -def set_device_config(device: "DeviceBase", config: dict | _PermissiveDeviceModel | None): - # device._config = config + +def set_device_config( + device: "DeviceBase", config: dict | _PermissiveDeviceModel | None +): device._config = ( # pylint: disable=protected-access - _PermissiveDeviceModel.model_validate(config).model_dump() if config is not None else None + _PermissiveDeviceModel.model_validate(config).model_dump(mode="python") + if config is not None + else None ) @@ -219,7 +235,9 @@ def __init__( class_name (str, optional): The class name of the device. Defaults to None. If None, the class name is inferred from the class of the object. """ self.name = name - self._class_name = class_name or object.__getattribute__(self, "__class__").__name__ + self._class_name = ( + class_name or object.__getattribute__(self, "__class__").__name__ + ) self._signal_info = signal_info set_device_config(self, config) if info is None: @@ -290,7 +308,9 @@ def __setattr__(self, name: str, value: Any) -> None: def _should_prevent_attribute_overwrite(self, name: str) -> bool: # pylint: disable=protected-access # allow override is defined on the device manager - if self.root.parent is None or getattr(self.root.parent, "_allow_override", True): + if self.root.parent is None or getattr( + self.root.parent, "_allow_override", True + ): return False if name.startswith("_"): return False @@ -392,7 +412,9 @@ def _get_rpc_response(self, request_id, rpc_id) -> Any: def _handle_client_info_msg(self): """Handle client messages during RPC calls""" - msgs = self.root.parent.connector.xread(MessageEndpoints.client_info(), block=200) + msgs = self.root.parent.connector.xread( + MessageEndpoints.client_info(), block=200 + ) # The client is the parent.parent of the device client: BECClient = self.root.parent.parent if client.live_updates_config.print_client_messages is False: @@ -402,7 +424,9 @@ def _handle_client_info_msg(self): for msg in msgs: print(QueueItem.format_client_msg(msg["data"])) - def _run_rpc_call(self, device, func_call, *args, wait_for_rpc_response=True, **kwargs) -> Any: + def _run_rpc_call( + self, device, func_call, *args, wait_for_rpc_response=True, **kwargs + ) -> Any: """ Runs an RPC call on the device. This method is used internally by the RPC decorator. If a call is interrupted by the user, the a stop signal is sent to this device. @@ -428,7 +452,9 @@ def _run_rpc_call(self, device, func_call, *args, wait_for_rpc_response=True, ** # prepare RPC message rpc_id = str(uuid.uuid4()) request_id = str(uuid.uuid4()) - msg = self._prepare_rpc_msg(rpc_id, request_id, device, func_call, *args, **kwargs) + msg = self._prepare_rpc_msg( + rpc_id, request_id, device, func_call, *args, **kwargs + ) # pylint: disable=protected-access if client.scans._scan_def_id: @@ -441,7 +467,9 @@ def _run_rpc_call(self, device, func_call, *args, wait_for_rpc_response=True, ** } # send RPC message - client.connector.send(MessageEndpoints.scan_queue_request(client.username), msg) + client.connector.send( + MessageEndpoints.scan_queue_request(client.username), msg + ) # wait for RPC response if not wait_for_rpc_response: @@ -472,7 +500,9 @@ def _validate_rpc_client(self) -> None: ) if client.alarm_handler is None: - raise RPCError("RPC calls require an alarm handler to be set in the BECClient.") + raise RPCError( + "RPC calls require an alarm handler to be set in the BECClient." + ) def _get_rpc_func_name(self, fcn=None, use_parent=False): func_call = [self._compile_function_path(use_parent=use_parent)] @@ -545,9 +575,15 @@ def _parse_info(self): base_class = dev["device_info"].get("device_base_class") attr_name = dev["device_info"].get("device_attr_name") if base_class == "positioner": - setattr(self, attr_name, Positioner(name=attr_name, info=dev, parent=self)) + setattr( + self, + attr_name, + Positioner(name=attr_name, info=dev, parent=self), + ) elif base_class == "device": - setattr(self, attr_name, Device(name=attr_name, info=dev, parent=self)) + setattr( + self, attr_name, Device(name=attr_name, info=dev, parent=self) + ) for user_access_name, descr in self._info.get("custom_user_access", {}).items(): # avoid circular imports as the signature serializer imports the DeviceBase class @@ -559,8 +595,14 @@ def _parse_info(self): self._custom_rpc_methods[user_access_name] = DeviceBase( name=user_access_name, info=descr, parent=self ) - setattr(self, user_access_name, self._custom_rpc_methods[user_access_name].run) - setattr(getattr(self, user_access_name), "__doc__", descr.get("doc")) + setattr( + self, + user_access_name, + self._custom_rpc_methods[user_access_name].run, + ) + setattr( + getattr(self, user_access_name), "__doc__", descr.get("doc") + ) setattr( getattr(self, user_access_name), "__signature__", @@ -579,7 +621,9 @@ def _parse_info(self): parent=self, class_name=descr["device_class"], ) - setattr(self, user_access_name, self._custom_rpc_methods[user_access_name]) + setattr( + self, user_access_name, self._custom_rpc_methods[user_access_name] + ) def __eq__(self, other): if isinstance(other, DeviceBase): @@ -632,7 +676,9 @@ def _repr_pretty_(self, p: PrettyPrinter, cycle: bool): @staticmethod def _compile_device_table(obj: DeviceBase) -> Table: # Create main table - table = Table(title=f"{obj._class_name}: {obj.name}", show_header=False, box=None) + table = Table( + title=f"{obj._class_name}: {obj.name}", show_header=False, box=None + ) table.add_column("Property", style="cyan", no_wrap=True) table.add_column("Value", style="white") @@ -642,7 +688,9 @@ def _compile_device_table(obj: DeviceBase) -> Table: table.add_row("Read only", str(obj.read_only)) table.add_row("Software Trigger", str(obj.root.software_trigger)) table.add_row("Device class", str(obj._config.get("deviceClass", "N/A"))) - table.add_row("Readout Priority", str(obj._config.get("readoutPriority", "N/A"))) + table.add_row( + "Readout Priority", str(obj._config.get("readoutPriority", "N/A")) + ) if obj._config.get("deviceTags"): tags = ", ".join(obj._config.get("deviceTags", [])) @@ -667,7 +715,9 @@ def _compile_current_values(current_values: dict) -> Table: # Format value (handle numpy arrays) if isinstance(value, np.ndarray): with np.printoptions(precision=4, suppress=True, threshold=10): - value_str = f"{str(value)}, shape={value.shape}, dtype={value.dtype}" + value_str = ( + f"{str(value)}, shape={value.shape}, dtype={value.dtype}" + ) else: value_str = str(value) # Format timestamp @@ -696,8 +746,12 @@ def _compile_config_section(device_config: dict) -> Table: return config_table @staticmethod - def _compile_rich_tables(obj: DeviceBase) -> tuple[Table, Table | None, Table | None]: - table = DeviceBase._compile_device_table(obj) # Add current values section if available + def _compile_rich_tables( + obj: DeviceBase, + ) -> tuple[Table, Table | None, Table | None]: + table = DeviceBase._compile_device_table( + obj + ) # Add current values section if available value_table = ( DeviceBase._compile_current_values(current_values) if (current_values := obj.read(cached=True)) @@ -706,7 +760,9 @@ def _compile_rich_tables(obj: DeviceBase) -> tuple[Table, Table | None, Table | # Get the updated device config. We use the cached version to avoid # excessive calls to Redis. device_config = ( - obj.parent.get_device_config_cached().get(obj.name, {}).get("deviceConfig", {}) + obj.parent.get_device_config_cached() + .get(obj.name, {}) + .get("deviceConfig", {}) ) # Filter down to only config signals config_signals = [ @@ -717,7 +773,9 @@ def _compile_rich_tables(obj: DeviceBase) -> tuple[Table, Table | None, Table | device_config = {k: v for k, v in device_config.items() if k in config_signals} # Add config signals section if available config_table = ( - DeviceBase._compile_config_section(device_config) if (device_config) else None + DeviceBase._compile_config_section(device_config) + if (device_config) + else None ) return table, value_table, config_table @@ -742,7 +800,6 @@ def _compile_rich_str(obj: DeviceBase) -> str | None: class DeviceBaseWithConfig(DeviceBase): - @property def full_name(self): """Returns the full name of the device or signal, separated by "_" e.g. samx_velocity""" @@ -777,10 +834,10 @@ def _update_config(self, update: dict) -> None: action="update", config={self.name: update} ) - def get_device_tags(self) -> list: + def get_device_tags(self) -> set[str]: """get the device tags for this device""" # pylint: disable=protected-access - return self.root._config.get("deviceTags", []) + return self.root._config.get("deviceTags", {}) @typechecked def set_device_tags(self, val: Iterable): @@ -788,7 +845,8 @@ def set_device_tags(self, val: Iterable): # pylint: disable=protected-access self.root._config["deviceTags"] = set(val) return self.root.parent.config_helper.send_config_request( - action="update", config={self.name: {"deviceTags": self.root._config["deviceTags"]}} + action="update", + config={self.name: {"deviceTags": self.root._config["deviceTags"]}}, ) @typechecked @@ -797,7 +855,8 @@ def add_device_tag(self, val: str): # pylint: disable=protected-access self.root._config["deviceTags"].add(val) return self.root.parent.config_helper.send_config_request( - action="update", config={self.name: {"deviceTags": self.root._config["deviceTags"]}} + action="update", + config={self.name: {"deviceTags": self.root._config["deviceTags"]}}, ) def remove_device_tag(self, val: str): @@ -805,7 +864,8 @@ def remove_device_tag(self, val: str): # pylint: disable=protected-access self.root._config["deviceTags"].remove(val) return self.root.parent.config_helper.send_config_request( - action="update", config={self.name: {"deviceTags": self.root._config["deviceTags"]}} + action="update", + config={self.name: {"deviceTags": self.root._config["deviceTags"]}}, ) @property @@ -844,7 +904,8 @@ def on_failure(self, val: OnFailure): # pylint: disable=protected-access self.root._config["onFailure"] = val return self.root.parent.config_helper.send_config_request( - action="update", config={self.name: {"onFailure": self.root._config["onFailure"]}} + action="update", + config={self.name: {"onFailure": self.root._config["onFailure"]}}, ) @property @@ -936,7 +997,9 @@ def read( MessageEndpoints.device_readback(self.root.name) ) else: - val = self.root.parent.connector.get(MessageEndpoints.device_read(self.root.name)) + val = self.root.parent.connector.get( + MessageEndpoints.device_read(self.root.name) + ) if not val: return None @@ -956,7 +1019,11 @@ def read_configuration(self, cached=False) -> dict[str, dict[str, Any]] | None: is_signal, is_config_signal, cached = self._get_rpc_signal_info(cached) if not cached: - fcn = self.read_configuration if (not is_signal or is_config_signal) else self.read + fcn = ( + self.read_configuration + if (not is_signal or is_config_signal) + else self.read + ) signals = self._run(cached=False, fcn=fcn) else: if is_signal and not is_config_signal: @@ -975,7 +1042,9 @@ def _filter_rpc_signals(self, signals: dict) -> dict: if self._signal_info: obj_name = self._signal_info.get("obj_name") return {obj_name: signals.get(obj_name, {})} - return {key: val for key, val in signals.items() if key.startswith(self.full_name)} + return { + key: val for key, val in signals.items() if key.startswith(self.full_name) + } def _get_rpc_signal_info(self, cached: bool): is_config_signal = False @@ -1160,7 +1229,9 @@ def limits(self): """ Returns the device limits. """ - limit_msg = self.root.parent.connector.get(MessageEndpoints.device_limits(self.root.name)) + limit_msg = self.root.parent.connector.get( + MessageEndpoints.device_limits(self.root.name) + ) if not limit_msg: return [0, 0] limits = [ @@ -1203,7 +1274,6 @@ class Signal(AdjustableMixin, OphydInterfaceBase): class ComputedSignal(Signal): - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._num_args_method = None @@ -1230,7 +1300,9 @@ def calculate_readback(signal): # check if it is a bound method if hasattr(method, "__self__") and method.__self__ is not None: - raise ValueError("The compute method must be an unbound function, not a bound method.") + raise ValueError( + "The compute method must be an unbound function, not a bound method." + ) # check if it is a lambda function if method.__name__ == "": @@ -1239,7 +1311,9 @@ def calculate_readback(signal): method_code = inspect.getsource(method) self._num_args_method = len(inspect.signature(method).parameters) - self._update_config({"deviceConfig": {"compute_method": self._header + method_code}}) + self._update_config( + {"deviceConfig": {"compute_method": self._header + method_code}} + ) if self._num_signals is None: return if self._num_args_method != self._num_signals: @@ -1288,7 +1362,8 @@ def show_all(self): table.add_row("Compute Method", compute_method if compute_method else "Not set") table.add_row( - "Input Signals", ", ".join(input_signals) if input_signals else "No input signals set" + "Input Signals", + ", ".join(input_signals) if input_signals else "No input signals set", ) console.print(table) diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 4f59ff304..fcb55f992 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -28,6 +28,7 @@ ) from typing_extensions import TypeAliasType +from bec_lib.atlas_models import Device from bec_lib.bec_serializable import BECSerializable from bec_lib.metadata_schema import get_metadata_schema_for_scan from bec_lib.one_way_registry import OneWaySerializationRegistry @@ -62,7 +63,7 @@ def _try_dump(v): try: msgpack.dumps(v) except TypeError as e: - raise ValueError(str(v)) from e + raise ValueError("Non-JSONable/msgpackable data in " + str(v)) from e return v @@ -1374,6 +1375,7 @@ class AvailableResourceMessage(BECMessage): | SpecificMessageType | Annotated[list[SpecificMessageType], FailFast] | dict[str, SpecificMessageType] + | Annotated[list[Device], FailFast] ) diff --git a/bec_lib/bec_lib/serialization.py b/bec_lib/bec_lib/serialization.py index fe3ffce1a..cbd223c4e 100644 --- a/bec_lib/bec_lib/serialization.py +++ b/bec_lib/bec_lib/serialization.py @@ -10,6 +10,7 @@ from abc import abstractmethod import msgpack as msgpack_module +from pydantic import BaseModel from bec_lib import messages as messages_module from bec_lib.logger import bec_logger @@ -36,7 +37,7 @@ class BECMessagePack(SerializationRegistry): def dumps(self, obj): """Pack object `obj` and return packed bytes.""" - if isinstance(obj, BECMessage): + if isinstance(obj, (BECMessage, BaseModel)): obj = obj.model_dump(mode="python", fallback=self.encode) return msgpack_module.packb(obj, default=self.encode) diff --git a/bec_lib/bec_lib/serialization_registry.py b/bec_lib/bec_lib/serialization_registry.py index e2e6c8d81..1ed584003 100644 --- a/bec_lib/bec_lib/serialization_registry.py +++ b/bec_lib/bec_lib/serialization_registry.py @@ -20,9 +20,7 @@ def __init__(self): self._legacy_codecs = [] # can be removed in future versions, see issue #516 self.register_codec(bec_codecs.EndpointInfoEncoder) - self.register_codec(bec_codecs.SetEncoder) self.register_codec(bec_codecs.BECTypeEncoder) - self.register_codec(bec_codecs.PydanticEncoder) self.register_codec(bec_codecs.EnumEncoder) if self.use_json: diff --git a/bec_lib/tests/test_devices.py b/bec_lib/tests/test_devices.py index 9f99b33c8..2af8a0269 100644 --- a/bec_lib/tests/test_devices.py +++ b/bec_lib/tests/test_devices.py @@ -100,26 +100,47 @@ def test_read_use_read(dev: Any): def test_read_nested_device(dev: Any): with mock.patch.object(dev.dyn_signals.root.parent.connector, "get") as mock_get: data = { - "dyn_signals_messages_message1": {"value": 0, "timestamp": 1701105880.0716832}, - "dyn_signals_messages_message2": {"value": 0, "timestamp": 1701105880.071722}, - "dyn_signals_messages_message3": {"value": 0, "timestamp": 1701105880.071739}, - "dyn_signals_messages_message4": {"value": 0, "timestamp": 1701105880.071753}, - "dyn_signals_messages_message5": {"value": 0, "timestamp": 1701105880.071766}, + "dyn_signals_messages_message1": { + "value": 0, + "timestamp": 1701105880.0716832, + }, + "dyn_signals_messages_message2": { + "value": 0, + "timestamp": 1701105880.071722, + }, + "dyn_signals_messages_message3": { + "value": 0, + "timestamp": 1701105880.071739, + }, + "dyn_signals_messages_message4": { + "value": 0, + "timestamp": 1701105880.071753, + }, + "dyn_signals_messages_message5": { + "value": 0, + "timestamp": 1701105880.071766, + }, } mock_get.return_value = messages.DeviceMessage( signals=data, metadata={"scan_id": "scan_id", "scan_type": "scan_type"} ) res = dev.dyn_signals.messages.read(cached=True) - mock_get.assert_called_once_with(MessageEndpoints.device_readback("dyn_signals")) + mock_get.assert_called_once_with( + MessageEndpoints.device_readback("dyn_signals") + ) assert res == data @pytest.mark.parametrize( - "kind,cached", [("normal", True), ("hinted", True), ("config", False), ("omitted", False)] + "kind,cached", + [("normal", True), ("hinted", True), ("config", False), ("omitted", False)], ) def test_read_kind_hinted( dev: Any, - kind: Literal["normal"] | Literal["hinted"] | Literal["config"] | Literal["omitted"], + kind: Literal["normal"] + | Literal["hinted"] + | Literal["config"] + | Literal["omitted"], cached: bool, ): with ( @@ -168,7 +189,9 @@ def test_read_configuration_not_cached( mock.patch.object(dev.samx.readback, "_run") as mock_run, ): dev.samx.readback.read_configuration(cached=False) - mock_run.assert_called_once_with(cached=False, fcn=getattr(dev.samx.readback, method)) + mock_run.assert_called_once_with( + cached=False, fcn=getattr(dev.samx.readback, method) + ) @pytest.mark.parametrize( @@ -176,7 +199,10 @@ def test_read_configuration_not_cached( [(True, False, "read"), (False, True, "redis"), (False, False, "redis")], ) def test_read_configuration_cached( - dev: Any, is_signal: bool, is_config_signal: bool, method: Literal["read"] | Literal["redis"] + dev: Any, + is_signal: bool, + is_config_signal: bool, + method: Literal["read"] | Literal["redis"], ): with ( mock.patch.object( @@ -197,7 +223,9 @@ def test_read_configuration_cached( ) dev.samx.readback.read_configuration(cached=True) if method == "redis": - mock_get.assert_called_once_with(MessageEndpoints.device_read_configuration("samx")) + mock_get.assert_called_once_with( + MessageEndpoints.device_read_configuration("samx") + ) mock_read.assert_not_called() else: mock_read.assert_called_once_with(cached=True) @@ -228,11 +256,15 @@ def test_get_rpc_func_name_read(dev: Any): @pytest.mark.parametrize( - "kind,cached", [("normal", True), ("hinted", True), ("config", False), ("omitted", False)] + "kind,cached", + [("normal", True), ("hinted", True), ("config", False), ("omitted", False)], ) def test_get_rpc_func_name_readback_get( dev: Any, - kind: Literal["normal"] | Literal["hinted"] | Literal["config"] | Literal["omitted"], + kind: Literal["normal"] + | Literal["hinted"] + | Literal["config"] + | Literal["omitted"], cached: bool, ): with ( @@ -265,17 +297,26 @@ def test_get_rpc_func_name_nested(dev: Any): "_run_rpc_call", ) as mock_rpc: dev.rt_controller.dummy_controller._func_with_args(1, 2) - mock_rpc.assert_called_once_with("rt_controller", "dummy_controller._func_with_args", 1, 2) + mock_rpc.assert_called_once_with( + "rt_controller", "dummy_controller._func_with_args", 1, 2 + ) def test_handle_rpc_response(dev: Any): - msg = messages.DeviceRPCMessage(device="samx", return_val=1, out="done", success=True) + msg = messages.DeviceRPCMessage( + device="samx", return_val=1, out="done", success=True + ) assert dev.samx._handle_rpc_response(msg) == 1 -def test_handle_rpc_response_returns_status(dev: Any, bec_client_mock: ClientMock | BECClient): +def test_handle_rpc_response_returns_status( + dev: Any, bec_client_mock: ClientMock | BECClient +): msg = messages.DeviceRPCMessage( - device="samx", return_val={"type": "status", "RID": "request_id"}, out="done", success=True + device="samx", + return_val={"type": "status", "RID": "request_id"}, + out="done", + success=True, ) assert dev.samx._handle_rpc_response(msg) == Status( bec_client_mock.device_manager.connector, "request_id" @@ -283,7 +324,9 @@ def test_handle_rpc_response_returns_status(dev: Any, bec_client_mock: ClientMoc def test_rpc_status_raises_error(dev: Any): - msg = messages.DeviceReqStatusMessage(device="samx", success=False, request_id="request_id") + msg = messages.DeviceReqStatusMessage( + device="samx", success=False, request_id="request_id" + ) connector = mock.MagicMock() status = Status(connector, "request_id") status._on_status_update({"data": msg}, parent=status) @@ -298,7 +341,9 @@ def test_handle_rpc_response_raises(dev: Any): device="samx", return_val={"type": "status", "RID": "request_id"}, out=messages.ErrorInfo( - exception_type="RPCError", error_message="An error occurred", compact_error_message=None + exception_type="RPCError", + error_message="An error occurred", + compact_error_message=None, ), success=False, ) @@ -307,7 +352,9 @@ def test_handle_rpc_response_raises(dev: Any): def test_handle_rpc_response_returns_dict(dev: Any): - msg = messages.DeviceRPCMessage(device="samx", return_val={"a": "b"}, out="done", success=True) + msg = messages.DeviceRPCMessage( + device="samx", return_val={"a": "b"}, out="done", success=True + ) assert dev.samx._handle_rpc_response(msg) == {"a": "b"} @@ -334,7 +381,7 @@ def device_config(): "readoutPriority": "monitored", "deviceClass": "SimCamera", "deviceConfig": {"device_access": True, "labels": "eiger", "name": "eiger"}, - "deviceTags": ["detector"], + "deviceTags": {"detector"}, } @@ -350,7 +397,9 @@ def dev_w_config(): def _func(config: dict = {}): dm_base = DeviceManagerBase(mock.MagicMock()) dm_base.config_helper = mock.MagicMock(spec=ConfigHelper) - return DeviceBaseWithConfig(name="test", config=BASIC_CONFIG | config, parent=dm_base) + return DeviceBaseWithConfig( + name="test", config=BASIC_CONFIG | config, parent=dm_base + ) return _func @@ -376,7 +425,9 @@ def test_create_device_saves_config( ) -def test_device_enabled(device_obj: DeviceBaseWithConfig, device_config: dict[str, Any]): +def test_device_enabled( + device_obj: DeviceBaseWithConfig, device_config: dict[str, Any] +): assert device_obj.enabled == device_config["enabled"] device_config["enabled"] = False set_device_config(device_obj, device_config) @@ -384,7 +435,9 @@ def test_device_enabled(device_obj: DeviceBaseWithConfig, device_config: dict[st def test_device_enable(device_obj: DeviceBaseWithConfig): - with mock.patch.object(device_obj.parent.config_helper, "send_config_request") as config_req: + with mock.patch.object( + device_obj.parent.config_helper, "send_config_request" + ) as config_req: device_obj.enabled = True config_req.assert_called_once_with( action="update", config={device_obj.name: {"enabled": True}} @@ -392,7 +445,9 @@ def test_device_enable(device_obj: DeviceBaseWithConfig): def test_device_enable_set(device_obj: DeviceBaseWithConfig): - with mock.patch.object(device_obj.parent.config_helper, "send_config_request") as config_req: + with mock.patch.object( + device_obj.parent.config_helper, "send_config_request" + ) as config_req: device_obj.read_only = False config_req.assert_called_once_with( action="update", config={device_obj.name: {"readOnly": False}} @@ -408,7 +463,9 @@ def test_device_set_user_parameter( val: dict[str, int] | set[str], raised_error: None | TypeCheckError, ): - with mock.patch.object(device_obj.parent.config_helper, "send_config_request") as config_req: + with mock.patch.object( + device_obj.parent.config_helper, "send_config_request" + ) as config_req: if raised_error is None: device_obj.set_user_parameter(val) config_req.assert_called_once_with( @@ -436,7 +493,9 @@ def test_device_update_user_parameter( raised_error: None | TypeCheckError, ): device_obj._config["userParameter"] = user_param - with mock.patch.object(device_obj.parent.config_helper, "send_config_request") as config_req: + with mock.patch.object( + device_obj.parent.config_helper, "send_config_request" + ) as config_req: if raised_error is None: device_obj.update_user_parameter(val) config_req.assert_called_once_with( @@ -467,14 +526,14 @@ def test_status_wait(): @pytest.fixture def device_w_tags(dev_w_config: Callable[..., DeviceBaseWithConfig]): - yield dev_w_config({"deviceTags": ["tag1", "tag2"]}) + yield dev_w_config({"deviceTags": {"tag1", "tag2"}}) @pytest.mark.parametrize( ["method", "args", "result"], [ ("set_device_tags", {"tag3", "tag4"}, {"tag3", "tag4"}), - ("set_device_tags", ["tag3", "tag3", "tag3", "tag4"], {"tag3", "tag4"}), + ("set_device_tags", {"tag3", "tag3", "tag3", "tag4"}, {"tag3", "tag4"}), ("add_device_tag", "tag3", {"tag1", "tag2", "tag3"}), ("add_device_tag", "tag1", {"tag1", "tag2"}), ("remove_device_tag", "tag1", {"tag2"}), @@ -499,7 +558,9 @@ def test_device_wm(device_w_tags): ({"read_only": False}, "read_only", False), ], ) -def test_properties(dev_w_config: Callable[..., DeviceBaseWithConfig], config, attr, value): +def test_properties( + dev_w_config: Callable[..., DeviceBaseWithConfig], config, attr, value +): assert getattr(dev_w_config(config), attr) == value @@ -507,7 +568,9 @@ def test_properties(dev_w_config: Callable[..., DeviceBaseWithConfig], config, a ["config", "method", "value"], [({"deviceTags": ["tag1", "tag2"]}, "get_device_tags", {"tag1", "tag2"})], ) -def test_methods(dev_w_config: Callable[..., DeviceBaseWithConfig], config, method, value): +def test_methods( + dev_w_config: Callable[..., DeviceBaseWithConfig], config, method, value +): assert getattr(dev_w_config(config), method)() == value @@ -563,7 +626,9 @@ def dev_container(dm_with_override): def test_device_container_wm(dev_container, capsys): - with mock.patch.object(dev_container.test, "read", return_value={"test": {"value": 1}}) as read: + with mock.patch.object( + dev_container.test, "read", return_value={"test": {"value": 1}} + ) as read: dev_container.wm("test") dev_container.wm("tes*") captured = capsys.readouterr() @@ -597,7 +662,9 @@ def test_device_container_wm_with_setpoint_names(dev_container, reading): def test_device_has_describe_method( device_cls: Device | Signal | Positioner, dev_container, dm_with_override ): - dev_container["test"] = device_cls(name="test", config=BASIC_CONFIG, parent=dm_with_override) + dev_container["test"] = device_cls( + name="test", config=BASIC_CONFIG, parent=dm_with_override + ) assert hasattr(dev_container.test, "describe") with mock.patch.object(dev_container.test, "_run_rpc_call") as mock_rpc: dev_container.test.describe() @@ -771,9 +838,12 @@ def test_computed_signal_set_signals(dm_with_override): comp_signal = ComputedSignal(name="comp_signal", parent=dm_with_override) with mock.patch.object(comp_signal, "_update_config") as _update_config: comp_signal.set_input_signals( - Signal(name="a", parent=dm_with_override), Signal(name="b", parent=dm_with_override) + Signal(name="a", parent=dm_with_override), + Signal(name="b", parent=dm_with_override), + ) + _update_config.assert_called_once_with( + {"deviceConfig": {"input_signals": ["a", "b"]}} ) - _update_config.assert_called_once_with({"deviceConfig": {"input_signals": ["a", "b"]}}) def test_computed_signal_set_signals_raises_error(dm_with_override): @@ -821,7 +891,9 @@ def test_device_summary_signal_grouping(dev: Any): dev.samx.summary() num_rows = mock_add_row.call_count - assert num_rows == len(dev.samx._info["signals"]) + 3 # 3 extra rows for headers + assert ( + num_rows == len(dev.samx._info["signals"]) + 3 + ) # 3 extra rows for headers assert mock_add_row.call_args_list[0][0] == ( "readback", @@ -842,7 +914,11 @@ def test_device_summary_signal_grouping(dev: Any): "", "setpoint doc string", ) - devs = [row_call[0][0] for row_call in mock_add_row.call_args_list if row_call[0]] + devs = [ + row_call[0][0] + for row_call in mock_add_row.call_args_list + if row_call[0] + ] assert devs == [ "readback", "setpoint", @@ -949,7 +1025,9 @@ def text(self, value): self.text_output = value with mock.patch.object( - dev, "read", return_value={"eiger": {"value": 1, "timestamp": 1701105880.1711318}} + dev, + "read", + return_value={"eiger": {"value": 1, "timestamp": 1701105880.1711318}}, ): p = MockPrinter() dev._repr_pretty_(p, cycle=False) @@ -972,7 +1050,10 @@ def test_device_compile_rich_str_with_values(dm_with_devices): "read", return_value={ "eiger": {"value": 5.0, "timestamp": 1701105880.1711318}, - "eiger_array": {"value": np.array([1, 2, 3, 4, 5]), "timestamp": 1701105880.1711318}, + "eiger_array": { + "value": np.array([1, 2, 3, 4, 5]), + "timestamp": 1701105880.1711318, + }, }, ): result = dev._compile_rich_str(dev) @@ -995,7 +1076,9 @@ def test_device_compile_rich_str_with_config_signals(dev): } with mock.patch.object( - dev.samx, "read", return_value={"samx": {"value": 5.0, "timestamp": 1701105880.1711318}} + dev.samx, + "read", + return_value={"samx": {"value": 5.0, "timestamp": 1701105880.1711318}}, ): result = dev.samx._compile_rich_str(dev.samx) @@ -1011,7 +1094,9 @@ def test_rpc_call_without_client_raises(dm_with_devices): dev = dm_with_devices.devices.eiger dev.parent.parent = None # Remove reference to DeviceManagerBase - with pytest.raises(RPCError, match="RPC calls can only be made from a BECClient instance"): + with pytest.raises( + RPCError, match="RPC calls can only be made from a BECClient instance" + ): dev.read(cached=False) @@ -1028,8 +1113,11 @@ def isinstance_side_effect(obj, classinfo): return original_isinstance(obj, classinfo) with mock.patch.object(dev.samx.root.parent.parent, "alarm_handler", None): - with mock.patch("bec_lib.device.isinstance", side_effect=isinstance_side_effect): + with mock.patch( + "bec_lib.device.isinstance", side_effect=isinstance_side_effect + ): with pytest.raises( - RPCError, match="RPC calls require an alarm handler to be set in the BECClient" + RPCError, + match="RPC calls require an alarm handler to be set in the BECClient", ): dev.samx.read(cached=False) diff --git a/bec_server/bec_server/scihub/atlas/config_handler.py b/bec_server/bec_server/scihub/atlas/config_handler.py index 710ae6086..1bb22d356 100644 --- a/bec_server/bec_server/scihub/atlas/config_handler.py +++ b/bec_server/bec_server/scihub/atlas/config_handler.py @@ -495,7 +495,7 @@ def remove_devices_from_redis(self, dev_configs: dict): config.pop(index) self.set_config_in_redis(config) - def get_config_from_redis(self): + def get_config_from_redis(self) -> list[dict]: """ Get the config from redis @@ -503,7 +503,7 @@ def get_config_from_redis(self): list: List of device configs """ config = self.device_manager.connector.get(MessageEndpoints.device_config()) - return config.content["resource"] + return config.resource def set_config_in_redis(self, config): """ diff --git a/bec_server/tests/tests_scan_server/test_scan_stubs.py b/bec_server/tests/tests_scan_server/test_scan_stubs.py index 98ab65bfc..f1c532f98 100644 --- a/bec_server/tests/tests_scan_server/test_scan_stubs.py +++ b/bec_server/tests/tests_scan_server/test_scan_stubs.py @@ -49,11 +49,7 @@ def stubs(): device="rtx", action="kickoff", parameter={ - "configure": { - "num_pos": 5, - "positions": [1, 2, 3, 4, 5], - "exp_time": 2, - } + "configure": {"num_pos": 5, "positions": [1, 2, 3, 4, 5], "exp_time": 2} }, metadata={}, ), @@ -61,9 +57,7 @@ def stubs(): ], ) def test_kickoff(stubs, device, parameter, metadata, reference_msg): - msg = list( - stubs.kickoff(device=device, parameter=parameter, metadata=metadata, wait=False) - ) + msg = list(stubs.kickoff(device=device, parameter=parameter, metadata=metadata, wait=False)) reference_msg.metadata["device_instr_id"] = msg[0].metadata["device_instr_id"] assert msg[0] == reference_msg @@ -80,16 +74,12 @@ def test_kickoff(stubs, device, parameter, metadata, reference_msg): False, ), ( - messages.ProgressMessage( - value=10, max_value=100, done=False, metadata={"RID": "rid"} - ), + messages.ProgressMessage(value=10, max_value=100, done=False, metadata={"RID": "rid"}), 10, False, ), ( - messages.DeviceStatusMessage( - device="samx", status=0, metadata={"RID": "rid"} - ), + messages.DeviceStatusMessage(device="samx", status=0, metadata={"RID": "rid"}), None, True, ), @@ -106,9 +96,7 @@ def test_device_progress(stubs, msg, ret_value, raised_error): def test_send_rpc_and_wait(stubs, ScanStubStatusMock): - with mock.patch.object( - stubs, "_get_result_from_status", return_value="msg" - ) as get_rpc: + with mock.patch.object(stubs, "_get_result_from_status", return_value="msg") as get_rpc: original_rpc = stubs.send_rpc with mock.patch.object(stubs, "send_rpc") as mock_rpc: @@ -118,19 +106,12 @@ def mock_rpc_func(*args, **kwargs): mock_rpc.side_effect = mock_rpc_func - instructions = list( - stubs.send_rpc_and_wait("sim_profile", "readback_profile") - ) + instructions = list(stubs.send_rpc_and_wait("sim_profile", "readback_profile")) rpc_call_1 = instructions[0] - instructions = list( - stubs.send_rpc_and_wait("sim_profile", "readback_profile") - ) + instructions = list(stubs.send_rpc_and_wait("sim_profile", "readback_profile")) rpc_call_2 = instructions[0] assert rpc_call_1 != rpc_call_2 - assert ( - rpc_call_1.metadata["device_instr_id"] - != rpc_call_2.metadata["device_instr_id"] - ) + assert rpc_call_1.metadata["device_instr_id"] != rpc_call_2.metadata["device_instr_id"] def test_stage(stubs): From 82a22b83253d31ea81989e808afe35c39855eb4e Mon Sep 17 00:00:00 2001 From: David Perl Date: Mon, 13 Apr 2026 11:17:12 +0200 Subject: [PATCH 8/8] fix: issues from rebase --- bec_lib/bec_lib/config_helper.py | 8 +- bec_lib/bec_lib/device.py | 140 ++++---------- bec_lib/bec_lib/messages.py | 12 +- bec_lib/bec_lib/scans.py | 6 +- bec_lib/bec_lib/serialization.py | 6 +- bec_lib/tests/test_devices.py | 177 +++++------------- bec_lib/tests/test_serializer.py | 7 +- .../data_processing/lmfit1d_service.py | 1 + .../bec_server/device_server/device_server.py | 7 +- .../bec_server/scan_server/scan_manager.py | 29 ++- 10 files changed, 119 insertions(+), 274 deletions(-) diff --git a/bec_lib/bec_lib/config_helper.py b/bec_lib/bec_lib/config_helper.py index 3a883ef95..db7a44c74 100644 --- a/bec_lib/bec_lib/config_helper.py +++ b/bec_lib/bec_lib/config_helper.py @@ -29,7 +29,7 @@ from bec_lib.endpoints import MessageEndpoints from bec_lib.file_utils import DeviceConfigWriter from bec_lib.logger import bec_logger -from bec_lib.messages import ConfigAction +from bec_lib.messages import ConfigAction, sanitize_one_way_encodable from bec_lib.utils.import_utils import lazy_import_from from bec_lib.utils.json_extended import ExtendedEncoder @@ -617,7 +617,11 @@ def send_config_request( request_id = str(uuid.uuid4()) self._connector.send( MessageEndpoints.device_config_request(), - DeviceConfigMessage(action=action, config=config, metadata={"RID": request_id}), + DeviceConfigMessage( + action=action, + config=sanitize_one_way_encodable(config), + metadata={"RID": request_id}, + ), ) if wait_for_response: diff --git a/bec_lib/bec_lib/device.py b/bec_lib/bec_lib/device.py index a4b035250..3b84c2fdf 100644 --- a/bec_lib/bec_lib/device.py +++ b/bec_lib/bec_lib/device.py @@ -41,9 +41,7 @@ logger = bec_logger.logger _MAX_RECURSION_DEPTH = 100 -rpc_method_context: ContextVar[Callable | None] = ContextVar( - "rpc_method_context", default=None -) +rpc_method_context: ContextVar[Callable | None] = ContextVar("rpc_method_context", default=None) class RPCError(AlarmBase): @@ -150,9 +148,7 @@ def __eq__(self, __value: object) -> bool: return False @staticmethod - def _on_status_update( - msg: dict[str, messages.DeviceReqStatusMessage], parent: Status - ): + def _on_status_update(msg: dict[str, messages.DeviceReqStatusMessage], parent: Status): # pylint: disable=protected-access parent._request_status = msg["data"] parent._set_done() @@ -160,8 +156,7 @@ def _on_status_update( def _set_done(self): self._status_done.set() self._connector.unregister( - MessageEndpoints.device_req_status(self._request_id), - cb=self._on_status_update, + MessageEndpoints.device_req_status(self._request_id), cb=self._on_status_update ) def wait(self, timeout=None, raise_on_failure=True): @@ -175,9 +170,7 @@ def wait(self, timeout=None, raise_on_failure=True): """ try: if not self._status_done.wait(timeout): - raise TimeoutError( - "The request has not been completed within the specified time." - ) + raise TimeoutError("The request has not been completed within the specified time.") finally: self._set_done() @@ -199,9 +192,7 @@ def serialize_devicetags(self, value: set[str], info): return value -def set_device_config( - device: "DeviceBase", config: dict | _PermissiveDeviceModel | None -): +def set_device_config(device: "DeviceBase", config: dict | _PermissiveDeviceModel | None): device._config = ( # pylint: disable=protected-access _PermissiveDeviceModel.model_validate(config).model_dump(mode="python") if config is not None @@ -235,9 +226,7 @@ def __init__( class_name (str, optional): The class name of the device. Defaults to None. If None, the class name is inferred from the class of the object. """ self.name = name - self._class_name = ( - class_name or object.__getattribute__(self, "__class__").__name__ - ) + self._class_name = class_name or object.__getattribute__(self, "__class__").__name__ self._signal_info = signal_info set_device_config(self, config) if info is None: @@ -308,9 +297,7 @@ def __setattr__(self, name: str, value: Any) -> None: def _should_prevent_attribute_overwrite(self, name: str) -> bool: # pylint: disable=protected-access # allow override is defined on the device manager - if self.root.parent is None or getattr( - self.root.parent, "_allow_override", True - ): + if self.root.parent is None or getattr(self.root.parent, "_allow_override", True): return False if name.startswith("_"): return False @@ -412,9 +399,7 @@ def _get_rpc_response(self, request_id, rpc_id) -> Any: def _handle_client_info_msg(self): """Handle client messages during RPC calls""" - msgs = self.root.parent.connector.xread( - MessageEndpoints.client_info(), block=200 - ) + msgs = self.root.parent.connector.xread(MessageEndpoints.client_info(), block=200) # The client is the parent.parent of the device client: BECClient = self.root.parent.parent if client.live_updates_config.print_client_messages is False: @@ -424,9 +409,7 @@ def _handle_client_info_msg(self): for msg in msgs: print(QueueItem.format_client_msg(msg["data"])) - def _run_rpc_call( - self, device, func_call, *args, wait_for_rpc_response=True, **kwargs - ) -> Any: + def _run_rpc_call(self, device, func_call, *args, wait_for_rpc_response=True, **kwargs) -> Any: """ Runs an RPC call on the device. This method is used internally by the RPC decorator. If a call is interrupted by the user, the a stop signal is sent to this device. @@ -452,9 +435,7 @@ def _run_rpc_call( # prepare RPC message rpc_id = str(uuid.uuid4()) request_id = str(uuid.uuid4()) - msg = self._prepare_rpc_msg( - rpc_id, request_id, device, func_call, *args, **kwargs - ) + msg = self._prepare_rpc_msg(rpc_id, request_id, device, func_call, *args, **kwargs) # pylint: disable=protected-access if client.scans._scan_def_id: @@ -467,9 +448,7 @@ def _run_rpc_call( } # send RPC message - client.connector.send( - MessageEndpoints.scan_queue_request(client.username), msg - ) + client.connector.send(MessageEndpoints.scan_queue_request(client.username), msg) # wait for RPC response if not wait_for_rpc_response: @@ -500,9 +479,7 @@ def _validate_rpc_client(self) -> None: ) if client.alarm_handler is None: - raise RPCError( - "RPC calls require an alarm handler to be set in the BECClient." - ) + raise RPCError("RPC calls require an alarm handler to be set in the BECClient.") def _get_rpc_func_name(self, fcn=None, use_parent=False): func_call = [self._compile_function_path(use_parent=use_parent)] @@ -575,15 +552,9 @@ def _parse_info(self): base_class = dev["device_info"].get("device_base_class") attr_name = dev["device_info"].get("device_attr_name") if base_class == "positioner": - setattr( - self, - attr_name, - Positioner(name=attr_name, info=dev, parent=self), - ) + setattr(self, attr_name, Positioner(name=attr_name, info=dev, parent=self)) elif base_class == "device": - setattr( - self, attr_name, Device(name=attr_name, info=dev, parent=self) - ) + setattr(self, attr_name, Device(name=attr_name, info=dev, parent=self)) for user_access_name, descr in self._info.get("custom_user_access", {}).items(): # avoid circular imports as the signature serializer imports the DeviceBase class @@ -595,14 +566,8 @@ def _parse_info(self): self._custom_rpc_methods[user_access_name] = DeviceBase( name=user_access_name, info=descr, parent=self ) - setattr( - self, - user_access_name, - self._custom_rpc_methods[user_access_name].run, - ) - setattr( - getattr(self, user_access_name), "__doc__", descr.get("doc") - ) + setattr(self, user_access_name, self._custom_rpc_methods[user_access_name].run) + setattr(getattr(self, user_access_name), "__doc__", descr.get("doc")) setattr( getattr(self, user_access_name), "__signature__", @@ -621,9 +586,7 @@ def _parse_info(self): parent=self, class_name=descr["device_class"], ) - setattr( - self, user_access_name, self._custom_rpc_methods[user_access_name] - ) + setattr(self, user_access_name, self._custom_rpc_methods[user_access_name]) def __eq__(self, other): if isinstance(other, DeviceBase): @@ -676,9 +639,7 @@ def _repr_pretty_(self, p: PrettyPrinter, cycle: bool): @staticmethod def _compile_device_table(obj: DeviceBase) -> Table: # Create main table - table = Table( - title=f"{obj._class_name}: {obj.name}", show_header=False, box=None - ) + table = Table(title=f"{obj._class_name}: {obj.name}", show_header=False, box=None) table.add_column("Property", style="cyan", no_wrap=True) table.add_column("Value", style="white") @@ -688,9 +649,7 @@ def _compile_device_table(obj: DeviceBase) -> Table: table.add_row("Read only", str(obj.read_only)) table.add_row("Software Trigger", str(obj.root.software_trigger)) table.add_row("Device class", str(obj._config.get("deviceClass", "N/A"))) - table.add_row( - "Readout Priority", str(obj._config.get("readoutPriority", "N/A")) - ) + table.add_row("Readout Priority", str(obj._config.get("readoutPriority", "N/A"))) if obj._config.get("deviceTags"): tags = ", ".join(obj._config.get("deviceTags", [])) @@ -715,9 +674,7 @@ def _compile_current_values(current_values: dict) -> Table: # Format value (handle numpy arrays) if isinstance(value, np.ndarray): with np.printoptions(precision=4, suppress=True, threshold=10): - value_str = ( - f"{str(value)}, shape={value.shape}, dtype={value.dtype}" - ) + value_str = f"{str(value)}, shape={value.shape}, dtype={value.dtype}" else: value_str = str(value) # Format timestamp @@ -746,12 +703,8 @@ def _compile_config_section(device_config: dict) -> Table: return config_table @staticmethod - def _compile_rich_tables( - obj: DeviceBase, - ) -> tuple[Table, Table | None, Table | None]: - table = DeviceBase._compile_device_table( - obj - ) # Add current values section if available + def _compile_rich_tables(obj: DeviceBase) -> tuple[Table, Table | None, Table | None]: + table = DeviceBase._compile_device_table(obj) # Add current values section if available value_table = ( DeviceBase._compile_current_values(current_values) if (current_values := obj.read(cached=True)) @@ -760,9 +713,7 @@ def _compile_rich_tables( # Get the updated device config. We use the cached version to avoid # excessive calls to Redis. device_config = ( - obj.parent.get_device_config_cached() - .get(obj.name, {}) - .get("deviceConfig", {}) + obj.parent.get_device_config_cached().get(obj.name, {}).get("deviceConfig", {}) ) # Filter down to only config signals config_signals = [ @@ -773,9 +724,7 @@ def _compile_rich_tables( device_config = {k: v for k, v in device_config.items() if k in config_signals} # Add config signals section if available config_table = ( - DeviceBase._compile_config_section(device_config) - if (device_config) - else None + DeviceBase._compile_config_section(device_config) if (device_config) else None ) return table, value_table, config_table @@ -845,8 +794,7 @@ def set_device_tags(self, val: Iterable): # pylint: disable=protected-access self.root._config["deviceTags"] = set(val) return self.root.parent.config_helper.send_config_request( - action="update", - config={self.name: {"deviceTags": self.root._config["deviceTags"]}}, + action="update", config={self.name: {"deviceTags": self.root._config["deviceTags"]}} ) @typechecked @@ -855,8 +803,7 @@ def add_device_tag(self, val: str): # pylint: disable=protected-access self.root._config["deviceTags"].add(val) return self.root.parent.config_helper.send_config_request( - action="update", - config={self.name: {"deviceTags": self.root._config["deviceTags"]}}, + action="update", config={self.name: {"deviceTags": self.root._config["deviceTags"]}} ) def remove_device_tag(self, val: str): @@ -864,8 +811,7 @@ def remove_device_tag(self, val: str): # pylint: disable=protected-access self.root._config["deviceTags"].remove(val) return self.root.parent.config_helper.send_config_request( - action="update", - config={self.name: {"deviceTags": self.root._config["deviceTags"]}}, + action="update", config={self.name: {"deviceTags": self.root._config["deviceTags"]}} ) @property @@ -904,8 +850,7 @@ def on_failure(self, val: OnFailure): # pylint: disable=protected-access self.root._config["onFailure"] = val return self.root.parent.config_helper.send_config_request( - action="update", - config={self.name: {"onFailure": self.root._config["onFailure"]}}, + action="update", config={self.name: {"onFailure": self.root._config["onFailure"]}} ) @property @@ -997,9 +942,7 @@ def read( MessageEndpoints.device_readback(self.root.name) ) else: - val = self.root.parent.connector.get( - MessageEndpoints.device_read(self.root.name) - ) + val = self.root.parent.connector.get(MessageEndpoints.device_read(self.root.name)) if not val: return None @@ -1019,11 +962,7 @@ def read_configuration(self, cached=False) -> dict[str, dict[str, Any]] | None: is_signal, is_config_signal, cached = self._get_rpc_signal_info(cached) if not cached: - fcn = ( - self.read_configuration - if (not is_signal or is_config_signal) - else self.read - ) + fcn = self.read_configuration if (not is_signal or is_config_signal) else self.read signals = self._run(cached=False, fcn=fcn) else: if is_signal and not is_config_signal: @@ -1042,9 +981,7 @@ def _filter_rpc_signals(self, signals: dict) -> dict: if self._signal_info: obj_name = self._signal_info.get("obj_name") return {obj_name: signals.get(obj_name, {})} - return { - key: val for key, val in signals.items() if key.startswith(self.full_name) - } + return {key: val for key, val in signals.items() if key.startswith(self.full_name)} def _get_rpc_signal_info(self, cached: bool): is_config_signal = False @@ -1229,9 +1166,7 @@ def limits(self): """ Returns the device limits. """ - limit_msg = self.root.parent.connector.get( - MessageEndpoints.device_limits(self.root.name) - ) + limit_msg = self.root.parent.connector.get(MessageEndpoints.device_limits(self.root.name)) if not limit_msg: return [0, 0] limits = [ @@ -1300,9 +1235,7 @@ def calculate_readback(signal): # check if it is a bound method if hasattr(method, "__self__") and method.__self__ is not None: - raise ValueError( - "The compute method must be an unbound function, not a bound method." - ) + raise ValueError("The compute method must be an unbound function, not a bound method.") # check if it is a lambda function if method.__name__ == "": @@ -1311,9 +1244,7 @@ def calculate_readback(signal): method_code = inspect.getsource(method) self._num_args_method = len(inspect.signature(method).parameters) - self._update_config( - {"deviceConfig": {"compute_method": self._header + method_code}} - ) + self._update_config({"deviceConfig": {"compute_method": self._header + method_code}}) if self._num_signals is None: return if self._num_args_method != self._num_signals: @@ -1362,8 +1293,7 @@ def show_all(self): table.add_row("Compute Method", compute_method if compute_method else "Not set") table.add_row( - "Input Signals", - ", ".join(input_signals) if input_signals else "No input signals set", + "Input Signals", ", ".join(input_signals) if input_signals else "No input signals set" ) console.print(table) diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index fcb55f992..c495b5d00 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -43,7 +43,7 @@ def sanitize_one_way_encodable(data: Any) -> Any: """Sanitize any data which can be serialized in a json-compatible format and is not supposed to be decoded, - for example, a parameter dict containing devices""" + # for example, a parameter dict containing devices""" if isinstance(data, (list, tuple, set)): return [sanitize_one_way_encodable(x) for x in data] if isinstance(data, Mapping): @@ -63,7 +63,7 @@ def _try_dump(v): try: msgpack.dumps(v) except TypeError as e: - raise ValueError("Non-JSONable/msgpackable data in " + str(v)) from e + raise ValueError(f"Non-JSONable/msgpackable data in {str(v)}\n {e}") from e return v @@ -1279,6 +1279,12 @@ def from_dict(cls, metrics: dict[str, str | int | float | bool]): ) +DictPossibleNumpy = TypeAliasType( + "DictPossibleNumpy", + dict[str, list[int] | list[float] | int | bool | float | str | np.ndarray | None], +) + + class ProcessedDataMessage(BECMessage): """Message for processed data @@ -1288,7 +1294,7 @@ class ProcessedDataMessage(BECMessage): """ msg_type: ClassVar[str] = "processed_data_message" - data: JsonableDict | list[JsonableDict] + data: DictPossibleNumpy | list[DictPossibleNumpy] | JsonableDict | list[JsonableDict] class DAPConfigMessage(BECMessage): diff --git a/bec_lib/bec_lib/scans.py b/bec_lib/bec_lib/scans.py index 5980cad12..a29121788 100644 --- a/bec_lib/bec_lib/scans.py +++ b/bec_lib/bec_lib/scans.py @@ -202,9 +202,7 @@ def _import_scans(self): setattr( getattr(self, scan_name), "__signature__", - dict_to_signature( - self._strip_scan_signature_annotations(scan_info.get("signature")) - ), + dict_to_signature(self._strip_scan_signature_annotations(scan_info.signature)), ) @staticmethod @@ -341,7 +339,7 @@ def prepare_scan_request( if not isinstance(arg, Scans.get_arg_type(arg_input[ii % len(arg_input)])): raise TypeError( f"{scan_info.doc}\n Argument {ii} must be of type" - f" {arg_input[ii%len(arg_input)]}, not {type(arg).__name__}." + f" {arg_input[ii % len(arg_input)]}, not {type(arg).__name__}." ) metadata = {} diff --git a/bec_lib/bec_lib/serialization.py b/bec_lib/bec_lib/serialization.py index cbd223c4e..5262a4d8a 100644 --- a/bec_lib/bec_lib/serialization.py +++ b/bec_lib/bec_lib/serialization.py @@ -14,7 +14,7 @@ from bec_lib import messages as messages_module from bec_lib.logger import bec_logger -from bec_lib.messages import BECMessage +from bec_lib.messages import BECMessage, sanitize_one_way_encodable from bec_lib.serialization_registry import SerializationRegistry logger = bec_logger.logger @@ -39,6 +39,7 @@ def dumps(self, obj): """Pack object `obj` and return packed bytes.""" if isinstance(obj, (BECMessage, BaseModel)): obj = obj.model_dump(mode="python", fallback=self.encode) + obj = sanitize_one_way_encodable(obj) return msgpack_module.packb(obj, default=self.encode) def loads(self, raw_bytes): @@ -56,6 +57,9 @@ class BECJson(SerializationRegistry): def dumps(self, obj, indent: int | None = None) -> str: """Pack object `obj` and return packed bytes.""" + if isinstance(obj, (BECMessage, BaseModel)): + obj = obj.model_dump(mode="python", fallback=self.encode) + obj = sanitize_one_way_encodable(obj) return json.dumps(obj, default=self.encode, indent=indent) def loads(self, raw_bytes): diff --git a/bec_lib/tests/test_devices.py b/bec_lib/tests/test_devices.py index 2af8a0269..ec3e46fc9 100644 --- a/bec_lib/tests/test_devices.py +++ b/bec_lib/tests/test_devices.py @@ -100,47 +100,26 @@ def test_read_use_read(dev: Any): def test_read_nested_device(dev: Any): with mock.patch.object(dev.dyn_signals.root.parent.connector, "get") as mock_get: data = { - "dyn_signals_messages_message1": { - "value": 0, - "timestamp": 1701105880.0716832, - }, - "dyn_signals_messages_message2": { - "value": 0, - "timestamp": 1701105880.071722, - }, - "dyn_signals_messages_message3": { - "value": 0, - "timestamp": 1701105880.071739, - }, - "dyn_signals_messages_message4": { - "value": 0, - "timestamp": 1701105880.071753, - }, - "dyn_signals_messages_message5": { - "value": 0, - "timestamp": 1701105880.071766, - }, + "dyn_signals_messages_message1": {"value": 0, "timestamp": 1701105880.0716832}, + "dyn_signals_messages_message2": {"value": 0, "timestamp": 1701105880.071722}, + "dyn_signals_messages_message3": {"value": 0, "timestamp": 1701105880.071739}, + "dyn_signals_messages_message4": {"value": 0, "timestamp": 1701105880.071753}, + "dyn_signals_messages_message5": {"value": 0, "timestamp": 1701105880.071766}, } mock_get.return_value = messages.DeviceMessage( signals=data, metadata={"scan_id": "scan_id", "scan_type": "scan_type"} ) res = dev.dyn_signals.messages.read(cached=True) - mock_get.assert_called_once_with( - MessageEndpoints.device_readback("dyn_signals") - ) + mock_get.assert_called_once_with(MessageEndpoints.device_readback("dyn_signals")) assert res == data @pytest.mark.parametrize( - "kind,cached", - [("normal", True), ("hinted", True), ("config", False), ("omitted", False)], + "kind,cached", [("normal", True), ("hinted", True), ("config", False), ("omitted", False)] ) def test_read_kind_hinted( dev: Any, - kind: Literal["normal"] - | Literal["hinted"] - | Literal["config"] - | Literal["omitted"], + kind: Literal["normal"] | Literal["hinted"] | Literal["config"] | Literal["omitted"], cached: bool, ): with ( @@ -189,9 +168,7 @@ def test_read_configuration_not_cached( mock.patch.object(dev.samx.readback, "_run") as mock_run, ): dev.samx.readback.read_configuration(cached=False) - mock_run.assert_called_once_with( - cached=False, fcn=getattr(dev.samx.readback, method) - ) + mock_run.assert_called_once_with(cached=False, fcn=getattr(dev.samx.readback, method)) @pytest.mark.parametrize( @@ -199,10 +176,7 @@ def test_read_configuration_not_cached( [(True, False, "read"), (False, True, "redis"), (False, False, "redis")], ) def test_read_configuration_cached( - dev: Any, - is_signal: bool, - is_config_signal: bool, - method: Literal["read"] | Literal["redis"], + dev: Any, is_signal: bool, is_config_signal: bool, method: Literal["read"] | Literal["redis"] ): with ( mock.patch.object( @@ -223,9 +197,7 @@ def test_read_configuration_cached( ) dev.samx.readback.read_configuration(cached=True) if method == "redis": - mock_get.assert_called_once_with( - MessageEndpoints.device_read_configuration("samx") - ) + mock_get.assert_called_once_with(MessageEndpoints.device_read_configuration("samx")) mock_read.assert_not_called() else: mock_read.assert_called_once_with(cached=True) @@ -256,15 +228,11 @@ def test_get_rpc_func_name_read(dev: Any): @pytest.mark.parametrize( - "kind,cached", - [("normal", True), ("hinted", True), ("config", False), ("omitted", False)], + "kind,cached", [("normal", True), ("hinted", True), ("config", False), ("omitted", False)] ) def test_get_rpc_func_name_readback_get( dev: Any, - kind: Literal["normal"] - | Literal["hinted"] - | Literal["config"] - | Literal["omitted"], + kind: Literal["normal"] | Literal["hinted"] | Literal["config"] | Literal["omitted"], cached: bool, ): with ( @@ -297,26 +265,17 @@ def test_get_rpc_func_name_nested(dev: Any): "_run_rpc_call", ) as mock_rpc: dev.rt_controller.dummy_controller._func_with_args(1, 2) - mock_rpc.assert_called_once_with( - "rt_controller", "dummy_controller._func_with_args", 1, 2 - ) + mock_rpc.assert_called_once_with("rt_controller", "dummy_controller._func_with_args", 1, 2) def test_handle_rpc_response(dev: Any): - msg = messages.DeviceRPCMessage( - device="samx", return_val=1, out="done", success=True - ) + msg = messages.DeviceRPCMessage(device="samx", return_val=1, out="done", success=True) assert dev.samx._handle_rpc_response(msg) == 1 -def test_handle_rpc_response_returns_status( - dev: Any, bec_client_mock: ClientMock | BECClient -): +def test_handle_rpc_response_returns_status(dev: Any, bec_client_mock: ClientMock | BECClient): msg = messages.DeviceRPCMessage( - device="samx", - return_val={"type": "status", "RID": "request_id"}, - out="done", - success=True, + device="samx", return_val={"type": "status", "RID": "request_id"}, out="done", success=True ) assert dev.samx._handle_rpc_response(msg) == Status( bec_client_mock.device_manager.connector, "request_id" @@ -324,9 +283,7 @@ def test_handle_rpc_response_returns_status( def test_rpc_status_raises_error(dev: Any): - msg = messages.DeviceReqStatusMessage( - device="samx", success=False, request_id="request_id" - ) + msg = messages.DeviceReqStatusMessage(device="samx", success=False, request_id="request_id") connector = mock.MagicMock() status = Status(connector, "request_id") status._on_status_update({"data": msg}, parent=status) @@ -341,9 +298,7 @@ def test_handle_rpc_response_raises(dev: Any): device="samx", return_val={"type": "status", "RID": "request_id"}, out=messages.ErrorInfo( - exception_type="RPCError", - error_message="An error occurred", - compact_error_message=None, + exception_type="RPCError", error_message="An error occurred", compact_error_message=None ), success=False, ) @@ -352,9 +307,7 @@ def test_handle_rpc_response_raises(dev: Any): def test_handle_rpc_response_returns_dict(dev: Any): - msg = messages.DeviceRPCMessage( - device="samx", return_val={"a": "b"}, out="done", success=True - ) + msg = messages.DeviceRPCMessage(device="samx", return_val={"a": "b"}, out="done", success=True) assert dev.samx._handle_rpc_response(msg) == {"a": "b"} @@ -397,9 +350,7 @@ def dev_w_config(): def _func(config: dict = {}): dm_base = DeviceManagerBase(mock.MagicMock()) dm_base.config_helper = mock.MagicMock(spec=ConfigHelper) - return DeviceBaseWithConfig( - name="test", config=BASIC_CONFIG | config, parent=dm_base - ) + return DeviceBaseWithConfig(name="test", config=BASIC_CONFIG | config, parent=dm_base) return _func @@ -417,17 +368,12 @@ def device_obj(device_config: dict[str, Any]): def test_create_device_saves_config( device_obj: DeviceBaseWithConfig, device_config: dict[str, Any] ): - assert ( - messages.sanitize_one_way_encodable( - {k: v for k, v in device_obj._config.items() if k in device_config} - ) - == device_config - ) + assert messages.sanitize_one_way_encodable( + {k: v for k, v in device_obj._config.items() if k in device_config} + ) == messages.sanitize_one_way_encodable(device_config) -def test_device_enabled( - device_obj: DeviceBaseWithConfig, device_config: dict[str, Any] -): +def test_device_enabled(device_obj: DeviceBaseWithConfig, device_config: dict[str, Any]): assert device_obj.enabled == device_config["enabled"] device_config["enabled"] = False set_device_config(device_obj, device_config) @@ -435,9 +381,7 @@ def test_device_enabled( def test_device_enable(device_obj: DeviceBaseWithConfig): - with mock.patch.object( - device_obj.parent.config_helper, "send_config_request" - ) as config_req: + with mock.patch.object(device_obj.parent.config_helper, "send_config_request") as config_req: device_obj.enabled = True config_req.assert_called_once_with( action="update", config={device_obj.name: {"enabled": True}} @@ -445,9 +389,7 @@ def test_device_enable(device_obj: DeviceBaseWithConfig): def test_device_enable_set(device_obj: DeviceBaseWithConfig): - with mock.patch.object( - device_obj.parent.config_helper, "send_config_request" - ) as config_req: + with mock.patch.object(device_obj.parent.config_helper, "send_config_request") as config_req: device_obj.read_only = False config_req.assert_called_once_with( action="update", config={device_obj.name: {"readOnly": False}} @@ -463,9 +405,7 @@ def test_device_set_user_parameter( val: dict[str, int] | set[str], raised_error: None | TypeCheckError, ): - with mock.patch.object( - device_obj.parent.config_helper, "send_config_request" - ) as config_req: + with mock.patch.object(device_obj.parent.config_helper, "send_config_request") as config_req: if raised_error is None: device_obj.set_user_parameter(val) config_req.assert_called_once_with( @@ -493,9 +433,7 @@ def test_device_update_user_parameter( raised_error: None | TypeCheckError, ): device_obj._config["userParameter"] = user_param - with mock.patch.object( - device_obj.parent.config_helper, "send_config_request" - ) as config_req: + with mock.patch.object(device_obj.parent.config_helper, "send_config_request") as config_req: if raised_error is None: device_obj.update_user_parameter(val) config_req.assert_called_once_with( @@ -558,9 +496,7 @@ def test_device_wm(device_w_tags): ({"read_only": False}, "read_only", False), ], ) -def test_properties( - dev_w_config: Callable[..., DeviceBaseWithConfig], config, attr, value -): +def test_properties(dev_w_config: Callable[..., DeviceBaseWithConfig], config, attr, value): assert getattr(dev_w_config(config), attr) == value @@ -568,9 +504,7 @@ def test_properties( ["config", "method", "value"], [({"deviceTags": ["tag1", "tag2"]}, "get_device_tags", {"tag1", "tag2"})], ) -def test_methods( - dev_w_config: Callable[..., DeviceBaseWithConfig], config, method, value -): +def test_methods(dev_w_config: Callable[..., DeviceBaseWithConfig], config, method, value): assert getattr(dev_w_config(config), method)() == value @@ -626,9 +560,7 @@ def dev_container(dm_with_override): def test_device_container_wm(dev_container, capsys): - with mock.patch.object( - dev_container.test, "read", return_value={"test": {"value": 1}} - ) as read: + with mock.patch.object(dev_container.test, "read", return_value={"test": {"value": 1}}) as read: dev_container.wm("test") dev_container.wm("tes*") captured = capsys.readouterr() @@ -662,9 +594,7 @@ def test_device_container_wm_with_setpoint_names(dev_container, reading): def test_device_has_describe_method( device_cls: Device | Signal | Positioner, dev_container, dm_with_override ): - dev_container["test"] = device_cls( - name="test", config=BASIC_CONFIG, parent=dm_with_override - ) + dev_container["test"] = device_cls(name="test", config=BASIC_CONFIG, parent=dm_with_override) assert hasattr(dev_container.test, "describe") with mock.patch.object(dev_container.test, "_run_rpc_call") as mock_rpc: dev_container.test.describe() @@ -838,12 +768,9 @@ def test_computed_signal_set_signals(dm_with_override): comp_signal = ComputedSignal(name="comp_signal", parent=dm_with_override) with mock.patch.object(comp_signal, "_update_config") as _update_config: comp_signal.set_input_signals( - Signal(name="a", parent=dm_with_override), - Signal(name="b", parent=dm_with_override), - ) - _update_config.assert_called_once_with( - {"deviceConfig": {"input_signals": ["a", "b"]}} + Signal(name="a", parent=dm_with_override), Signal(name="b", parent=dm_with_override) ) + _update_config.assert_called_once_with({"deviceConfig": {"input_signals": ["a", "b"]}}) def test_computed_signal_set_signals_raises_error(dm_with_override): @@ -891,9 +818,7 @@ def test_device_summary_signal_grouping(dev: Any): dev.samx.summary() num_rows = mock_add_row.call_count - assert ( - num_rows == len(dev.samx._info["signals"]) + 3 - ) # 3 extra rows for headers + assert num_rows == len(dev.samx._info["signals"]) + 3 # 3 extra rows for headers assert mock_add_row.call_args_list[0][0] == ( "readback", @@ -914,11 +839,7 @@ def test_device_summary_signal_grouping(dev: Any): "", "setpoint doc string", ) - devs = [ - row_call[0][0] - for row_call in mock_add_row.call_args_list - if row_call[0] - ] + devs = [row_call[0][0] for row_call in mock_add_row.call_args_list if row_call[0]] assert devs == [ "readback", "setpoint", @@ -1025,9 +946,7 @@ def text(self, value): self.text_output = value with mock.patch.object( - dev, - "read", - return_value={"eiger": {"value": 1, "timestamp": 1701105880.1711318}}, + dev, "read", return_value={"eiger": {"value": 1, "timestamp": 1701105880.1711318}} ): p = MockPrinter() dev._repr_pretty_(p, cycle=False) @@ -1050,10 +969,7 @@ def test_device_compile_rich_str_with_values(dm_with_devices): "read", return_value={ "eiger": {"value": 5.0, "timestamp": 1701105880.1711318}, - "eiger_array": { - "value": np.array([1, 2, 3, 4, 5]), - "timestamp": 1701105880.1711318, - }, + "eiger_array": {"value": np.array([1, 2, 3, 4, 5]), "timestamp": 1701105880.1711318}, }, ): result = dev._compile_rich_str(dev) @@ -1076,9 +992,7 @@ def test_device_compile_rich_str_with_config_signals(dev): } with mock.patch.object( - dev.samx, - "read", - return_value={"samx": {"value": 5.0, "timestamp": 1701105880.1711318}}, + dev.samx, "read", return_value={"samx": {"value": 5.0, "timestamp": 1701105880.1711318}} ): result = dev.samx._compile_rich_str(dev.samx) @@ -1094,9 +1008,7 @@ def test_rpc_call_without_client_raises(dm_with_devices): dev = dm_with_devices.devices.eiger dev.parent.parent = None # Remove reference to DeviceManagerBase - with pytest.raises( - RPCError, match="RPC calls can only be made from a BECClient instance" - ): + with pytest.raises(RPCError, match="RPC calls can only be made from a BECClient instance"): dev.read(cached=False) @@ -1113,11 +1025,8 @@ def isinstance_side_effect(obj, classinfo): return original_isinstance(obj, classinfo) with mock.patch.object(dev.samx.root.parent.parent, "alarm_handler", None): - with mock.patch( - "bec_lib.device.isinstance", side_effect=isinstance_side_effect - ): + with mock.patch("bec_lib.device.isinstance", side_effect=isinstance_side_effect): with pytest.raises( - RPCError, - match="RPC calls require an alarm handler to be set in the BECClient", + RPCError, match="RPC calls require an alarm handler to be set in the BECClient" ): dev.samx.read(cached=False) diff --git a/bec_lib/tests/test_serializer.py b/bec_lib/tests/test_serializer.py index e1a254e7b..e73c98fde 100644 --- a/bec_lib/tests/test_serializer.py +++ b/bec_lib/tests/test_serializer.py @@ -11,10 +11,10 @@ from bec_lib.devicemanager import DeviceManagerBase from bec_lib.endpoints import MessageEndpoints from bec_lib.one_way_registry import OneWaySerializationRegistry -from bec_lib.serialization import MsgpackSerialization, json_ext, msgpack +from bec_lib.serialization import json_ext, msgpack -@pytest.fixture(params=[json_ext, msgpack, MsgpackSerialization]) +@pytest.fixture(params=[json_ext, msgpack]) def serializer(request): yield request.param @@ -67,12 +67,11 @@ class CustomEnum(enum.Enum): ], ) def test_serialize(serializer, data): - res = serializer.loads(serializer.dumps(data)) == data + res = serializer.loads(serializer.dumps(data)) == messages.sanitize_one_way_encodable(data) assert all(res) if isinstance(data, np.ndarray) else res def test_serialize_model(serializer): - class DummyModel(BaseModel): a: int b: int diff --git a/bec_server/bec_server/data_processing/lmfit1d_service.py b/bec_server/bec_server/data_processing/lmfit1d_service.py index d6b069d5b..20ff4d336 100644 --- a/bec_server/bec_server/data_processing/lmfit1d_service.py +++ b/bec_server/bec_server/data_processing/lmfit1d_service.py @@ -371,6 +371,7 @@ def _process_and_publish_current_scan(self) -> None: if not out: return stream_output, metadata = out + # TODO: refactor processed data message to allow numpy types in specific structure self.client.connector.xadd( MessageEndpoints.processed_data(self.model.__class__.__name__), msg_dict={"data": messages.ProcessedDataMessage(data=stream_output, metadata=metadata)}, diff --git a/bec_server/bec_server/device_server/device_server.py b/bec_server/bec_server/device_server/device_server.py index da8b82d11..69f54601c 100644 --- a/bec_server/bec_server/device_server/device_server.py +++ b/bec_server/bec_server/device_server/device_server.py @@ -247,7 +247,6 @@ def get_error_info(self, error: Exception, obj: StatusBase) -> messages.ErrorInf device_name = obj.obj.dotted_name or obj.obj.name else: device_name = None - msg = ( f"{error.__class__.__name__}: {error}\n" f"The status {obj.__class__.__name__} from device {device_name} failed during the execution " @@ -840,7 +839,7 @@ def _read_and_update_devices(self, devices: list[str], metadata: dict) -> list: ) pipe.execute() logger.trace( - f"Elapsed time for reading and updating status info: {(time.time()-start)*1000} ms" + f"Elapsed time for reading and updating status info: {(time.time() - start) * 1000} ms" ) return signal_container @@ -865,7 +864,7 @@ def _read_config_and_update_devices(self, devices: list[str], metadata: dict) -> ) pipe.execute() logger.trace( - f"Elapsed time for reading and updating status info: {(time.time()-start)*1000} ms" + f"Elapsed time for reading and updating status info: {(time.time() - start) * 1000} ms" ) return signal_container @@ -938,7 +937,7 @@ def _stage_device( break except ophyd_errors.WaitTimeoutError: logger.warning( - f"Unstaging device {dev} still running, {timeout_on_unstage*(ii+1)} seconds passed." + f"Unstaging device {dev} still running, {timeout_on_unstage * (ii + 1)} seconds passed." ) if status is not None: raise ValueError(f"Unstaging device {dev} failed to finish in 30 seconds") diff --git a/bec_server/bec_server/scan_server/scan_manager.py b/bec_server/bec_server/scan_server/scan_manager.py index 53ccf18b5..89ec8a08a 100644 --- a/bec_server/bec_server/scan_server/scan_manager.py +++ b/bec_server/bec_server/scan_server/scan_manager.py @@ -10,7 +10,6 @@ from bec_lib.logger import bec_logger from bec_lib.messages import AvailableResourceMessage from bec_lib.signature_serializer import serialize_dtype, signature_to_dict - from bec_server.scan_server.scan_gui_models import GUIConfig from . import scans as scans_module @@ -64,9 +63,7 @@ def update_available_scans(self): members: list[tuple[str, type]] = inspect.getmembers( scans_module, predicate=inspect.isclass ) - members.extend( - (name, cls) for name, cls in self._plugins.items() if inspect.isclass(cls) - ) + members.extend((name, cls) for name, cls in self._plugins.items() if inspect.isclass(cls)) for name, scan_cls in members: is_scan = issubclass(scan_cls, scans_module.RequestBase) @@ -96,19 +93,17 @@ def update_available_scans(self): elif hasattr(scan_cls, "gui_config"): # type: ignore gui_visibility = scan_cls.gui_config # type: ignore - self.available_scans[scan_cls.scan_name] = ( - messages.AvailableScan.model_validate( - { - "class_name": scan_cls.__name__, - "base_class": base_cls, - "arg_input": self.convert_arg_input(scan_cls.arg_input), - "gui_config": gui_config, - "required_kwargs": scan_cls.required_kwargs, - "arg_bundle_size": scan_cls.arg_bundle_size, - "doc": scan_cls.__doc__ or scan_cls.__init__.__doc__, - "signature": signature_to_dict(scan_cls.__init__), - } - ) + self.available_scans[scan_cls.scan_name] = messages.AvailableScan.model_validate( + { + "class_name": scan_cls.__name__, + "base_class": base_cls, + "arg_input": self.convert_arg_input(scan_cls.arg_input), + "gui_config": gui_config, + "required_kwargs": scan_cls.required_kwargs, + "arg_bundle_size": scan_cls.arg_bundle_size, + "doc": scan_cls.__doc__ or scan_cls.__init__.__doc__, + "signature": signature_to_dict(scan_cls.__init__), + } ) def validate_gui_config(self, scan_cls) -> dict: