Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bec_lib/bec_lib/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def register(self, topics=None, patterns=None, cb=None, start_thread=True, **kwa
"""Register a callback for a topic or pattern"""

@abc.abstractmethod
def unregister(self, topics=None, patterns=None, cb=None):
def unregister(self, topics=None, patterns=None, cb=None, **kwargs):
"""Unregister a callback for a topic or pattern"""

@abc.abstractmethod
Expand Down
23 changes: 16 additions & 7 deletions bec_lib/bec_lib/redis_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,14 +1002,23 @@ def _register_stream(
self._stream_events_listener_thread.name += f" ({self.name})"
self._stream_events_listener_thread.start()

def _filter_topics_cb(self, topics: list, cb: Callable | None):
@staticmethod
def _matches_subscription(
item: tuple[louie.saferef.BoundMethodWeakref, dict], cb: Callable | None, kwargs: dict
) -> bool:
cb_ref, item_kwargs = item
if cb is not None and cb_ref() != cb:
Comment on lines +1006 to +1010
return False
return all(key in item_kwargs and item_kwargs[key] == value for key, value in kwargs.items())

def _filter_topics_cb(self, topics: list, cb: Callable | None, kwargs: dict):
unsubscribe_list = []
with self._topics_cb_lock:
for topic in topics:
topics_cb = self._topics_cb[topic]
# remove callback from list
self._topics_cb[topic] = list(
filter(lambda item: cb and item[0]() is not cb, topics_cb)
filter(lambda item: not self._matches_subscription(item, cb, kwargs), topics_cb)
)
if not self._topics_cb[topic]:
# no callbacks left, unsubscribe
Expand All @@ -1019,7 +1028,7 @@ def _filter_topics_cb(self, topics: list, cb: Callable | None):
del self._topics_cb[topic]
return unsubscribe_list

def unregister(self, topics=None, patterns=None, cb=None):
def unregister(self, topics=None, patterns=None, cb=None, **kwargs):
if self._events_listener_thread is None:
return
if topics and patterns:
Expand All @@ -1031,20 +1040,20 @@ def unregister(self, topics=None, patterns=None, cb=None):
# see if registered streams can be unregistered
for pattern in patterns:
self._unregister_stream(fnmatch.filter(self._stream_subs.all_topics, pattern), cb)
pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb)
pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb, kwargs)
if pubsub_unsubscribe_list:
self._pubsub_conn.punsubscribe(pubsub_unsubscribe_list)
elif topics is not None:
topics, _ = self._convert_endpointinfo(topics, check_message_op=False)
if not self._unregister_stream(topics, cb):
unsubscribe_list = self._filter_topics_cb(topics, cb)
unsubscribe_list = self._filter_topics_cb(topics, cb, kwargs)
if unsubscribe_list:
self._pubsub_conn.unsubscribe(unsubscribe_list)
else:
with self._topics_cb_lock:
topics = list(self._topics_cb.keys())
self.unregister(topics, cb)
self.unregister(self._stream_subs.all_topics, cb)
self.unregister(topics=topics, cb=cb, **kwargs)
self.unregister(topics=self._stream_subs.all_topics, cb=cb, **kwargs)

def _unregister_stream(self, topics: list[str], cb: Callable | None = None) -> bool:
"""Unregister callbacks from a list of topics. Returns true if any were removed"""
Expand Down
144 changes: 144 additions & 0 deletions bec_lib/tests/test_redis_connector_fakeredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,99 @@ def test_redis_connector_unregister_cb_not_topic(connected_connector):
assert received_event2.call_count == 2


def test_redis_connector_unregister_cb_not_topic_with_kwargs(connected_connector):
connector = connected_connector

topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND)
topic2 = EndpointInfo("topic2", TestMessage, MessageOp.SEND)

received_event1 = mock.Mock(spec=[])
received_event2 = mock.Mock(spec=[])

connector.register(topics=topic1, cb=received_event1, start_thread=False, a=1)
connector.register(topics=topic1, cb=received_event2, start_thread=False, b=2)
connector.register(topics=topic2, cb=received_event1, start_thread=False, c=3)

connector.send(topic1, TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)
received_event1.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), a=1)
received_event2.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), b=2)

connector.unregister(topic1, cb=received_event1)

received_event1.reset_mock()
received_event2.reset_mock()
connector.send(topic1, TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)
received_event1.assert_not_called()
received_event2.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), b=2)

connector.send(topic2, TestMessage(msg="topic2"))
connector.poll_messages(timeout=1)
received_event1.assert_called_once_with(MessageObject("topic2", TestMessage(msg="topic2")), c=3)
assert list(connector._topics_cb.keys()) == ["topic1", "topic2"]


def test_redis_connector_unregister_only_specified_callback_kwargs(connected_connector):
connector = connected_connector

topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND)
received_event = mock.Mock(spec=[])

connector.register(topics=topic1, cb=received_event, start_thread=False, a=1)
connector.register(topics=topic1, cb=received_event, start_thread=False, a=2)

connector.send(topic1, TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)
received_event.assert_has_calls(
[
mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=1),
mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=2),
]
)

received_event.reset_mock()
connector.unregister(topic1, cb=received_event, a=1)

connector.send(topic1, TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)
received_event.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), a=2)
assert len(connector._topics_cb["topic1"]) == 1


def test_redis_connector_unregister_bound_method_with_kwargs(connected_connector):
connector = connected_connector

topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND)

class Receiver:
def __init__(self):
self.received = mock.Mock(spec=[])

def on_message(self, msg_obj, scan_id):
self.received(msg_obj, scan_id=scan_id)

receiver = Receiver()

connector.register(topics=topic1, cb=receiver.on_message, start_thread=False, scan_id="scan-1")

connector.send(topic1, TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)
receiver.received.assert_called_once_with(
MessageObject("topic1", TestMessage(msg="topic1")), scan_id="scan-1"
)

receiver.received.reset_mock()
connector.unregister(topic1, cb=receiver.on_message, scan_id="scan-1")

connector.send(topic1, TestMessage(msg="topic1"))
with pytest.raises(TimeoutError):
connector.poll_messages(timeout=1)
receiver.received.assert_not_called()
assert connector._redis_conn.execute_command("PUBSUB CHANNELS") == []
assert len(connector._topics_cb) == 0


def test_redis_connector_unregister_topic_keeps_others_alive(connected_connector):
def send_msgs_and_poll(timeout=None):
connector.send(topic1, TestMessage())
Expand Down Expand Up @@ -223,6 +316,57 @@ def send_msgs_and_poll(timeout=None):
assert received_event2.call_count == 2


def test_redis_connector_unregister_all_callback_subscriptions_with_kwargs(connected_connector):
connector = connected_connector

received_event1 = mock.Mock(spec=[])
received_event2 = mock.Mock(spec=[])

connector.register(topics="topic1", cb=received_event1, start_thread=False, a=1)
connector.register(topics="topic2", cb=received_event1, start_thread=False, b=2)
connector.register(topics="topic2", cb=received_event2, start_thread=False, c=3)

connector.unregister(cb=received_event1)

connector.send("topic1", TestMessage(msg="topic1"))
connector.send("topic2", TestMessage(msg="topic2"))
connector.poll_messages(timeout=1)

received_event1.assert_not_called()
received_event2.assert_called_once()
assert list(connector._topics_cb.keys()) == ["topic2"]


def test_redis_connector_unregister_same_callback_registered_with_multiple_kwargs(
connected_connector,
):
connector = connected_connector

received_event = mock.Mock(spec=[])

connector.register(topics="topic1", cb=received_event, start_thread=False, a=1)
connector.register(topics="topic1", cb=received_event, start_thread=False, a=2)

connector.send("topic1", TestMessage(msg="topic1"))
connector.poll_messages(timeout=1)

received_event.assert_has_calls(
[
mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=1),
mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=2),
]
)
assert received_event.call_count == 2

received_event.reset_mock()
connector.unregister("topic1", cb=received_event)
connector.send("topic1", TestMessage(msg="topic1"))

assert received_event.call_count == 0
assert connector._redis_conn.execute_command("PUBSUB CHANNELS") == []
assert len(connector._topics_cb) == 0


def test_redis_register_poll_messages(connected_connector):
connector = connected_connector
cb_fcn_has_been_called = False
Expand Down
Loading