Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions astroml/features/graph/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime, timedelta, timezone
from typing import Generator, Iterable, Iterator, List, Optional, Sequence, Set, Tuple
import bisect
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


# Issue #199 — default chunk size for the streaming graph builder. SQLAlchemy
Expand Down Expand Up @@ -227,13 +228,66 @@ def _edges_iter(_result=result) -> Iterator[Edge]:
index += 1


def _build_snapshot_window(
    index: int,
    window_start: datetime,
    window_end: datetime,
    chunk_size: int,
) -> SnapshotWindow:
    """Load one time window of transactions and package it as a snapshot.

    A fresh session is obtained from ``get_session()`` for this call and is
    closed before the function returns, whether or not the query succeeds.

    Args:
        index: Position of this window in the snapshot sequence; copied
            verbatim into the returned ``SnapshotWindow``.
        window_start: Lower timestamp bound (inclusive).
        window_end: Upper timestamp bound (inclusive).
        chunk_size: Batch size handed to ``yield_per`` while streaming rows.

    Returns:
        A ``SnapshotWindow`` whose ``edges`` follow the query's timestamp
        ordering and whose ``nodes`` is the set of every sender and receiver
        that appears on those edges.
    """
    from astroml.db.schema import NormalizedTransaction
    from astroml.db.session import get_session
    from sqlalchemy import select

    txn = NormalizedTransaction
    session = get_session()
    try:
        # Rows with no receiver and self-transfers are excluded by the query.
        query = (
            select(txn.sender, txn.receiver, txn.timestamp)
            .where(
                txn.timestamp >= window_start,
                txn.timestamp <= window_end,
                txn.receiver.isnot(None),
                txn.sender != txn.receiver,
            )
            .order_by(txn.timestamp)
        )
        rows = session.execute(query).yield_per(chunk_size)

        collected_edges: List[Edge] = []
        seen_nodes: Set[str] = set()
        for record in rows:
            link = Edge(
                src=record.sender,
                dst=record.receiver,
                # presumably a datetime column, giving whole epoch seconds
                timestamp=int(record.timestamp.timestamp()),
            )
            collected_edges.append(link)
            seen_nodes.update((link.src, link.dst))

        return SnapshotWindow(
            index=index,
            start=window_start,
            end=window_end,
            edges=collected_edges,
            nodes=seen_nodes,
        )
    finally:
        session.close()


def iter_db_snapshots(
window: str = "7d",
t0: Optional[datetime] = None,
t_now: Optional[datetime] = None,
step: Optional[str] = None,
session=None,
chunk_size: int = 100_000,
workers: int = 1,
) -> Generator[SnapshotWindow, None, None]:
"""Yield discrete time-windowed graph snapshots from the database.

Expand All @@ -250,13 +304,16 @@ def iter_db_snapshots(
chunk_size: Number of rows to stream per fetch from the DB. Larger values
reduce round-trips but increase peak memory; smaller values keep the
working set bounded for long-window snapshots.
workers: Number of concurrent window fetch workers. Set to >1 to prefetch
windows in parallel when using the default session factory.

Yields:
:class:`SnapshotWindow` instances in chronological order.
"""
from astroml.db.schema import NormalizedTransaction
from sqlalchemy import select, func as sqlfunc

session_provided = session is not None
if session is None:
from astroml.db.session import get_session
session = get_session()
Expand All @@ -275,6 +332,7 @@ def iter_db_snapshots(
select(sqlfunc.min(NormalizedTransaction.timestamp))
).scalar()
if result is None:
session.close()
return # empty DB
t0 = result if result.tzinfo else result.replace(tzinfo=timezone.utc)

Expand All @@ -286,6 +344,46 @@ def iter_db_snapshots(
window_start = t0
index = 0

if workers > 1 and not session_provided:
session.close()

pending_windows: Dict[int, SnapshotWindow] = {}
futures: Dict[int, "concurrent.futures.Future[SnapshotWindow]"] = {}
next_index_to_yield = 0

with ThreadPoolExecutor(max_workers=workers) as executor:
while window_start < t_now or futures:
while window_start < t_now and len(futures) < workers:
window_end = min(window_start + win_delta, t_now)
future = executor.submit(
_build_snapshot_window,
index,
window_start,
window_end,
chunk_size,
)
futures[index] = future
window_start += step_delta
index += 1

if not futures:
break

done, _ = wait(set(futures.values()), return_when=FIRST_COMPLETED)
for future in done:
result_window = future.result()
pending_windows[result_window.index] = result_window
future_index = next(
idx for idx, fut in futures.items() if fut is future
)
del futures[future_index]

while next_index_to_yield in pending_windows:
yield pending_windows.pop(next_index_to_yield)
next_index_to_yield += 1

return

while window_start < t_now:
window_end = min(window_start + win_delta, t_now)

Expand Down
25 changes: 25 additions & 0 deletions examples/01_getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,31 @@
"5. Running a baseline GCN model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "951c94dc",
"metadata": {},
"outputs": [],
"source": [
"# Example notebook dependency setup\n",
"# Run these commands from the repository root before executing this notebook:\n",
"# pip install -r requirements.txt\n",
"# pip install -e .\n",
"#\n",
"# If this notebook is opened from a different working directory, add the repository root to sys.path.\n",
"import os\n",
"import sys\n",
"\n",
"repo_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
"if repo_root not in sys.path:\n",
" sys.path.insert(0, repo_root)\n",
"\n",
"print(\"Repository root candidate:\", repo_root)\n",
"print(\"Current Python executable:\", sys.executable)\n",
"print(\"First sys.path entry:\", sys.path[0])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
25 changes: 25 additions & 0 deletions examples/02_fraud_detection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,31 @@
"- Wash trading loops (circular value transfer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6a98ff8",
"metadata": {},
"outputs": [],
"source": [
"# Example notebook dependency setup\n",
"# Run these commands from the repository root before executing this notebook:\n",
"# pip install -r requirements.txt\n",
"# pip install -e .\n",
"#\n",
"# If this notebook is opened from a different working directory, add the repository root to sys.path.\n",
"import os\n",
"import sys\n",
"\n",
"repo_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
"if repo_root not in sys.path:\n",
" sys.path.insert(0, repo_root)\n",
"\n",
"print(\"Repository root candidate:\", repo_root)\n",
"print(\"Current Python executable:\", sys.executable)\n",
"print(\"First sys.path entry:\", sys.path[0])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
25 changes: 25 additions & 0 deletions examples/03_transaction_graph_analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,31 @@
"5. **Graph validation** — data quality checks before training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68fb2cab",
"metadata": {},
"outputs": [],
"source": [
"# Example notebook dependency setup\n",
"# Run these commands from the repository root before executing this notebook:\n",
"# pip install -r requirements.txt\n",
"# pip install -e .\n",
"#\n",
"# If this notebook is opened from a different working directory, add the repository root to sys.path.\n",
"import os\n",
"import sys\n",
"\n",
"repo_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
"if repo_root not in sys.path:\n",
" sys.path.insert(0, repo_root)\n",
"\n",
"print(\"Repository root candidate:\", repo_root)\n",
"print(\"Current Python executable:\", sys.executable)\n",
"print(\"First sys.path entry:\", sys.path[0])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
20 changes: 20 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Example Notebooks Setup

Before running the example notebooks, install project dependencies from the repository root.

```bash
pip install -r requirements.txt
pip install -e .
```

If your notebook kernel is started from a different directory, make sure the repository root is on `sys.path` or change the working directory to the project root before importing `astroml`.

Example (assumes the kernel's working directory is one level below the repository root, e.g. `examples/`):

```python
import os
import sys
repo_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)
```
32 changes: 31 additions & 1 deletion tests/test_graph_to_pyg.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,37 @@ def test_conversion_with_node_labels(self):
# Check labels
assert data.y is not None
assert data.y.shape[0] == 3 # num_nodes


def test_conversion_with_numpy_edge_features_and_node_labels(self):
    """Check output dtypes and label shape when every input is a NumPy array."""
    x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)
    edges = np.array([[0, 1], [1, 0]], dtype=np.int32)
    edge_attrs = np.array([[0.5], [0.6]], dtype=np.float64)
    labels = np.array([0, 1], dtype=np.int64)

    converted = graph_to_pyg_data(x, edges, edge_attrs, labels)

    # float64 edge features should come back as float32; int64 labels stay
    # int64 and remain one-dimensional with one entry per node.
    assert converted.edge_attr.dtype == torch.float32
    assert converted.y.dtype == torch.int64
    assert converted.y.shape == (2,)

def test_invalid_edge_index_negative_id(self):
    """A negative node ID in the edge index must raise ``ValueError``."""
    features = [[1.0, 2.0], [3.0, 4.0]]
    bad_edges = [[0, -1], [1, 0]]
    expected_message = "Edge index contains negative node IDs"

    with pytest.raises(ValueError, match=expected_message):
        graph_to_pyg_data(features, bad_edges)

def test_invalid_node_labels_2d_shape(self):
    """Two-dimensional node labels must be rejected with ``ValueError``."""
    features = [[1.0, 2.0], [3.0, 4.0]]
    edges = [[0, 1], [1, 0]]
    column_labels = [[0], [1]]  # shape (2, 1) instead of the required 1D
    expected_message = "node_labels must be 1D array"

    with pytest.raises(ValueError, match=expected_message):
        graph_to_pyg_data(features, edges, node_labels=column_labels)

def test_edge_index_format_conversion(self):
"""Test edge index format conversion from [num_edges, 2] to [2, num_edges]."""
node_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
Expand Down
55 changes: 55 additions & 0 deletions tests/test_snapshot_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,58 @@ def test_iter_db_snapshots_streams_in_chunks():
Edge(src="c", dst="d", timestamp=int(t0.replace(minute=1).timestamp())),
]
assert session.execute_calls == 1


def test_iter_db_snapshots_parallel_prefetches_windows(monkeypatch):
    """Parallel mode fetches each window on its own session and yields in order.

    ``get_session`` is replaced with a factory that hands out fake sessions.
    The first session goes to ``iter_db_snapshots`` itself; the following two
    are requested by the worker threads, one per window.

    The workers call the factory concurrently, so which window receives which
    row set depends on thread scheduling. The test therefore checks window
    ordering through ``SnapshotWindow.index`` and checks the fetched rows
    without assuming which worker got which fake result.
    """
    import threading
    from datetime import timedelta
    from astroml.features.graph.snapshot import iter_db_snapshots

    t0 = datetime(2024, 1, 1, tzinfo=timezone.utc)
    t_now = t0 + timedelta(hours=2)

    class FakeResult:
        def __init__(self, rows, scalar_value=None):
            self._rows = rows
            self._scalar = scalar_value

        def yield_per(self, size):
            # The chunk size given to iter_db_snapshots must reach the query.
            assert size == 2
            return iter(self._rows)

        def scalar(self):
            return self._scalar

    class FakeSession:
        def __init__(self, result):
            self._result = result
            self.closed = False

        def execute(self, _query):
            return self._result

        def close(self):
            self.closed = True

    windows_rows = [
        [type("Row", (), {"sender": "a", "receiver": "b", "timestamp": t0})()],
        [type("Row", (), {"sender": "c", "receiver": "d", "timestamp": t0 + timedelta(hours=1)})()],
    ]
    call_count = {"calls": 0}
    created_sessions = []
    # Worker threads call the factory concurrently; guard the shared counter
    # so every call gets a distinct number and no row set is handed out twice.
    counter_lock = threading.Lock()

    def fake_get_session():
        with counter_lock:
            call_number = call_count["calls"]
            call_count["calls"] += 1
        if call_number == 0:
            result = FakeResult([], scalar_value=t0)
        else:
            result = FakeResult(windows_rows[call_number - 1])
        fake_session = FakeSession(result)
        with counter_lock:
            created_sessions.append(fake_session)
        return fake_session

    monkeypatch.setattr("astroml.db.session.get_session", fake_get_session)

    windows = list(iter_db_snapshots("1h", t0=t0, t_now=t_now, chunk_size=2, workers=2))

    assert len(windows) == 2
    # Windows must come out in chronological (index) order even though they
    # are built concurrently.
    assert [w.index for w in windows] == [0, 1]
    # Each window got exactly one of the fake row sets; which one is
    # scheduling-dependent, so compare without regard to order.
    assert all(len(w.edges) == 1 for w in windows)
    assert sorted(w.edges[0].src for w in windows) == ["a", "c"]
    # One session for the generator itself plus one per worker-built window.
    assert call_count["calls"] == 3
    assert all(s.closed for s in created_sessions)
17 changes: 17 additions & 0 deletions tests/test_train_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os
import sys
from importlib import reload


def test_parse_command_line_seed_sets_astroml_seed(monkeypatch):
    """Ensure the top-level --seed CLI flag is parsed and preserved for Hydra.

    The parser is expected to export the seed as ``ASTROML_SEED`` and to strip
    ``--seed <value>`` from ``sys.argv`` while leaving the remaining
    arguments untouched.
    """
    monkeypatch.delenv("ASTROML_SEED", raising=False)
    monkeypatch.setattr(sys, "argv", ["train.py", "--seed", "123", "experiment=debug"])

    try:
        import train
        reload(train)

        train._parse_command_line_seed()

        assert os.environ["ASTROML_SEED"] == "123"
        assert sys.argv == ["train.py", "experiment=debug"]
    finally:
        # The code under test writes ASTROML_SEED to os.environ directly. If
        # the variable was unset beforehand, delenv(raising=False) recorded
        # nothing to undo, so remove it here to keep it from leaking into
        # later tests. If it was set beforehand, monkeypatch restores the
        # original value at teardown, after this cleanup.
        os.environ.pop("ASTROML_SEED", None)
Loading
Loading