diff --git a/bec_lib/bec_lib/connector.py b/bec_lib/bec_lib/connector.py index 107960f4e..9329809f8 100644 --- a/bec_lib/bec_lib/connector.py +++ b/bec_lib/bec_lib/connector.py @@ -139,7 +139,7 @@ def register(self, topics=None, patterns=None, cb=None, start_thread=True, **kwa """Register a callback for a topic or pattern""" @abc.abstractmethod - def unregister(self, topics=None, patterns=None, cb=None): + def unregister(self, topics=None, patterns=None, cb=None, **kwargs): """Unregister a callback for a topic or pattern""" @abc.abstractmethod diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index ae36db08c..f824505b8 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ b/bec_lib/bec_lib/redis_connector.py @@ -1002,14 +1002,23 @@ def _register_stream( self._stream_events_listener_thread.name += f" ({self.name})" self._stream_events_listener_thread.start() - def _filter_topics_cb(self, topics: list, cb: Callable | None): + @staticmethod + def _matches_subscription( + item: tuple[louie.saferef.BoundMethodWeakref, dict], cb: Callable | None, kwargs: dict + ) -> bool: + cb_ref, item_kwargs = item + if cb is not None and cb_ref() != cb: + return False + return all(item_kwargs.get(key) == value for key, value in kwargs.items()) + + def _filter_topics_cb(self, topics: list, cb: Callable | None, kwargs: dict): unsubscribe_list = [] with self._topics_cb_lock: for topic in topics: topics_cb = self._topics_cb[topic] # remove callback from list self._topics_cb[topic] = list( - filter(lambda item: cb and item[0]() is not cb, topics_cb) + filter(lambda item: not self._matches_subscription(item, cb, kwargs), topics_cb) ) if not self._topics_cb[topic]: # no callbacks left, unsubscribe @@ -1019,7 +1028,7 @@ def _filter_topics_cb(self, topics: list, cb: Callable | None): del self._topics_cb[topic] return unsubscribe_list - def unregister(self, topics=None, patterns=None, cb=None): + def unregister(self, topics=None, patterns=None, cb=None, **kwargs): if self._events_listener_thread is None: return if topics and patterns: @@ -1031,20 +1040,20 @@ def unregister(self, topics=None, patterns=None, cb=None): # see if registered streams can be unregistered for pattern in patterns: self._unregister_stream(fnmatch.filter(self._stream_subs.all_topics, pattern), cb) - pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb) + pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb, kwargs) if pubsub_unsubscribe_list: self._pubsub_conn.punsubscribe(pubsub_unsubscribe_list) elif topics is not None: topics, _ = self._convert_endpointinfo(topics, check_message_op=False) if not self._unregister_stream(topics, cb): - unsubscribe_list = self._filter_topics_cb(topics, cb) + unsubscribe_list = self._filter_topics_cb(topics, cb, kwargs) if unsubscribe_list: self._pubsub_conn.unsubscribe(unsubscribe_list) else: with self._topics_cb_lock: topics = list(self._topics_cb.keys()) - self.unregister(topics, cb) - self.unregister(self._stream_subs.all_topics, cb) + self.unregister(topics=topics, cb=cb, **kwargs) + self.unregister(topics=self._stream_subs.all_topics, cb=cb, **kwargs) def _unregister_stream(self, topics: list[str], cb: Callable | None = None) -> bool: """Unregister callbacks from a list of topics. Returns true if any were removed""" diff --git a/bec_lib/tests/test_redis_connector_fakeredis.py b/bec_lib/tests/test_redis_connector_fakeredis.py index abc792ed4..8dc2b63cf 100644 --- a/bec_lib/tests/test_redis_connector_fakeredis.py +++ b/bec_lib/tests/test_redis_connector_fakeredis.py @@ -180,6 +180,99 @@ def test_redis_connector_unregister_cb_not_topic(connected_connector): assert received_event2.call_count == 2 +def test_redis_connector_unregister_cb_not_topic_with_kwargs(connected_connector): + connector = connected_connector + + topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND) + topic2 = EndpointInfo("topic2", TestMessage, MessageOp.SEND) + + received_event1 = mock.Mock(spec=[]) + received_event2 = mock.Mock(spec=[]) + + connector.register(topics=topic1, cb=received_event1, start_thread=False, a=1) + connector.register(topics=topic1, cb=received_event2, start_thread=False, b=2) + connector.register(topics=topic2, cb=received_event1, start_thread=False, c=3) + + connector.send(topic1, TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + received_event1.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), a=1) + received_event2.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), b=2) + + connector.unregister(topic1, cb=received_event1) + + received_event1.reset_mock() + received_event2.reset_mock() + connector.send(topic1, TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + received_event1.assert_not_called() + received_event2.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), b=2) + + connector.send(topic2, TestMessage(msg="topic2")) + connector.poll_messages(timeout=1) + received_event1.assert_called_once_with(MessageObject("topic2", TestMessage(msg="topic2")), c=3) + assert list(connector._topics_cb.keys()) == ["topic1", "topic2"] + + +def test_redis_connector_unregister_only_specified_callback_kwargs(connected_connector): + connector = connected_connector + + topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND) + received_event = mock.Mock(spec=[]) + + connector.register(topics=topic1, cb=received_event, start_thread=False, a=1) + connector.register(topics=topic1, cb=received_event, start_thread=False, a=2) + + connector.send(topic1, TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + received_event.assert_has_calls( + [ + mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=1), + mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=2), + ] + ) + + received_event.reset_mock() + connector.unregister(topic1, cb=received_event, a=1) + + connector.send(topic1, TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + received_event.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), a=2) + assert len(connector._topics_cb["topic1"]) == 1 + + +def test_redis_connector_unregister_bound_method_with_kwargs(connected_connector): + connector = connected_connector + + topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND) + + class Receiver: + def __init__(self): + self.received = mock.Mock(spec=[]) + + def on_message(self, msg_obj, scan_id): + self.received(msg_obj, scan_id=scan_id) + + receiver = Receiver() + + connector.register(topics=topic1, cb=receiver.on_message, start_thread=False, scan_id="scan-1") + + connector.send(topic1, TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + receiver.received.assert_called_once_with( + MessageObject("topic1", TestMessage(msg="topic1")), scan_id="scan-1" + ) + + receiver.received.reset_mock() + connector.unregister(topic1, cb=receiver.on_message, scan_id="scan-1") + + connector.send(topic1, TestMessage(msg="topic1")) + with pytest.raises(TimeoutError): + connector.poll_messages(timeout=1) + receiver.received.assert_not_called() + assert connector._redis_conn.execute_command("PUBSUB CHANNELS") == [] + assert len(connector._topics_cb) == 0 + + def test_redis_connector_unregister_topic_keeps_others_alive(connected_connector): def send_msgs_and_poll(timeout=None): connector.send(topic1, TestMessage()) @@ -223,6 +316,57 @@ def send_msgs_and_poll(timeout=None): assert received_event2.call_count == 2 +def test_redis_connector_unregister_all_callback_subscriptions_with_kwargs(connected_connector): + connector = connected_connector + + received_event1 = mock.Mock(spec=[]) + received_event2 = mock.Mock(spec=[]) + + connector.register(topics="topic1", cb=received_event1, start_thread=False, a=1) + connector.register(topics="topic2", cb=received_event1, start_thread=False, b=2) + connector.register(topics="topic2", cb=received_event2, start_thread=False, c=3) + + connector.unregister(cb=received_event1) + + connector.send("topic1", TestMessage(msg="topic1")) + connector.send("topic2", TestMessage(msg="topic2")) + connector.poll_messages(timeout=1) + + received_event1.assert_not_called() + received_event2.assert_called_once() + assert list(connector._topics_cb.keys()) == ["topic2"] + + +def test_redis_connector_unregister_same_callback_registered_with_multiple_kwargs( + connected_connector, +): + connector = connected_connector + + received_event = mock.Mock(spec=[]) + + connector.register(topics="topic1", cb=received_event, start_thread=False, a=1) + connector.register(topics="topic1", cb=received_event, start_thread=False, a=2) + + connector.send("topic1", TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + + received_event.assert_has_calls( + [ + mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=1), + mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=2), + ] + ) + assert received_event.call_count == 2 + + received_event.reset_mock() + connector.unregister("topic1", cb=received_event) + connector.send("topic1", TestMessage(msg="topic1")) + + assert received_event.call_count == 0 + assert connector._redis_conn.execute_command("PUBSUB CHANNELS") == [] + assert len(connector._topics_cb) == 0 + + def test_redis_register_poll_messages(connected_connector): connector = connected_connector cb_fcn_has_been_called = False diff --git a/bec_server/bec_server/scan_bundler/bec_emitter.py b/bec_server/bec_server/scan_bundler/bec_emitter.py index 69d7ddb23..212d82b03 100644 --- a/bec_server/bec_server/scan_bundler/bec_emitter.py +++ b/bec_server/bec_server/scan_bundler/bec_emitter.py @@ -3,7 +3,7 @@ import threading import time from queue import Queue -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from bec_lib import messages from bec_lib.endpoints import MessageEndpoints @@ -14,6 +14,8 @@ logger = bec_logger.logger if TYPE_CHECKING: + from bec_lib.redis_connector import MessageObject + from .scan_bundler import ScanBundler @@ -96,7 +98,8 @@ def _send_bec_scan_point(self, scan_id: str, point_id: int) -> None: MessageEndpoints.scan_segment(), MessageEndpoints.public_scan_segment(scan_id=scan_id, point_id=point_id), ) - self._update_scan_progress(scan_id, point_id) + if not sb.sync_storage[scan_id].get("device_progress_sub"): + self._update_scan_progress(scan_id, point_id) def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None: if scan_id not in self.scan_bundler.sync_storage: @@ -107,18 +110,36 @@ def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None info = self.scan_bundler.sync_storage[scan_id]["info"] num_monitored_readouts = info.get("num_monitored_readouts", info.get("num_points", 0)) - + value = point_id + 1 + max_value = num_monitored_readouts or point_id + 1 + self.send_scan_progress(scan_id, value=value, max_value=max_value, done=done) + + def send_scan_progress(self, scan_id: str, value: float, max_value: float, done=False) -> None: + """ + Send a scan progress update. + + Args: + scan_id (str): The ID of the scan. + value (float): The current progress value. + max_value (float): The maximum progress value. + done (bool): Whether the scan is done. + """ + storage = self.scan_bundler.sync_storage.get(scan_id) + if not storage: + return + info = storage["info"] msg = messages.ProgressMessage( - value=point_id + 1, - max_value=num_monitored_readouts or point_id + 1, + value=value, + max_value=max_value, done=done, metadata={ "scan_id": scan_id, "RID": info.get("RID", ""), "queue_id": info.get("queue_id", ""), - "status": self.scan_bundler.sync_storage[scan_id]["status"], + "status": storage["status"], }, ) + storage["last_progress_sent"] = msg self.scan_bundler.connector.set_and_publish(MessageEndpoints.scan_progress(), msg) def _send_baseline(self, scan_id: str) -> None: @@ -141,29 +162,88 @@ def _send_baseline(self, scan_id: str) -> None: pipe.execute() def on_scan_status_update(self, status_msg: messages.ScanStatusMessage): + sb = self.scan_bundler + if status_msg.scan_id not in sb.sync_storage: + logger.warning( + f"Cannot update scan progress: Scan {status_msg.scan_id} not found in sync storage." + ) + return + if status_msg.status == "open": - # No need to update progress for an open scan. This is handled by the scan point emit. + # Update progress subscription: + # - If the scan report instruction contains "scan_progress", we simply emit + # progress updates as they come in. + # - If the scan report instruction contains "device_progress", we subscribe + # to the progress of the first device and use that as the progress for the whole scan. + self._update_device_progress_subscription(status_msg.scan_id) return num_points = max(status_msg.info.get("num_points", 0) - 1, 0) num_monitored_readouts = status_msg.info.get("num_monitored_readouts", num_points) if status_msg.status == "closed": - self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True) + storage = sb.sync_storage[status_msg.scan_id] + device_sub = storage.get("device_progress_sub") + if not device_sub: + self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True) + return + + self.connector.unregister(**storage["device_progress_sub"]) + self._emit_last_progress(status_msg.scan_id) return - sb = self.scan_bundler - if status_msg.scan_id not in sb.sync_storage: - logger.warning( - f"Cannot update scan progress: Scan {status_msg.scan_id} not found in sync storage." - ) - return + # Scan is not open or closed but instead in ["aborted", "halted", "user_completed"] storage = sb.sync_storage[status_msg.scan_id] + if storage.get("device_progress_sub"): + self.connector.unregister(**storage["device_progress_sub"]) + self._emit_last_progress(status_msg.scan_id) + return sent_vals = storage.get("sent", {0}) or {0} max_point = max(sent_vals) self._update_scan_progress(status_msg.scan_id, max_point, done=True) + def on_cleanup(self, scan_id: str): + if scan_id in self.scan_bundler.sync_storage: + device_progress_sub = self.scan_bundler.sync_storage[scan_id].get("device_progress_sub") + if device_progress_sub: + self.connector.unregister(**device_progress_sub) + def shutdown(self): if self._buffered_connector_thread: self._buffered_publisher_stop_event.set() self._buffered_connector_thread.join() self._buffered_connector_thread = None + + ############################################################# + ################# Device Progress Helpers ################### + ############################################################# + + def _update_device_progress_subscription(self, scan_id: str): + sb = self.scan_bundler + instructions = sb.scan_report_instructions.get(scan_id, []) + if sb.sync_storage[scan_id].get("device_progress_sub"): + return + for instruction in instructions: + if "device_progress" in instruction: + device = instruction["device_progress"][0] + sub = sb.sync_storage[scan_id]["device_progress_sub"] = { + "topics": MessageEndpoints.device_progress(device=device), + "cb": self._on_device_progress, + "scan_id": scan_id, + } + + self.connector.register(**sub) + + def _emit_last_progress(self, scan_id: str): + storage = self.scan_bundler.sync_storage.get(scan_id, {}) + msg = storage.get("last_progress_sent") + value = msg.value if msg else 0 + max_value = msg.max_value if msg else 0 + self.send_scan_progress(scan_id, value=value, max_value=max_value, done=True) + + def _on_device_progress(self, msg_obj: MessageObject, scan_id: str): + msg = cast(messages.ProgressMessage, msg_obj.value) + if msg.done: + sub_info = self.scan_bundler.sync_storage.get(scan_id, {}).get("device_progress_sub") + if sub_info: + self.connector.unregister(**sub_info) + self.send_scan_progress(scan_id, value=msg.value, max_value=msg.max_value, done=msg.done) diff --git a/bec_server/bec_server/scan_bundler/scan_bundler.py b/bec_server/bec_server/scan_bundler/scan_bundler.py index 78fd7a0be..27ecd75d4 100644 --- a/bec_server/bec_server/scan_bundler/scan_bundler.py +++ b/bec_server/bec_server/scan_bundler/scan_bundler.py @@ -6,7 +6,7 @@ import traceback from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from bec_lib import messages from bec_lib.bec_service import BECService @@ -17,7 +17,7 @@ from .bec_emitter import BECEmitter if TYPE_CHECKING: - from bec_lib.redis_connector import RedisConnector + from bec_lib.redis_connector import MessageObject, RedisConnector logger = bec_logger.logger @@ -35,12 +35,13 @@ def __init__(self, config, connector_cls: type[RedisConnector]) -> None: name="device_read_register", ) self.connector.register(MessageEndpoints.scan_status(), cb=self._scan_status_callback) - self.sync_storage = {} self.monitored_devices = {} self.baseline_devices = {} self.device_storage = {} self.readout_priority = {} + self.scan_queue: messages.ScanQueueStatusMessage | None = None + self.scan_report_instructions: dict[str, list] = {} self.storage_initialized = set() self.executor = ThreadPoolExecutor(max_workers=4) self.executor_tasks = collections.deque(maxlen=100) @@ -48,6 +49,9 @@ def __init__(self, config, connector_cls: type[RedisConnector]) -> None: self._lock = threading.Lock() self._emitter = [] self._initialize_emitters() + self.connector.register( + MessageEndpoints.scan_queue_status(), cb=self.on_scan_queue_status_update + ) self.status = messages.BECStatus.RUNNING def _initialize_emitters(self): @@ -95,6 +99,35 @@ def handle_scan_status_message(self, msg: messages.ScanStatusMessage) -> None: self._scan_status_modification(msg) self.run_emitter("on_scan_status_update", msg) + def on_scan_queue_status_update(self, msg_obj: MessageObject): + """ + Update the scan_report_instructions based on the active request block + in the scan queue status message. + + Args: + status_msg (messages.ScanQueueStatusMessage): The scan queue status message + containing the active request block. + """ + status_msg = cast(messages.ScanQueueStatusMessage, msg_obj.value) + for scan_queue_status in status_msg.queue.values(): + if not scan_queue_status.info: + continue + info = scan_queue_status.info[0] + active_request_block = info.active_request_block + if not active_request_block: + continue + scan_id = active_request_block.scan_id + if scan_id is None: + continue + report_instructions = active_request_block.report_instructions + if not report_instructions: + continue + + self.scan_report_instructions[scan_id] = report_instructions + logger.debug( + f"Updated report instructions for scan_id {scan_id}: {report_instructions}" + ) + def _scan_status_modification(self, msg: messages.ScanStatusMessage): status = msg.content.get("status") if status not in ["closed", "aborted", "paused", "halted", "user_completed"]: @@ -358,6 +391,7 @@ def cleanup_storage(self): remove_scan_ids.append(scan_id) for scan_id in remove_scan_ids: + self.run_emitter("on_cleanup", scan_id) for storage in [ "sync_storage", "monitored_devices", @@ -368,7 +402,6 @@ def cleanup_storage(self): getattr(self, storage).pop(scan_id) except KeyError: logger.warning(f"Failed to remove {scan_id} from {storage}.") - self.run_emitter("on_cleanup", scan_id) self.storage_initialized.remove(scan_id) def _send_scan_point(self, scan_id, point_id) -> None: diff --git a/bec_server/tests/tests_scan_bundler/test_bec_emitter.py b/bec_server/tests/tests_scan_bundler/test_bec_emitter.py index 5c21852ce..48b1c5a2d 100644 --- a/bec_server/tests/tests_scan_bundler/test_bec_emitter.py +++ b/bec_server/tests/tests_scan_bundler/test_bec_emitter.py @@ -3,6 +3,7 @@ import pytest from bec_lib import messages +from bec_lib.connector import MessageObject from bec_lib.endpoints import MessageEndpoints from bec_server.scan_bundler.bec_emitter import BECEmitter @@ -52,6 +53,27 @@ def test_send_bec_scan_point(bec_emitter_mock): ) +def test_send_bec_scan_point_skips_point_progress_with_device_progress_sub(bec_emitter_mock): + sb = bec_emitter_mock.scan_bundler + scan_id = "lkajsdlkj" + point_id = 2 + sb.sync_storage[scan_id] = { + "info": {}, + "status": "open", + "sent": set(), + "device_progress_sub": {"topics": MessageEndpoints.device_progress("samx")}, + } + sb.sync_storage[scan_id][point_id] = {} + + with ( + mock.patch.object(bec_emitter_mock, "add_message") as send, + mock.patch.object(bec_emitter_mock, "_update_scan_progress") as update_progress, + ): + bec_emitter_mock._send_bec_scan_point(scan_id, point_id) + send.assert_called_once() + update_progress.assert_not_called() + + def test_send_baseline_BEC(bec_emitter_mock): sb = bec_emitter_mock.scan_bundler scan_id = "lkajsdlkj" @@ -161,6 +183,19 @@ def test_add_message(msg, endpoint, public): emitter.shutdown() +def test_bec_emitter_scan_status_update_open_updates_subscription(bec_emitter_mock): + bec_emitter_mock.scan_bundler.sync_storage["lkajsdlkj"] = { + "info": {}, + "status": "open", + "sent": set(), + "baseline": {}, + } + msg = messages.ScanStatusMessage(scan_id="lkajsdlkj", status="open", info={"num_points": 10}) + with mock.patch.object(bec_emitter_mock, "_update_device_progress_subscription") as update_sub: + bec_emitter_mock.on_scan_status_update(msg) + update_sub.assert_called_once_with("lkajsdlkj") + + @pytest.mark.parametrize( "msg, sent, progress, ref_scan_id", [ @@ -175,7 +210,7 @@ def test_add_message(msg, endpoint, public): scan_id="lkajsdlkj", status="closed", info={"num_points": 10} ), {0, 1}, - 9, # 10 points, but sent 0 and 1, so progress is 9 + 9, "lkajsdlkj", ), ( @@ -192,7 +227,7 @@ def test_add_message(msg, endpoint, public): ), {0, 1}, 1, - "lkajsdlkj", # This is a different scan_id, should not update progress + "lkajsdlkj", ), ( messages.ScanStatusMessage( @@ -200,19 +235,173 @@ def test_add_message(msg, endpoint, public): ), {}, 0, - "lkajsdlkj", # This is a different scan_id, should not update progress + "lkajsdlkj", ), ], ) -def test_bec_emitter_scan_status_update(bec_emitter_mock, msg, sent, progress, ref_scan_id): - +def test_bec_emitter_scan_status_update_point_progress_path( + bec_emitter_mock, msg, sent, progress, ref_scan_id +): sb = bec_emitter_mock.scan_bundler - sb.sync_storage[ref_scan_id] = {"info": {}, "status": msg.status, "sent": sent} - sb.sync_storage[ref_scan_id]["baseline"] = {} + sb.sync_storage[ref_scan_id] = {"info": {}, "status": msg.status, "sent": sent, "baseline": {}} - with mock.patch.object(bec_emitter_mock, "_update_scan_progress") as update: + with ( + mock.patch.object(bec_emitter_mock, "_update_scan_progress") as update, + mock.patch.object(bec_emitter_mock, "_update_device_progress_subscription") as update_sub, + ): bec_emitter_mock.on_scan_status_update(msg) - if msg.status == "open" or msg.scan_id != ref_scan_id: + if msg.status == "open": + update.assert_not_called() + update_sub.assert_called_once_with(msg.scan_id) + elif msg.scan_id != ref_scan_id: update.assert_not_called() + update_sub.assert_not_called() else: update.assert_called_once_with(msg.scan_id, progress, done=True) + update_sub.assert_not_called() + + +def test_bec_emitter_scan_status_update_missing_scan_id_does_not_update(bec_emitter_mock): + msg = messages.ScanStatusMessage( + scan_id="wrong_scan_id", status="aborted", info={"num_points": 10} + ) + with mock.patch.object(bec_emitter_mock, "_update_scan_progress") as update: + bec_emitter_mock.on_scan_status_update(msg) + update.assert_not_called() + + +@pytest.mark.parametrize("status", ["closed", "aborted"]) +def test_bec_emitter_scan_status_update_wrong_scan_id_does_not_emit_progress( + bec_emitter_mock, status +): + msg = messages.ScanStatusMessage( + scan_id="wrong_scan_id", status=status, info={"num_points": 10} + ) + with ( + mock.patch.object(bec_emitter_mock, "_update_scan_progress") as update, + mock.patch.object(bec_emitter_mock, "send_scan_progress") as send_scan_progress, + ): + bec_emitter_mock.on_scan_status_update(msg) + update.assert_not_called() + send_scan_progress.assert_not_called() + + +def test_update_device_progress_subscription_registers_device_progress(bec_emitter_mock): + sb = bec_emitter_mock.scan_bundler + scan_id = "scan_id" + sb.sync_storage[scan_id] = {"info": {}, "status": "open", "sent": set()} + sb.scan_report_instructions[scan_id] = [{"device_progress": ["samx"]}] + + with mock.patch.object(bec_emitter_mock.connector, "register") as register: + bec_emitter_mock._update_device_progress_subscription(scan_id) + + expected_sub = { + "topics": MessageEndpoints.device_progress(device="samx"), + "cb": bec_emitter_mock._on_device_progress, + "scan_id": scan_id, + } + register.assert_called_once_with(**expected_sub) + assert sb.sync_storage[scan_id]["device_progress_sub"] == expected_sub + + +def test_on_device_progress_done_unregisters_and_emits_progress(bec_emitter_mock): + sb = bec_emitter_mock.scan_bundler + scan_id = "scan_id" + sub = { + "topics": MessageEndpoints.device_progress(device="samx"), + "cb": bec_emitter_mock._on_device_progress, + "scan_id": scan_id, + } + sb.sync_storage[scan_id] = { + "info": {}, + "status": "open", + "sent": set(), + "device_progress_sub": sub, + } + progress_msg = messages.ProgressMessage(value=3, max_value=7, done=True) + msg_obj = MessageObject(MessageEndpoints.device_progress("samx").endpoint, progress_msg) + + with ( + mock.patch.object(bec_emitter_mock.connector, "unregister") as unregister, + mock.patch.object(bec_emitter_mock, "send_scan_progress") as send_scan_progress, + ): + bec_emitter_mock._on_device_progress(msg_obj, scan_id) + + unregister.assert_called_once_with(**sub) + send_scan_progress.assert_called_once_with(scan_id, value=3, max_value=7, done=True) + + +def test_scan_status_update_closed_with_device_progress_unsubscribes_and_emits_last_progress( + bec_emitter_mock, +): + sb = bec_emitter_mock.scan_bundler + scan_id = "scan_id" + sub = { + "topics": MessageEndpoints.device_progress(device="samx"), + "cb": bec_emitter_mock._on_device_progress, + "scan_id": scan_id, + } + sb.sync_storage[scan_id] = { + "info": {}, + "status": "closed", + "sent": {0, 1}, + "baseline": {}, + "device_progress_sub": sub, + "last_progress_sent": messages.ProgressMessage(value=4, max_value=9, done=False), + } + msg = messages.ScanStatusMessage(scan_id=scan_id, status="closed", info={"num_points": 10}) + + with ( + mock.patch.object(bec_emitter_mock.connector, "unregister") as unregister, + mock.patch.object(bec_emitter_mock, "send_scan_progress") as send_scan_progress, + ): + bec_emitter_mock.on_scan_status_update(msg) + + unregister.assert_called_once_with(**sub) + send_scan_progress.assert_called_once_with(scan_id, value=4, max_value=9, done=True) + + +@pytest.mark.parametrize("status", ["closed", "aborted"]) +def test_scan_status_update_device_progress_without_last_progress_emits_done_message( + bec_emitter_mock, status +): + sb = bec_emitter_mock.scan_bundler + scan_id = "scan_id" + sub = { + "topics": MessageEndpoints.device_progress(device="samx"), + "cb": bec_emitter_mock._on_device_progress, + "scan_id": scan_id, + } + sb.sync_storage[scan_id] = { + "info": {}, + "status": status, + "sent": {0, 1}, + "baseline": {}, + "device_progress_sub": sub, + } + msg = messages.ScanStatusMessage(scan_id=scan_id, status=status, info={"num_points": 10}) + + with ( + mock.patch.object(bec_emitter_mock.connector, "unregister") as unregister, + mock.patch.object(bec_emitter_mock, "send_scan_progress") as send_scan_progress, + ): + bec_emitter_mock.on_scan_status_update(msg) + + unregister.assert_called_once_with(**sub) + send_scan_progress.assert_called_once_with(scan_id, value=0, max_value=0, done=True) + + +def test_on_cleanup_unregisters_device_progress_subscription(bec_emitter_mock): + sb = bec_emitter_mock.scan_bundler + scan_id = "scan_id" + sub = { + "topics": MessageEndpoints.device_progress(device="samx"), + "cb": bec_emitter_mock._on_device_progress, + "scan_id": scan_id, + } + sb.sync_storage[scan_id] = {"device_progress_sub": sub} + + with mock.patch.object(bec_emitter_mock.connector, "unregister") as unregister: + bec_emitter_mock.on_cleanup(scan_id) + + unregister.assert_called_once_with(**sub)