diff --git a/pyproject.toml b/pyproject.toml index f3b54956..299fef11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,7 @@ dependencies = [ "tickit==0.2.3", "typing_extensions", "softioc", - "pydantic>1", - "apischema" + "pydantic>1" ] dynamic = ["version"] license.file = "LICENSE" diff --git a/src/tickit_devices/eiger/eiger_adapters.py b/src/tickit_devices/eiger/eiger_adapters.py index 6e7c085c..9504222a 100644 --- a/src/tickit_devices/eiger/eiger_adapters.py +++ b/src/tickit_devices/eiger/eiger_adapters.py @@ -1,14 +1,21 @@ import logging +from typing import Any, Dict, List, Union from aiohttp import web -from apischema import serialize +from pydantic.v1 import BaseModel from tickit.adapters.httpadapter import HttpAdapter from tickit.adapters.interpreters.endpoints.http_endpoint import HttpEndpoint from tickit.adapters.zeromq.push_adapter import ZeroMqPushAdapter from tickit_devices.eiger.eiger import EigerDevice -from tickit_devices.eiger.eiger_schema import SequenceComplete, Value, construct_value +from tickit_devices.eiger.eiger_schema import ( + AccessMode, + SequenceComplete, + Value, + construct_value, +) from tickit_devices.eiger.eiger_status import State +from tickit_devices.utils import serialize API_VERSION = "1.8.0" DETECTOR_API = f"detector/api/{API_VERSION}" @@ -41,9 +48,11 @@ async def get_config(self, request: web.Request) -> web.Response: data = construct_value(self.device.settings, param) else: - data = serialize(Value("None", "string", access_mode="None")) + data = serialize( + Value(value="None", value_type="string", access_mode=AccessMode.NONE) + ) - return web.json_response(data) + return web.json_response(data.dict()) @HttpEndpoint.put(f"/{DETECTOR_API}" + "/config/{parameter_name}") async def put_config(self, request: web.Request) -> web.Response: @@ -59,11 +68,11 @@ async def put_config(self, request: web.Request) -> web.Response: """ param = request.match_info["parameter_name"] - response = await request.json() + response = await request.dict() if self.device.get_state() is not State.IDLE: LOGGER.warning("Eiger not initialized or is currently running.") - return web.json_response(serialize([])) + return web.json_response([]) elif ( hasattr(self.device.settings, param) and self.device.get_state() is State.IDLE @@ -75,10 +84,10 @@ async def put_config(self, request: web.Request) -> web.Response: self.device.settings[param] = attr LOGGER.debug("Set " + str(param) + " to " + str(attr)) - return web.json_response(serialize([param])) + return web.json_response([param]) else: LOGGER.debug("Eiger has no config variable: " + str(param)) - return web.json_response(serialize([])) + return web.json_response([]) @HttpEndpoint.get(f"/{DETECTOR_API}" + "/status/{status_param}") async def get_status(self, request: web.Request) -> web.Response: @@ -97,9 +106,11 @@ async def get_status(self, request: web.Request) -> web.Response: data = construct_value(self.device.status, param) else: - data = serialize(Value("None", "string", access_mode="None")) + data = serialize( + Value(value="None", value_type="string", access_mode=AccessMode.NONE) + ) - return web.json_response(data) + return web.json_response(data.dict()) @HttpEndpoint.get(f"/{DETECTOR_API}" + "/status/board_000/{status_param}") async def get_board_000_status(self, request: web.Request) -> web.Response: @@ -141,7 +152,7 @@ async def initialize_eiger(self, request: web.Request) -> web.Response: await self.device.initialize() LOGGER.debug("Initializing Eiger...") - return web.json_response(serialize(SequenceComplete(1))) + return web.json_response(serialize(SequenceComplete.number(1))) @HttpEndpoint.put(f"/{DETECTOR_API}" + "/command/arm", interrupt=True) async def arm_eiger(self, request: web.Request) -> web.Response: @@ -157,7 +168,7 @@ async def arm_eiger(self, request: web.Request) -> web.Response: await self.device.arm() LOGGER.debug("Arming Eiger...") - return web.json_response(serialize(SequenceComplete(2))) + return web.json_response(serialize(SequenceComplete.number(2))) @HttpEndpoint.put(f"/{DETECTOR_API}" + "/command/disarm", interrupt=True) async def disarm_eiger(self, request: web.Request) -> web.Response: @@ -173,7 +184,7 @@ async def disarm_eiger(self, request: web.Request) -> web.Response: await self.device.disarm() LOGGER.debug("Disarming Eiger...") - return web.json_response(serialize(SequenceComplete(3))) + return web.json_response(serialize(SequenceComplete.number(3))) @HttpEndpoint.put(f"/{DETECTOR_API}" + "/command/trigger", interrupt=False) async def trigger_eiger(self, request: web.Request) -> web.Response: @@ -192,7 +203,7 @@ async def trigger_eiger(self, request: web.Request) -> web.Response: await self.raise_interrupt() await self.device.finished_aquisition.wait() - return web.json_response(serialize(SequenceComplete(4))) + return web.json_response(serialize(SequenceComplete.number(4))) @HttpEndpoint.put(f"/{DETECTOR_API}" + "/command/cancel", interrupt=True) async def cancel_eiger(self, request: web.Request) -> web.Response: @@ -208,7 +219,7 @@ async def cancel_eiger(self, request: web.Request) -> web.Response: await self.device.cancel() LOGGER.debug("Cancelling Eiger...") - return web.json_response(serialize(SequenceComplete(5))) + return web.json_response(serialize(SequenceComplete.number(5))) @HttpEndpoint.put(f"/{DETECTOR_API}" + "/command/abort", interrupt=True) async def abort_eiger(self, request: web.Request) -> web.Response: @@ -224,7 +235,7 @@ async def abort_eiger(self, request: web.Request) -> web.Response: await self.device.abort() LOGGER.debug("Aborting Eiger...") - return web.json_response(serialize(SequenceComplete(6))) + return web.json_response(serialize(SequenceComplete.number(6))) @HttpEndpoint.get(f"/{STREAM_API}" + "/status/{param}") async def get_stream_status(self, request: web.Request) -> web.Response: @@ -274,7 +285,7 @@ async def put_stream_config(self, request: web.Request) -> web.Response: """ param = request.match_info["param"] - response = await request.json() + response = await request.dict() if hasattr(self.device.stream.config, param): attr = response["value"] @@ -284,10 +295,10 @@ async def put_stream_config(self, request: web.Request) -> web.Response: self.device.stream.config[param] = attr LOGGER.debug("Set " + str(param) + " to " + str(attr)) - return web.json_response(serialize([param])) + return web.json_response([param]) else: LOGGER.debug("Eiger has no config variable: " + str(param)) - return web.json_response(serialize([])) + return web.json_response([]) @HttpEndpoint.get(f"/{MONITOR_API}" + "/config/{param}") async def get_monitor_config(self, request: web.Request) -> web.Response: @@ -320,7 +331,7 @@ async def put_monitor_config(self, request: web.Request) -> web.Response: """ param = request.match_info["param"] - response = await request.json() + response = await request.dict() if hasattr(self.device.monitor_config, param): attr = response["value"] @@ -330,10 +341,10 @@ async def put_monitor_config(self, request: web.Request) -> web.Response: self.device.monitor_config[param] = attr LOGGER.debug("Set " + str(param) + " to " + str(attr)) - return web.json_response(serialize([param])) + return web.json_response([param]) else: LOGGER.debug("Eiger has no config variable: " + str(param)) - return web.json_response(serialize([])) + return web.json_response([]) @HttpEndpoint.get(f"/{MONITOR_API}" + "/status/{param}") async def get_monitor_status(self, request: web.Request) -> web.Response: @@ -383,7 +394,7 @@ async def put_filewriter_config(self, request: web.Request) -> web.Response: """ param = request.match_info["param"] - response = await request.json() + response = await request.dict() if hasattr(self.device.filewriter_config, param): attr = response["value"] @@ -393,10 +404,10 @@ async def put_filewriter_config(self, request: web.Request) -> web.Response: self.device.filewriter_config[param] = attr LOGGER.debug("Set " + str(param) + " to " + str(attr)) - return web.json_response(serialize([param])) + return web.json_response([param]) else: LOGGER.debug("Eiger has no config variable: " + str(param)) - return web.json_response(serialize([])) + return web.json_response([]) @HttpEndpoint.get(f"/{FILEWRITER_API}" + "/status/{param}") async def get_filewriter_status(self, request: web.Request) -> web.Response: diff --git a/src/tickit_devices/eiger/eiger_schema.py b/src/tickit_devices/eiger/eiger_schema.py index 1e59833a..9346a68e 100644 --- a/src/tickit_devices/eiger/eiger_schema.py +++ b/src/tickit_devices/eiger/eiger_schema.py @@ -1,13 +1,11 @@ import logging -from dataclasses import dataclass, field from enum import Enum from functools import partial from typing import Any, Generic, List, Mapping, Optional, TypeVar -from apischema import serialized -from apischema.fields import with_fields_set -from apischema.metadata import skip -from apischema.serialization import serialize +from pydantic.v1 import BaseModel, Field + +from tickit_devices.utils import serialize T = TypeVar("T") @@ -29,15 +27,16 @@ def field_config(**kwargs) -> Mapping[str, Any]: return dict(**kwargs) -class AccessMode(Enum): +class AccessMode(str, Enum): """Possible access modes for field metadata.""" READ_ONLY: str = "r" WRITE_ONLY: str = "w" READ_WRITE: str = "rw" + NONE: str = "None" -class ValueType(Enum): +class ValueType(str, Enum): """Possible value types for field metadata.""" FLOAT: str = "float" @@ -105,53 +104,60 @@ class ValueType(Enum): ) -@with_fields_set -@dataclass -class Value(Generic[T]): +class Value(BaseModel, Generic[T]): """Schema for a value to be returned by the API. Most fields are optional.""" value: T value_type: str - access_mode: Optional[str] = None + access_mode: Optional[AccessMode] = None unit: Optional[str] = None min: Optional[T] = None max: Optional[T] = None allowed_values: Optional[List[str]] = None -def construct_value(obj, param): # noqa: D103 +def construct_value(obj, param) -> Value: # noqa: D103 value = obj[param]["value"] meta = obj[param]["metadata"] if "allowed_values" in meta: - data = serialize( - Value( - value, - meta["value_type"].value, - access_mode=meta["access_mode"].value, - allowed_values=meta["allowed_values"], - ) + data = Value( + value=value, + value_type=meta["value_type"].value, + access_mode=meta["access_mode"].value, + allowed_values=meta["allowed_values"], ) else: - data = serialize( - Value( - value, - meta["value_type"].value, - access_mode=meta["access_mode"].value, - ) + data = Value( + value=value, + value_type=meta["value_type"].value, + access_mode=meta["access_mode"].value, ) return data -@dataclass -class SequenceComplete: +class SequenceComplete(BaseModel): """Schema for confirmation returned by operations that do not return values.""" - _sequence_id: int = field(default=1, metadata=skip, init=True, repr=False) + sequence_id: int = Field(default=1, alias="sequence id") + + @classmethod + def number(cls, number: int) -> "SequenceComplete": + """Create a new completion document with the given ID. + + This function exists as a workaround for mypy ignoring aliases. + See https://github.com/pydantic/pydantic/discussions/2889 + + Args: + number: The sequence ID + + Returns: + SequenceComplete: Document describing a completed sequence of operations + """ + return SequenceComplete(sequence_id=number) # type: ignore - @serialized("sequence id") # type: ignore - @property - def sequence_id(self) -> int: # noqa: D102 - return self._sequence_id + class Config: + allow_population_by_field_name = True + fields = {"sequence_id": "sequence id"} diff --git a/src/tickit_devices/eiger/eiger_settings.py b/src/tickit_devices/eiger/eiger_settings.py index fec60af9..697a0114 100644 --- a/src/tickit_devices/eiger/eiger_settings.py +++ b/src/tickit_devices/eiger/eiger_settings.py @@ -112,6 +112,7 @@ class EigerSettings: trigger_mode: str = field( default="exts", metadata=rw_str(allowed_values=["exts", "ints", "exte", "inte"]) ) + trigger_start_delay: float = field(default=0.0, metadata=rw_float()) two_theta_increment: float = field(default=0.0, metadata=rw_float()) two_theta_start: float = field(default=0.0, metadata=rw_float()) wavelength: float = field(default=1.0, metadata=rw_float()) diff --git a/src/tickit_devices/eiger/eiger_status.py b/src/tickit_devices/eiger/eiger_status.py index 74a96ec4..932e0553 100644 --- a/src/tickit_devices/eiger/eiger_status.py +++ b/src/tickit_devices/eiger/eiger_status.py @@ -6,7 +6,7 @@ from .eiger_schema import ro_str_list, rw_datetime, rw_float, rw_state -class State(Enum): +class State(str, Enum): """Possible states of the Eiger detector.""" NA = "na" diff --git a/src/tickit_devices/eiger/stream/eiger_stream.py b/src/tickit_devices/eiger/stream/eiger_stream.py index acdc477d..ec166d02 100644 --- a/src/tickit_devices/eiger/stream/eiger_stream.py +++ b/src/tickit_devices/eiger/stream/eiger_stream.py @@ -24,6 +24,9 @@ _Message = Union[BaseModel, Mapping[str, Any], bytes] +_Sendable = Union[bytes, Frame, memoryview] +_Message = Union[_Sendable, str, Mapping[str, Any], BaseModel] + class EigerStream: """Simulation of an Eiger stream.""" @@ -34,6 +37,8 @@ class EigerStream: _message_buffer: Queue[_Message] + _message_buffer: Queue[_Sendable] + #: An empty typed mapping of input values Inputs: TypedDict = TypedDict("Inputs", {}) #: A typed mapping containing the 'value' output value diff --git a/src/tickit_devices/utils.py b/src/tickit_devices/utils.py new file mode 100644 index 00000000..075a0dee --- /dev/null +++ b/src/tickit_devices/utils.py @@ -0,0 +1,34 @@ +from typing import Any, List, Mapping, Union + +from pydantic.v1 import BaseModel + +_Serialized = Union[str, float, int, bool, List[Any], Mapping[str, Any]] +_Serializable = Union[_Serialized, BaseModel] + + +def serialize(document: _Serializable) -> _Serialized: + """Helper to serialize using pydantic base models + + Args: + document: A JSON-serializable document or base model + + Raises: + TypeError: If the document cannot be serialized + + Returns: + _Serialized: A JSON-serializable document + """ + + if ( + isinstance(document, str) + or isinstance(document, float) + or isinstance(document, int) + or isinstance(document, bool) + or isinstance(document, list) + or isinstance(document, dict) + ): + return document + elif isinstance(document, BaseModel): + return document.dict(by_alias=True) + else: + raise TypeError(f"Document {document} is of unrecognized type {type(document)}") diff --git a/tests/eiger/test_eiger.py b/tests/eiger/test_eiger.py index c4e3eea2..4924c93c 100644 --- a/tests/eiger/test_eiger.py +++ b/tests/eiger/test_eiger.py @@ -24,6 +24,10 @@ def test_starting_state_is_na(eiger: EigerDevice): assert_in_state(eiger, State.NA) +def test_starting_state_is_na(eiger: EigerDevice): + assert_in_state(eiger, State.NA) + + @pytest.mark.asyncio async def test_initialize(eiger: EigerDevice): await eiger.initialize() @@ -115,6 +119,9 @@ async def test_armed_eiger_starts_series(eiger: EigerDevice, mock_stream: Mock): await eiger.arm() mock_stream.begin_series.assert_called_once_with(eiger.settings, 1) + eiger.update(SimTime(0.0), {}) + eiger.update(SimTime(0.0), {}) + @pytest.mark.asyncio async def test_disarmed_eiger_starts_and_ends_series( diff --git a/tests/eiger/test_eiger_schema.py b/tests/eiger/test_eiger_schema.py new file mode 100644 index 00000000..6955074d --- /dev/null +++ b/tests/eiger/test_eiger_schema.py @@ -0,0 +1,25 @@ +import pytest + +from tickit_devices.eiger.eiger_schema import SequenceComplete +from tickit_devices.utils import serialize + + +@pytest.mark.parametrize("sequence_id", [1, 2]) +def test_sequence_complete_uses_alternative_constructor(sequence_id: int) -> None: + complete = SequenceComplete.number(sequence_id) + assert complete.sequence_id == sequence_id + + +@pytest.mark.parametrize("sequence_id", [1, 2]) +def test_sequence_complete_uses_space_in_field_name(sequence_id: int) -> None: + complete = SequenceComplete.number(sequence_id) + serialized = serialize(complete) + assert isinstance(serialized, dict) + assert serialized["sequence id"] == sequence_id + + +def test_sequence_complete_uses_alias_only() -> None: + complete = SequenceComplete.number(1) + serialized = serialize(complete) + assert isinstance(serialized, dict) + assert "sequence_id" not in serialized.keys() diff --git a/tests/eiger/test_eiger_system.py b/tests/eiger/test_eiger_system.py index 5233ad29..23fecbdb 100644 --- a/tests/eiger/test_eiger_system.py +++ b/tests/eiger/test_eiger_system.py @@ -1,5 +1,8 @@ +from datetime import datetime + import aiohttp import pytest +from pydantic.v1 import parse_obj_as DETECTOR_URL = "http://localhost:8081/detector/api/1.8.0/" FILE_WRITER_URL = "http://localhost:8081/filewriter/api/1.8.0/" @@ -47,7 +50,7 @@ async def get_status(status, expected): DETECTOR_URL + f"command/{key}", timeout=REQUEST_TIMEOUT, ) as response: - assert value == (await response.json()) + assert value == (await response.json()), key # Check status await get_status(status="doesnt_exist", expected="None") @@ -55,6 +58,14 @@ async def get_status(status, expected): await get_status(status="board_000/doesnt_exist", expected="None") await get_status(status="builder/dcu_buffer_free", expected=0.5) await get_status(status="builder/doesnt_exist", expected="None") + async with session.get( + DETECTOR_URL + "status/time", + timeout=REQUEST_TIMEOUT, + ) as response: + assert response.status == 200 + value = (await response.json())["value"] + eiger_time = parse_obj_as(datetime, value) + assert isinstance(eiger_time, datetime) # Test Eiger in IDLE state await get_status(status="state", expected="idle") @@ -88,6 +99,24 @@ async def get_status(status, expected): ) as response: assert (await response.json()) == ["element"] + async with session.get( + DETECTOR_URL + "config/frame_time", + timeout=REQUEST_TIMEOUT, + ) as response: + assert response.status == 200 + data = await response.json() + assert data["value"] == 0.12 + assert data["access_mode"] == "rw" + + async with session.put( + DETECTOR_URL + "config/frame_time", + headers=headers, + json={"value": 0.1}, + timeout=REQUEST_TIMEOUT, + ) as response: + assert response.status == 200 + assert (await response.json()) == ["frame_time"] + async with session.get( DETECTOR_URL + "config/photon_energy", timeout=REQUEST_TIMEOUT, @@ -116,6 +145,24 @@ async def get_status(status, expected): ) as response: assert [] == (await response.json()) + async with session.get( + DETECTOR_URL + "config/trigger_start_delay", + timeout=REQUEST_TIMEOUT, + ) as response: + assert response.status == 200 + data = await response.json() + assert data["value"] == 0.0 + assert data["access_mode"] == "rw" + + async with session.put( + DETECTOR_URL + "config/trigger_start_delay", + headers=headers, + json={"value": 0.1}, + timeout=REQUEST_TIMEOUT, + ) as response: + assert response.status == 200 + assert (await response.json()) == ["trigger_start_delay"] + # Test filewriter, monitor and stream endpoints async with session.get( FILE_WRITER_URL + "status/state",