-
Notifications
You must be signed in to change notification settings - Fork 89
Add automated notebook testing with Papermill #602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| name: Test Notebooks | ||
|
|
||
| on: | ||
| pull_request: | ||
| branches: [main] | ||
| paths: | ||
| - "pyproject.toml" | ||
| - "causalpy/**" | ||
| - ".github/workflows/test_notebook.yml" | ||
| - "scripts/run_notebooks/**" | ||
| - "docs/source/notebooks/**" | ||
| push: | ||
| branches: [main] | ||
| paths: | ||
| - "pyproject.toml" | ||
| - "causalpy/**" | ||
| - ".github/workflows/test_notebook.yml" | ||
| - "scripts/run_notebooks/**" | ||
| - "docs/source/notebooks/**" | ||
|
|
||
| concurrency: | ||
| group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ||
| cancel-in-progress: true | ||
|
|
||
| jobs: | ||
| notebooks: | ||
| runs-on: ubuntu-latest | ||
| timeout-minutes: 60 | ||
| strategy: | ||
| matrix: | ||
| split: | ||
| - "--pattern *_pymc*.ipynb" | ||
| - "--pattern *_skl*.ipynb" | ||
| - "--exclude-pattern _pymc --exclude-pattern _skl" | ||
| steps: | ||
| - uses: actions/checkout@v4 | ||
|
|
||
| - uses: actions/setup-python@v5 | ||
| with: | ||
| python-version: "3.12" | ||
|
|
||
| - name: Install dependencies | ||
| run: | | ||
| pip install --upgrade pip | ||
| pip install -e ".[test,docs]" | ||
|
|
||
| - name: Run notebooks | ||
| run: python scripts/run_notebooks/runner.py ${{ matrix.split }} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| # Notebook Runner | ||
|
|
||
| This script runs Jupyter notebooks from `docs/source/notebooks/` to validate they execute without errors. | ||
|
|
||
| ## How It Works | ||
|
|
||
| 1. **Mocks `pm.sample()`** — Replaces MCMC sampling with prior predictive (10 draws) for speed | ||
| 2. **Uses Papermill** — Executes notebooks programmatically | ||
| 3. **Discards outputs** — Only checks for errors, doesn't save results | ||
|
|
||
| ## Usage | ||
|
|
||
| ```bash | ||
| # Run all notebooks | ||
| python scripts/run_notebooks/runner.py | ||
|
|
||
| # Run only PyMC notebooks | ||
| python scripts/run_notebooks/runner.py --pattern "*_pymc*.ipynb" | ||
|
|
||
| # Run only sklearn notebooks | ||
| python scripts/run_notebooks/runner.py --pattern "*_skl*.ipynb" | ||
|
|
||
| # Exclude PyMC and sklearn notebooks (run others) | ||
| python scripts/run_notebooks/runner.py --exclude-pattern _pymc --exclude-pattern _skl | ||
| ``` | ||
|
|
||
| ## CI Integration | ||
|
|
||
| The GitHub Actions workflow (`.github/workflows/test_notebook.yml`) runs this script in parallel: | ||
| - Job 1: PyMC notebooks | ||
| - Job 2: Sklearn notebooks | ||
| - Job 3: Other notebooks | ||
|
|
||
| ## Files | ||
|
|
||
| - `runner.py` — Main script | ||
| - `injected.py` — Code injected into notebooks to mock `pm.sample()` | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,44 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Injected code to mock pm.sample for faster notebook execution.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pymc as pm | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| import xarray as xr | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| def mock_sample(*args, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Mock pm.sample using prior predictive sampling for speed.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| random_seed = kwargs.get("random_seed") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = kwargs.get("model") | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = kwargs.get("model") | |
| model = kwargs.get("model") | |
| # If no model is provided via kwargs, try to infer it from positional args | |
| if model is None and args: | |
| first_arg = args[0] | |
| if isinstance(first_arg, pm.Model): | |
| model = first_arg |
Copilot
AI
Dec 20, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable name 'samples' is misleading. In MCMC terminology, this represents the number of 'draws', not 'samples' (which typically refers to chains × draws). Consider renaming to 'draws' or 'n_draws' for clarity.
| samples = 10 | |
| idata = pm.sample_prior_predictive( | |
| model=model, | |
| random_seed=random_seed, | |
| draws=samples, | |
| ) | |
| idata.add_groups(posterior=idata.prior) | |
| # Create mock sample stats with diverging data | |
| if "sample_stats" not in idata: | |
| n_chains = 1 | |
| n_draws = samples | |
| n_draws = 10 | |
| idata = pm.sample_prior_predictive( | |
| model=model, | |
| random_seed=random_seed, | |
| draws=n_draws, | |
| ) | |
| idata.add_groups(posterior=idata.prior) | |
| # Create mock sample stats with diverging data | |
| if "sample_stats" not in idata: | |
| n_chains = 1 |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,135 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Script to run notebooks in docs/source/notebooks directory. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Examples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| -------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Run all notebooks: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| python scripts/run_notebooks/runner.py | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Run only PyMC notebooks: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| python scripts/run_notebooks/runner.py --pattern "*_pymc*.ipynb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Run only sklearn notebooks: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| python scripts/run_notebooks/runner.py --pattern "*_skl*.ipynb" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Exclude PyMC and sklearn notebooks (run others): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| python scripts/run_notebooks/runner.py --exclude-pattern _pymc --exclude-pattern _skl | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import argparse | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tempfile import NamedTemporaryFile | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import papermill | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nbformat.notebooknode import NotebookNode | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from papermill.iorw import load_notebook_node, write_ipynb | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| HERE = Path(__file__).parent | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NOTEBOOKS_PATH = Path("docs/source/notebooks") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| KERNEL_NAME = "python3" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| INJECTED_CODE_FILE = HERE / "injected.py" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| INJECTED_CODE = INJECTED_CODE_FILE.read_text() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def setup_logging() -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logging.basicConfig( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| level=logging.INFO, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| format="%(asctime)s - %(levelname)s - %(message)s", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def inject_mock_code(cells: list) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Inject mock pm.sample code at the start of the notebook.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cells.insert( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NotebookNode( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| id="mock-injection", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| execution_count=0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cell_type="code", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metadata={"tags": []}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs=[], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| source=INJECTED_CODE, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def run_notebook(notebook_path: Path) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Run a notebook with mocked pm.sample.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logging.info(f"Running notebook: {notebook_path.name}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nb = load_notebook_node(str(notebook_path)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inject_mock_code(nb.cells) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with NamedTemporaryFile(suffix=".ipynb", delete=False) as f: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| write_ipynb(nb, f.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| papermill.execute_notebook( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_path=f.name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_path=None, # Discard output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel_name=KERNEL_NAME, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| progress_bar=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cwd=notebook_path.parent, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logging.error(f"Error running notebook: {notebook_path.name}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise e | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+79
to
+81
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+69
to
+83
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with NamedTemporaryFile(suffix=".ipynb", delete=False) as f: | |
| write_ipynb(nb, f.name) | |
| try: | |
| papermill.execute_notebook( | |
| input_path=f.name, | |
| output_path=None, # Discard output | |
| kernel_name=KERNEL_NAME, | |
| progress_bar=True, | |
| cwd=notebook_path.parent, | |
| ) | |
| except Exception as e: | |
| logging.error(f"Error running notebook: {notebook_path.name}") | |
| raise e | |
| temp_path: Path | None = None | |
| try: | |
| with NamedTemporaryFile(suffix=".ipynb", delete=False) as f: | |
| temp_path = Path(f.name) | |
| write_ipynb(nb, f.name) | |
| papermill.execute_notebook( | |
| input_path=str(temp_path), | |
| output_path=None, # Discard output | |
| kernel_name=KERNEL_NAME, | |
| progress_bar=True, | |
| cwd=notebook_path.parent, | |
| ) | |
| except Exception as e: | |
| logging.error(f"Error running notebook: {notebook_path.name}") | |
| raise e | |
| finally: | |
| if temp_path is not None: | |
| try: | |
| temp_path.unlink(missing_ok=True) | |
| except OSError as cleanup_error: | |
| logging.warning( | |
| "Failed to delete temporary notebook file %s: %s", | |
| temp_path, | |
| cleanup_error, | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The documentation states the mock uses "10 draws" but doesn't explain that this happens with only 1 chain. For users debugging test failures, it would be helpful to mention both the number of chains and draws (e.g., "1 chain × 10 draws").