diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 6e88828d..8a148b1f 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -10,6 +10,7 @@ from tracksdata.graph._base_graph import BaseGraph from tracksdata.options import get_options from tracksdata.utils._dtypes import polars_dtype_to_numpy_dtype +from tracksdata.utils._signal import iter_node_added_events, iter_node_updated_events if TYPE_CHECKING: from tracksdata.nodes._mask import Mask @@ -415,15 +416,24 @@ def _invalidate_from_attrs(self, attrs: dict) -> None: if slices is not None: self._cache.invalidate(time=time, volume_slicing=slices) - def _on_node_added(self, node_id: int, new_attrs: dict) -> None: - del node_id - self._invalidate_from_attrs(new_attrs) + def _on_node_added( + self, + node_id: int | Sequence[int], + new_attrs: dict | Sequence[dict], + ) -> None: + for _, attrs in iter_node_added_events(node_id, new_attrs): + self._invalidate_from_attrs(attrs) def _on_node_removed(self, node_id: int, old_attrs: dict) -> None: del node_id self._invalidate_from_attrs(old_attrs) - def _on_node_updated(self, node_id: int, old_attrs: dict, new_attrs: dict) -> None: - del node_id - self._invalidate_from_attrs(old_attrs) - self._invalidate_from_attrs(new_attrs) + def _on_node_updated( + self, + node_id: int | Sequence[int], + old_attrs: dict | Sequence[dict], + new_attrs: dict | Sequence[dict], + ) -> None: + for _, old_attr, new_attr in iter_node_updated_events(node_id, old_attrs, new_attrs): + self._invalidate_from_attrs(old_attr) + self._invalidate_from_attrs(new_attr) diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index 5419f7a7..a0f44f0c 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -12,7 +12,11 @@ from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph, RXFilter from tracksdata.graph.filters._indexed_filter import IndexRXFilter from tracksdata.utils._dtypes import AttrSchema -from tracksdata.utils._signal import is_signal_on +from tracksdata.utils._signal import ( + emit_node_added_events, + emit_node_updated_events, + is_signal_on, +) class GraphView(MappedGraphMixin, RustWorkXGraph): @@ -411,20 +415,19 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None with self._root.node_added.blocked(): parent_node_ids = self._root.bulk_add_nodes(nodes, indices=indices) + emitted_nodes = [ + {key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID} + for node_attrs in nodes + ] if self.sync: with self.node_added.blocked(): - node_ids = RustWorkXGraph.bulk_add_nodes(self, nodes) + node_ids = RustWorkXGraph.bulk_add_nodes(self, emitted_nodes) self._add_id_mappings(list(zip(node_ids, parent_node_ids, strict=True))) else: self._out_of_sync = True - if is_signal_on(self._root.node_added): - for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True): - self._root.node_added.emit(node_id, node_attrs) - - if is_signal_on(self.node_added): - for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True): - self.node_added.emit(node_id, node_attrs) + emit_node_added_events(self._root.node_added, zip(parent_node_ids, emitted_nodes, strict=True)) + emit_node_added_events(self.node_added, zip(parent_node_ids, emitted_nodes, strict=True)) return parent_node_ids @@ -684,13 +687,11 @@ def update_node_attrs( self._out_of_sync = True if is_signal_on(self.node_updated): - for node_id in node_ids: - old_attrs_by_id = cast(dict[int, dict[str, Any]], old_attrs_by_id) # for mypy - self.node_updated.emit( - node_id, - old_attrs_by_id[node_id], - self._root.nodes[node_id].to_dict(), - ) + old_attrs_by_id = cast(dict[int, dict[str, Any]], old_attrs_by_id) # for mypy + emit_node_updated_events( + self.node_updated, + ((node_id, old_attrs_by_id[node_id], self._root.nodes[node_id].to_dict()) for node_id in node_ids), + ) def update_edge_attrs( self, diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 56770004..4dea094b 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -16,7 +16,11 @@ from tracksdata.utils._dataframe import unpack_array_attrs from tracksdata.utils._dtypes import AttrSchema, process_attr_key_args from tracksdata.utils._logging import LOG -from tracksdata.utils._signal import is_signal_on +from tracksdata.utils._signal import ( + emit_node_added_events, + emit_node_updated_events, + is_signal_on, +) if TYPE_CHECKING: from tracksdata.graph._graph_view import GraphView @@ -522,10 +526,7 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None for node, index in zip(nodes, node_indices, strict=True): self._time_to_nodes.setdefault(node["t"], []).append(index) - # checking if it has connections to reduce overhead - if is_signal_on(self.node_added): - for node_id, node_attrs in zip(node_indices, nodes, strict=True): - self.node_added.emit(node_id, node_attrs) + emit_node_added_events(self.node_added, zip(node_indices, nodes, strict=True)) return node_indices @@ -1213,6 +1214,8 @@ def update_node_attrs( """ if node_ids is None: node_ids = self.node_ids() + else: + node_ids = list(node_ids) if is_signal_on(self.node_updated): old_attrs_by_id = {node_id: dict(self._graph[node_id]) for node_id in node_ids} @@ -1232,8 +1235,10 @@ def update_node_attrs( self._graph[node_id][key] = v if is_signal_on(self.node_updated): - for node_id in node_ids: - self.node_updated.emit(node_id, old_attrs_by_id[node_id], dict(self._graph[node_id])) + emit_node_updated_events( + self.node_updated, + ((node_id, old_attrs_by_id[node_id], dict(self._graph[node_id])) for node_id in node_ids), + ) def update_edge_attrs( self, @@ -1666,9 +1671,7 @@ def bulk_add_nodes( self._add_id_mappings(list(zip(graph_ids, indices, strict=True))) - if is_signal_on(self.node_added): - for index, node_attrs in zip(indices, nodes, strict=True): - self.node_added.emit(index, node_attrs) + emit_node_added_events(self.node_added, zip(indices, nodes, strict=True)) return indices @@ -1959,12 +1962,13 @@ def update_node_attrs( super().update_node_attrs(attrs=attrs, node_ids=local_node_ids) if is_signal_on(self.node_updated) and old_attrs_by_id is not None: - for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True): - self.node_updated.emit( - external_node_id, - old_attrs_by_id[external_node_id], - dict(self._graph[local_node_id]), - ) + emit_node_updated_events( + self.node_updated, + ( + (external_node_id, old_attrs_by_id[external_node_id], dict(self._graph[local_node_id])) + for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True) + ), + ) def remove_node(self, node_id: int) -> None: """ diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index a0ced553..92591770 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -28,7 +28,11 @@ sqlalchemy_type_to_polars_dtype, ) from tracksdata.utils._logging import LOG -from tracksdata.utils._signal import is_signal_on +from tracksdata.utils._signal import ( + emit_node_added_events, + emit_node_updated_events, + is_signal_on, +) if TYPE_CHECKING: from tracksdata.graph._graph_view import GraphView @@ -833,6 +837,7 @@ def bulk_add_nodes( self._validate_indices_length(nodes, indices) node_ids = [] + insert_rows = [] for i, node in enumerate(nodes): time = node["t"] @@ -844,15 +849,15 @@ def bulk_add_nodes( else: node_id = indices[i] - node[DEFAULT_ATTR_KEYS.NODE_ID] = node_id node_ids.append(node_id) + insert_rows.append({**node, DEFAULT_ATTR_KEYS.NODE_ID: node_id}) - self._chunked_sa_write(Session.bulk_insert_mappings, nodes, self.Node) + self._chunked_sa_write(Session.bulk_insert_mappings, insert_rows, self.Node) - if is_signal_on(self.node_added): - for node_id, node_attrs in zip(node_ids, nodes, strict=True): - new_attrs = {key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID} - self.node_added.emit(node_id, new_attrs) + emit_node_added_events( + self.node_added, + zip(node_ids, nodes, strict=True), + ) return node_ids @@ -1920,12 +1925,6 @@ def update_node_attrs( ) self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs) - new_df = self.filter(node_ids=updated_node_ids).node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys]) - new_attrs_by_id = new_df.rows_by_key(key=DEFAULT_ATTR_KEYS.NODE_ID, named=True, unique=True, include_key=True) - - if is_signal_on(self.node_updated): - for node_id in updated_node_ids: - self.node_updated.emit(node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) if is_signal_on(self.node_updated): new_df = self.filter(node_ids=updated_node_ids).node_attrs( @@ -1934,8 +1933,10 @@ def update_node_attrs( new_attrs_by_id = new_df.rows_by_key( key=DEFAULT_ATTR_KEYS.NODE_ID, named=True, unique=True, include_key=True ) - for node_id in updated_node_ids: - self.node_updated.emit(node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) + emit_node_updated_events( + self.node_updated, + ((node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) for node_id in updated_node_ids), + ) def update_edge_attrs( self, diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index c296aee5..305216da 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -518,6 +518,38 @@ def test_update_node_attrs(graph_backend: BaseGraph) -> None: graph_backend.update_node_attrs(node_ids=[node_1, node_2], attrs={"x": [1.0]}) +def test_bulk_add_nodes_emits_batched_node_added_callback(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("x", pl.Float64) + + calls: list[tuple[Any, Any]] = [] + graph_backend.node_added.connect(lambda node_ids, attrs: calls.append((node_ids, attrs))) + + nodes = [{"t": 0, "x": 1.0}, {"t": 1, "x": 2.0}, {"t": 1, "x": 3.0}] + node_ids = graph_backend.bulk_add_nodes(nodes) + + assert len(calls) == 1 + assert calls[0][0] == node_ids + assert calls[0][1] == nodes + + +def test_update_node_attrs_emits_batched_node_updated_callback(graph_backend: BaseGraph) -> None: + graph_backend.add_node_attr_key("x", pl.Float64) + + node_ids = graph_backend.bulk_add_nodes([{"t": 0, "x": 1.0}, {"t": 1, "x": 2.0}, {"t": 1, "x": 3.0}]) + + calls: list[tuple[Any, Any, Any]] = [] + graph_backend.node_updated.connect( + lambda node_ids, old_attrs, new_attrs: calls.append((node_ids, old_attrs, new_attrs)) + ) + + graph_backend.update_node_attrs(node_ids=node_ids, attrs={"x": [10.0, 20.0, 30.0]}) + + assert len(calls) == 1 + assert calls[0][0] == node_ids + assert [attrs["x"] for attrs in calls[0][1]] == [1.0, 2.0, 3.0] + assert [attrs["x"] for attrs in calls[0][2]] == [10.0, 20.0, 30.0] + + def test_update_edge_attrs(graph_backend: BaseGraph) -> None: """Test updating edge attributes.""" node1 = graph_backend.add_node({"t": 0}) diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index 8664620a..d57b17be 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -473,6 +473,43 @@ def test_subgraph_add_node(graph_backend: BaseGraph) -> None: assert attributes["label"].to_list()[0] == "NEW" +def test_subgraph_bulk_add_nodes_emits_batched_node_added_callbacks(graph_backend: BaseGraph) -> None: + graph_with_data = create_test_graph(graph_backend, use_subgraph=False) + subgraph = graph_with_data.filter(node_ids=graph_with_data._test_nodes[:2]).subgraph() # type: ignore + + root_calls: list[tuple[object, object]] = [] + subgraph_calls: list[tuple[object, object]] = [] + graph_with_data.node_added.connect(lambda node_ids, attrs: root_calls.append((node_ids, attrs))) + subgraph.node_added.connect(lambda node_ids, attrs: subgraph_calls.append((node_ids, attrs))) + + nodes = [ + {"t": 10, "x": 10.0, "y": 10.0, "label": "A"}, + {"t": 11, "x": 11.0, "y": 11.0, "label": "B"}, + ] + node_ids = subgraph.bulk_add_nodes(nodes) + + assert len(root_calls) == 1 + assert len(subgraph_calls) == 1 + assert root_calls[0] == (node_ids, nodes) + assert subgraph_calls[0] == (node_ids, nodes) + + +def test_subgraph_update_node_attrs_emits_batched_node_updated_callback(graph_backend: BaseGraph) -> None: + graph_with_data = create_test_graph(graph_backend, use_subgraph=False) + subgraph = graph_with_data.filter(node_ids=graph_with_data._test_nodes[:3]).subgraph() # type: ignore + node_ids = graph_with_data._test_nodes[:2] # type: ignore + + calls: list[tuple[object, object, object]] = [] + subgraph.node_updated.connect(lambda node_ids, old_attrs, new_attrs: calls.append((node_ids, old_attrs, new_attrs))) + + subgraph.update_node_attrs(node_ids=node_ids, attrs={"x": [10.0, 20.0]}) + + assert len(calls) == 1 + assert calls[0][0] == node_ids + assert [attrs["x"] for attrs in calls[0][1]] == [0.0, 1.0] + assert [attrs["x"] for attrs in calls[0][2]] == [10.0, 20.0] + + def test_subgraph_add_edge(graph_backend: BaseGraph) -> None: """Test adding edges to a subgraph.""" graph_with_data = create_test_graph(graph_backend, use_subgraph=False) diff --git a/src/tracksdata/graph/filters/_spatial_filter.py b/src/tracksdata/graph/filters/_spatial_filter.py index e13e1711..9e85fa66 100644 --- a/src/tracksdata/graph/filters/_spatial_filter.py +++ b/src/tracksdata/graph/filters/_spatial_filter.py @@ -6,6 +6,7 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.utils._logging import LOG +from tracksdata.utils._signal import iter_node_added_events, iter_node_updated_events if TYPE_CHECKING: from tracksdata.graph._base_graph import BaseGraph @@ -205,24 +206,25 @@ def _attrs_to_point(self, attrs: dict[str, Any]) -> np.ndarray: def _add_node( self, - node_id: int, - new_attrs: dict[str, Any], + node_id: int | list[int], + new_attrs: dict[str, Any] | list[dict[str, Any]], ) -> None: from spatial_graph import PointRTree - if self._df_filter._node_rtree is None: - self._df_filter._node_rtree = PointRTree( - item_dtype="int64", - coord_dtype="float32", - dims=len(self._attr_keys), - ) - self._df_filter._ndims = len(self._attr_keys) + for event_node_id, event_attrs in iter_node_added_events(node_id, new_attrs): + if self._df_filter._node_rtree is None: + self._df_filter._node_rtree = PointRTree( + item_dtype="int64", + coord_dtype="float32", + dims=len(self._attr_keys), + ) + self._df_filter._ndims = len(self._attr_keys) - positions = self._attrs_to_point(new_attrs) - self._df_filter._node_rtree.insert_point_items( - np.atleast_1d(node_id).astype(np.int64), - positions, - ) + positions = self._attrs_to_point(event_attrs) + self._df_filter._node_rtree.insert_point_items( + np.atleast_1d(event_node_id).astype(np.int64), + positions, + ) def _remove_node( self, @@ -241,12 +243,13 @@ def _remove_node( def _update_node( self, - node_id: int, - old_attrs: dict[str, Any], - new_attrs: dict[str, Any], + node_id: int | list[int], + old_attrs: dict[str, Any] | list[dict[str, Any]], + new_attrs: dict[str, Any] | list[dict[str, Any]], ) -> None: - self._remove_node(node_id, old_attrs=old_attrs) - self._add_node(node_id, new_attrs=new_attrs) + for event_node_id, event_old_attrs, event_new_attrs in iter_node_updated_events(node_id, old_attrs, new_attrs): + self._remove_node(event_node_id, old_attrs=event_old_attrs) + self._add_node(event_node_id, new_attrs=event_new_attrs) class BBoxSpatialFilter: @@ -414,8 +417,8 @@ def _attrs_to_bb_window(self, attrs: dict[str, Any]) -> tuple[np.ndarray, np.nda def _add_node( self, - node_id: int, - new_attrs: dict[str, Any], + node_id: int | list[int], + new_attrs: dict[str, Any] | list[dict[str, Any]], ) -> None: """ Add a node to the spatial filter. @@ -429,29 +432,30 @@ def _add_node( """ from spatial_graph import PointRTree - if self._node_rtree is None: - bbox = new_attrs[self._bbox_attr_key] - if len(bbox) % 2 != 0: - raise ValueError(f"Bounding box coordinates must have even number of dimensions, got {len(bbox)}") - num_dims = len(bbox) // 2 - if self._frame_attr_key is None: - self._ndims = num_dims - else: - self._ndims = num_dims + 1 # +1 for the frame dimension - - self._node_rtree = PointRTree( - item_dtype="int64", - coord_dtype="float32", - dims=self._ndims, - ) + for event_node_id, event_attrs in iter_node_added_events(node_id, new_attrs): + if self._node_rtree is None: + bbox = event_attrs[self._bbox_attr_key] + if len(bbox) % 2 != 0: + raise ValueError(f"Bounding box coordinates must have even number of dimensions, got {len(bbox)}") + num_dims = len(bbox) // 2 + if self._frame_attr_key is None: + self._ndims = num_dims + else: + self._ndims = num_dims + 1 # +1 for the frame dimension + + self._node_rtree = PointRTree( + item_dtype="int64", + coord_dtype="float32", + dims=self._ndims, + ) - positions_min, positions_max = self._attrs_to_bb_window(new_attrs) + positions_min, positions_max = self._attrs_to_bb_window(event_attrs) - self._node_rtree.insert_bb_items( - np.atleast_1d(node_id).astype(np.int64), - positions_min, - positions_max, - ) + self._node_rtree.insert_bb_items( + np.atleast_1d(event_node_id).astype(np.int64), + positions_min, + positions_max, + ) def _remove_node( self, @@ -481,12 +485,13 @@ def _remove_node( def _update_node( self, - node_id: int, - old_attrs: dict[str, Any], - new_attrs: dict[str, Any], + node_id: int | list[int], + old_attrs: dict[str, Any] | list[dict[str, Any]], + new_attrs: dict[str, Any] | list[dict[str, Any]], ) -> None: - self._remove_node(node_id, old_attrs=old_attrs) - self._add_node(node_id, new_attrs=new_attrs) + for event_node_id, event_old_attrs, event_new_attrs in iter_node_updated_events(node_id, old_attrs, new_attrs): + self._remove_node(event_node_id, old_attrs=event_old_attrs) + self._add_node(event_node_id, new_attrs=event_new_attrs) @staticmethod def _bboxes_to_array(bbox_series: pl.Series) -> np.ndarray: diff --git a/src/tracksdata/utils/_signal.py b/src/tracksdata/utils/_signal.py index ee4c431e..70af9ecd 100644 --- a/src/tracksdata/utils/_signal.py +++ b/src/tracksdata/utils/_signal.py @@ -1,6 +1,96 @@ +from collections.abc import Iterable, Iterator, Sequence +from typing import Any + from psygnal import Signal, SignalInstance +def _is_batched(value: object) -> bool: + return isinstance(value, Sequence) and not isinstance(value, str | bytes | bytearray | dict) + + +def reduce_node_added_events( + event_args: Iterable[tuple[int, dict[str, Any]]], +) -> tuple[int | list[int], dict[str, Any] | list[dict[str, Any]]]: + events = list(event_args) + if len(events) == 1: + return events[0] + + node_ids, attrs = zip(*events, strict=True) + return list(node_ids), list(attrs) + + +def reduce_node_updated_events( + event_args: Iterable[tuple[int, dict[str, Any], dict[str, Any]]], +) -> tuple[int | list[int], dict[str, Any] | list[dict[str, Any]], dict[str, Any] | list[dict[str, Any]]]: + events = list(event_args) + if len(events) == 1: + return events[0] + + node_ids, old_attrs, new_attrs = zip(*events, strict=True) + return list(node_ids), list(old_attrs), list(new_attrs) + + +def emit_node_added_events( + sig: Signal | SignalInstance, + event_args: Iterable[tuple[int, dict[str, Any]]], +) -> None: + events = list(event_args) + if len(events) == 0 or not is_signal_on(sig): + return + + with sig.paused(reduce_node_added_events): + for node_id, attrs in events: + sig.emit(node_id, attrs) + + +def emit_node_updated_events( + sig: Signal | SignalInstance, + event_args: Iterable[tuple[int, dict[str, Any], dict[str, Any]]], +) -> None: + events = list(event_args) + if len(events) == 0 or not is_signal_on(sig): + return + + with sig.paused(reduce_node_updated_events): + for node_id, old_attrs, new_attrs in events: + sig.emit(node_id, old_attrs, new_attrs) + + +def iter_node_added_events( + node_ids: int | Sequence[int], + attrs: dict[str, Any] | Sequence[dict[str, Any]], +) -> Iterator[tuple[int, dict[str, Any]]]: + if _is_batched(node_ids): + if not _is_batched(attrs): + raise TypeError("Expected a sequence of node attributes for batched node_added events.") + + yield from zip(node_ids, attrs, strict=True) + return + + if _is_batched(attrs): + raise TypeError("Expected a single node attributes dict for node_added events.") + + yield node_ids, attrs + + +def iter_node_updated_events( + node_ids: int | Sequence[int], + old_attrs: dict[str, Any] | Sequence[dict[str, Any]], + new_attrs: dict[str, Any] | Sequence[dict[str, Any]], +) -> Iterator[tuple[int, dict[str, Any], dict[str, Any]]]: + if _is_batched(node_ids): + if not _is_batched(old_attrs) or not _is_batched(new_attrs): + raise TypeError("Expected sequences of node attribute dicts for batched node_updated events.") + + yield from zip(node_ids, old_attrs, new_attrs, strict=True) + return + + if _is_batched(old_attrs) or _is_batched(new_attrs): + raise TypeError("Expected single node attribute dicts for node_updated events.") + + yield node_ids, old_attrs, new_attrs + + def is_signal_on(sig: Signal | SignalInstance) -> bool: """Check if a signal is connected and not blocked.""" return len(sig._slots) > 0 and not sig._is_blocked