diff --git a/CLAUDE.md b/CLAUDE.md index 4e3d77237..eb6b163ab 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -77,6 +77,14 @@ uv run ruff format . `HSSM.sample()` passes `**kwargs` through to `bambi.Model.fit()`, which in turn passes them to PyMC's `pm.sample()`. So parameters like `cores`, `chains`, `nuts_sampler`, `target_accept`, etc. are valid even though they don't appear in HSSM's own signature. Similarly, the HSSM constructor passes `**kwargs` to `bambi.Model()`, so bambi parameters like `noncentered` are valid. +### ONNX likelihoods are single-trial + `jax.vmap` + +Every ONNX graph consumed by HSSM must be exported with a concrete single-trial input shape (no `dynamic_axes`). Per-trial batching happens at the HSSM layer via `jax.vmap` over trials — see [`src/hssm/distribution_utils/onnx.py:115-138`](src/hssm/distribution_utils/onnx.py#L115-L138), where `logp(*inputs)` builds one flat per-trial vector and `make_vmap_func` lifts it. + +Enforced at load time by `_check_single_trial_input_shape` in [`src/hssm/distribution_utils/onnx_utils/onnx2jax.py`](src/hssm/distribution_utils/onnx_utils/onnx2jax.py), which raises a `ValueError` on any symbolic input dim. The constraint exists because `jaxonnxruntime` traces against the construction-time dummy and bakes those shapes into the returned closure — calling that closure at a different batch size silently produces wrong outputs for graphs with batch-dependent intermediates (log-det accumulators, `Reshape` with `-1`). + +LANfactory's exporters (`transform_sbi_to_onnx`, BayesFlow LRE export) already follow this convention. A new ONNX source must do the same: trace with a rank-1 dummy, no `dynamic_axes`. + ### Notebook execution in CI Two separate skip mechanisms for notebooks: diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb new file mode 100644 index 000000000..eb444dd6d --- /dev/null +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -0,0 +1,685 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c5b5342c", + "metadata": {}, + "source": [ + "# Integrating sbi-trained likelihoods into HSSM (NRE via ONNX)\n", + "\n", + "This tutorial mirrors the structure of `bayesflow_lre_integration.ipynb` for the\n", + "third major SBI library in the HSSM ecosystem: **[sbi](https://github.com/sbi-dev/sbi)**\n", + "(mackelab). It demonstrates how to:\n", + "\n", + "1. Train a neural ratio estimator (NRE) on synthetic DDM simulations using sbi.\n", + "2. Export the trained estimator to ONNX via\n", + " [`lanfactory.onnx.transform_sbi_to_onnx`](https://alexanderfengler.github.io/LANfactory/exporting_sbi_models/).\n", + "3. Load the ONNX file into HSSM exactly like any other LAN-style approximator and run\n", + " MCMC inference.\n", + "4. Compare against HSSM's analytical DDM likelihood as the gold-standard posterior.\n", + "\n", + "> **Why NRE only, not NLE/MNLE?**\n", + "> Vanilla NLE with a MAF flow misbehaves on DDM data because rt is continuous but\n", + "> choice is discrete (∈ {−1, +1}). The flow treats choice as continuous, can't\n", + "> represent the support boundary `rt > t_nd`, and produces qualitatively wrong\n", + "> posteriors (we observed v ≈ 0.12 vs truth 0.5, with spurious bimodality on a).\n", + "> The correct sbi method is **MNLE** (Mixed Neural Likelihood Estimator), which\n", + "> splits x into discrete and continuous dims and models each properly. But MNLE's\n", + "> categorical lookup uses `torch.searchsorted`, which `torch.onnx.export` doesn't\n", + "> support and `jaxonnxruntime` lacks a handler for. See\n", + "> [plans/sbi-onnx-integration.md](../../../HSSMSpine/plans/sbi-onnx-integration.md)\n", + "> \"Deferred sbi paths\" for the resolution roadmap (a ~50-line upstream PR to\n", + "> `jaxonnxruntime` unlocks both MNLE and NSF flows in one stroke).\n", + ">\n", + "> Until then, **NRE is the working integration path** — it doesn't need to model\n", + "> density at all, just a classifier between joint and marginal pairs, which is\n", + "> robust to the discrete/continuous mixing.\n", + "\n", + "> **Environment note:** This tutorial requires both `hssm` and `lanfactory[all]` (which\n", + "> pulls `sbi` and `nflows`) in the same environment. JAX/flax/numpyro pins must be\n", + "> resolved jointly across the two packages." + ] + }, + { + "cell_type": "markdown", + "id": "8d15a24a", + "metadata": {}, + "source": [ + "## Part 1 — Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99a29072", + "metadata": {}, + "outputs": [], + "source": [ + "# Enable x64 BEFORE any other JAX-touching import.\n", + "# HSSM's onnx2jax auto-flips this when loading sbi-exported ONNX graphs that carry\n", + "# int64 tensors (typical for normalizing flows from torch.onnx.export), but setting\n", + "# it explicitly here is best practice and silences the auto-flip UserWarning.\n", + "import jax\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "from pathlib import Path\n", + "import warnings\n", + "\n", + "import arviz as az\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "from torch import nn\n", + "\n", + "import hssm\n", + "from sbi.inference import NRE_A\n", + "from sbi.neural_nets import classifier_nn\n", + "from sbi.utils import BoxUniform\n", + "from ssms.basic_simulators.simulator import simulator\n", + "\n", + "# Import transform_sbi_to_onnx with a fallback. The clean path works when\n", + "# lanfactory is installed in an env whose JAX/flax pins are compatible with this\n", + "# notebook's other deps. Until the cross-repo env alignment lands, the fallback\n", + "# loads only the sbi exporter module directly, bypassing lanfactory's top-level\n", + "# __init__.py (which would otherwise pull the flax-dependent jax_mlp trainer).\n", + "try:\n", + " from lanfactory.onnx import transform_sbi_to_onnx\n", + "except ImportError:\n", + " import importlib.util\n", + " import os\n", + "\n", + " # Try candidates: explicit env var, then several relative paths covering\n", + " # common Jupyter launch contexts (notebook dir / repo root / spine root).\n", + " _candidates = [\n", + " os.environ.get(\"LANFACTORY_SBI_PATH\"),\n", + " \"../../../LANfactory/src/lanfactory/onnx/sbi.py\", # cwd = notebook dir\n", + " \"repos/LANfactory/src/lanfactory/onnx/sbi.py\", # cwd = HSSMSpine root\n", + " \"../LANfactory/src/lanfactory/onnx/sbi.py\", # cwd = repos/HSSM root\n", + " ]\n", + " _path = None\n", + " for _c in _candidates:\n", + " if _c and Path(_c).exists():\n", + " _path = Path(_c).resolve()\n", + " break\n", + " if _path is None:\n", + " raise ImportError(\n", + " \"Could not locate lanfactory/onnx/sbi.py. Set the LANFACTORY_SBI_PATH \"\n", + " \"environment variable to the absolute path of that file, or run the \"\n", + " \"notebook from a directory where one of the relative candidates resolves: \"\n", + " f\"{[c for c in _candidates if c]}\"\n", + " )\n", + " _spec = importlib.util.spec_from_file_location(\"_lanfactory_sbi\", _path)\n", + " _mod = importlib.util.module_from_spec(_spec)\n", + " _spec.loader.exec_module(_mod)\n", + " transform_sbi_to_onnx = _mod.transform_sbi_to_onnx\n", + " print(f\"(fallback) loaded transform_sbi_to_onnx from {_path}\")\n", + "\n", + "np.random.seed(0)\n", + "torch.manual_seed(0)\n", + "\n", + "# Training budget for NRE_A — the last known-working configuration.\n", + "# 1M (theta, x) pairs from ssm-simulators (1 sample per theta) at the wider\n", + "# HSSM-default prior bounds, with a moderate-sized MLP classifier.\n", + "# We are deliberately reverting to a simple, previously-validated setup to\n", + "# rule out new variables (NRE_B / multi-sample-per-theta / FCEmbedding /\n", + "# bigger classifier) before introducing them again one at a time.\n", + "N_TRAIN = 1_000_000\n", + "N_OBS = 500\n", + "NUM_EPOCHS = 300\n", + "STOP_AFTER_EPOCHS = 50\n", + "TRAINING_BATCH_SIZE = 500\n", + "HIDDEN_FEATURES = 100\n", + "# Smaller MCMC budget for verification runs. Together with target_accept=0.8\n", + "# and max_tree_depth=8 below, this keeps the MCMC step bounded so we can\n", + "# diagnose training-quality issues without paying 30+ minutes per pass.\n", + "MCMC_DRAWS = 500\n", + "MCMC_TUNE = 500\n", + "MCMC_CHAINS = 2" + ] + }, + { + "cell_type": "markdown", + "id": "dbe9e2bd", + "metadata": {}, + "source": [ + "## Part 2 — Simulate observed DDM data\n", + "\n", + "We use the standard 4-parameter DDM (`v`, `a`, `z`, `t`) from `ssm-simulators`. The\n", + "true parameters and parameter ranges below match the BayesFlow LRE tutorial so that\n", + "posteriors from the two tutorials can be compared apples-to-apples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce250bb6", + "metadata": {}, + "outputs": [], + "source": [ + "DDM_PARAM_NAMES = [\"v\", \"a\", \"z\", \"t\"]\n", + "# Training prior matches HSSM's default bounds for model=\"ddm\" with\n", + "# loglik_kind=\"approx_differentiable\" — see hssm.defaults.default_model_config.\n", + "# This is important: if the training prior is narrower than HSSM's posterior\n", + "# can explore, MCMC will walk into parameter regions the surrogate never saw\n", + "# and extrapolate badly.\n", + "PRIOR_LOW = np.array([-3.0, 0.3, 0.0, 0.0], dtype=np.float32)\n", + "PRIOR_HIGH = np.array([3.0, 2.5, 1.0, 2.0], dtype=np.float32)\n", + "TRUE_THETA = np.array([0.5, 1.2, 0.5, 0.25], dtype=np.float32)\n", + "\n", + "out = simulator(theta=TRUE_THETA[None, :], model=\"ddm\", n_samples=N_OBS)\n", + "obs_data = pd.DataFrame(\n", + " {\n", + " \"rt\": out[\"rts\"].squeeze().astype(np.float32),\n", + " \"response\": out[\"choices\"].squeeze().astype(np.float32),\n", + " }\n", + ")\n", + "print(f\"observed: {len(obs_data)} trials at true theta = {TRUE_THETA}\")\n", + "obs_data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0c34c88", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", + "for resp, color in zip([-1, 1], [\"C0\", \"C1\"]):\n", + " mask = obs_data[\"response\"] == resp\n", + " ax.hist(\n", + " obs_data.loc[mask, \"rt\"],\n", + " bins=40,\n", + " alpha=0.6,\n", + " label=f\"choice={int(resp)}\",\n", + " color=color,\n", + " )\n", + "ax.set_xlabel(\"RT (s)\")\n", + "ax.set_ylabel(\"count\")\n", + "ax.set_title(\"Observed RT histogram by choice\")\n", + "ax.legend()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "d49088e8", + "metadata": {}, + "source": [ + "## Part 3 — Train an sbi NRE_A classifier on DDM simulations\n", + "\n", + "`NRE_A` (Hermans et al. 2020) learns a binary classifier that distinguishes\n", + "joint `(θ, x)` pairs from marginal `(θ', x)` pairs (where θ' is drawn from the\n", + "prior). The output logit equals `log p(x | θ) − log p(x)` up to a constant, so\n", + "it serves directly as the HSSM log-likelihood for MCMC (the θ-independent\n", + "constant drops out under MCMC's accept ratios).\n", + "\n", + "This iteration uses the **last known-working NRE configuration**:\n", + "\n", + "- `NRE_A` (binary classifier, not contrastive)\n", + "- 1M `(θ, x)` training pairs (1 sample per θ)\n", + "- `hidden_features = 100` (sbi default is 50)\n", + "- No embedding net on θ\n", + "- `norm_layer = nn.Identity` (jaxonnxruntime doesn't implement\n", + " `LayerNormalization`, so the MLP norm layer is disabled)\n", + "\n", + "We are reverting to this baseline because a more ambitious configuration\n", + "(NRE_B + atomic contrastive + multi-sample-per-θ + FCEmbedding +\n", + "`hidden_features=128`) produced a classifier that gave HSSM a near-constant\n", + "log-likelihood at MCMC time — the chains explored the entire prior with no\n", + "concentration. The bisect strategy is to verify this simpler config works,\n", + "then re-introduce changes one at a time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f383c5f1", + "metadata": {}, + "outputs": [], + "source": [ + "prior = BoxUniform(\n", + " low=torch.from_numpy(PRIOR_LOW),\n", + " high=torch.from_numpy(PRIOR_HIGH),\n", + ")\n", + "theta_train = prior.sample((N_TRAIN,))\n", + "\n", + "# Batched ssm-simulators: theta of shape (N, 4) with n_samples=1 returns\n", + "# rts/choices of shape (N, 1). Much faster than a Python loop for large N.\n", + "sim = simulator(\n", + " theta=theta_train.numpy().astype(np.float32),\n", + " model=\"ddm\",\n", + " n_samples=1,\n", + ")\n", + "x_train = torch.from_numpy(\n", + " np.stack(\n", + " [sim[\"rts\"].squeeze(-1), sim[\"choices\"].squeeze(-1)], axis=-1\n", + " ).astype(np.float32)\n", + ")\n", + "print(f\"training set: theta={theta_train.shape}, x={x_train.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "483214d2", + "metadata": {}, + "outputs": [], + "source": "from torch.utils.tensorboard import SummaryWriter\n\n# User-configurable: where sbi writes tensorboard training logs. Default is\n# ~/sbi_logs_tutorial — outside the HSSM repo so notebook re-runs don't leave\n# artifacts in your working tree. Set to None below (and drop summary_writer\n# from the NRE_A call) to use sbi's default of ./sbi-logs/ relative to cwd.\nTUTORIAL_LOG_DIR = Path.home() / \"sbi_logs_tutorial\"\nTUTORIAL_LOG_DIR.mkdir(parents=True, exist_ok=True)\n\n# Build the classifier. LayerNorm is disabled because jaxonnxruntime\n# doesn't implement LayerNormalization. No embedding net on theta in this\n# baseline-revert iteration (see Part 3 markdown for the bisect context).\nclassifier_builder = classifier_nn(\n model=\"mlp\",\n norm_layer=nn.Identity,\n hidden_features=HIDDEN_FEATURES,\n)\ninference_nre = NRE_A(\n prior=prior,\n classifier=classifier_builder,\n summary_writer=SummaryWriter(log_dir=str(TUTORIAL_LOG_DIR)),\n)\nclassifier_nre = inference_nre.append_simulations(theta_train, x_train).train(\n training_batch_size=TRAINING_BATCH_SIZE,\n max_num_epochs=NUM_EPOCHS,\n stop_after_epochs=STOP_AFTER_EPOCHS,\n)\nclassifier_nre.eval()\nprint(\"NRE_A training complete\")" + }, + { + "cell_type": "markdown", + "id": "fdc22335", + "metadata": {}, + "source": [ + "## Part 4 — Export the trained NRE to ONNX\n", + "\n", + "The exporter wraps the classifier's `forward(theta, x)` logit as the HSSM\n", + "log-likelihood. No Jacobian correction is needed — ratios are invariant to the\n", + "z-score standardization sbi applies internally." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f536645", + "metadata": {}, + "outputs": [], + "source": "# User-configurable: where the .onnx file lands. Default is outside the HSSM\n# repo so notebook re-runs don't pollute the working tree.\n# Override examples:\n# ARTIFACT_DIR = Path(\"/path/to/my/project/onnx\") # keep nearby\n# ARTIFACT_DIR = Path(tempfile.mkdtemp()) # ephemeral\nARTIFACT_DIR = Path.home() / \"sbi_onnx_tutorial\"\nARTIFACT_DIR.mkdir(parents=True, exist_ok=True)\nnre_onnx_path = ARTIFACT_DIR / \"ddm_nre.onnx\"\n\ntransform_sbi_to_onnx(\n classifier_nre,\n str(nre_onnx_path),\n mode=\"nre\",\n example_theta_dim=4,\n example_x_dim=2,\n)\nprint(f\"exported NRE: {nre_onnx_path} ({nre_onnx_path.stat().st_size:,} bytes)\")" + }, + { + "cell_type": "markdown", + "id": "3d355ed9", + "metadata": {}, + "source": [ + "## Part 4b — Pre-MCMC verification: is the trained classifier any good?\n", + "\n", + "Before paying the multi-minute MCMC cost, sanity-check two things:\n", + "\n", + "1. **Logit sweep across θ-space.** Hold three θ dimensions at their true values\n", + " and sweep the fourth across its prior range, plotting the summed classifier\n", + " log-ratio on the observed data. A well-trained NRE shows a sharp peak near\n", + " the true value with tens-to-hundreds of log units of vertical range. A\n", + " flat curve (< 5 log units of range) is the smoking gun for \"classifier\n", + " collapsed\" — MCMC will produce a posterior equal to the prior, no point\n", + " continuing.\n", + "\n", + "2. **ONNX export round-trip.** Compare `classifier_nre(theta, x).item()` to the\n", + " exported ONNX file evaluated through `onnxruntime` on the same input. If\n", + " they disagree, the export has introduced a bug and MCMC won't be running\n", + " the model you think it is.\n", + "\n", + "These cells use the in-memory `classifier_nre` from Part 3 and the exported\n", + "file from Part 4 — they're cheap, no MCMC required." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4148b2d", + "metadata": {}, + "outputs": [], + "source": [ + "# Diagnostic: how much does the NRE logit change as we sweep each theta dim?\n", + "sweep_pts = 50\n", + "sweep_results = {}\n", + "obs_x_t = torch.from_numpy(\n", + " obs_data[[\"rt\", \"response\"]].values.astype(np.float32)\n", + ")\n", + "theta_center = torch.tensor(TRUE_THETA, dtype=torch.float32)\n", + "for dim, name in enumerate(DDM_PARAM_NAMES):\n", + " sweep = torch.linspace(PRIOR_LOW[dim], PRIOR_HIGH[dim], sweep_pts)\n", + " logits = []\n", + " for v in sweep:\n", + " theta = theta_center.clone()\n", + " theta[dim] = v\n", + " theta_row = theta.unsqueeze(0).repeat(len(obs_x_t), 1)\n", + " with torch.no_grad():\n", + " logits.append(classifier_nre(theta_row, obs_x_t).sum().item())\n", + " sweep_results[name] = (sweep.numpy(), np.array(logits))\n", + "\n", + "fig, axes = plt.subplots(1, 4, figsize=(16, 3.5))\n", + "for ax, name in zip(axes, DDM_PARAM_NAMES):\n", + " th, lp = sweep_results[name]\n", + " ax.plot(th, lp - lp.max(), \"C0-\", linewidth=2)\n", + " ax.axvline(TRUE_THETA[DDM_PARAM_NAMES.index(name)], color=\"red\",\n", + " linestyle=\"--\", linewidth=2, label=\"true θ\")\n", + " ax.set_xlabel(name)\n", + " ax.set_ylabel(\"Δ summed log-ratio (= 0 at max)\")\n", + " ax.set_title(f\"sweep over {name}\\n(vertical range = {np.ptp(lp):.2f})\")\n", + " ax.legend(fontsize=8)\n", + "fig.suptitle(\n", + " \"Trained NRE classifier: log-ratio along each θ axis (others held at truth)\",\n", + " y=1.02,\n", + ")\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# Sanity print: vertical range per dim. < ~5 log units is bad.\n", + "print(\"\\nPer-dim vertical range (max log-ratio − min log-ratio):\")\n", + "for name in DDM_PARAM_NAMES:\n", + " _, lp = sweep_results[name]\n", + " print(f\" {name}: {np.ptp(lp):.2f}\")\n", + "print(\"\\nInterpretation:\")\n", + "print(\" > 50 log units : strong discriminative signal, classifier well-trained.\")\n", + "print(\" 10–50 : moderate signal; should give a usable posterior.\")\n", + "print(\" < 5 : near-flat; classifier collapsed to uninformative.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba5082cf", + "metadata": {}, + "outputs": [], + "source": [ + "# Diagnostic: does the exported ONNX match the torch classifier?\n", + "import onnxruntime as _ort\n", + "\n", + "_sess = _ort.InferenceSession(str(nre_onnx_path))\n", + "_input_name = _sess.get_inputs()[0].name\n", + "_test_theta = torch.tensor([[0.5, 1.2, 0.5, 0.25]], dtype=torch.float32)\n", + "_test_x = torch.tensor([[0.5, 1.0]], dtype=torch.float32)\n", + "_combined = (\n", + " torch.cat([_test_theta, _test_x], dim=-1).squeeze(0).numpy().astype(np.float32)\n", + ")\n", + "\n", + "with torch.no_grad():\n", + " _y_torch = float(classifier_nre(_test_theta, _test_x).item())\n", + "_y_ort = float(_sess.run(None, {_input_name: _combined})[0])\n", + "\n", + "print(f\"torch logit at (θ=true, x=(0.5, +1)): {_y_torch:+.5f}\")\n", + "print(f\"ORT logit at (θ=true, x=(0.5, +1)): {_y_ort:+.5f}\")\n", + "print(f\"|Δ|: {abs(_y_torch - _y_ort):.2e}\")\n", + "print()\n", + "if abs(_y_torch - _y_ort) < 1e-4:\n", + " print(\"→ Export is faithful; any pathology is in the trained classifier itself.\")\n", + "else:\n", + " print(\"→ Export disagrees with torch by more than 1e-4 — the exporter is dropping\"\n", + " \" information on the way through ONNX, regardless of training quality.\")" + ] + }, + { + "cell_type": "markdown", + "id": "6bdc1ff3", + "metadata": {}, + "source": [ + "## Part 5 — High-level integration via `hssm.HSSM()`\n", + "\n", + "HSSM's `loglik_kind=\"approx_differentiable\"` path consumes the `.onnx` file\n", + "identically to a LAN-trained network. With `model=\"ddm\"` HSSM already knows the\n", + "parameter list and response columns; we just hand it the file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68a5d9f5", + "metadata": {}, + "outputs": [], + "source": [ + "model_nre = hssm.HSSM(\n", + " data=obs_data,\n", + " model=\"ddm\",\n", + " loglik_kind=\"approx_differentiable\",\n", + " loglik=str(nre_onnx_path),\n", + " p_outlier=0,\n", + ")\n", + "print(model_nre)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2ecc3a8", + "metadata": {}, + "outputs": [], + "source": [ + "# Verification-budget MCMC. target_accept=0.8 (was 0.9) and max_tree_depth=8\n", + "# (caps NUTS at 256 leapfrog steps per draw instead of 1024) bound the per-step\n", + "# cost so a pathological surrogate geometry can't produce a multi-hour run.\n", + "# progressbar=True lets you actually see chain progress as it goes.\n", + "idata_nre = model_nre.sample(\n", + " sampler=\"numpyro\",\n", + " draws=MCMC_DRAWS,\n", + " tune=MCMC_TUNE,\n", + " chains=MCMC_CHAINS,\n", + " target_accept=0.8,\n", + " progressbar=True,\n", + " nuts_sampler_kwargs={\"max_tree_depth\": 8},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be7ca299", + "metadata": {}, + "outputs": [], + "source": [ + "summary_nre = az.summary(idata_nre, var_names=DDM_PARAM_NAMES)\n", + "summary_nre" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c4d79b0", + "metadata": {}, + "outputs": [], + "source": [ + "az.plot_trace(idata_nre, var_names=DDM_PARAM_NAMES)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "7b6190be", + "metadata": {}, + "source": [ + "### Part 5b — Diagnostic: is the NRE classifier itself biased, or is HSSM not finding its mode?\n", + "\n", + "For NRE the logit `forward(theta, x)` equals `log p(x | θ) − log p(x)` up to a\n", + "constant. The θ-independent term cancels when we compare two θ values on the\n", + "same data, so we can use the *summed logit over trials* as a proxy for \"which θ\n", + "does the classifier think makes the data more likely.\" If the classifier itself\n", + "prefers a θ far from the truth, training quality is the issue; if it prefers the\n", + "truth, MCMC is exploring elsewhere." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ff1da78", + "metadata": {}, + "outputs": [], + "source": [ + "posterior_mean_nre = {\n", + " p: float(idata_nre.posterior[p].mean()) for p in DDM_PARAM_NAMES\n", + "}\n", + "obs_x_t = torch.from_numpy(\n", + " obs_data[[\"rt\", \"response\"]].values.astype(np.float32)\n", + ")\n", + "\n", + "def total_logit(classifier, theta_dict):\n", + " \"\"\"Sum log-ratio logit over all observed trials at a single theta.\"\"\"\n", + " theta_row = torch.tensor(\n", + " [[theta_dict[p] for p in DDM_PARAM_NAMES]], dtype=torch.float32\n", + " ).repeat(len(obs_x_t), 1)\n", + " with torch.no_grad():\n", + " return classifier(theta_row, obs_x_t).sum().item()\n", + "\n", + "lt_true = total_logit(classifier_nre, dict(zip(DDM_PARAM_NAMES, TRUE_THETA)))\n", + "lt_mean = total_logit(classifier_nre, posterior_mean_nre)\n", + "\n", + "print(f\"NRE total logit at true theta: {lt_true:+.2f}\")\n", + "print(f\"NRE total logit at posterior mean: {lt_mean:+.2f}\")\n", + "print(f\"Δ (mean − true): {lt_mean - lt_true:+.2f}\")\n", + "print()\n", + "if lt_mean > lt_true + 5.0:\n", + " print(\"→ NRE itself prefers the wrong θ by a large margin.\")\n", + " print(\" Diagnosis: TRAINING QUALITY.\")\n", + "elif lt_mean > lt_true:\n", + " print(\"→ NRE mildly prefers the posterior mean over the truth.\")\n", + " print(\" Could be marginal-vs-joint, mild miscalibration, or sampling.\")\n", + "else:\n", + " print(\"→ NRE prefers the truth; the wrong posterior mean is a sampling\")\n", + " print(\" artifact (priors / init / mixing), not a training issue.\")\n", + "print()\n", + "print(f\"Posterior mean: {posterior_mean_nre}\")\n", + "print(f\"True theta: {dict(zip(DDM_PARAM_NAMES, TRUE_THETA.tolist()))}\")" + ] + }, + { + "cell_type": "markdown", + "id": "15d935a3", + "metadata": {}, + "source": [ + "## Part 6 — Ground-truth posterior via HSSM's analytical DDM\n", + "\n", + "HSSM ships a closed-form analytical likelihood for the standard DDM\n", + "(`loglik_kind=\"analytical\"`, the [Navarro & Fuss](https://psyarxiv.com/cwsbm/)\n", + "form). Running MCMC against this likelihood on the *same observed data*\n", + "produces what we should treat as the gold-standard posterior for this\n", + "model + dataset: any deviation of the sbi-NRE marginals from these is\n", + "*approximation error* in the neural surrogate, not intrinsic posterior width.\n", + "\n", + "The analytical likelihood uses slightly different parameter bounds\n", + "(a, t unbounded above; otherwise the same) but on the observed data the\n", + "posterior concentrates regardless of the wider bound." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13386f5d", + "metadata": {}, + "outputs": [], + "source": [ + "model_analytical = hssm.HSSM(\n", + " data=obs_data,\n", + " model=\"ddm\",\n", + " loglik_kind=\"analytical\",\n", + " p_outlier=0,\n", + ")\n", + "idata_analytical = model_analytical.sample(\n", + " sampler=\"numpyro\",\n", + " draws=MCMC_DRAWS,\n", + " tune=MCMC_TUNE,\n", + " chains=MCMC_CHAINS,\n", + " target_accept=0.9,\n", + " progressbar=False,\n", + ")\n", + "summary_analytical = az.summary(idata_analytical, var_names=DDM_PARAM_NAMES)\n", + "summary_analytical" + ] + }, + { + "cell_type": "markdown", + "id": "04f905c3", + "metadata": {}, + "source": [ + "## Part 7 — Posterior comparison: analytical vs sbi NRE\n", + "\n", + "The keystone comparison. The analytical posterior is the gold standard for this\n", + "model + data; the sbi-NRE marginals are the approximation we built. Distance\n", + "between NRE and analytical is the surrogate's approximation error; distance\n", + "between the analytical posterior and the true θ is intrinsic posterior width\n", + "(the data simply isn't infinite at N=500 trials).\n", + "\n", + "For the broader cross-tutorial comparison alongside the LAN baseline and\n", + "BayesFlow-LRE result on the same simulated data, load their cached posteriors\n", + "here and add panels to the plot.\n", + "\n", + "> Reuses the exact `TRUE_THETA`, `N_OBS`, and prior ranges from\n", + "> `bayesflow_lre_integration.ipynb`, so the BayesFlow-LRE posteriors are\n", + "> directly comparable when added." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fa03b3f", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 4, figsize=(16, 4))\n", + "for ax, name, true_val in zip(axes, DDM_PARAM_NAMES, TRUE_THETA):\n", + " samples_ana = idata_analytical.posterior[name].values.flatten()\n", + " samples_nre = idata_nre.posterior[name].values.flatten()\n", + " ax.hist(\n", + " samples_ana,\n", + " bins=30,\n", + " alpha=0.5,\n", + " label=\"analytical (truth)\",\n", + " color=\"C2\",\n", + " density=True,\n", + " )\n", + " ax.hist(samples_nre, bins=30, alpha=0.5, label=\"sbi NRE\", color=\"C1\", density=True)\n", + " ax.axvline(true_val, color=\"red\", linestyle=\"--\", linewidth=2, label=\"true θ\")\n", + " ax.set_xlabel(name)\n", + " ax.set_title(f\"posterior over {name}\")\n", + " ax.legend(fontsize=8)\n", + "fig.suptitle(\n", + " \"DDM posterior recovery: analytical (gold) vs sbi NRE\", y=1.02\n", + ")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "ba353cde", + "metadata": {}, + "source": [ + "## Summary and deferred work\n", + "\n", + "We trained an sbi NRE_B classifier on synthetic DDM data with FCEmbedding on\n", + "the parameters and atomic contrastive estimation (`num_atoms=20`), exported it\n", + "to ONNX via `lanfactory.onnx.transform_sbi_to_onnx`, and ran MCMC through\n", + "HSSM's existing `loglik_kind=\"approx_differentiable\"` pipeline. The resulting\n", + "posterior is compared against HSSM's analytical DDM posterior on the same\n", + "data — that's our gold-standard reference.\n", + "\n", + "**Why not NLE in this tutorial?**\n", + "\n", + "We originally planned an NLE section too. Vanilla NLE with a MAF flow produces\n", + "qualitatively wrong posteriors on DDM data because rt is continuous but choice\n", + "is discrete (∈ {−1, +1}); the flow can't represent that structure or the hard\n", + "support boundary `rt > t_nd`. The correct sbi method is **MNLE** (Mixed Neural\n", + "Likelihood Estimator), which factorizes `p(rt, choice | θ) = p(choice | θ) ·\n", + "p(rt | choice, θ)`. But MNLE's `CategoricalMassEstimator` uses\n", + "`torch.searchsorted` for value-to-index lookup, which `torch.onnx.export` does\n", + "not support — blocking the ONNX export path until a `SearchSorted` ONNX-op\n", + "handler is added to `jaxonnxruntime`. The same gap blocks Neural Spline Flows.\n", + "\n", + "See `plans/sbi-onnx-integration.md` \"Deferred sbi paths (MNLE, vanilla NLE on\n", + "DDM, NSF flows)\" in HSSMSpine for the resolution roadmap. A single ~50-line\n", + "upstream PR to `jaxonnxruntime` adding `SearchSorted` unlocks both NSF flows\n", + "AND MNLE in one stroke.\n", + "\n", + "**Where to look next**\n", + "\n", + "- LANfactory's [Exporting sbi Models guide](https://alexanderfengler.github.io/LANfactory/exporting_sbi_models/) — supported-architecture matrix, known constraints, troubleshooting.\n", + "- The BayesFlow LRE tutorial (`bayesflow_lre_integration.ipynb`) — the same DDM with a different SBI library, for cross-toolkit comparison.\n", + "- LAN tutorials (`main_tutorial.ipynb`) — the original LANfactory workflow this integration builds on top of." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.x" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/src/hssm/base.py b/src/hssm/base.py index 806e0eded..1da37bee8 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -674,12 +674,15 @@ def sample( compute_likelihood = kwargs["idata_kwargs"].pop("log_likelihood", True) omit_offsets = kwargs.pop("omit_offsets", False) + # Pass the user's sampler choice through to bambi verbatim. Bambi's + # inference_method natively accepts "pymc"/"numpyro"/"blackjax"/"nutpie" + # and routes each to a distinct pm.sample(nuts_sampler=...) call. + # The previous conditional collapsed all four NUTS variants to "pymc", + # which silently downgraded user choice — surfaced via the sbi NLE + # tutorial where sampler="numpyro" still went through PyMC's multiprocess + # path and tripped cloudpickle on the ONNX ModelProto. self._inference_obj = self.model.fit( - inference_method=( - "pymc" - if sampler in ["pymc", "numpyro", "blackjax", "nutpie"] - else sampler - ), + inference_method=sampler, init=init, include_response_params=include_response_params, omit_offsets=omit_offsets, diff --git a/src/hssm/distribution_utils/onnx_utils/onnx2jax.py b/src/hssm/distribution_utils/onnx_utils/onnx2jax.py index 7fea352c3..64e1dfee2 100644 --- a/src/hssm/distribution_utils/onnx_utils/onnx2jax.py +++ b/src/hssm/distribution_utils/onnx_utils/onnx2jax.py @@ -1,40 +1,166 @@ -"""Use jaxonnxruntime to convert ONNX models to JAX functions.""" +"""Use jaxonnxruntime to convert ONNX models to JAX functions. +This module assumes the **single-trial export contract**: ONNX graphs reaching +``make_jax_func`` are exported with a fully concrete input shape representing +*one* per-trial input vector (e.g. ``(n_params + n_data_cols,)``). HSSM batches +across trials at a layer above this, via ``jax.vmap`` in +``make_jax_logp_funcs_from_onnx`` (see ``hssm.distribution_utils.onnx``). + +If a graph arrives with a symbolic / dynamic dimension, ``make_jax_func`` +raises a ``ValueError`` rather than trying to make it work: jaxonnxruntime +traces against the construction-time dummy shape and bakes the resulting +shapes into the returned closure, so calling that closure with a different +shape silently produces wrong outputs for any graph that carries a +batch-dependent intermediate (e.g. a ``torch.zeros(x.shape[0])`` log-det +accumulator, or a ``Reshape`` whose ``-1`` resolves against the dynamic dim). +LANs and the sbi NRE/NLE exporters in ``LANfactory.onnx`` already follow this +contract; this guard prevents accidental violations from a future contributor. + +On precision: pytensor's JAX dispatch +(``pytensor/link/jax/dispatch/basic.py``) sets ``jax_enable_x64`` from +``pytensor.config.floatX`` at import time. With HSSM's default +``floatX="float64"`` x64 is already on by the time this module loads; +under ``hssm.set_floatX("float32")`` x64 is off. The previous version of +this module also tried to flip ``jax_enable_x64`` at first call; that has +been removed (it duplicated pytensor's contract, mutated global state, and +hard-failed if JAX had already warmed up). Instead we pre-cast int64 +tensors / Cast targets to int32 in the graph at load time -- lossless for +the index/shape values torch.onnx.export produces, and removes the silent +truncation that ``jax_enable_x64=False`` would otherwise apply. +""" + +import logging from typing import Callable import jax import numpy as np import onnx -from jaxonnxruntime import call_onnx +from jaxonnxruntime import call_onnx, config + +# torch.onnx.export emits some shape arguments (e.g. for Reshape) as Constant +# nodes rather than as model initializers. jaxonnxruntime's default strict +# mode rejects these as static args during jax.jit. The flag below relaxes +# that check. This is safe for our use cases: those shapes are constant by +# construction (baked at export time). +config.update("jaxort_only_allow_initializers_as_static_args", False) + +_logger = logging.getLogger("hssm") + + +def _recast_int64_to_int32(model: onnx.ModelProto) -> int: + """Rewrite int64 tensors and Cast targets in the graph to int32, in place. + + torch.onnx.export carries int64 metadata (Reshape shape args, Constant + tensors, Cast targets) whose values are indices/shapes that always fit + losslessly in int32. With ``jax_enable_x64=False`` JAX truncates int64 + to int32 implicitly and emits a UserWarning per access. Pre-casting at + load time: + + * is bit-identical for valid index values (twos-complement of small + non-negative integers is preserved when the upper 32 bits are dropped), + * silences the JAX UserWarning, + * removes any dependency on global JAX state. + + Returns + ------- + int + Number of int64 sites rewritten (0 if none). + """ + int64 = onnx.TensorProto.INT64 + int32 = onnx.TensorProto.INT32 + n_rewritten = 0 + + def _convert(tensor: onnx.TensorProto) -> None: + nonlocal n_rewritten + if tensor.data_type == int64: + arr = onnx.numpy_helper.to_array(tensor).astype(np.int32) + new = onnx.numpy_helper.from_array(arr, tensor.name) + tensor.CopyFrom(new) + n_rewritten += 1 + + for initializer in model.graph.initializer: + _convert(initializer) + for node in model.graph.node: + for attr in node.attribute: + if attr.type == onnx.AttributeProto.TENSOR: + _convert(attr.t) + elif attr.type == onnx.AttributeProto.TENSORS: + for t in attr.tensors: + _convert(t) + if node.op_type == "Cast": + for attr in node.attribute: + if attr.name == "to" and attr.i == int64: + attr.i = int32 + n_rewritten += 1 + return n_rewritten + + +def _check_single_trial_input_shape(model: onnx.ModelProto) -> None: + """Raise if any input dimension is symbolic / dynamic. + + HSSM's ONNX-likelihood path is built around single-trial inputs that get + vmapped over trials at a layer above this. jaxonnxruntime, however, + traces the graph against the construction-time dummy shape and bakes + those shapes into the returned closure -- so a graph with dynamic dims + called at a different shape later will produce wrong-but-non-erroring + outputs (the trace re-uses the dummy's broadcast shapes for + batch-dependent intermediates). + """ + bad: list[str] = [] + for inp in model.graph.input: + for i, dim in enumerate(inp.type.tensor_type.shape.dim): + if dim.dim_value <= 0: + label = dim.dim_param or f"axis {i}" + bad.append(f"{inp.name}[{label}]") + if bad: + raise ValueError( + "ONNX model has dynamic (symbolic) input dimensions: " + f"{', '.join(bad)}. HSSM uses single-trial input shapes and " + "vmaps over trials at a layer above this conversion -- " + "re-export the model with a concrete per-trial input shape " + "(omit `dynamic_axes` in `torch.onnx.export`, or pass a single " + "rank-1 dummy as LANfactory.onnx.transform_sbi_to_onnx does). " + "Dynamic dims here would cause jaxonnxruntime to silently " + "produce wrong outputs for graphs with batch-dependent " + "intermediates (e.g. log-det accumulators)." + ) def make_jax_func(onnx_model: onnx.ModelProto) -> Callable: """Convert an ONNX model to a JAX function using jaxonnxruntime. + The model must have a fully concrete input shape -- see the module + docstring for the single-trial-input + vmap contract. + Parameters ---------- onnx_model : onnx.ModelProto - The ONNX model to be converted. + The ONNX model to be converted. Will be mutated in place to recast + int64 tensors to int32 (lossless for index/shape values produced by + torch.onnx.export). Returns ------- Callable - A JAX function that represents the ONNX model. + A JAX function ``f(x)`` that runs the ONNX graph on ``x``. + + Raises + ------ + ValueError + If the ONNX graph has any dynamic / symbolic input dimension. """ - model_graph = onnx_model.graph + _check_single_trial_input_shape(onnx_model) + _recast_int64_to_int32(onnx_model) - # Get the input name and shape from the ONNX model to create a dummy input for - # initialization. + model_graph = onnx_model.graph input_name = model_graph.input[0].name input_dims = tuple( - dim.dim_value if (dim.dim_value > 0) else 1 - for dim in model_graph.input[0].type.tensor_type.shape.dim + dim.dim_value for dim in model_graph.input[0].type.tensor_type.shape.dim ) model_func, model_weights = call_onnx.call_onnx_model( onnx_model, {input_name: np.ones(input_dims)} ) - # Create a JAX function that takes the input and applies the ONNX model. run_func = jax.tree_util.Partial(model_func, model_weights) jax_func = lambda x: run_func({input_name: x})[0].squeeze() diff --git a/tests/distribution_utils/test_onnx.py b/tests/distribution_utils/test_onnx.py index ee49b6ece..d3ba5a20e 100644 --- a/tests/distribution_utils/test_onnx.py +++ b/tests/distribution_utils/test_onnx.py @@ -11,6 +11,10 @@ make_jax_logp_funcs_from_onnx, make_jax_matrix_logp_funcs_from_onnx, ) +from hssm.distribution_utils.onnx_utils.onnx2jax import ( + _recast_int64_to_int32, + make_jax_func, +) hssm.set_floatX("float32") DECIMAL = 4 @@ -156,3 +160,60 @@ def test_make_simple_jax_logp_funcs_from_onnx(fixture_path): result_simple, decimal=DECIMAL, ) + + +# --------------------------------------------------------------------------- +# Guards introduced when removing the auto-x64 flip in onnx2jax.py. +# --------------------------------------------------------------------------- + + +def _make_minimal_onnx(input_shape): + """Construct a minimal Identity-graph ONNX model with the requested shape. + + ``input_shape`` is an iterable of ``int`` (concrete) or ``str``/``None`` + (symbolic / dynamic). + """ + from onnx import helper, TensorProto + + dims = list(input_shape) + in_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, dims) + out_tensor = helper.make_tensor_value_info("y", TensorProto.FLOAT, dims) + node = helper.make_node("Identity", inputs=["x"], outputs=["y"]) + graph = helper.make_graph([node], "minimal", [in_tensor], [out_tensor]) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +def test_make_jax_func_rejects_dynamic_input_dim(): + """Dynamic input dims would silently produce wrong outputs -- must raise.""" + model = _make_minimal_onnx(["batch", 5]) # symbolic batch dim + with pytest.raises(ValueError, match="dynamic"): + make_jax_func(model) + + +def test_make_jax_func_accepts_concrete_input_dim(): + """Concrete-shape models pass the guard and produce a callable.""" + model = _make_minimal_onnx([1, 5]) + fn = make_jax_func(model) + out = np.asarray(fn(np.ones((1, 5), dtype=np.float32))) + assert out.shape == (5,) or out.shape == (1, 5) # squeeze may drop the 1 + + +def test_recast_int64_to_int32_rewrites_initializers(): + """The int64 -> int32 pre-cast must rewrite int64 initializers in place.""" + from onnx import helper, numpy_helper, TensorProto + + init = numpy_helper.from_array(np.array([1, 2, 3], dtype=np.int64), "shape") + in_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3]) + out_tensor = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3]) + node = helper.make_node("Reshape", inputs=["x", "shape"], outputs=["y"]) + graph = helper.make_graph( + [node], "minimal", [in_tensor], [out_tensor], initializer=[init] + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + assert model.graph.initializer[0].data_type == TensorProto.INT64 + n = _recast_int64_to_int32(model) + assert n == 1 + assert model.graph.initializer[0].data_type == TensorProto.INT32 + # Idempotent: a second call is a no-op. + assert _recast_int64_to_int32(model) == 0