Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tracksdata.graph._base_graph import BaseGraph
from tracksdata.options import get_options
from tracksdata.utils._dtypes import polars_dtype_to_numpy_dtype
from tracksdata.utils._signal import iter_node_added_events, iter_node_updated_events

if TYPE_CHECKING:
from tracksdata.nodes._mask import Mask
Expand Down Expand Up @@ -415,15 +416,24 @@ def _invalidate_from_attrs(self, attrs: dict) -> None:
if slices is not None:
self._cache.invalidate(time=time, volume_slicing=slices)

def _on_node_added(self, node_id: int, new_attrs: dict) -> None:
del node_id
self._invalidate_from_attrs(new_attrs)
def _on_node_added(
    self,
    node_id: int | Sequence[int],
    new_attrs: dict | Sequence[dict],
) -> None:
    """Signal handler: invalidate cached array regions for newly added node(s).

    Accepts either a single (node_id, attrs) event or a batched event
    (sequences of ids and attr dicts); ``iter_node_added_events`` normalizes
    both forms into per-node pairs. The node id itself is discarded — the
    attrs alone determine which cache region to invalidate.
    """
    for _, attrs in iter_node_added_events(node_id, new_attrs):
        self._invalidate_from_attrs(attrs)

def _on_node_removed(self, node_id: int, old_attrs: dict) -> None:
    """Signal handler: invalidate the cached region covered by a removed node.

    Single-event form only (unlike the added/updated handlers, this one is
    not batched). The id is unused; invalidation is driven by the attrs.
    """
    del node_id  # explicitly discard — only the attrs matter here
    self._invalidate_from_attrs(old_attrs)

def _on_node_updated(self, node_id: int, old_attrs: dict, new_attrs: dict) -> None:
del node_id
self._invalidate_from_attrs(old_attrs)
self._invalidate_from_attrs(new_attrs)
def _on_node_updated(
    self,
    node_id: int | Sequence[int],
    old_attrs: dict | Sequence[dict],
    new_attrs: dict | Sequence[dict],
) -> None:
    """Signal handler: invalidate cache for both old and new attrs of updated node(s).

    Accepts single or batched events; ``iter_node_updated_events`` normalizes
    both forms into per-node triples. Both the old and the new attrs are
    invalidated — presumably because an update can change the attrs that
    locate the node's cached region, touching two areas (confirm against
    ``_invalidate_from_attrs``).
    """
    for _, old_attr, new_attr in iter_node_updated_events(node_id, old_attrs, new_attrs):
        self._invalidate_from_attrs(old_attr)
        self._invalidate_from_attrs(new_attr)
33 changes: 17 additions & 16 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph, RXFilter
from tracksdata.graph.filters._indexed_filter import IndexRXFilter
from tracksdata.utils._dtypes import AttrSchema
from tracksdata.utils._signal import is_signal_on
from tracksdata.utils._signal import (
emit_node_added_events,
emit_node_updated_events,
is_signal_on,
)


class GraphView(MappedGraphMixin, RustWorkXGraph):
Expand Down Expand Up @@ -411,20 +415,19 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None
with self._root.node_added.blocked():
parent_node_ids = self._root.bulk_add_nodes(nodes, indices=indices)

emitted_nodes = [
{key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID}
for node_attrs in nodes
]
if self.sync:
with self.node_added.blocked():
node_ids = RustWorkXGraph.bulk_add_nodes(self, nodes)
node_ids = RustWorkXGraph.bulk_add_nodes(self, emitted_nodes)
self._add_id_mappings(list(zip(node_ids, parent_node_ids, strict=True)))
else:
self._out_of_sync = True

if is_signal_on(self._root.node_added):
for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True):
self._root.node_added.emit(node_id, node_attrs)

if is_signal_on(self.node_added):
for node_id, node_attrs in zip(parent_node_ids, nodes, strict=True):
self.node_added.emit(node_id, node_attrs)
emit_node_added_events(self._root.node_added, zip(parent_node_ids, emitted_nodes, strict=True))
emit_node_added_events(self.node_added, zip(parent_node_ids, emitted_nodes, strict=True))

return parent_node_ids

Expand Down Expand Up @@ -684,13 +687,11 @@ def update_node_attrs(
self._out_of_sync = True

if is_signal_on(self.node_updated):
for node_id in node_ids:
old_attrs_by_id = cast(dict[int, dict[str, Any]], old_attrs_by_id) # for mypy
self.node_updated.emit(
node_id,
old_attrs_by_id[node_id],
self._root.nodes[node_id].to_dict(),
)
old_attrs_by_id = cast(dict[int, dict[str, Any]], old_attrs_by_id) # for mypy
emit_node_updated_events(
self.node_updated,
((node_id, old_attrs_by_id[node_id], self._root.nodes[node_id].to_dict()) for node_id in node_ids),
)

def update_edge_attrs(
self,
Expand Down
36 changes: 20 additions & 16 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from tracksdata.utils._dataframe import unpack_array_attrs
from tracksdata.utils._dtypes import AttrSchema, process_attr_key_args
from tracksdata.utils._logging import LOG
from tracksdata.utils._signal import is_signal_on
from tracksdata.utils._signal import (
emit_node_added_events,
emit_node_updated_events,
is_signal_on,
)

if TYPE_CHECKING:
from tracksdata.graph._graph_view import GraphView
Expand Down Expand Up @@ -522,10 +526,7 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None
for node, index in zip(nodes, node_indices, strict=True):
self._time_to_nodes.setdefault(node["t"], []).append(index)

# checking if it has connections to reduce overhead
if is_signal_on(self.node_added):
for node_id, node_attrs in zip(node_indices, nodes, strict=True):
self.node_added.emit(node_id, node_attrs)
emit_node_added_events(self.node_added, zip(node_indices, nodes, strict=True))

return node_indices

Expand Down Expand Up @@ -1213,6 +1214,8 @@ def update_node_attrs(
"""
if node_ids is None:
node_ids = self.node_ids()
else:
node_ids = list(node_ids)

if is_signal_on(self.node_updated):
old_attrs_by_id = {node_id: dict(self._graph[node_id]) for node_id in node_ids}
Expand All @@ -1232,8 +1235,10 @@ def update_node_attrs(
self._graph[node_id][key] = v

if is_signal_on(self.node_updated):
for node_id in node_ids:
self.node_updated.emit(node_id, old_attrs_by_id[node_id], dict(self._graph[node_id]))
emit_node_updated_events(
self.node_updated,
((node_id, old_attrs_by_id[node_id], dict(self._graph[node_id])) for node_id in node_ids),
)

def update_edge_attrs(
self,
Expand Down Expand Up @@ -1666,9 +1671,7 @@ def bulk_add_nodes(

self._add_id_mappings(list(zip(graph_ids, indices, strict=True)))

if is_signal_on(self.node_added):
for index, node_attrs in zip(indices, nodes, strict=True):
self.node_added.emit(index, node_attrs)
emit_node_added_events(self.node_added, zip(indices, nodes, strict=True))

return indices

Expand Down Expand Up @@ -1959,12 +1962,13 @@ def update_node_attrs(
super().update_node_attrs(attrs=attrs, node_ids=local_node_ids)

if is_signal_on(self.node_updated) and old_attrs_by_id is not None:
for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True):
self.node_updated.emit(
external_node_id,
old_attrs_by_id[external_node_id],
dict(self._graph[local_node_id]),
)
emit_node_updated_events(
self.node_updated,
(
(external_node_id, old_attrs_by_id[external_node_id], dict(self._graph[local_node_id]))
for external_node_id, local_node_id in zip(external_node_ids, local_node_ids, strict=True)
),
)

def remove_node(self, node_id: int) -> None:
"""
Expand Down
31 changes: 16 additions & 15 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
sqlalchemy_type_to_polars_dtype,
)
from tracksdata.utils._logging import LOG
from tracksdata.utils._signal import is_signal_on
from tracksdata.utils._signal import (
emit_node_added_events,
emit_node_updated_events,
is_signal_on,
)

if TYPE_CHECKING:
from tracksdata.graph._graph_view import GraphView
Expand Down Expand Up @@ -833,6 +837,7 @@ def bulk_add_nodes(
self._validate_indices_length(nodes, indices)

node_ids = []
insert_rows = []
for i, node in enumerate(nodes):
time = node["t"]

Expand All @@ -844,15 +849,15 @@ def bulk_add_nodes(
else:
node_id = indices[i]

node[DEFAULT_ATTR_KEYS.NODE_ID] = node_id
node_ids.append(node_id)
insert_rows.append({**node, DEFAULT_ATTR_KEYS.NODE_ID: node_id})

self._chunked_sa_write(Session.bulk_insert_mappings, nodes, self.Node)
self._chunked_sa_write(Session.bulk_insert_mappings, insert_rows, self.Node)

if is_signal_on(self.node_added):
for node_id, node_attrs in zip(node_ids, nodes, strict=True):
new_attrs = {key: value for key, value in node_attrs.items() if key != DEFAULT_ATTR_KEYS.NODE_ID}
self.node_added.emit(node_id, new_attrs)
emit_node_added_events(
self.node_added,
zip(node_ids, nodes, strict=True),
)

return node_ids

Expand Down Expand Up @@ -1920,12 +1925,6 @@ def update_node_attrs(
)

self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs)
new_df = self.filter(node_ids=updated_node_ids).node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, *attr_keys])
new_attrs_by_id = new_df.rows_by_key(key=DEFAULT_ATTR_KEYS.NODE_ID, named=True, unique=True, include_key=True)

if is_signal_on(self.node_updated):
for node_id in updated_node_ids:
self.node_updated.emit(node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id])

if is_signal_on(self.node_updated):
new_df = self.filter(node_ids=updated_node_ids).node_attrs(
Expand All @@ -1934,8 +1933,10 @@ def update_node_attrs(
new_attrs_by_id = new_df.rows_by_key(
key=DEFAULT_ATTR_KEYS.NODE_ID, named=True, unique=True, include_key=True
)
for node_id in updated_node_ids:
self.node_updated.emit(node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id])
emit_node_updated_events(
self.node_updated,
((node_id, old_attrs_by_id[node_id], new_attrs_by_id[node_id]) for node_id in updated_node_ids),
)

def update_edge_attrs(
self,
Expand Down
32 changes: 32 additions & 0 deletions src/tracksdata/graph/_test/test_graph_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,38 @@ def test_update_node_attrs(graph_backend: BaseGraph) -> None:
graph_backend.update_node_attrs(node_ids=[node_1, node_2], attrs={"x": [1.0]})


def test_bulk_add_nodes_emits_batched_node_added_callback(graph_backend: BaseGraph) -> None:
    """``bulk_add_nodes`` must emit ``node_added`` exactly once, as one batch.

    The callback receives the full list of new node ids and the full list of
    attr dicts — not one emission per node.
    """
    graph_backend.add_node_attr_key("x", pl.Float64)

    # record every (node_ids, attrs) emission
    calls: list[tuple[Any, Any]] = []
    graph_backend.node_added.connect(lambda node_ids, attrs: calls.append((node_ids, attrs)))

    nodes = [{"t": 0, "x": 1.0}, {"t": 1, "x": 2.0}, {"t": 1, "x": 3.0}]
    node_ids = graph_backend.bulk_add_nodes(nodes)

    # a single batched emission carrying all ids and all attrs
    assert len(calls) == 1
    assert calls[0][0] == node_ids
    assert calls[0][1] == nodes


def test_update_node_attrs_emits_batched_node_updated_callback(graph_backend: BaseGraph) -> None:
    """``update_node_attrs`` must emit ``node_updated`` exactly once, as one batch.

    The callback receives the node ids plus parallel sequences of old and new
    attr dicts, in the same order as ``node_ids``.
    """
    graph_backend.add_node_attr_key("x", pl.Float64)

    node_ids = graph_backend.bulk_add_nodes([{"t": 0, "x": 1.0}, {"t": 1, "x": 2.0}, {"t": 1, "x": 3.0}])

    # connect after the bulk add so only the update emission is captured
    calls: list[tuple[Any, Any, Any]] = []
    graph_backend.node_updated.connect(
        lambda node_ids, old_attrs, new_attrs: calls.append((node_ids, old_attrs, new_attrs))
    )

    graph_backend.update_node_attrs(node_ids=node_ids, attrs={"x": [10.0, 20.0, 30.0]})

    # a single batched emission; old values precede the update, new values follow it
    assert len(calls) == 1
    assert calls[0][0] == node_ids
    assert [attrs["x"] for attrs in calls[0][1]] == [1.0, 2.0, 3.0]
    assert [attrs["x"] for attrs in calls[0][2]] == [10.0, 20.0, 30.0]


def test_update_edge_attrs(graph_backend: BaseGraph) -> None:
"""Test updating edge attributes."""
node1 = graph_backend.add_node({"t": 0})
Expand Down
37 changes: 37 additions & 0 deletions src/tracksdata/graph/_test/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,43 @@ def test_subgraph_add_node(graph_backend: BaseGraph) -> None:
assert attributes["label"].to_list()[0] == "NEW"


def test_subgraph_bulk_add_nodes_emits_batched_node_added_callbacks(graph_backend: BaseGraph) -> None:
    """Adding nodes through a subgraph emits one batched ``node_added`` on
    BOTH the root graph and the subgraph, with identical payloads."""
    graph_with_data = create_test_graph(graph_backend, use_subgraph=False)
    subgraph = graph_with_data.filter(node_ids=graph_with_data._test_nodes[:2]).subgraph()  # type: ignore

    # capture emissions on root and view independently
    root_calls: list[tuple[object, object]] = []
    subgraph_calls: list[tuple[object, object]] = []
    graph_with_data.node_added.connect(lambda node_ids, attrs: root_calls.append((node_ids, attrs)))
    subgraph.node_added.connect(lambda node_ids, attrs: subgraph_calls.append((node_ids, attrs)))

    nodes = [
        {"t": 10, "x": 10.0, "y": 10.0, "label": "A"},
        {"t": 11, "x": 11.0, "y": 11.0, "label": "B"},
    ]
    node_ids = subgraph.bulk_add_nodes(nodes)

    # exactly one batched emission per graph, both carrying the same ids/attrs
    assert len(root_calls) == 1
    assert len(subgraph_calls) == 1
    assert root_calls[0] == (node_ids, nodes)
    assert subgraph_calls[0] == (node_ids, nodes)


def test_subgraph_update_node_attrs_emits_batched_node_updated_callback(graph_backend: BaseGraph) -> None:
    """Updating attrs through a subgraph emits one batched ``node_updated``
    on the subgraph, with old/new attr sequences parallel to ``node_ids``."""
    graph_with_data = create_test_graph(graph_backend, use_subgraph=False)
    subgraph = graph_with_data.filter(node_ids=graph_with_data._test_nodes[:3]).subgraph()  # type: ignore
    node_ids = graph_with_data._test_nodes[:2]  # type: ignore

    calls: list[tuple[object, object, object]] = []
    subgraph.node_updated.connect(lambda node_ids, old_attrs, new_attrs: calls.append((node_ids, old_attrs, new_attrs)))

    subgraph.update_node_attrs(node_ids=node_ids, attrs={"x": [10.0, 20.0]})

    # a single batched emission; old x values come from the test fixture
    # (presumably 0.0/1.0 as seeded by create_test_graph — confirm there)
    assert len(calls) == 1
    assert calls[0][0] == node_ids
    assert [attrs["x"] for attrs in calls[0][1]] == [0.0, 1.0]
    assert [attrs["x"] for attrs in calls[0][2]] == [10.0, 20.0]


def test_subgraph_add_edge(graph_backend: BaseGraph) -> None:
"""Test adding edges to a subgraph."""
graph_with_data = create_test_graph(graph_backend, use_subgraph=False)
Expand Down
Loading
Loading