From ccad9963188ab6dc2b6524a154590fd5af9dc997 Mon Sep 17 00:00:00 2001 From: ritik4ever Date: Mon, 1 Jun 2026 09:46:24 +0000 Subject: [PATCH] Fix deterministic seed handling, add graph utils tests, improve snapshot prefetch concurrency, and document notebook dependencies --- astroml/features/graph/snapshot.py | 98 ++++++++++++++++++++ examples/01_getting_started.ipynb | 25 +++++ examples/02_fraud_detection.ipynb | 25 +++++ examples/03_transaction_graph_analysis.ipynb | 25 +++++ examples/README.md | 20 ++++ tests/test_graph_to_pyg.py | 32 ++++++- tests/test_snapshot_memory.py | 55 +++++++++++ tests/test_train_seed.py | 17 ++++ train.py | 77 ++++++++------- 9 files changed, 342 insertions(+), 32 deletions(-) create mode 100644 examples/README.md create mode 100644 tests/test_train_seed.py diff --git a/astroml/features/graph/snapshot.py b/astroml/features/graph/snapshot.py index 5fc6c1d..a4635b2 100644 --- a/astroml/features/graph/snapshot.py +++ b/astroml/features/graph/snapshot.py @@ -4,6 +4,7 @@ from datetime import datetime, timedelta, timezone from typing import Generator, Iterable, Iterator, List, Optional, Sequence, Set, Tuple import bisect +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait # Issue #199 — default chunk size for the streaming graph builder. SQLAlchemy @@ -227,6 +228,58 @@ def _edges_iter(_result=result) -> Iterator[Edge]: index += 1 +def _build_snapshot_window( + index: int, + window_start: datetime, + window_end: datetime, + chunk_size: int, +) -> SnapshotWindow: + """Build a single snapshot window from the database.""" + from astroml.db.schema import NormalizedTransaction + from astroml.db.session import get_session + from sqlalchemy import select + + session = get_session() + try: + result = session.execute( + select( + NormalizedTransaction.sender, + NormalizedTransaction.receiver, + NormalizedTransaction.timestamp, + ) + .where( + NormalizedTransaction.timestamp >= window_start, + NormalizedTransaction.timestamp <= window_end, + NormalizedTransaction.receiver.isnot(None), + NormalizedTransaction.sender != NormalizedTransaction.receiver, + ) + .order_by(NormalizedTransaction.timestamp) + ) + + edges: List[Edge] = [] + nodes: Set[str] = set() + + for row in result.yield_per(chunk_size): + edge = Edge( + src=row.sender, + dst=row.receiver, + timestamp=int(row.timestamp.timestamp()), + ) + edges.append(edge) + nodes.add(edge.src) + nodes.add(edge.dst) + + return SnapshotWindow( + index=index, + start=window_start, + end=window_end, + edges=edges, + nodes=nodes, + ) + finally: + session.close() + + def iter_db_snapshots( window: str = "7d", t0: Optional[datetime] = None, @@ -234,6 +287,7 @@ def iter_db_snapshots( step: Optional[str] = None, session=None, chunk_size: int = 100_000, + workers: int = 1, ) -> Generator[SnapshotWindow, None, None]: """Yield discrete time-windowed graph snapshots from the database. @@ -250,6 +304,8 @@ def iter_db_snapshots( chunk_size: Number of rows to stream per fetch from the DB. Larger values reduce round-trips but increase peak memory; smaller values keep the working set bounded for long-window snapshots. + workers: Number of concurrent window fetch workers. Set to >1 to prefetch + windows in parallel when using the default session factory. Yields: :class:`SnapshotWindow` instances in chronological order. @@ -257,6 +313,7 @@ def iter_db_snapshots( from astroml.db.schema import NormalizedTransaction from sqlalchemy import select, func as sqlfunc + session_provided = session is not None if session is None: from astroml.db.session import get_session session = get_session() @@ -275,6 +332,7 @@ def iter_db_snapshots( select(sqlfunc.min(NormalizedTransaction.timestamp)) ).scalar() if result is None: + session.close() return # empty DB t0 = result if result.tzinfo else result.replace(tzinfo=timezone.utc) @@ -286,6 +344,46 @@ def iter_db_snapshots( window_start = t0 index = 0 + if workers > 1 and not session_provided: + session.close() + + pending_windows: Dict[int, SnapshotWindow] = {} + futures: Dict[int, "concurrent.futures.Future[SnapshotWindow]"] = {} + next_index_to_yield = 0 + + with ThreadPoolExecutor(max_workers=workers) as executor: + while window_start < t_now or futures: + while window_start < t_now and len(futures) < workers: + window_end = min(window_start + win_delta, t_now) + future = executor.submit( + _build_snapshot_window, + index, + window_start, + window_end, + chunk_size, + ) + futures[index] = future + window_start += step_delta + index += 1 + + if not futures: + break + + done, _ = wait(set(futures.values()), return_when=FIRST_COMPLETED) + for future in done: + result_window = future.result() + pending_windows[result_window.index] = result_window + future_index = next( + idx for idx, fut in futures.items() if fut is future + ) + del futures[future_index] + + while next_index_to_yield in pending_windows: + yield pending_windows.pop(next_index_to_yield) + next_index_to_yield += 1 + + return + while window_start < t_now: window_end = min(window_start + win_delta, t_now) diff --git a/examples/01_getting_started.ipynb b/examples/01_getting_started.ipynb index 7814713..6b8a9b7 100644 --- a/examples/01_getting_started.ipynb +++ b/examples/01_getting_started.ipynb @@ -16,6 +16,31 @@ "5. Running a baseline GCN model" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "951c94dc", + "metadata": {}, + "outputs": [], + "source": [ + "# Example notebook dependency setup\n", + "# Run these commands from the repository root before executing this notebook:\n", + "# pip install -r requirements.txt\n", + "# pip install -e .\n", + "#\n", + "# If this notebook is opened from a different working directory, add the repository root to sys.path.\n", + "import os\n", + "import sys\n", + "\n", + "repo_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", + "if repo_root not in sys.path:\n", + " sys.path.insert(0, repo_root)\n", + "\n", + "print(\"Repository root candidate:\", repo_root)\n", + "print(\"Current Python executable:\", sys.executable)\n", + "print(\"First sys.path entry:\", sys.path[0])\n" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/examples/02_fraud_detection.ipynb b/examples/02_fraud_detection.ipynb index 5529d91..162c874 100644 --- a/examples/02_fraud_detection.ipynb +++ b/examples/02_fraud_detection.ipynb @@ -17,6 +17,31 @@ "- Wash trading loops (circular value transfer)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6a98ff8", + "metadata": {}, + "outputs": [], + "source": [ + "# Example notebook dependency setup\n", + "# Run these commands from the repository root before executing this notebook:\n", + "# pip install -r requirements.txt\n", + "# pip install -e .\n", + "#\n", + "# If this notebook is opened from a different working directory, add the repository root to sys.path.\n", + "import os\n", + "import sys\n", + "\n", + "repo_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", + "if repo_root not in sys.path:\n", + " sys.path.insert(0, repo_root)\n", + "\n", + "print(\"Repository root candidate:\", repo_root)\n", + "print(\"Current Python executable:\", sys.executable)\n", + "print(\"First sys.path entry:\", sys.path[0])\n" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/examples/03_transaction_graph_analysis.ipynb b/examples/03_transaction_graph_analysis.ipynb index 4b58c72..2fb1d63 100644 --- a/examples/03_transaction_graph_analysis.ipynb +++ b/examples/03_transaction_graph_analysis.ipynb @@ -15,6 +15,31 @@ "5. **Graph validation** — data quality checks before training" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "68fb2cab", + "metadata": {}, + "outputs": [], + "source": [ + "# Example notebook dependency setup\n", + "# Run these commands from the repository root before executing this notebook:\n", + "# pip install -r requirements.txt\n", + "# pip install -e .\n", + "#\n", + "# If this notebook is opened from a different working directory, add the repository root to sys.path.\n", + "import os\n", + "import sys\n", + "\n", + "repo_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", + "if repo_root not in sys.path:\n", + " sys.path.insert(0, repo_root)\n", + "\n", + "print(\"Repository root candidate:\", repo_root)\n", + "print(\"Current Python executable:\", sys.executable)\n", + "print(\"First sys.path entry:\", sys.path[0])\n" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..e5845db --- /dev/null +++ b/examples/README.md @@ -0,0 +1,20 @@ +# Example Notebooks Setup + +Before running the example notebooks, install project dependencies from the repository root. + +```bash +pip install -r requirements.txt +pip install -e . +``` + +If your notebook kernel is started from a different directory, make sure the repository root is on `sys.path` or change the working directory to the project root before importing `astroml`. + +Example: + +```python +import os +import sys +repo_root = os.path.abspath(os.path.join(os.getcwd(), "..")) +if repo_root not in sys.path: + sys.path.insert(0, repo_root) +``` diff --git a/tests/test_graph_to_pyg.py b/tests/test_graph_to_pyg.py index 0760543..b97479c 100644 --- a/tests/test_graph_to_pyg.py +++ b/tests/test_graph_to_pyg.py @@ -59,7 +59,37 @@ def test_conversion_with_node_labels(self): # Check labels assert data.y is not None assert data.y.shape[0] == 3 # num_nodes - + + def test_conversion_with_numpy_edge_features_and_node_labels(self): + """Test conversion with numpy arrays for edge features and labels.""" + node_features = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64) + edge_index = np.array([[0, 1], [1, 0]], dtype=np.int32) + edge_features = np.array([[0.5], [0.6]], dtype=np.float64) + node_labels = np.array([0, 1], dtype=np.int64) + + data = graph_to_pyg_data(node_features, edge_index, edge_features, node_labels) + + assert data.edge_attr.dtype == torch.float32 + assert data.y.dtype == torch.int64 + assert data.y.shape == (2,) + + def test_invalid_edge_index_negative_id(self): + """Test error handling for negative edge index values.""" + node_features = [[1.0, 2.0], [3.0, 4.0]] + edge_index = [[0, -1], [1, 0]] + + with pytest.raises(ValueError, match="Edge index contains negative node IDs"): + graph_to_pyg_data(node_features, edge_index) + + def test_invalid_node_labels_2d_shape(self): + """Test error handling for node labels with incorrect dimensionality.""" + node_features = [[1.0, 2.0], [3.0, 4.0]] + edge_index = [[0, 1], [1, 0]] + node_labels = [[0], [1]] + + with pytest.raises(ValueError, match="node_labels must be 1D array"): + graph_to_pyg_data(node_features, edge_index, node_labels=node_labels) + def test_edge_index_format_conversion(self): """Test edge index format conversion from [num_edges, 2] to [2, num_edges].""" node_features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] diff --git a/tests/test_snapshot_memory.py b/tests/test_snapshot_memory.py index 2b849b0..6d47e53 100644 --- a/tests/test_snapshot_memory.py +++ b/tests/test_snapshot_memory.py @@ -47,3 +47,58 @@ def test_iter_db_snapshots_streams_in_chunks(): Edge(src="c", dst="d", timestamp=int(t0.replace(minute=1).timestamp())), ] assert session.execute_calls == 1 + + +def test_iter_db_snapshots_parallel_prefetches_windows(monkeypatch): + from datetime import timedelta + from astroml.features.graph.snapshot import iter_db_snapshots + + t0 = datetime(2024, 1, 1, tzinfo=timezone.utc) + t_now = t0 + timedelta(hours=2) + + class FakeResult: + def __init__(self, rows, scalar_value=None): + self._rows = rows + self._scalar = scalar_value + + def yield_per(self, size): + assert size == 2 + return iter(self._rows) + + def scalar(self): + return self._scalar + + class FakeSession: + def __init__(self, result): + self._result = result + self.closed = False + + def execute(self, _query): + return self._result + + def close(self): + self.closed = True + + windows_rows = [ + [type("Row", (), {"sender": "a", "receiver": "b", "timestamp": t0})()], + [type("Row", (), {"sender": "c", "receiver": "d", "timestamp": t0 + timedelta(hours=1)})()], + ] + call_count = {"calls": 0} + + def fake_get_session(): + if call_count["calls"] == 0: + result = FakeResult([], scalar_value=t0) + else: + window_index = call_count["calls"] - 1 + result = FakeResult(windows_rows[window_index]) + call_count["calls"] += 1 + return FakeSession(result) + + monkeypatch.setattr("astroml.db.session.get_session", fake_get_session) + + windows = list(iter_db_snapshots("1h", t0=t0, t_now=t_now, chunk_size=2, workers=2)) + + assert len(windows) == 2 + assert windows[0].edges[0].src == "a" + assert windows[1].edges[0].src == "c" + assert call_count["calls"] == 3 diff --git a/tests/test_train_seed.py b/tests/test_train_seed.py new file mode 100644 index 0000000..3cdfd2e --- /dev/null +++ b/tests/test_train_seed.py @@ -0,0 +1,17 @@ +import os +import sys +from importlib import reload + + +def test_parse_command_line_seed_sets_astroml_seed(monkeypatch): + """Ensure the top-level --seed CLI flag is parsed and preserved for Hydra.""" + monkeypatch.delenv("ASTROML_SEED", raising=False) + monkeypatch.setattr(sys, "argv", ["train.py", "--seed", "123", "experiment=debug"]) + + import train + reload(train) + + train._parse_command_line_seed() + + assert os.environ["ASTROML_SEED"] == "123" + assert sys.argv == ["train.py", "experiment=debug"] diff --git a/train.py b/train.py index 6e713ed..fbf32f9 100644 --- a/train.py +++ b/train.py @@ -9,8 +9,10 @@ python train.py --multirun model.lr=0.001,0.01,0.1 # Hyperparameter sweep """ +import argparse import os import logging +import sys from pathlib import Path from typing import Dict, Any @@ -35,11 +37,27 @@ def set_device(device_config: str) -> torch.device: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device(device_config) - + logger.info(f"Using device: {device}") return device +def set_random_seed(seed: int) -> None: + """Set deterministic random seeds for Python, NumPy, and PyTorch.""" + import random as _random + import numpy as _np + + _random.seed(seed) + _np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PYTHONHASHSEED"] = str(seed) + + def apply_temporal_masks(data: Any, cfg: DictConfig) -> Any: """Replace dataset masks with strict temporal train/val/test splits. @@ -310,23 +328,34 @@ def train(cfg: DictConfig) -> Dict[str, Any]: } +def _parse_command_line_seed() -> None: + """Parse an optional top-level --seed flag and set ASTROML_SEED.""" + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument( + "--seed", + type=int, + help="Deterministic seed for Python, NumPy, and PyTorch", + ) + + args, remaining = parser.parse_known_args() + if args.seed is not None: + os.environ["ASTROML_SEED"] = str(args.seed) + + # Preserve all other arguments for Hydra + sys.argv = [sys.argv[0]] + remaining + + @hydra.main(version_base=None, config_path="configs", config_name="config") -def main(cfg: DictConfig) -> None: - """Main entry point.""" +def _hydra_main(cfg: DictConfig) -> None: + """Hydra entry point after CLI preprocessing.""" # Create save directory save_dir = Path(cfg.experiment.save_dir) save_dir.mkdir(parents=True, exist_ok=True) - + # Log configuration logger.info("Configuration:") logger.info(OmegaConf.to_yaml(cfg)) - - # Set random seeds for reproducibility (#189). Hydra's `cfg.experiment.seed` - # remains the canonical source; a top-level `seed=` override on the CLI - # (e.g. `python train.py seed=42`) is automatically merged into the - # experiment group, so no explicit `--seed` flag is needed beyond Hydra's - # standard overrides. We additionally honour `ASTROML_SEED` as an env - # fallback for non-Hydra entrypoints. + env_seed = os.environ.get("ASTROML_SEED") seed = cfg.experiment.seed if seed is None and env_seed is not None: @@ -341,35 +370,21 @@ def main(cfg: DictConfig) -> None: if seed is not None: seed = int(seed) logger.info("Setting deterministic seeds: %d", seed) - import random as _random - - import numpy as _np - - _random.seed(seed) - _np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - # Trade some throughput for reproducibility on GPU runs. - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - # Ensure DataLoader workers inherit the seed. - os.environ["PYTHONHASHSEED"] = str(seed) + set_random_seed(seed) # Run training results = train(cfg) - - # Log results + logger.info("Training completed!") logger.info(f"Results: {results}") - + # Save results results_path = save_dir / "results.yaml" OmegaConf.save(OmegaConf.create(results), results_path) - + logger.info(f"Results saved to {results_path}") if __name__ == "__main__": - main() + _parse_command_line_seed() + _hydra_main()