48 changes: 48 additions & 0 deletions .github/workflows/test_notebook.yml
@@ -0,0 +1,48 @@
name: Test Notebooks

on:
pull_request:
branches: [main]
paths:
- "pyproject.toml"
- "causalpy/**"
- ".github/workflows/test_notebook.yml"
- "scripts/run_notebooks/**"
- "docs/source/notebooks/**"
push:
branches: [main]
paths:
- "pyproject.toml"
- "causalpy/**"
- ".github/workflows/test_notebook.yml"
- "scripts/run_notebooks/**"
- "docs/source/notebooks/**"

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:
notebooks:
runs-on: ubuntu-latest
timeout-minutes: 60
strategy:
matrix:
split:
- "--pattern *_pymc*.ipynb"
- "--pattern *_skl*.ipynb"
- "--exclude-pattern _pymc --exclude-pattern _skl"
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install dependencies
run: |
pip install --upgrade pip
pip install -e ".[test,docs]"

- name: Run notebooks
run: python scripts/run_notebooks/runner.py ${{ matrix.split }}
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
(SVG badge file; content diff not displayed.)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -91,7 +91,7 @@ docs = [
"sphinx-togglebutton",
]
lint = ["interrogate", "pre-commit", "ruff", "mypy"]
test = ["pytest", "pytest-cov", "codespell", "nbformat", "nbconvert"]
test = ["pytest", "pytest-cov", "codespell", "nbformat", "nbconvert", "papermill"]

[project.urls]
Homepage = "https://github.com/pymc-labs/CausalPy"
37 changes: 37 additions & 0 deletions scripts/run_notebooks/README.md
@@ -0,0 +1,37 @@
# Notebook Runner

This script runs the Jupyter notebooks in `docs/source/notebooks/` to validate that they execute without errors.

## How It Works

1. **Mocks `pm.sample()`** — Replaces MCMC sampling with prior predictive (10 draws) for speed

Copilot AI (Dec 20, 2025):

The documentation states the mock uses "10 draws" but doesn't explain that this happens with only 1 chain. For users debugging test failures, it would be helpful to mention both the number of chains and draws (e.g., "1 chain × 10 draws").

Suggested change:
- 1. **Mocks `pm.sample()`** — Replaces MCMC sampling with prior predictive (10 draws) for speed
+ 1. **Mocks `pm.sample()`** — Replaces MCMC sampling with prior predictive (1 chain × 10 draws) for speed

2. **Uses Papermill** — Executes notebooks programmatically
3. **Discards outputs** — Only checks for errors, doesn't save results
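
As a rough sketch of the mocking idea (illustration only, not one of the PR's files; the real implementation is `scripts/run_notebooks/injected.py`, shown later in this diff):

```python
import pymc as pm


def _fast_sample(*args, **kwargs):
    # Stand-in for pm.sample: draw a handful of prior-predictive samples and
    # expose them under the "posterior" group so downstream plotting and
    # summary code still finds what it expects.
    idata = pm.sample_prior_predictive(
        model=kwargs.get("model"),
        random_seed=kwargs.get("random_seed"),
        draws=10,
    )
    idata.add_groups(posterior=idata.prior)
    del idata.prior
    return idata


pm.sample = _fast_sample  # every pm.sample() call in the notebook now uses the mock
```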

## Usage

```bash
# Run all notebooks
python scripts/run_notebooks/runner.py

# Run only PyMC notebooks
python scripts/run_notebooks/runner.py --pattern "*_pymc*.ipynb"

# Run only sklearn notebooks
python scripts/run_notebooks/runner.py --pattern "*_skl*.ipynb"

# Exclude PyMC and sklearn notebooks (run others)
python scripts/run_notebooks/runner.py --exclude-pattern _pymc --exclude-pattern _skl
```

## CI Integration

The GitHub Actions workflow (`.github/workflows/test_notebook.yml`) runs this script in parallel:
- Job 1: PyMC notebooks
- Job 2: Sklearn notebooks
- Job 3: Other notebooks

## Files

- `runner.py` — Main script
- `injected.py` — Code injected into notebooks to mock `pm.sample()`
44 changes: 44 additions & 0 deletions scripts/run_notebooks/injected.py
@@ -0,0 +1,44 @@
"""Injected code to mock pm.sample for faster notebook execution."""

import numpy as np
import pymc as pm
import xarray as xr


def mock_sample(*args, **kwargs):
"""Mock pm.sample using prior predictive sampling for speed."""
random_seed = kwargs.get("random_seed")
model = kwargs.get("model")

Copilot AI (Dec 20, 2025):

The mock_sample function doesn't handle the case when 'model' is not provided in kwargs. If pm.sample is called with a positional model argument or without a model in the current context, this will raise a KeyError or TypeError.

Suggested change:
-    model = kwargs.get("model")
+    model = kwargs.get("model")
+    # If no model is provided via kwargs, try to infer it from positional args
+    if model is None and args:
+        first_arg = args[0]
+        if isinstance(first_arg, pm.Model):
+            model = first_arg

samples = 10

idata = pm.sample_prior_predictive(
model=model,
random_seed=random_seed,
draws=samples,
)
idata.add_groups(posterior=idata.prior)

# Create mock sample stats with diverging data
if "sample_stats" not in idata:
n_chains = 1
n_draws = samples
Comment on lines +12 to +24
Copilot AI (Dec 20, 2025):

The variable name 'samples' is misleading. In MCMC terminology, this represents the number of 'draws', not 'samples' (which typically refers to chains × draws). Consider renaming to 'draws' or 'n_draws' for clarity.

Suggested change:
-    samples = 10
-    idata = pm.sample_prior_predictive(
-        model=model,
-        random_seed=random_seed,
-        draws=samples,
-    )
-    idata.add_groups(posterior=idata.prior)
-    # Create mock sample stats with diverging data
-    if "sample_stats" not in idata:
-        n_chains = 1
-        n_draws = samples
+    n_draws = 10
+    idata = pm.sample_prior_predictive(
+        model=model,
+        random_seed=random_seed,
+        draws=n_draws,
+    )
+    idata.add_groups(posterior=idata.prior)
+    # Create mock sample stats with diverging data
+    if "sample_stats" not in idata:
+        n_chains = 1

sample_stats = xr.Dataset(
{
"diverging": xr.DataArray(
np.zeros((n_chains, n_draws), dtype=int),
dims=("chain", "draw"),
)
}
)
idata.add_groups(sample_stats=sample_stats)

del idata.prior
if "prior_predictive" in idata:
del idata.prior_predictive

return idata


pm.sample = mock_sample
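# Flat and HalfFlat cannot be forward-sampled, so swap in proper priors for prior predictive sampling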
pm.HalfFlat = pm.HalfNormal
pm.Flat = pm.Normal
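
A hypothetical local smoke test for this mock (not part of the PR; the toy model and expected sizes are illustrative, and it assumes the snippet is run from the repository root):

```python
import numpy as np
import pymc as pm

# Apply the same patches that the runner injects into each notebook.
exec(open("scripts/run_notebooks/injected.py").read())

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu, sigma=1, observed=np.zeros(5))
    idata = pm.sample(random_seed=42)  # dispatches to the mock

print(idata.posterior.sizes)                       # expect chain=1, draw=10
print(int(idata.sample_stats["diverging"].sum()))  # expect 0
```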
135 changes: 135 additions & 0 deletions scripts/run_notebooks/runner.py
@@ -0,0 +1,135 @@
"""Script to run notebooks in docs/source/notebooks directory.

Examples
--------
Run all notebooks:

python scripts/run_notebooks/runner.py

Run only PyMC notebooks:

python scripts/run_notebooks/runner.py --pattern "*_pymc*.ipynb"

Run only sklearn notebooks:

python scripts/run_notebooks/runner.py --pattern "*_skl*.ipynb"

Exclude PyMC and sklearn notebooks (run others):

python scripts/run_notebooks/runner.py --exclude-pattern _pymc --exclude-pattern _skl

"""

import argparse
import logging
from pathlib import Path
from tempfile import NamedTemporaryFile

import papermill
from nbformat.notebooknode import NotebookNode
from papermill.iorw import load_notebook_node, write_ipynb

HERE = Path(__file__).parent
NOTEBOOKS_PATH = Path("docs/source/notebooks")
KERNEL_NAME = "python3"

INJECTED_CODE_FILE = HERE / "injected.py"
INJECTED_CODE = INJECTED_CODE_FILE.read_text()


def setup_logging() -> None:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)


def inject_mock_code(cells: list) -> None:
"""Inject mock pm.sample code at the start of the notebook."""
cells.insert(
0,
NotebookNode(
id="mock-injection",
execution_count=0,
cell_type="code",
metadata={"tags": []},
outputs=[],
source=INJECTED_CODE,
),
)


def run_notebook(notebook_path: Path) -> None:
"""Run a notebook with mocked pm.sample."""
logging.info(f"Running notebook: {notebook_path.name}")

nb = load_notebook_node(str(notebook_path))
inject_mock_code(nb.cells)

with NamedTemporaryFile(suffix=".ipynb", delete=False) as f:
write_ipynb(nb, f.name)
try:
papermill.execute_notebook(
input_path=f.name,
output_path=None, # Discard output
kernel_name=KERNEL_NAME,
progress_bar=True,
cwd=notebook_path.parent,
)
except Exception as e:
logging.error(f"Error running notebook: {notebook_path.name}")
raise e
Comment on lines +79 to +81
Copilot AI (Dec 20, 2025):

The error is caught, logged, and then re-raised, but the temporary file created on line 69 won't be cleaned up when an exception occurs. Consider using a try-finally block or pathlib's unlink() to ensure cleanup.


Comment on lines +69 to +83
Copilot AI (Dec 20, 2025):

The temporary file created by NamedTemporaryFile is never deleted because delete=False is set, but there's no cleanup code to remove it after use. This will leave temporary files on the filesystem after each notebook run.

Suggested change:
-    with NamedTemporaryFile(suffix=".ipynb", delete=False) as f:
-        write_ipynb(nb, f.name)
-        try:
-            papermill.execute_notebook(
-                input_path=f.name,
-                output_path=None,  # Discard output
-                kernel_name=KERNEL_NAME,
-                progress_bar=True,
-                cwd=notebook_path.parent,
-            )
-        except Exception as e:
-            logging.error(f"Error running notebook: {notebook_path.name}")
-            raise e
+    temp_path: Path | None = None
+    try:
+        with NamedTemporaryFile(suffix=".ipynb", delete=False) as f:
+            temp_path = Path(f.name)
+            write_ipynb(nb, f.name)
+        papermill.execute_notebook(
+            input_path=str(temp_path),
+            output_path=None,  # Discard output
+            kernel_name=KERNEL_NAME,
+            progress_bar=True,
+            cwd=notebook_path.parent,
+        )
+    except Exception as e:
+        logging.error(f"Error running notebook: {notebook_path.name}")
+        raise e
+    finally:
+        if temp_path is not None:
+            try:
+                temp_path.unlink(missing_ok=True)
+            except OSError as cleanup_error:
+                logging.warning(
+                    "Failed to delete temporary notebook file %s: %s",
+                    temp_path,
+                    cleanup_error,
+                )

def get_notebooks(
pattern: str | None = None,
exclude_patterns: list[str] | None = None,
) -> list[Path]:
"""Get list of notebooks to run, optionally filtered."""
notebooks = list(NOTEBOOKS_PATH.glob("*.ipynb"))

if pattern:
notebooks = [nb for nb in notebooks if Path(nb).match(pattern)]

if exclude_patterns:
for exc in exclude_patterns:
notebooks = [nb for nb in notebooks if exc not in nb.name]

return sorted(notebooks)


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run CausalPy notebooks.")
parser.add_argument(
"--pattern",
type=str,
default=None,
help="Glob pattern to filter notebooks (e.g., '*_pymc*.ipynb')",
)
parser.add_argument(
"--exclude-pattern",
type=str,
action="append",
dest="exclude_patterns",
help="Pattern to exclude from notebook names (can be used multiple times)",
)
return parser.parse_args()


if __name__ == "__main__":
setup_logging()
args = parse_args()

notebooks = get_notebooks(
pattern=args.pattern,
exclude_patterns=args.exclude_patterns,
)

logging.info(f"Found {len(notebooks)} notebooks to run")
for nb in notebooks:
logging.info(f" - {nb.name}")

for notebook in notebooks:
run_notebook(notebook)

logging.info("All notebooks completed successfully!")