Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bec_lib/bec_lib/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def register(self, topics=None, patterns=None, cb=None, start_thread=True, **kwa
"""Register a callback for a topic or pattern"""

@abc.abstractmethod
def unregister(self, topics=None, patterns=None, cb=None):
def unregister(self, topics=None, patterns=None, cb=None, **kwargs):
"""Unregister a callback for a topic or pattern"""

@abc.abstractmethod
Expand Down
23 changes: 16 additions & 7 deletions bec_lib/bec_lib/redis_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,14 +1002,23 @@ def _register_stream(
self._stream_events_listener_thread.name += f" ({self.name})"
self._stream_events_listener_thread.start()

def _filter_topics_cb(self, topics: list, cb: Callable | None):
@staticmethod
def _matches_subscription(
item: tuple[louie.saferef.BoundMethodWeakref, dict], cb: Callable | None, kwargs: dict
) -> bool:
cb_ref, item_kwargs = item
if cb is not None and cb_ref() != cb:
return False
return all(item_kwargs.get(key) == value for key, value in kwargs.items())

def _filter_topics_cb(self, topics: list, cb: Callable | None, kwargs: dict):
unsubscribe_list = []
with self._topics_cb_lock:
for topic in topics:
topics_cb = self._topics_cb[topic]
# remove callback from list
self._topics_cb[topic] = list(
filter(lambda item: cb and item[0]() is not cb, topics_cb)
filter(lambda item: not self._matches_subscription(item, cb, kwargs), topics_cb)
)
if not self._topics_cb[topic]:
# no callbacks left, unsubscribe
Expand All @@ -1019,7 +1028,7 @@ def _filter_topics_cb(self, topics: list, cb: Callable | None):
del self._topics_cb[topic]
return unsubscribe_list

def unregister(self, topics=None, patterns=None, cb=None):
def unregister(self, topics=None, patterns=None, cb=None, **kwargs):
if self._events_listener_thread is None:
return
if topics and patterns:
Expand All @@ -1031,20 +1040,20 @@ def unregister(self, topics=None, patterns=None, cb=None):
# see if registered streams can be unregistered
for pattern in patterns:
self._unregister_stream(fnmatch.filter(self._stream_subs.all_topics, pattern), cb)
pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb)
pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb, kwargs)
if pubsub_unsubscribe_list:
self._pubsub_conn.punsubscribe(pubsub_unsubscribe_list)
elif topics is not None:
topics, _ = self._convert_endpointinfo(topics, check_message_op=False)
if not self._unregister_stream(topics, cb):
unsubscribe_list = self._filter_topics_cb(topics, cb)
unsubscribe_list = self._filter_topics_cb(topics, cb, kwargs)
if unsubscribe_list:
self._pubsub_conn.unsubscribe(unsubscribe_list)
else:
with self._topics_cb_lock:
topics = list(self._topics_cb.keys())
self.unregister(topics, cb)
self.unregister(self._stream_subs.all_topics, cb)
self.unregister(topics=topics, cb=cb, **kwargs)
self.unregister(topics=self._stream_subs.all_topics, cb=cb, **kwargs)

def _unregister_stream(self, topics: list[str], cb: Callable | None = None) -> bool:
"""Unregister callbacks from a list of topics. Returns true if any were removed"""
Expand Down
144 changes: 144 additions & 0 deletions bec_lib/tests/test_redis_connector_fakeredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,99 @@ def test_redis_connector_unregister_cb_not_topic(connected_connector):
assert received_event2.call_count == 2


def test_redis_connector_unregister_cb_not_topic_with_kwargs(connected_connector):
connector = connected_connector

topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND)
topic2 = EndpointInfo("topic2", TestMessage, MessageOp.SEND)

received_event1 = mock.Mock(spec=[])
received_event2 = mock.Mock(spec=[])

connector.register(topics=topic1, cb=received_event1, start_thread=False, a=1)
connector.register(topics=topic1, cb=received_event2, start_thread=False, b=2)
connector.register(topics=topic2, cb=received_event1, start_thread=False, c=3)

connector.send(topic1, TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)
received_event1.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), a=1)
received_event2.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), b=2)

connector.unregister(topic1, cb=received_event1)

received_event1.reset_mock()
received_event2.reset_mock()
connector.send(topic1, TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)
received_event1.assert_not_called()
received_event2.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), b=2)

connector.send(topic2, TestMessage(msg="topic2"))
connector.poll_messages(timeout=1)
received_event1.assert_called_once_with(MessageObject("topic2", TestMessage(msg="topic2")), c=3)
assert list(connector._topics_cb.keys()) == ["topic1", "topic2"]


def test_redis_connector_unregister_only_specified_callback_kwargs(connected_connector):
connector = connected_connector

topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND)
received_event = mock.Mock(spec=[])

connector.register(topics=topic1, cb=received_event, start_thread=False, a=1)
connector.register(topics=topic1, cb=received_event, start_thread=False, a=2)

connector.send(topic1, TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)
received_event.assert_has_calls(
[
mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=1),
mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=2),
]
)

received_event.reset_mock()
connector.unregister(topic1, cb=received_event, a=1)

connector.send(topic1, TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)
received_event.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), a=2)
assert len(connector._topics_cb["topic1"]) == 1


def test_redis_connector_unregister_bound_method_with_kwargs(connected_connector):
connector = connected_connector

topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND)

class Receiver:
def __init__(self):
self.received = mock.Mock(spec=[])

def on_message(self, msg_obj, scan_id):
self.received(msg_obj, scan_id=scan_id)

receiver = Receiver()

connector.register(topics=topic1, cb=receiver.on_message, start_thread=False, scan_id="scan-1")

connector.send(topic1, TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)
receiver.received.assert_called_once_with(
MessageObject("topic1", TestMessage(msg="topic1")), scan_id="scan-1"
)

receiver.received.reset_mock()
connector.unregister(topic1, cb=receiver.on_message, scan_id="scan-1")

connector.send(topic1, TestMessage(msg="topic1"))
with pytest.raises(TimeoutError):
connector.poll_messages(timeout=1)
receiver.received.assert_not_called()
assert connector._redis_conn.execute_command("PUBSUB CHANNELS") == []
assert len(connector._topics_cb) == 0


def test_redis_connector_unregister_topic_keeps_others_alive(connected_connector):
def send_msgs_and_poll(timeout=None):
connector.send(topic1, TestMessage())
Expand Down Expand Up @@ -223,6 +316,57 @@ def send_msgs_and_poll(timeout=None):
assert received_event2.call_count == 2


def test_redis_connector_unregister_all_callback_subscriptions_with_kwargs(connected_connector):
connector = connected_connector

received_event1 = mock.Mock(spec=[])
received_event2 = mock.Mock(spec=[])

connector.register(topics="topic1", cb=received_event1, start_thread=False, a=1)
connector.register(topics="topic2", cb=received_event1, start_thread=False, b=2)
connector.register(topics="topic2", cb=received_event2, start_thread=False, c=3)

connector.unregister(cb=received_event1)

connector.send("topic1", TestMessage(msg="topic1"))
connector.send("topic2", TestMessage(msg="topic2"))
connector.poll_messages(timeout=1)

received_event1.assert_not_called()
received_event2.assert_called_once()
assert list(connector._topics_cb.keys()) == ["topic2"]


def test_redis_connector_unregister_same_callback_registered_with_multiple_kwargs(
connected_connector,
):
connector = connected_connector

received_event = mock.Mock(spec=[])

connector.register(topics="topic1", cb=received_event, start_thread=False, a=1)
connector.register(topics="topic1", cb=received_event, start_thread=False, a=2)

connector.send("topic1", TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)

received_event.assert_has_calls(
[
mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=1),
mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=2),
]
)
assert received_event.call_count == 2

received_event.reset_mock()
connector.unregister("topic1", cb=received_event)
connector.send("topic1", TestMessage(msg="topic1"))

assert received_event.call_count == 0
assert connector._redis_conn.execute_command("PUBSUB CHANNELS") == []
assert len(connector._topics_cb) == 0


def test_redis_register_poll_messages(connected_connector):
connector = connected_connector
cb_fcn_has_been_called = False
Expand Down
108 changes: 94 additions & 14 deletions bec_server/bec_server/scan_bundler/bec_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import threading
import time
from queue import Queue
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
Expand All @@ -14,6 +14,8 @@
logger = bec_logger.logger

if TYPE_CHECKING:
from bec_lib.redis_connector import MessageObject

from .scan_bundler import ScanBundler


Expand Down Expand Up @@ -96,7 +98,8 @@ def _send_bec_scan_point(self, scan_id: str, point_id: int) -> None:
MessageEndpoints.scan_segment(),
MessageEndpoints.public_scan_segment(scan_id=scan_id, point_id=point_id),
)
self._update_scan_progress(scan_id, point_id)
if not sb.sync_storage[scan_id].get("device_progress_sub"):
self._update_scan_progress(scan_id, point_id)

def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None:
if scan_id not in self.scan_bundler.sync_storage:
Expand All @@ -107,18 +110,36 @@ def _update_scan_progress(self, scan_id: str, point_id: int, done=False) -> None
info = self.scan_bundler.sync_storage[scan_id]["info"]

num_monitored_readouts = info.get("num_monitored_readouts", info.get("num_points", 0))

value = point_id + 1
max_value = num_monitored_readouts or point_id + 1
self.send_scan_progress(scan_id, value=value, max_value=max_value, done=done)

def send_scan_progress(self, scan_id: str, value: float, max_value: float, done=False) -> None:
"""
Send a scan progress update.

Args:
scan_id (str): The ID of the scan.
value (float): The current progress value.
max_value (float): The maximum progress value.
done (bool): Whether the scan is done.
"""
storage = self.scan_bundler.sync_storage.get(scan_id)
if not storage:
return
info = storage["info"]
msg = messages.ProgressMessage(
value=point_id + 1,
max_value=num_monitored_readouts or point_id + 1,
value=value,
max_value=max_value,
done=done,
metadata={
"scan_id": scan_id,
"RID": info.get("RID", ""),
"queue_id": info.get("queue_id", ""),
"status": self.scan_bundler.sync_storage[scan_id]["status"],
"status": storage["status"],
},
)
storage["last_progress_sent"] = msg
self.scan_bundler.connector.set_and_publish(MessageEndpoints.scan_progress(), msg)

def _send_baseline(self, scan_id: str) -> None:
Expand All @@ -141,29 +162,88 @@ def _send_baseline(self, scan_id: str) -> None:
pipe.execute()

def on_scan_status_update(self, status_msg: messages.ScanStatusMessage):
sb = self.scan_bundler
if status_msg.scan_id not in sb.sync_storage:
logger.warning(
f"Cannot update scan progress: Scan {status_msg.scan_id} not found in sync storage."
)
return

if status_msg.status == "open":
# No need to update progress for an open scan. This is handled by the scan point emit.
# Update progress subscription:
# - If the scan report instruction contains "scan_progress", we simply emit
# progress updates as they come in.
# - If the scan report instruction contains "device_progress", we subscribe
# to the progress of the first device and use that as the progress for the whole scan.
self._update_device_progress_subscription(status_msg.scan_id)
return

num_points = max(status_msg.info.get("num_points", 0) - 1, 0)
num_monitored_readouts = status_msg.info.get("num_monitored_readouts", num_points)
if status_msg.status == "closed":
self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True)
storage = sb.sync_storage[status_msg.scan_id]
device_sub = storage.get("device_progress_sub")
if not device_sub:
self._update_scan_progress(status_msg.scan_id, num_monitored_readouts, done=True)
return

self.connector.unregister(**storage["device_progress_sub"])
self._emit_last_progress(status_msg.scan_id)
return

sb = self.scan_bundler
if status_msg.scan_id not in sb.sync_storage:
logger.warning(
f"Cannot update scan progress: Scan {status_msg.scan_id} not found in sync storage."
)
return
# Scan is not open or closed but instead in ["aborted", "halted", "user_completed"]
storage = sb.sync_storage[status_msg.scan_id]
if storage.get("device_progress_sub"):
self.connector.unregister(**storage["device_progress_sub"])
self._emit_last_progress(status_msg.scan_id)
return
sent_vals = storage.get("sent", {0}) or {0}
max_point = max(sent_vals)
self._update_scan_progress(status_msg.scan_id, max_point, done=True)

def on_cleanup(self, scan_id: str):
if scan_id in self.scan_bundler.sync_storage:
device_progress_sub = self.scan_bundler.sync_storage[scan_id].get("device_progress_sub")
if device_progress_sub:
self.connector.unregister(**device_progress_sub)

def shutdown(self):
if self._buffered_connector_thread:
self._buffered_publisher_stop_event.set()
self._buffered_connector_thread.join()
self._buffered_connector_thread = None

#############################################################
################# Device Progress Helpers ###################
#############################################################

def _update_device_progress_subscription(self, scan_id: str):
sb = self.scan_bundler
instructions = sb.scan_report_instructions.get(scan_id, [])
if sb.sync_storage[scan_id].get("device_progress_sub"):
return
for instruction in instructions:
if "device_progress" in instruction:
device = instruction["device_progress"][0]
sub = sb.sync_storage[scan_id]["device_progress_sub"] = {
"topics": MessageEndpoints.device_progress(device=device),
"cb": self._on_device_progress,
"scan_id": scan_id,
}

self.connector.register(**sub)

def _emit_last_progress(self, scan_id: str):
storage = self.scan_bundler.sync_storage.get(scan_id, {})
msg = storage.get("last_progress_sent")
value = msg.value if msg else 0
max_value = msg.max_value if msg else 0
self.send_scan_progress(scan_id, value=value, max_value=max_value, done=True)

def _on_device_progress(self, msg_obj: MessageObject, scan_id: str):
msg = cast(messages.ProgressMessage, msg_obj.value)
if msg.done:
sub_info = self.scan_bundler.sync_storage.get(scan_id, {}).get("device_progress_sub")
if sub_info:
self.connector.unregister(**sub_info)
self.send_scan_progress(scan_id, value=msg.value, max_value=msg.max_value, done=msg.done)
Loading
Loading