diff --git a/bec_ipython_client/tests/client_tests/test_live_table.py b/bec_ipython_client/tests/client_tests/test_live_table.py index 891803831..53c8ccf56 100644 --- a/bec_ipython_client/tests/client_tests/test_live_table.py +++ b/bec_ipython_client/tests/client_tests/test_live_table.py @@ -303,8 +303,11 @@ def test_print_table_data_hinted_value_with_precision( @pytest.mark.parametrize( "value,expected", [ - (np.int32(1), "1.00"), - (np.float64(1.00000), "1.00"), + # Commented out cases are not supported in unstructured serialized data, because msgpack doesn't distinguish + # lists, tuples, or sets. To support this, ScanMessage must be refactored to support the type information directly + # except for numpy arrays, which are currently special-cased but will be removed in a future refactor. + # (np.int32(1), "1.00"), + # (np.float64(1.00000), "1.00"), (0, "0.00"), (1, "1.00"), (0.000, "0.00"), @@ -314,10 +317,10 @@ def test_print_table_data_hinted_value_with_precision( ("False", "False"), ("0", "0"), ("1", "1"), - ((0, 1), "(0, 1)"), + # ((0, 1), "(0, 1)"), ({"value": 0}, "{'value': 0}"), (np.array([0, 1]), "[0 1]"), - ({1, 2}, "{1, 2}"), + # ({1, 2}, "{1, 2}"), ], ) def test_print_table_data_variants(self, client_with_grid_scan, value, expected): diff --git a/bec_lib/bec_lib/bec_service.py b/bec_lib/bec_lib/bec_service.py index 9659effbb..b0d1e7fb7 100644 --- a/bec_lib/bec_lib/bec_service.py +++ b/bec_lib/bec_lib/bec_service.py @@ -259,6 +259,7 @@ def _update_existing_services(self) -> None: msgs = [ self.connector.get(MessageEndpoints.service_status(service)) for service in services ] + print(msgs) self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None} msgs = [self.connector.get(MessageEndpoints.metrics(service)) for service in services] self._services_metric = {msg.content["name"]: msg for msg in msgs if msg is not None} diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index e09751510..c1109a397 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -9,13 +9,92 @@ from enum import Enum, auto from importlib.metadata import PackageNotFoundError from importlib.metadata import version as importlib_version -from typing import Annotated, Any, ClassVar, Literal, Self, Union +from types import NoneType +from typing import Annotated, Any, ClassVar, Literal, Self from uuid import uuid4 +import msgpack import numpy as np -from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + Field, + ValidationError, + WithJsonSchema, + field_validator, + model_validator, +) +from typing_extensions import TypeAliasType from bec_lib.metadata_schema import get_metadata_schema_for_scan +from bec_lib.one_way_registry import OneWaySerializationRegistry + +_one_way_registry = OneWaySerializationRegistry() + + +def _sanitize_one_way(data: Any) -> Any: + # TODO: Temporary fix for standardizing message structure, will be replaced + # by encoders in a future iteration + if isinstance(data, np.ndarray): + return data + if isinstance(data, np.bool_): + return bool(data) + if isinstance(data, (np.float16, np.float32, np.float64)): + return float(data) + if isinstance(data, (np.int16, np.int32, np.int64, np.uint16, np.uint32, np.uint64)): + return int(data) + if isinstance(data, (list, tuple, set)): + return [_sanitize_one_way(x) for x in data] + if isinstance(data, dict): + return {_sanitize_one_way(k): _sanitize_one_way(v) for k, v in data.items()} + return _one_way_registry.encode(data) + + +def _ignore_ndarray(data: Any) -> Any: + if isinstance(data, np.ndarray): + return [] + raise ValueError(f"Cannot serialize unknown type for {data}: {type(data)}") + + +def _test_packable(data: Any): + try: + msgpack.dumps(data, default=_ignore_ndarray) + except Exception as e: + raise ValueError(f"Non-JSONable/msgpackable data in {data}!") from e + + +def _validate_packable(data: Any) -> Any: + # Skip sanitization if the data is already valid + if isinstance(data, int | float | str | bool | NoneType): + return data + if isinstance(data, np.bool_): + return bool(data) + try: + _test_packable(data) + return data + # Recursively check if we should replace anything which is not supposed to be decoded to a custom + # type on the other end + except ValueError: + data = _sanitize_one_way(data) + _test_packable(data) + return data + + +Jsonable = TypeAliasType( + "Jsonable", + Annotated[ + int | float | str | bool | None | list["Jsonable"] | dict[str, "Jsonable"] | np.ndarray, + BeforeValidator(_validate_packable), + ], +) + +JsonableDict = TypeAliasType( + "JsonableDict", + Annotated[ + dict[str, Jsonable], BeforeValidator(_validate_packable), WithJsonSchema({"type": "object"}) + ], +) class ProcedureWorkerStatus(Enum): @@ -44,8 +123,9 @@ class BECMessage(BaseModel): """ + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") msg_type: ClassVar[str] - metadata: dict = Field(default_factory=dict) + metadata: JsonableDict = Field(default_factory=dict) @field_validator("metadata") @classmethod @@ -142,7 +222,7 @@ class ScanQueueMessage(BECMessage): msg_type: ClassVar[str] = "scan_queue_message" scan_type: str - parameter: dict + parameter: JsonableDict queue: str = Field(default="primary") allow_restart: bool = Field( default=True, @@ -226,18 +306,18 @@ class ScanStatusMessage(BECMessage): scan_type: Literal["step", "fly"] | None = Field(default=None, description="Type of scan") dataset_number: int | None = None scan_report_devices: list[str] | None = None - user_metadata: dict | None = None + user_metadata: JsonableDict | None = None readout_priority: ( dict[Literal["monitored", "baseline", "async", "continuous", "on_request"], list[str]] | None ) = None scan_parameters: dict[ - Literal["exp_time", "frames_per_trigger", "settling_time", "readout_time"] | str, Any + Literal["exp_time", "frames_per_trigger", "settling_time", "readout_time"] | str, Jsonable ] = Field(default_factory=dict) - request_inputs: dict[Literal["arg_bundle", "inputs", "kwargs"], Any] = Field( + request_inputs: dict[Literal["arg_bundle", "inputs", "kwargs"], Jsonable] = Field( default_factory=dict ) - info: dict + info: JsonableDict timestamp: float = Field(default_factory=time.time) def __str__(self): @@ -303,7 +383,7 @@ class ScanQueueModificationMessage(BECMessage): "release_lock", "user_completed", ] - parameter: dict + parameter: JsonableDict queue: str = Field(default="primary") @@ -551,7 +631,7 @@ class DeviceInstructionMessage(BECMessage): "publish_data_as_read", "close_scan_group", ] - parameter: dict + parameter: JsonableDict class ErrorInfo(BaseModel): @@ -749,7 +829,7 @@ class DeviceInfoMessage(BECMessage): msg_type: ClassVar[str] = "device_info_message" device: str - info: dict + info: JsonableDict class DeviceMonitor2DMessage(BECMessage): @@ -769,8 +849,6 @@ class DeviceMonitor2DMessage(BECMessage): data: np.ndarray timestamp: float = Field(default_factory=time.time) - metadata: dict | None = Field(default_factory=dict) - # Needed for pydantic to accept numpy arrays model_config = ConfigDict(arbitrary_types_allowed=True) @@ -810,8 +888,6 @@ class DeviceMonitor1DMessage(BECMessage): data: np.ndarray timestamp: float = Field(default_factory=time.time) - metadata: dict | None = Field(default_factory=dict) - # Needed for pydantic to accept numpy arrays model_config = ConfigDict(arbitrary_types_allowed=True) @@ -869,7 +945,7 @@ class DeviceUserROIMessage(BECMessage): device: str signal: str roi_type: str = Field(description="Type of the ROI, e.g. 'rectangle', 'circle', 'polygon'") - roi: dict = Field( + roi: JsonableDict = Field( description="Dictionary containing the ROI information, e.g. {'x': 100, 'y': 200, 'width': 50, 'height': 50}" ) timestamp: float = Field(default_factory=time.time) @@ -889,7 +965,7 @@ class ScanMessage(BECMessage): msg_type: ClassVar[str] = "scan_message" point_id: int scan_id: str - data: dict + data: JsonableDict class ScanHistoryMessage(BECMessage): @@ -923,7 +999,7 @@ class ScanHistoryMessage(BECMessage): end_time: float scan_name: str num_points: int - request_inputs: dict | None = None + request_inputs: JsonableDict | None = None stored_data_info: dict[str, dict[str, _StoredDataInfo]] | None = None @@ -951,7 +1027,7 @@ class ScanBaselineMessage(BECMessage): msg_type: ClassVar[str] = "scan_baseline_message" scan_id: str - data: dict + data: JsonableDict ConfigAction = Literal["add", "set", "update", "reload", "remove", "reset", "cancel"] @@ -969,7 +1045,7 @@ class DeviceConfigMessage(BECMessage): msg_type: ClassVar[str] = "device_config_message" action: ConfigAction | None = Field(default=None, validate_default=True) - config: dict | None = Field(default=None) + config: JsonableDict | None = Field(default=None) @model_validator(mode="after") @classmethod @@ -1015,7 +1091,7 @@ class LogMessage(BECMessage): log_type: Literal[ "trace", "debug", "info", "success", "warning", "error", "critical", "console_log" ] - log_msg: dict | str + log_msg: JsonableDict | str class AlarmMessage(BECMessage): @@ -1137,8 +1213,8 @@ class FileContentMessage(BECMessage): msg_type: ClassVar[str] = "file_content_message" file_path: str - data: dict - scan_info: dict + data: JsonableDict + scan_info: JsonableDict class VariableMessage(BECMessage): @@ -1179,7 +1255,7 @@ class ServiceMetricMessage(BECMessage): msg_type: ClassVar[str] = "service_metric_message" name: str - metrics: dict + metrics: JsonableDict class _StrDynamicMetricValue(BaseModel): @@ -1241,7 +1317,7 @@ class ProcessedDataMessage(BECMessage): """ msg_type: ClassVar[str] = "processed_data_message" - data: dict | list[dict] + data: JsonableDict | list[JsonableDict] class DAPConfigMessage(BECMessage): @@ -1253,7 +1329,7 @@ class DAPConfigMessage(BECMessage): """ msg_type: ClassVar[str] = "dap_config_message" - config: dict + config: JsonableDict class DAPRequestMessage(BECMessage): @@ -1269,7 +1345,7 @@ class DAPRequestMessage(BECMessage): msg_type: ClassVar[str] = "dap_request_message" dap_cls: str dap_type: Literal["continuous", "on_demand"] - config: dict + config: JsonableDict class DAPResponseMessage(BECMessage): @@ -1299,7 +1375,7 @@ class AvailableResourceMessage(BECMessage): """ msg_type: ClassVar[str] = "available_resource_message" - resource: dict | list[dict] | BaseModel | list[BaseModel] + resource: JsonableDict | list[JsonableDict] | BECMessage | list[BECMessage] class ProgressMessage(BECMessage): @@ -1327,7 +1403,7 @@ class GUIConfigMessage(BECMessage): """ msg_type: ClassVar[str] = "gui_config_message" - config: dict + config: JsonableDict class GUIDataMessage(BECMessage): @@ -1339,7 +1415,7 @@ class GUIDataMessage(BECMessage): """ msg_type: ClassVar[str] = "gui_data_message" - data: dict + data: JsonableDict class GUIInstructionMessage(BECMessage): @@ -1352,7 +1428,7 @@ class GUIInstructionMessage(BECMessage): msg_type: ClassVar[str] = "gui_instruction_message" action: str - parameter: dict + parameter: JsonableDict class GUIAutoUpdateConfigMessage(BECMessage): @@ -1389,7 +1465,7 @@ class GUIRegistryStateMessage(BECMessage): "container_proxy", "skip_rpc_namespace", ], - str | bool | dict | None, + str | bool | JsonableDict | None, ], ] @@ -1403,7 +1479,7 @@ class ServiceResponseMessage(BECMessage): """ msg_type: ClassVar[str] = "service_response_message" - response: dict + response: JsonableDict class CredentialsMessage(BECMessage): @@ -1415,7 +1491,7 @@ class CredentialsMessage(BECMessage): """ msg_type: ClassVar[str] = "credentials_message" - credentials: dict + credentials: JsonableDict class RawMessage(BECMessage): @@ -1428,7 +1504,7 @@ class RawMessage(BECMessage): """ msg_type: ClassVar[str] = "raw_message" - data: Any + data: Jsonable model_config = ConfigDict(arbitrary_types_allowed=True) @@ -1695,8 +1771,7 @@ class MessagingConfig(BaseModel): AvailableMessagingServices = Annotated[ - Union[SignalServiceInfo, SciLogServiceInfo, TeamsServiceInfo], - Field(discriminator="service_type"), + SignalServiceInfo | SciLogServiceInfo | TeamsServiceInfo, Field(discriminator="service_type") ] @@ -1764,7 +1839,6 @@ class EndpointInfoMessage(BECMessage): msg_type: ClassVar[str] = "endpoint_info_message" endpoint: str - metadata: dict | None = Field(default_factory=dict) class ScriptExecutionInfoMessage(BECMessage): @@ -1797,8 +1871,6 @@ class MacroUpdateMessage(BECMessage): macro_name: str | None = None file_path: str | None = None - metadata: dict | None = Field(default_factory=dict) - @model_validator(mode="after") @classmethod def check_macro(cls, values): @@ -1873,13 +1945,13 @@ class MessagingServiceGiphyContent(BaseModel): giphy_url: str -MessagingServiceContent = Union[ - MessagingServiceTextContent, - MessagingServiceFileContent, - MessagingServiceTagsContent, - MessagingServiceStickerContent, - MessagingServiceGiphyContent, -] +MessagingServiceContent = ( + MessagingServiceTextContent + | MessagingServiceFileContent + | MessagingServiceTagsContent + | MessagingServiceStickerContent + | MessagingServiceGiphyContent +) class MessagingServiceMessage(BECMessage): @@ -1897,7 +1969,6 @@ class MessagingServiceMessage(BECMessage): service_name: Literal["signal", "teams", "scilog"] message: list[MessagingServiceContent] scope: str | list[str] | None = None - metadata: dict | None = Field(default_factory=dict) class FeedbackMessage(BECMessage): diff --git a/bec_lib/bec_lib/one_way_registry.py b/bec_lib/bec_lib/one_way_registry.py new file mode 100644 index 000000000..d1f92a90e --- /dev/null +++ b/bec_lib/bec_lib/one_way_registry.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any, Callable, Type + +from bec_lib.device import DeviceBase + + +class OneWayBECCodec(ABC): + """Abstract base class for custom encoders""" + + obj_type: Type | list[Type] + + @staticmethod + @abstractmethod + def encode(obj: Any) -> Any: + """Encode an object into a serializable format.""" + + +class BECDeviceEncoder(OneWayBECCodec): + obj_type = DeviceBase + + @staticmethod + def encode(obj: DeviceBase) -> str: + if hasattr(obj, "_compile_function_path"): + # pylint: disable=protected-access + return obj._compile_function_path() + return obj.name + + +class OneWaySerializationRegistry: + """Registry for serialization codecs""" + + def __init__(self): + self._registry: dict[str, tuple[Type, Callable]] = {} + + self.register_codec(BECDeviceEncoder) + + def register_codec(self, codec: Type[OneWayBECCodec]): + """ + Register a codec for a specific BECCodec subclass. + This method allows for easy registration of custom encoders and decoders + for BECMessage and other types. + + Args: + codec: A subclass of BECCodec that implements encode and decode methods. + Raises: + ValueError: If a codec for the specified type is already registered. + """ + if isinstance(codec.obj_type, list): + for cls in codec.obj_type: + self.register(cls, codec.encode) + else: + self.register(codec.obj_type, codec.encode) + + def register(self, cls: Type, encoder: Callable): + """Register a codec for a specific type.""" + + if cls.__name__ in self._registry: + raise ValueError(f"Codec for {cls} already registered.") + self._registry[cls.__name__] = (cls, encoder) + self.get_codec.cache_clear() # Clear the cache when a new codec is registered + + @lru_cache(maxsize=2000) + def get_codec(self, cls: Type) -> tuple[Type, Callable] | None: + """Get the codec for a specific type.""" + codec = self._registry.get(cls.__name__) + if codec: + return codec + for _, (registered_cls, encoder) in self._registry.items(): + if issubclass(cls, registered_cls): + return registered_cls, encoder + return None + + def is_registered(self, cls: Type) -> bool: + """ + Check if a codec is registered for a specific type. + Args: + cls: The class type to check for a registered codec. + Returns: + bool: True if a codec is registered for the type, False otherwise. + """ + return self.get_codec(cls) is not None + + def encode(self, obj): + """Encode an object using the registered codec.""" + codec = self.get_codec(type(obj)) + if not codec: + return obj # No codec registered for this type + _, encoder = codec + try: + return encoder(obj) + except Exception as e: + raise ValueError( + f"Serialization failed: Failed to encode {obj.__class__.__name__} with codec {encoder}: {e}" + ) from e diff --git a/bec_lib/bec_lib/serialization_registry.py b/bec_lib/bec_lib/serialization_registry.py index e9aa923c8..7bcbfefcc 100644 --- a/bec_lib/bec_lib/serialization_registry.py +++ b/bec_lib/bec_lib/serialization_registry.py @@ -19,7 +19,6 @@ def __init__(self): self._legacy_codecs = [] # can be removed in future versions, see issue #516 self.register_codec(bec_codecs.BECMessageEncoder) - self.register_codec(bec_codecs.BECDeviceEncoder) self.register_codec(bec_codecs.EndpointInfoEncoder) self.register_codec(bec_codecs.SetEncoder) self.register_codec(bec_codecs.BECTypeEncoder) diff --git a/bec_lib/pyproject.toml b/bec_lib/pyproject.toml index 2e0887dc7..b4023608d 100644 --- a/bec_lib/pyproject.toml +++ b/bec_lib/pyproject.toml @@ -1,61 +1,43 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - [project] name = "bec_lib" version = "3.113.1" description = "BEC library" requires-python = ">=3.11" classifiers = [ - "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3", - "Topic :: Scientific/Engineering", + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering", ] dependencies = [ - "fastjsonschema~=2.19", - "fpdf2~=2.7, >=2.7.7", - "hiredis~=3.0", - "lmfit~=1.3", - "loguru~=0.7", - "louie~=2.0", - "msgpack~=1.0, >1.0.4", - "numpy>=1.24, <3.0", - "psutil~=5.9", - "pydantic~=2.8, <2.12.0", - "pylint~=3.0", - "pyyaml~=6.0", - "redis~=6.2,>=6.2.0", - "requests~=2.31", - "rich~=13.7", - "scipy~=1.12", - "tomli~=2.0, >=2.0.1", - "toolz~=0.12", - "typeguard ~= 4.1, >=4.1.5", - "prettytable~=3.9", - "h5py~=3.10", - "hdf5plugin >=4.3, < 6.0", - "python-dotenv~=1.0", - "python-slugify~=8.0", + "fastjsonschema~=2.19", + "fpdf2~=2.7, >=2.7.7", + "h5py~=3.10", + "hdf5plugin >=4.3, < 6.0", + "hiredis~=3.0", + "lmfit~=1.3", + "loguru~=0.7", + "louie~=2.0", + "msgpack~=1.0, >1.0.4", + "numpy>=1.24, <3.0", + "prettytable~=3.9", + "psutil~=5.9", + "pydantic~=2.8, <2.12.0", + "pylint~=3.0", + "python-dotenv~=1.0", + "python-slugify~=8.0", + "pyyaml~=6.0", + "redis~=6.2,>=6.2.0", + "requests~=2.31", + "rich~=13.7", + "scipy~=1.12", + "tomli~=2.0, >=2.0.1", + "toolz~=0.12", + "typeguard ~= 4.1, >=4.1.5", ] - -[project.optional-dependencies] -dev = [ - "black~=26.0", - "coverage~=7.0", - "fakeredis~=2.23, >=2.23.2", - "isort~=5.13, >=5.13.2", - "pandas~=2.0", - "pytest~=8.0", - "pytest-random-order~=1.1", - "pytest-timeout~=2.2", - "pytest-redis~=3.0", - "Jinja2~=3.1", - "copier~=9.7", - "typer~=0.15", -] -ci = ["bec-testing-plugin"] +[project.urls] +"Bug Tracker" = "https://github.com/bec-project/bec/issues" +Homepage = "https://github.com/bec-project/bec" [project.scripts] bec-channel-monitor = "bec_lib.channel_monitor:channel_monitor_launch" @@ -65,9 +47,30 @@ bec-plugin-manager = "bec_lib.utils.plugin_manager.main:main" [project.entry-points.pytest11] bec_lib_fixtures = "bec_lib.tests.fixtures" -[project.urls] -"Bug Tracker" = "https://github.com/bec-project/bec/issues" -Homepage = "https://github.com/bec-project/bec" +[project.optional-dependencies] +ci = ["bec-testing-plugin"] +dev = [ + "black~=26.0", + "coverage~=7.0", + "fakeredis~=2.23, >=2.23.2", + "isort~=5.13, >=5.13.2", + "pandas~=2.0", + "pytest~=8.0", + "pytest-random-order~=1.1", + "pytest-timeout~=2.2", + "pytest-redis~=3.0", + "Jinja2~=3.1", + "copier~=9.7", + "typer~=0.15", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.black] +line-length = 100 +skip-magic-trailing-comma = true [tool.hatch.build.targets.wheel] include = ["*"] @@ -75,13 +78,15 @@ include = ["*"] [tool.hatch.metadata] allow-direct-references = true -[tool.black] -line-length = 100 -skip-magic-trailing-comma = true - [tool.isort] profile = "black" line_length = 100 multi_line_output = 3 include_trailing_comma = true known_first_party = ["bec_lib", "bec_server", "bec_ipython_client"] + +[tool.ruff] +line-length = 100 + +[tool.ruff.format] +skip-magic-trailing-comma = true diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index ee700fcdc..3af5721b0 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -431,15 +431,13 @@ def test_DeviceInstructionMessage(): def test_DeviceMonitor2DMessage(): # Test 2D data - msg = messages.DeviceMonitor2DMessage( - device="eiger", data=np.random.rand(2, 100), metadata=None - ) + msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(2, 100)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg assert res_loaded.metadata == {} # Test rgb image, i.e. image with 3 channels - msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(3, 3), metadata=None) + msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(3, 3)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -456,7 +454,7 @@ def test_DeviceMonitor2DMessage(): def test_DeviceMonitor1DMessage(): # Test 2D data - msg = messages.DeviceMonitor1DMessage(device="eiger", data=np.random.rand(100), metadata=None) + msg = messages.DeviceMonitor1DMessage(device="eiger", data=np.random.rand(100)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -704,3 +702,9 @@ def test_feedback_message(): assert res_loaded == msg assert res_loaded.username == getpass.getuser() assert res_loaded.versions == messages.ServiceVersions._get_version_numbers() + + +def test_message_with_np_array_in_dict(): + arr = np.zeros(5) + msg = messages.ScanMessage(point_id=0, scan_id="", data={"device": {"value": arr}}, metadata={}) + assert isinstance(msg.data["device"]["value"], np.ndarray) diff --git a/bec_lib/tests/test_config_helper.py b/bec_lib/tests/test_config_helper.py index 72a9431e9..8cf830815 100644 --- a/bec_lib/tests/test_config_helper.py +++ b/bec_lib/tests/test_config_helper.py @@ -419,7 +419,8 @@ def test_config_helper_get_config_conflicts( config.update(dev_cfg) config_in_redis.append(config) with mock.patch.object(config_helper._device_manager.connector, "get") as mock_get: - mock_get.return_value = messages.AvailableResourceMessage(resource=config_in_redis) + available_resource_message = messages.AvailableResourceMessage(resource=config_in_redis) + mock_get.return_value = available_resource_message conflicts = config_helper._get_config_conflicts(new_config) assert conflicts == expected_conflicts diff --git a/bec_lib/tests/test_serializer.py b/bec_lib/tests/test_serializer.py index 93f5990de..e1a254e7b 100644 --- a/bec_lib/tests/test_serializer.py +++ b/bec_lib/tests/test_serializer.py @@ -10,6 +10,7 @@ from bec_lib.device import DeviceBase from bec_lib.devicemanager import DeviceManagerBase from bec_lib.endpoints import MessageEndpoints +from bec_lib.one_way_registry import OneWaySerializationRegistry from bec_lib.serialization import MsgpackSerialization, json_ext, msgpack @@ -81,10 +82,11 @@ class DummyModel(BaseModel): assert data.model_dump() == converted_data -def test_device_serializer(serializer): +def test_device_serializer(): + serializer = OneWaySerializationRegistry() device_manager = mock.MagicMock(spec=DeviceManagerBase) dummy = DeviceBase(name="dummy", parent=device_manager) - assert serializer.loads(serializer.dumps(dummy)) == "dummy" + assert serializer.encode(dummy) == "dummy" def test_enum_serializer(serializer): diff --git a/bec_server/tests/tests_scan_server/test_scan_stubs.py b/bec_server/tests/tests_scan_server/test_scan_stubs.py index 4ab51e9bd..f1c532f98 100644 --- a/bec_server/tests/tests_scan_server/test_scan_stubs.py +++ b/bec_server/tests/tests_scan_server/test_scan_stubs.py @@ -131,6 +131,7 @@ def test_rpc_call_returns_status(stubs): fake_status = mock.MagicMock() fake_status._result_is_status = True fake_status.wait = mock.MagicMock() + fake_status._device_instr_id = "fake device instruction ID" with mock.patch.object(stubs, "_create_status", return_value=fake_status): result = stubs._rpc_call("samx", "velocity.set", 10) @@ -146,6 +147,7 @@ def test_rpc_call_returns_dict(stubs): fake_status._result_is_status = False fake_status.result = expected fake_status.wait = mock.MagicMock() + fake_status._device_instr_id = "fake device instruction ID" with mock.patch.object(stubs, "_create_status", return_value=fake_status): result = stubs._rpc_call("samx", "velocity.set", 10) diff --git a/bec_server/tests/tests_scan_server/test_scan_worker.py b/bec_server/tests/tests_scan_server/test_scan_worker.py index 46cf59370..6de30a457 100644 --- a/bec_server/tests/tests_scan_server/test_scan_worker.py +++ b/bec_server/tests/tests_scan_server/test_scan_worker.py @@ -293,7 +293,7 @@ def test_initialize_scan_info(scan_worker_mock, msg): assert worker.current_scan_info["scan_msgs"] == [] assert worker.current_scan_info["monitor_sync"] == "bec" assert worker.current_scan_info["frames_per_trigger"] == 1 - assert worker.current_scan_info["args"] == {"samx": (-5, 5, 5), "samy": (-1, 1, 2)} + assert worker.current_scan_info["args"] == {"samx": [-5, 5, 5], "samy": [-1, 1, 2]} assert worker.current_scan_info["kwargs"] == msg.parameter["kwargs"] assert "samx" in worker.current_scan_info["readout_priority"]["monitored"] assert "samy" in worker.current_scan_info["readout_priority"]["baseline"] diff --git a/bec_server/tests/tests_scan_server/test_scans.py b/bec_server/tests/tests_scan_server/test_scans.py index 50c293983..900103311 100644 --- a/bec_server/tests/tests_scan_server/test_scans.py +++ b/bec_server/tests/tests_scan_server/test_scans.py @@ -167,7 +167,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx", "samy"], "start": [0, 0], - "end": np.array([1.0, 2.0]), + "end": [1.0, 2.0], } }, metadata={ @@ -211,7 +211,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx", "samy", "samz"], "start": [0, 0, 0], - "end": np.array([1.0, 2.0, 3.0]), + "end": [1.0, 2.0, 3.0], } }, metadata={ @@ -264,7 +264,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx"], "start": [0], - "end": np.array([1.0]), + "end": [1.0], } }, metadata={