From 44ef3155c9db8f8301ac8b05a630c865247fb854 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Sat, 16 May 2026 13:31:49 +0200 Subject: [PATCH 1/2] fix(redis connector): unsubscibe with kwargs --- bec_lib/bec_lib/connector.py | 2 +- bec_lib/bec_lib/redis_connector.py | 23 ++- .../tests/test_redis_connector_fakeredis.py | 144 ++++++++++++++++++ 3 files changed, 161 insertions(+), 8 deletions(-) diff --git a/bec_lib/bec_lib/connector.py b/bec_lib/bec_lib/connector.py index 107960f4e..9329809f8 100644 --- a/bec_lib/bec_lib/connector.py +++ b/bec_lib/bec_lib/connector.py @@ -139,7 +139,7 @@ def register(self, topics=None, patterns=None, cb=None, start_thread=True, **kwa """Register a callback for a topic or pattern""" @abc.abstractmethod - def unregister(self, topics=None, patterns=None, cb=None): + def unregister(self, topics=None, patterns=None, cb=None, **kwargs): """Unregister a callback for a topic or pattern""" @abc.abstractmethod diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index ae36db08c..f824505b8 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ b/bec_lib/bec_lib/redis_connector.py @@ -1002,14 +1002,23 @@ def _register_stream( self._stream_events_listener_thread.name += f" ({self.name})" self._stream_events_listener_thread.start() - def _filter_topics_cb(self, topics: list, cb: Callable | None): + @staticmethod + def _matches_subscription( + item: tuple[louie.saferef.BoundMethodWeakref, dict], cb: Callable | None, kwargs: dict + ) -> bool: + cb_ref, item_kwargs = item + if cb is not None and cb_ref() != cb: + return False + return all(item_kwargs.get(key) == value for key, value in kwargs.items()) + + def _filter_topics_cb(self, topics: list, cb: Callable | None, kwargs: dict): unsubscribe_list = [] with self._topics_cb_lock: for topic in topics: topics_cb = self._topics_cb[topic] # remove callback from list self._topics_cb[topic] = list( - filter(lambda item: cb and item[0]() is not cb, topics_cb) + filter(lambda item: not self._matches_subscription(item, cb, kwargs), topics_cb) ) if not self._topics_cb[topic]: # no callbacks left, unsubscribe @@ -1019,7 +1028,7 @@ def _filter_topics_cb(self, topics: list, cb: Callable | None): del self._topics_cb[topic] return unsubscribe_list - def unregister(self, topics=None, patterns=None, cb=None): + def unregister(self, topics=None, patterns=None, cb=None, **kwargs): if self._events_listener_thread is None: return if topics and patterns: @@ -1031,20 +1040,20 @@ def unregister(self, topics=None, patterns=None, cb=None): # see if registered streams can be unregistered for pattern in patterns: self._unregister_stream(fnmatch.filter(self._stream_subs.all_topics, pattern), cb) - pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb) + pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb, kwargs) if pubsub_unsubscribe_list: self._pubsub_conn.punsubscribe(pubsub_unsubscribe_list) elif topics is not None: topics, _ = self._convert_endpointinfo(topics, check_message_op=False) if not self._unregister_stream(topics, cb): - unsubscribe_list = self._filter_topics_cb(topics, cb) + unsubscribe_list = self._filter_topics_cb(topics, cb, kwargs) if unsubscribe_list: self._pubsub_conn.unsubscribe(unsubscribe_list) else: with self._topics_cb_lock: topics = list(self._topics_cb.keys()) - self.unregister(topics, cb) - self.unregister(self._stream_subs.all_topics, cb) + self.unregister(topics=topics, cb=cb, **kwargs) + self.unregister(topics=self._stream_subs.all_topics, cb=cb, **kwargs) def _unregister_stream(self, topics: list[str], cb: Callable | None = None) -> bool: """Unregister callbacks from a list of topics. Returns true if any were removed""" diff --git a/bec_lib/tests/test_redis_connector_fakeredis.py b/bec_lib/tests/test_redis_connector_fakeredis.py index abc792ed4..8dc2b63cf 100644 --- a/bec_lib/tests/test_redis_connector_fakeredis.py +++ b/bec_lib/tests/test_redis_connector_fakeredis.py @@ -180,6 +180,99 @@ def test_redis_connector_unregister_cb_not_topic(connected_connector): assert received_event2.call_count == 2 +def test_redis_connector_unregister_cb_not_topic_with_kwargs(connected_connector): + connector = connected_connector + + topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND) + topic2 = EndpointInfo("topic2", TestMessage, MessageOp.SEND) + + received_event1 = mock.Mock(spec=[]) + received_event2 = mock.Mock(spec=[]) + + connector.register(topics=topic1, cb=received_event1, start_thread=False, a=1) + connector.register(topics=topic1, cb=received_event2, start_thread=False, b=2) + connector.register(topics=topic2, cb=received_event1, start_thread=False, c=3) + + connector.send(topic1, TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + received_event1.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), a=1) + received_event2.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), b=2) + + connector.unregister(topic1, cb=received_event1) + + received_event1.reset_mock() + received_event2.reset_mock() + connector.send(topic1, TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + received_event1.assert_not_called() + received_event2.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), b=2) + + connector.send(topic2, TestMessage(msg="topic2")) + connector.poll_messages(timeout=1) + received_event1.assert_called_once_with(MessageObject("topic2", TestMessage(msg="topic2")), c=3) + assert list(connector._topics_cb.keys()) == ["topic1", "topic2"] + + +def test_redis_connector_unregister_only_specified_callback_kwargs(connected_connector): + connector = connected_connector + + topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND) + received_event = mock.Mock(spec=[]) + + connector.register(topics=topic1, cb=received_event, start_thread=False, a=1) + connector.register(topics=topic1, cb=received_event, start_thread=False, a=2) + + connector.send(topic1, TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + received_event.assert_has_calls( + [ + mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=1), + mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=2), + ] + ) + + received_event.reset_mock() + connector.unregister(topic1, cb=received_event, a=1) + + connector.send(topic1, TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + received_event.assert_called_once_with(MessageObject("topic1", TestMessage(msg="topic1")), a=2) + assert len(connector._topics_cb["topic1"]) == 1 + + +def test_redis_connector_unregister_bound_method_with_kwargs(connected_connector): + connector = connected_connector + + topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND) + + class Receiver: + def __init__(self): + self.received = mock.Mock(spec=[]) + + def on_message(self, msg_obj, scan_id): + self.received(msg_obj, scan_id=scan_id) + + receiver = Receiver() + + connector.register(topics=topic1, cb=receiver.on_message, start_thread=False, scan_id="scan-1") + + connector.send(topic1, TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + receiver.received.assert_called_once_with( + MessageObject("topic1", TestMessage(msg="topic1")), scan_id="scan-1" + ) + + receiver.received.reset_mock() + connector.unregister(topic1, cb=receiver.on_message, scan_id="scan-1") + + connector.send(topic1, TestMessage(msg="topic1")) + with pytest.raises(TimeoutError): + connector.poll_messages(timeout=1) + receiver.received.assert_not_called() + assert connector._redis_conn.execute_command("PUBSUB CHANNELS") == [] + assert len(connector._topics_cb) == 0 + + def test_redis_connector_unregister_topic_keeps_others_alive(connected_connector): def send_msgs_and_poll(timeout=None): connector.send(topic1, TestMessage()) @@ -223,6 +316,57 @@ def send_msgs_and_poll(timeout=None): assert received_event2.call_count == 2 +def test_redis_connector_unregister_all_callback_subscriptions_with_kwargs(connected_connector): + connector = connected_connector + + received_event1 = mock.Mock(spec=[]) + received_event2 = mock.Mock(spec=[]) + + connector.register(topics="topic1", cb=received_event1, start_thread=False, a=1) + connector.register(topics="topic2", cb=received_event1, start_thread=False, b=2) + connector.register(topics="topic2", cb=received_event2, start_thread=False, c=3) + + connector.unregister(cb=received_event1) + + connector.send("topic1", TestMessage(msg="topic1")) + connector.send("topic2", TestMessage(msg="topic2")) + connector.poll_messages(timeout=1) + + received_event1.assert_not_called() + received_event2.assert_called_once() + assert list(connector._topics_cb.keys()) == ["topic2"] + + +def test_redis_connector_unregister_same_callback_registered_with_multiple_kwargs( + connected_connector, +): + connector = connected_connector + + received_event = mock.Mock(spec=[]) + + connector.register(topics="topic1", cb=received_event, start_thread=False, a=1) + connector.register(topics="topic1", cb=received_event, start_thread=False, a=2) + + connector.send("topic1", TestMessage(msg="topic1")) + connector.poll_messages(timeout=1) + + received_event.assert_has_calls( + [ + mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=1), + mock.call(MessageObject("topic1", TestMessage(msg="topic1")), a=2), + ] + ) + assert received_event.call_count == 2 + + received_event.reset_mock() + connector.unregister("topic1", cb=received_event) + connector.send("topic1", TestMessage(msg="topic1")) + + assert received_event.call_count == 0 + assert connector._redis_conn.execute_command("PUBSUB CHANNELS") == [] + assert len(connector._topics_cb) == 0 + + def test_redis_register_poll_messages(connected_connector): connector = connected_connector cb_fcn_has_been_called = False From 7cc65024cc7136f6f05e02afd7cad224c1900965 Mon Sep 17 00:00:00 2001 From: Jan Wyzula <133381102+wyzula-jan@users.noreply.github.com> Date: Mon, 18 May 2026 15:35:20 +0200 Subject: [PATCH 2/2] fix: more robust kwarg match for unregister Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- bec_lib/bec_lib/redis_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index f824505b8..9cef328d3 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ b/bec_lib/bec_lib/redis_connector.py @@ -1009,7 +1009,7 @@ def _matches_subscription( cb_ref, item_kwargs = item if cb is not None and cb_ref() != cb: return False - return all(item_kwargs.get(key) == value for key, value in kwargs.items()) + return all(key in item_kwargs and item_kwargs[key] == value for key, value in kwargs.items()) def _filter_topics_cb(self, topics: list, cb: Callable | None, kwargs: dict): unsubscribe_list = []