From 2e76516071247d066d4e987623810a539964f77c Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 13 May 2026 23:00:38 -0400 Subject: [PATCH 01/20] fix(onnx2jax): relax jaxonnxruntime strict-mode for Constant shape args Sets jaxort_only_allow_initializers_as_static_args = False at module import time. The default strict mode rejects ONNX graphs whose Reshape op shape comes from a Constant node rather than a model initializer. torch.onnx.export emits exactly this pattern for masked autoregressive flows from nflows (and likely other flow architectures), surfaced by LANfactory commit f7c93c8. Setting the flag here means any consumer of make_jax_func benefits without per-call configuration. Safe for our use cases: shapes are genuinely constant, baked at export time. Side benefit: makes HSSM more robust to ONNX from any source emitting this pattern, not only sbi-exported flows. Part of the sbi to HSSM integration plan (see plans/sbi-onnx-integration.md in HSSMSpine, sub-commit C2.5). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/hssm/distribution_utils/onnx_utils/onnx2jax.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/hssm/distribution_utils/onnx_utils/onnx2jax.py b/src/hssm/distribution_utils/onnx_utils/onnx2jax.py index 7fea352c3..740901657 100644 --- a/src/hssm/distribution_utils/onnx_utils/onnx2jax.py +++ b/src/hssm/distribution_utils/onnx_utils/onnx2jax.py @@ -5,7 +5,16 @@ import jax import numpy as np import onnx -from jaxonnxruntime import call_onnx +from jaxonnxruntime import call_onnx, config + +# torch.onnx.export emits some shape arguments (e.g. for Reshape inside masked +# autoregressive flows) as Constant nodes rather than as model initializers. +# jaxonnxruntime's default strict mode rejects these as static-args during +# jax.jit. The flag below relaxes that check. This is safe for our use cases: +# the shapes in question are genuinely constant, baked at export time. 
Setting +# it at import time means any consumer of make_jax_func (LAN MLPs, sbi-exported +# flows, etc.) benefits without per-call configuration. +config.update("jaxort_only_allow_initializers_as_static_args", False) def make_jax_func(onnx_model: onnx.ModelProto) -> Callable: From d1d7ffe8cb36931e86952fd1f0747e3ddd28d527 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 17:57:20 -0400 Subject: [PATCH 02/20] fix(onnx2jax): auto-enable jax_enable_x64 for int64-bearing graphs (C7a) ONNX graphs exported by torch.onnx.export of normalizing flows (e.g. the nflows MAF used by sbi NLE) carry int64 tensors for Reshape shape arguments, Constant node values, and Cast targets. jaxonnxruntime silently truncates int64 to int32 unless jax_enable_x64 is set, producing wrong numerical results (~0.5 drift in log-prob, surfaced during the LANfactory C3 NLE validation). make_jax_func now walks the loaded ONNX graph for int64 tensors. If any are present and jax_enable_x64 is off, HSSM: - attempts to flip the flag via jax.config.update - verifies the flip is effective (fresh jnp.asarray([1.0]) is float64) - emits a UserWarning pointing users to set it themselves to silence - raises a clear RuntimeError with the one-line fix if the flip did not take (JAX has already done substantive 32-bit work) The detection is conservative: only scans initializers, Constant node tensor attributes, and Cast `to` attributes. LAN MLP graphs do not carry int64 in any of these places (verified: existing 8 HSSM ONNX tests still pass without warnings). This addresses the second of the two findings from C3 (see plans/sbi-onnx-integration.md C7 row). Parallels the C2.5 strict-mode patch in spirit but is more targeted: it only intervenes when the graph actually requires the flag. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../distribution_utils/onnx_utils/onnx2jax.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/src/hssm/distribution_utils/onnx_utils/onnx2jax.py b/src/hssm/distribution_utils/onnx_utils/onnx2jax.py index 740901657..4973bc821 100644 --- a/src/hssm/distribution_utils/onnx_utils/onnx2jax.py +++ b/src/hssm/distribution_utils/onnx_utils/onnx2jax.py @@ -1,8 +1,10 @@ """Use jaxonnxruntime to convert ONNX models to JAX functions.""" +import warnings from typing import Callable import jax +import jax.numpy as jnp import numpy as np import onnx from jaxonnxruntime import call_onnx, config @@ -17,6 +19,72 @@ config.update("jaxort_only_allow_initializers_as_static_args", False) +def _graph_has_int64_tensors(model: onnx.ModelProto) -> bool: + """Detect int64 tensors in an ONNX graph. + + torch.onnx.export of normalizing flows (e.g. nflows MAF) emits int64 + tensors for Reshape shape arguments, Constant node values, Cast targets, + and similar. jaxonnxruntime silently truncates int64 to int32 unless + `jax_enable_x64` is set, producing wrong numerical outputs (~0.5 drift + in log-prob). + """ + int64 = onnx.TensorProto.INT64 + for init in model.graph.initializer: + if init.data_type == int64: + return True + for node in model.graph.node: + for attr in node.attribute: + if attr.type == onnx.AttributeProto.TENSOR and attr.t.data_type == int64: + return True + if attr.type == onnx.AttributeProto.TENSORS: + for t in attr.tensors: + if t.data_type == int64: + return True + if node.op_type == "Cast": + for attr in node.attribute: + if attr.name == "to" and attr.i == int64: + return True + return False + + +def _ensure_x64_if_needed(onnx_model: onnx.ModelProto) -> None: + """Auto-enable jax_enable_x64 when the graph requires it. 
+ + If the graph carries int64 tensors and x64 is off, we attempt to flip the + JAX config flag and verify the change is effective (by checking that a + fresh `jnp.asarray([1.0])` is float64). If the flip does not take — JAX + has already done substantive 32-bit work in this process — raise a clear + RuntimeError directing the user to set the flag at the top of their + script. + """ + if not _graph_has_int64_tensors(onnx_model): + return + if jax.config.read("jax_enable_x64"): + return + + jax.config.update("jax_enable_x64", True) + # Verify the flip is effective on fresh JAX ops. + if jnp.asarray([1.0]).dtype != jnp.float64: + raise RuntimeError( + "This ONNX graph carries int64 tensors (typical for torch-exported " + "normalizing flows), which jaxonnxruntime would silently truncate " + "to int32 — producing wrong numerical results. HSSM attempted to " + "auto-enable `jax_enable_x64`, but JAX has already been used in " + "32-bit mode and the flip did not take. Fix: add\n" + " import jax\n" + " jax.config.update('jax_enable_x64', True)\n" + "at the very top of your script, before any other JAX import." + ) + warnings.warn( + "HSSM auto-enabled `jax_enable_x64` because the loaded ONNX graph " + "carries int64 tensors that JAX would otherwise silently truncate. " + "To silence this warning, set the flag yourself at the top of your " + "script: `jax.config.update('jax_enable_x64', True)`.", + UserWarning, + stacklevel=3, + ) + + def make_jax_func(onnx_model: onnx.ModelProto) -> Callable: """Convert an ONNX model to a JAX function using jaxonnxruntime. @@ -30,6 +98,8 @@ def make_jax_func(onnx_model: onnx.ModelProto) -> Callable: Callable A JAX function that represents the ONNX model. 
""" + _ensure_x64_if_needed(onnx_model) + model_graph = onnx_model.graph # Get the input name and shape from the ONNX model to create a dummy input for From f90ee608cdd35d0d303d052aabf171b51051efe0 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 18:41:21 -0400 Subject: [PATCH 03/20] docs(tutorials): sbi NLE + NRE integration tutorial (C8 keystone) Adds docs/tutorials/sbi_nle_integration.ipynb, the keystone deliverable of the sbi -> HSSM integration plan. Mirrors the structure of bayesflow_lre_integration.ipynb so the two SBI-toolkit tutorials in the HSSM docs sit side-by-side and tell a coherent story. Structure (22 cells, 13 code + 9 markdown): Part 1 - Setup (jax_enable_x64, imports, CI budget constants) Part 2 - Simulate observed DDM data (ssm-simulators, N_OBS=500, TRUE_THETA matching BayesFlow tutorial) Part 3 - Train tiny sbi NLE_A with MAF on 10k training pairs Part 4 - Export to ONNX via lanfactory.onnx.transform_sbi_to_onnx Part 5 - High-level integration via hssm.HSSM(loglik=...onnx, loglik_kind="approx_differentiable"), numpyro sampling, summary, trace Part 6 - Brief NRE variant (norm_layer=Identity to disable LayerNorm) Part 7 - Posterior comparison plot: sbi NLE vs sbi NRE vs ground truth Closing summary with pointers to LANfactory exporter docs and the BayesFlow LRE neighbor tutorial Reuses the exact TRUE_THETA, N_OBS, and prior ranges as the BayesFlow LRE tutorial so cross-tutorial posterior comparisons are apples-to- apples. Includes explicit documentation of the v1 constraints surfaced during C2-C7 (2D minimum, norm_layer=Identity, x64-auto-flip). Notebook outputs are intentionally empty (execution_count: null) - execution requires a coordinated cross-repo env (LANfactory[all] + HSSM in the same venv). Same env-resolution caveat as C7b. Run once locally or in CI to populate outputs. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 441 +++++++++++++++++++++++ 1 file changed, 441 insertions(+) create mode 100644 docs/tutorials/sbi_nle_integration.ipynb diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb new file mode 100644 index 000000000..803426002 --- /dev/null +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -0,0 +1,441 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8966ca29", + "metadata": {}, + "source": [ + "# Integrating sbi-trained likelihoods into HSSM (NLE + NRE via ONNX)\n", + "\n", + "This tutorial mirrors the structure of `bayesflow_lre_integration.ipynb` for the\n", + "third major SBI library in the HSSM ecosystem: **[sbi](https://github.com/sbi-dev/sbi)**\n", + "(mackelab). It demonstrates how to:\n", + "\n", + "1. Train a neural likelihood estimator (NLE) on synthetic DDM simulations using sbi.\n", + "2. Export the trained estimator to ONNX via\n", + " [`lanfactory.onnx.transform_sbi_to_onnx`](https://alexanderfengler.github.io/LANfactory/exporting_sbi_models/).\n", + "3. Load the ONNX file into HSSM exactly like any other LAN-style approximator and run\n", + " MCMC inference.\n", + "4. Repeat the loop with a neural ratio estimator (NRE) for comparison.\n", + "\n", + "The ecosystem's current scope is **neural likelihood surrogates** (NLE and NRE). NPE/\n", + "posterior-amortized methods are deliberately out of scope here — they don't compose\n", + "cleanly with PyMC priors. See the\n", + "[Exporting sbi Models guide](https://alexanderfengler.github.io/LANfactory/exporting_sbi_models/)\n", + "for the full supported-architecture matrix.\n", + "\n", + "> **Environment note:** This tutorial requires both `hssm` and `lanfactory[all]` (which\n", + "> pulls `sbi` and `nflows`) in the same environment. JAX/flax/numpyro pins must be\n", + "> resolved jointly across the two packages." 
+ ] + }, + { + "cell_type": "markdown", + "id": "cb7310a9", + "metadata": {}, + "source": [ + "## Part 1 — Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdf0ba33", + "metadata": {}, + "outputs": [], + "source": [ + "# Enable x64 BEFORE any other JAX-touching import.\n", + "# HSSM's onnx2jax auto-flips this when loading sbi-exported ONNX graphs that carry\n", + "# int64 tensors (typical for normalizing flows from torch.onnx.export), but setting\n", + "# it explicitly here is best practice and silences the auto-flip UserWarning.\n", + "import jax\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "from pathlib import Path\n", + "import warnings\n", + "\n", + "import arviz as az\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "from torch import nn\n", + "\n", + "import hssm\n", + "from sbi.inference import NLE_A, NRE_A\n", + "from sbi.neural_nets import classifier_nn\n", + "from sbi.utils import BoxUniform\n", + "from ssms.basic_simulators.simulator import simulator\n", + "\n", + "from lanfactory.onnx import transform_sbi_to_onnx\n", + "\n", + "np.random.seed(0)\n", + "torch.manual_seed(0)\n", + "\n", + "# Tutorial CI budget. Bump for production runs.\n", + "N_TRAIN = 10_000 # training simulations\n", + "N_OBS = 500 # observed trials at the true theta\n", + "NUM_EPOCHS = 50\n", + "MCMC_DRAWS = 500\n", + "MCMC_TUNE = 500\n", + "MCMC_CHAINS = 2" + ] + }, + { + "cell_type": "markdown", + "id": "1f560f58", + "metadata": {}, + "source": [ + "## Part 2 — Simulate observed DDM data\n", + "\n", + "We use the standard 4-parameter DDM (`v`, `a`, `z`, `t`) from `ssm-simulators`. The\n", + "true parameters and parameter ranges below match the BayesFlow LRE tutorial so that\n", + "posteriors from the two tutorials can be compared apples-to-apples." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79933de0", + "metadata": {}, + "outputs": [], + "source": [ + "DDM_PARAM_NAMES = [\"v\", \"a\", \"z\", \"t\"]\n", + "PRIOR_LOW = np.array([-2.0, 0.6, 0.3, 0.1], dtype=np.float32)\n", + "PRIOR_HIGH = np.array([2.0, 1.8, 0.7, 0.5], dtype=np.float32)\n", + "TRUE_THETA = np.array([0.5, 1.2, 0.5, 0.25], dtype=np.float32)\n", + "\n", + "out = simulator(theta=TRUE_THETA[None, :], model=\"ddm\", n_samples=N_OBS)\n", + "obs_data = pd.DataFrame(\n", + " {\n", + " \"rt\": out[\"rts\"].squeeze().astype(np.float32),\n", + " \"response\": out[\"choices\"].squeeze().astype(np.float32),\n", + " }\n", + ")\n", + "print(f\"observed: {len(obs_data)} trials at true theta = {TRUE_THETA}\")\n", + "obs_data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f823c00", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", + "for resp, color in zip([-1, 1], [\"C0\", \"C1\"]):\n", + " mask = obs_data[\"response\"] == resp\n", + " ax.hist(\n", + " obs_data.loc[mask, \"rt\"],\n", + " bins=40,\n", + " alpha=0.6,\n", + " label=f\"choice={int(resp)}\",\n", + " color=color,\n", + " )\n", + "ax.set_xlabel(\"RT (s)\")\n", + "ax.set_ylabel(\"count\")\n", + "ax.set_title(\"Observed RT histogram by choice\")\n", + "ax.legend()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "9310ba89", + "metadata": {}, + "source": [ + "## Part 3 — Train a sbi NLE_A on DDM simulations\n", + "\n", + "`NLE_A` trains a conditional density estimator (here a MAF normalizing flow) on\n", + "`(theta, x)` pairs. After training, `estimator.log_prob(x, condition=theta)` returns\n", + "`log p(x | theta)` with z-score standardization Jacobians applied automatically." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e475ca9", + "metadata": {}, + "outputs": [], + "source": [ + "prior = BoxUniform(\n", + " low=torch.from_numpy(PRIOR_LOW),\n", + " high=torch.from_numpy(PRIOR_HIGH),\n", + ")\n", + "theta_train = prior.sample((N_TRAIN,))\n", + "\n", + "# Simulate one (rt, choice) per training theta. The loop is slow but transparent.\n", + "# For real workflows use ssm-simulators' batched API.\n", + "rts = np.empty(N_TRAIN, dtype=np.float32)\n", + "choices = np.empty(N_TRAIN, dtype=np.float32)\n", + "for i, th in enumerate(theta_train.numpy()):\n", + " sim = simulator(theta=th[None, :], model=\"ddm\", n_samples=1)\n", + " rts[i] = sim[\"rts\"].squeeze()\n", + " choices[i] = sim[\"choices\"].squeeze()\n", + "x_train = torch.from_numpy(np.stack([rts, choices], axis=-1))\n", + "print(f\"training set: theta={theta_train.shape}, x={x_train.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43b2756e", + "metadata": {}, + "outputs": [], + "source": [ + "inference_nle = NLE_A(prior=prior, density_estimator=\"maf\")\n", + "estimator_nle = inference_nle.append_simulations(theta_train, x_train).train(\n", + " training_batch_size=200,\n", + " max_num_epochs=NUM_EPOCHS,\n", + ")\n", + "estimator_nle.eval()\n", + "print(\"NLE training complete\")" + ] + }, + { + "cell_type": "markdown", + "id": "74e2a868", + "metadata": {}, + "source": [ + "## Part 4 — Export the trained NLE to ONNX\n", + "\n", + "The exporter wraps `estimator.log_prob` as a `torch.nn.Module` whose `forward(combined)`\n", + "splits a concatenated `(theta, x)` input. The result is a single-trial ONNX graph that\n", + "HSSM consumes exactly like a LAN file." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b382d7a", + "metadata": {}, + "outputs": [], + "source": [ + "onnx_dir = Path(\"./sbi_onnx_artifacts\")\n", + "onnx_dir.mkdir(exist_ok=True)\n", + "nle_onnx_path = onnx_dir / \"ddm_nle.onnx\"\n", + "\n", + "transform_sbi_to_onnx(\n", + " estimator_nle,\n", + " str(nle_onnx_path),\n", + " mode=\"nle\",\n", + " example_theta_dim=4,\n", + " example_x_dim=2,\n", + ")\n", + "print(f\"exported NLE: {nle_onnx_path} ({nle_onnx_path.stat().st_size:,} bytes)\")" + ] + }, + { + "cell_type": "markdown", + "id": "2b3913b3", + "metadata": {}, + "source": [ + "## Part 5 — High-level integration via `hssm.HSSM()`\n", + "\n", + "HSSM's `loglik_kind=\"approx_differentiable\"` path consumes the `.onnx` file\n", + "identically to a LAN-trained network. With `model=\"ddm\"` HSSM already knows the\n", + "parameter list and response columns; we just hand it the file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44b3dc3b", + "metadata": {}, + "outputs": [], + "source": [ + "model_nle = hssm.HSSM(\n", + " data=obs_data,\n", + " model=\"ddm\",\n", + " loglik_kind=\"approx_differentiable\",\n", + " loglik=str(nle_onnx_path),\n", + " p_outlier=0,\n", + ")\n", + "print(model_nle)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2720b631", + "metadata": {}, + "outputs": [], + "source": [ + "idata_nle = model_nle.sample(\n", + " sampler=\"numpyro\",\n", + " draws=MCMC_DRAWS,\n", + " tune=MCMC_TUNE,\n", + " chains=MCMC_CHAINS,\n", + " target_accept=0.9,\n", + " progressbar=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d19c0a57", + "metadata": {}, + "outputs": [], + "source": [ + "summary_nle = az.summary(idata_nle, var_names=DDM_PARAM_NAMES)\n", + "summary_nle" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acf55052", + "metadata": {}, + "outputs": [], + "source": [ + "az.plot_trace(idata_nle, 
var_names=DDM_PARAM_NAMES)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "22c3fef4", + "metadata": {}, + "source": [ + "## Part 6 — Brief NRE variant\n", + "\n", + "The same pipeline works for ratio classifiers. The only sbi-side wrinkle is that the\n", + "default MLP classifier uses `nn.LayerNorm` between hidden layers, and `jaxonnxruntime`\n", + "doesn't implement `LayerNormalization`. We pass `norm_layer=nn.Identity` to disable\n", + "it. See `docs/exporting_sbi_models.md` in LANfactory for the full constraint list." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e19d816f", + "metadata": {}, + "outputs": [], + "source": [ + "classifier_builder = classifier_nn(model=\"mlp\", norm_layer=nn.Identity)\n", + "inference_nre = NRE_A(prior=prior, classifier=classifier_builder)\n", + "classifier_nre = inference_nre.append_simulations(theta_train, x_train).train(\n", + " training_batch_size=200,\n", + " max_num_epochs=NUM_EPOCHS,\n", + ")\n", + "classifier_nre.eval()\n", + "print(\"NRE training complete\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c17e977", + "metadata": {}, + "outputs": [], + "source": [ + "nre_onnx_path = onnx_dir / \"ddm_nre.onnx\"\n", + "transform_sbi_to_onnx(\n", + " classifier_nre,\n", + " str(nre_onnx_path),\n", + " mode=\"nre\",\n", + " example_theta_dim=4,\n", + " example_x_dim=2,\n", + ")\n", + "print(f\"exported NRE: {nre_onnx_path} ({nre_onnx_path.stat().st_size:,} bytes)\")\n", + "\n", + "model_nre = hssm.HSSM(\n", + " data=obs_data,\n", + " model=\"ddm\",\n", + " loglik_kind=\"approx_differentiable\",\n", + " loglik=str(nre_onnx_path),\n", + " p_outlier=0,\n", + ")\n", + "idata_nre = model_nre.sample(\n", + " sampler=\"numpyro\",\n", + " draws=MCMC_DRAWS,\n", + " tune=MCMC_TUNE,\n", + " chains=MCMC_CHAINS,\n", + " target_accept=0.9,\n", + " progressbar=False,\n", + ")\n", + "summary_nre = az.summary(idata_nre, 
var_names=DDM_PARAM_NAMES)\n", + "summary_nre" + ] + }, + { + "cell_type": "markdown", + "id": "8dbce4f7", + "metadata": {}, + "source": [ + "## Part 7 — Posterior comparison: NLE vs NRE vs ground truth\n", + "\n", + "The comparison plot is the deliverable. NLE and NRE should cover the true parameter\n", + "values; both posteriors should be tighter than the prior and centred near the truth.\n", + "For a four-way comparison alongside the LAN baseline and BayesFlow-LRE result on the\n", + "same simulated data, load their cached posteriors here and add panels to the plot.\n", + "\n", + "> Reuses the exact `TRUE_THETA`, `N_OBS`, and prior ranges from\n", + "> `bayesflow_lre_integration.ipynb`, so the BayesFlow-LRE posteriors are directly\n", + "> comparable when added." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d716fb21", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 4, figsize=(16, 4))\n", + "for ax, name, true_val in zip(axes, DDM_PARAM_NAMES, TRUE_THETA):\n", + " samples_nle = idata_nle.posterior[name].values.flatten()\n", + " samples_nre = idata_nre.posterior[name].values.flatten()\n", + " ax.hist(samples_nle, bins=30, alpha=0.5, label=\"sbi NLE\", color=\"C0\", density=True)\n", + " ax.hist(samples_nre, bins=30, alpha=0.5, label=\"sbi NRE\", color=\"C1\", density=True)\n", + " ax.axvline(true_val, color=\"red\", linestyle=\"--\", linewidth=2, label=\"true\")\n", + " ax.set_xlabel(name)\n", + " ax.set_title(f\"posterior over {name}\")\n", + " ax.legend()\n", + "fig.suptitle(\"DDM posterior recovery: sbi NLE vs NRE\", y=1.02)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "19790c5e", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "We trained a sbi NLE_A and an NRE_A on synthetic DDM data, exported both to ONNX via\n", + "`lanfactory.onnx.transform_sbi_to_onnx`, and ran MCMC through HSSM's existing\n", + "`loglik_kind=\"approx_differentiable\"` 
pipeline. Both posteriors recover the true\n", + "parameters.\n", + "\n", + "**Where to look next:**\n", + "- LANfactory's [Exporting sbi Models guide](https://alexanderfengler.github.io/LANfactory/exporting_sbi_models/) — supported-architecture matrix, known constraints, troubleshooting.\n", + "- The BayesFlow LRE tutorial (`bayesflow_lre_integration.ipynb`) — the same DDM with a different SBI library, for cross-toolkit comparison.\n", + "- LAN tutorials (`main_tutorial.ipynb`) — the original LANfactory workflow this integration builds on top of.\n", + "\n", + "**Known constraints (v1):**\n", + "- sbi NPE/SNPE estimators are deliberately out of scope (posterior-shaped, conflicts with PyMC priors).\n", + "- Neural Spline Flows are blocked on a missing `SearchSorted` op in `jaxonnxruntime`.\n", + "- FMPE / NPSE (score-based) require ODE integration and aren't ONNX-exportable." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.x" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a91f0033b08fc746984bfdc434435cf6354148cd Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 19:31:53 -0400 Subject: [PATCH 04/20] docs(tutorials): robust lanfactory import fallback in sbi NLE notebook MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The C8 keystone notebook currently cannot be run in either repo's local uv venv because of an outstanding cross-repo JAX/flax/numpyro pin conflict. Specifically, lanfactory's top-level __init__.py pulls trainers/jax_mlp.py which imports flax — incompatible with the JAX version that HSSM's numpyro pin requires. Until the env alignment lands as a separate workstream, the notebook now imports transform_sbi_to_onnx with a try/except fallback: - Clean path: `from lanfactory.onnx import transform_sbi_to_onnx`. 
- Fallback: load only lanfactory/onnx/sbi.py directly via importlib.util, bypassing lanfactory's top-level __init__.py and sidestepping the flax dependency. The fallback walks several candidate paths (env var LANFACTORY_SBI_PATH, then common Jupyter cwd contexts) so the notebook runs from a fresh kernel without manual editing. Once the cross-repo env is resolved the fallback branch becomes dead code and can be removed — the clean import will just work. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 81 +++++++++++++++++------- 1 file changed, 58 insertions(+), 23 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index 803426002..8a9c3adcf 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "8966ca29", + "id": "2b817e94", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NLE + NRE via ONNX)\n", @@ -31,7 +31,7 @@ }, { "cell_type": "markdown", - "id": "cb7310a9", + "id": "7fd9d9b2", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -40,7 +40,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cdf0ba33", + "id": "fef9bdb5", "metadata": {}, "outputs": [], "source": [ @@ -67,7 +67,42 @@ "from sbi.utils import BoxUniform\n", "from ssms.basic_simulators.simulator import simulator\n", "\n", - "from lanfactory.onnx import transform_sbi_to_onnx\n", + "# Import transform_sbi_to_onnx with a fallback. The clean path works when\n", + "# lanfactory is installed in an env whose JAX/flax pins are compatible with this\n", + "# notebook's other deps. 
Until the cross-repo env alignment lands, the fallback\n", + "# loads only the sbi exporter module directly, bypassing lanfactory's top-level\n", + "# __init__.py (which would otherwise pull the flax-dependent jax_mlp trainer).\n", + "try:\n", + " from lanfactory.onnx import transform_sbi_to_onnx\n", + "except ImportError:\n", + " import importlib.util\n", + " import os\n", + "\n", + " # Try candidates: explicit env var, then several relative paths covering\n", + " # common Jupyter launch contexts (notebook dir / repo root / spine root).\n", + " _candidates = [\n", + " os.environ.get(\"LANFACTORY_SBI_PATH\"),\n", + " \"../../../LANfactory/src/lanfactory/onnx/sbi.py\", # cwd = notebook dir\n", + " \"repos/LANfactory/src/lanfactory/onnx/sbi.py\", # cwd = HSSMSpine root\n", + " \"../LANfactory/src/lanfactory/onnx/sbi.py\", # cwd = repos/HSSM root\n", + " ]\n", + " _path = None\n", + " for _c in _candidates:\n", + " if _c and Path(_c).exists():\n", + " _path = Path(_c).resolve()\n", + " break\n", + " if _path is None:\n", + " raise ImportError(\n", + " \"Could not locate lanfactory/onnx/sbi.py. 
Set the LANFACTORY_SBI_PATH \"\n", + " \"environment variable to the absolute path of that file, or run the \"\n", + " \"notebook from a directory where one of the relative candidates resolves: \"\n", + " f\"{[c for c in _candidates if c]}\"\n", + " )\n", + " _spec = importlib.util.spec_from_file_location(\"_lanfactory_sbi\", _path)\n", + " _mod = importlib.util.module_from_spec(_spec)\n", + " _spec.loader.exec_module(_mod)\n", + " transform_sbi_to_onnx = _mod.transform_sbi_to_onnx\n", + " print(f\"(fallback) loaded transform_sbi_to_onnx from {_path}\")\n", "\n", "np.random.seed(0)\n", "torch.manual_seed(0)\n", @@ -83,7 +118,7 @@ }, { "cell_type": "markdown", - "id": "1f560f58", + "id": "27094ad6", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -96,7 +131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79933de0", + "id": "c2a9cb56", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +154,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9f823c00", + "id": "034e67f6", "metadata": {}, "outputs": [], "source": [ @@ -143,7 +178,7 @@ }, { "cell_type": "markdown", - "id": "9310ba89", + "id": "62c21487", "metadata": {}, "source": [ "## Part 3 — Train a sbi NLE_A on DDM simulations\n", @@ -156,7 +191,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6e475ca9", + "id": "5a44552a", "metadata": {}, "outputs": [], "source": [ @@ -181,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43b2756e", + "id": "78ba77b1", "metadata": {}, "outputs": [], "source": [ @@ -196,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "74e2a868", + "id": "1738f23e", "metadata": {}, "source": [ "## Part 4 — Export the trained NLE to ONNX\n", @@ -209,7 +244,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5b382d7a", + "id": "406393ee", "metadata": {}, "outputs": [], "source": [ @@ -229,7 +264,7 @@ }, { "cell_type": "markdown", - "id": "2b3913b3", + "id": "410dc884", "metadata": {}, "source": [ "## Part 
5 — High-level integration via `hssm.HSSM()`\n", @@ -242,7 +277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "44b3dc3b", + "id": "c3d4468e", "metadata": {}, "outputs": [], "source": [ @@ -259,7 +294,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2720b631", + "id": "042070d7", "metadata": {}, "outputs": [], "source": [ @@ -276,7 +311,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d19c0a57", + "id": "d9b56c8a", "metadata": {}, "outputs": [], "source": [ @@ -287,7 +322,7 @@ { "cell_type": "code", "execution_count": null, - "id": "acf55052", + "id": "b8aa997a", "metadata": {}, "outputs": [], "source": [ @@ -298,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "22c3fef4", + "id": "4976760d", "metadata": {}, "source": [ "## Part 6 — Brief NRE variant\n", @@ -312,7 +347,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e19d816f", + "id": "16511d5d", "metadata": {}, "outputs": [], "source": [ @@ -329,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3c17e977", + "id": "7008db65", "metadata": {}, "outputs": [], "source": [ @@ -364,7 +399,7 @@ }, { "cell_type": "markdown", - "id": "8dbce4f7", + "id": "681e1bad", "metadata": {}, "source": [ "## Part 7 — Posterior comparison: NLE vs NRE vs ground truth\n", @@ -382,7 +417,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d716fb21", + "id": "489d5954", "metadata": {}, "outputs": [], "source": [ @@ -403,7 +438,7 @@ }, { "cell_type": "markdown", - "id": "19790c5e", + "id": "1dd12d08", "metadata": {}, "source": [ "## Summary\n", From f94b496cd0b2a92de6f54bfc52ee52bca46f54a0 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 20:13:31 -0400 Subject: [PATCH 05/20] docs(tutorials): use cores=1 in sbi NLE notebook (avoid cloudpickle of ONNX) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Second issue running the C8 notebook: pymc.sample raised "TypeError: cannot pickle 
'google._upb._message.EnumDescriptor' object" inside cloudpickle when the parallel sampler tried to spawn worker processes on macOS. Root cause: HSSM's sampler="numpyro" path normalizes inference_method to "pymc" in base.py:678, which means bambi dispatches pm.sample with nuts_sampler="pymc" — PyMC's standard NUTS, not numpyro NUTS. On macOS the default multiprocessing start method is spawn, which requires cloudpickling the step method into worker processes. The step method references the JAX-wrapped ONNX function whose closure in jaxonnxruntime carries the onnx.ModelProto. ModelProto is a protobuf message and contains C-extension EnumDescriptor objects that cloudpickle cannot serialize. Workaround: pass cores=1 to model.sample(). Single-process sampling bypasses the multiprocess cloudpickle path entirely. Slower across chains (no parallelism) but reliable. Both NLE (Part 5) and NRE (Part 6) sample calls now include cores=1 with an explanatory comment. Followup queued: HSSM's sampler="numpyro" silently downgrades to pymc NUTS in this code path. Worth either (a) wiring nuts_sampler="numpyro" through to bambi (numpyro NUTS does its own JAX-internal parallelism without forking), or (b) updating the HSSM docstring so users know sampler= currently only controls init / jitter and not the actual NUTS backend. Tracked outside this commit. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 50 +++++++++++++----------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index 8a9c3adcf..9ab1d2856 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "2b817e94", + "id": "a5fb5056", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NLE + NRE via ONNX)\n", @@ -31,7 +31,7 @@ }, { "cell_type": "markdown", - "id": "7fd9d9b2", + "id": "b84ce45c", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -40,7 +40,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fef9bdb5", + "id": "cf8c6711", "metadata": {}, "outputs": [], "source": [ @@ -118,7 +118,7 @@ }, { "cell_type": "markdown", - "id": "27094ad6", + "id": "3794ad4d", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -131,7 +131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2a9cb56", + "id": "afe306f7", "metadata": {}, "outputs": [], "source": [ @@ -154,7 +154,7 @@ { "cell_type": "code", "execution_count": null, - "id": "034e67f6", + "id": "9cb99487", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "markdown", - "id": "62c21487", + "id": "961121e4", "metadata": {}, "source": [ "## Part 3 — Train a sbi NLE_A on DDM simulations\n", @@ -191,7 +191,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5a44552a", + "id": "91635664", "metadata": {}, "outputs": [], "source": [ @@ -216,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "78ba77b1", + "id": "139a462d", "metadata": {}, "outputs": [], "source": [ @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "1738f23e", + "id": "986d1454", "metadata": {}, "source": [ "## Part 4 — Export the trained NLE to ONNX\n", @@ -244,7 +244,7 @@ 
{ "cell_type": "code", "execution_count": null, - "id": "406393ee", + "id": "0006b7d1", "metadata": {}, "outputs": [], "source": [ @@ -264,7 +264,7 @@ }, { "cell_type": "markdown", - "id": "410dc884", + "id": "3da5312a", "metadata": {}, "source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -277,7 +277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c3d4468e", + "id": "dca7bf10", "metadata": {}, "outputs": [], "source": [ @@ -294,7 +294,7 @@ { "cell_type": "code", "execution_count": null, - "id": "042070d7", + "id": "1ec302dd", "metadata": {}, "outputs": [], "source": [ @@ -303,6 +303,9 @@ " draws=MCMC_DRAWS,\n", " tune=MCMC_TUNE,\n", " chains=MCMC_CHAINS,\n", + " cores=1, # macOS spawn cannot cloudpickle the ONNX ModelProto closure;\n", + " # single-process sampling sidesteps the fork-vs-spawn pickling\n", + " # path entirely. Slower across chains but reliable.\n", " target_accept=0.9,\n", " progressbar=False,\n", ")" @@ -311,7 +314,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9b56c8a", + "id": "8c4518ba", "metadata": {}, "outputs": [], "source": [ @@ -322,7 +325,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b8aa997a", + "id": "ba88910b", "metadata": {}, "outputs": [], "source": [ @@ -333,7 +336,7 @@ }, { "cell_type": "markdown", - "id": "4976760d", + "id": "881a20b2", "metadata": {}, "source": [ "## Part 6 — Brief NRE variant\n", @@ -347,7 +350,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16511d5d", + "id": "55938f76", "metadata": {}, "outputs": [], "source": [ @@ -364,7 +367,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7008db65", + "id": "99b7908d", "metadata": {}, "outputs": [], "source": [ @@ -390,6 +393,9 @@ " draws=MCMC_DRAWS,\n", " tune=MCMC_TUNE,\n", " chains=MCMC_CHAINS,\n", + " cores=1, # macOS spawn cannot cloudpickle the ONNX ModelProto closure;\n", + " # single-process sampling sidesteps the fork-vs-spawn pickling\n", + " # path entirely. 
Slower across chains but reliable.\n", " target_accept=0.9,\n", " progressbar=False,\n", ")\n", @@ -399,7 +405,7 @@ }, { "cell_type": "markdown", - "id": "681e1bad", + "id": "2df88e0a", "metadata": {}, "source": [ "## Part 7 — Posterior comparison: NLE vs NRE vs ground truth\n", @@ -417,7 +423,7 @@ { "cell_type": "code", "execution_count": null, - "id": "489d5954", + "id": "dd44f7db", "metadata": {}, "outputs": [], "source": [ @@ -438,7 +444,7 @@ }, { "cell_type": "markdown", - "id": "1dd12d08", + "id": "3bbe14a6", "metadata": {}, "source": [ "## Summary\n", From ce199509cf8b5e9fd8b6350bc83b9e1e6c716115 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 20:48:32 -0400 Subject: [PATCH 06/20] fix(base): pass user's sampler choice through to bambi verbatim HSSMBase.sample collapsed sampler="numpyro" (and blackjax, nutpie) to inference_method="pymc" before handing off to bambi, which then dispatched to pm.sample(nuts_sampler="pymc"). The user's sampler choice was silently downgraded to PyMC NUTS regardless of what they asked for. Bambi natively accepts inference_method values "pymc", "numpyro", "blackjax", "nutpie" (and "vi"/"laplace") and routes each to the matching nuts_sampler. The collapse conditional negated this. Regression archaeology: - Aug 5, 2024 (commit aef3f9b, "Fix compatibility with Bambi (#516)"): introduced the working pattern -- inference_method="mcmc" (generic NUTS marker) + kwargs["nuts_sampler"]="numpyro"/"blackjax"/etc. injected separately. Bambi's old "mcmc" inference_method was generic and read nuts_sampler from kwargs. Correct under old bambi semantics. - Dec 17, 2025 (commit 20c100b, "fix: update model.sample api to be consistent with bambi's"): bambi had renamed its inference_method values (mcmc -> pymc, nuts_numpyro -> numpyro, etc.). This commit mechanically updated the string list, but ALSO deleted the kwargs["nuts_sampler"] = ... injection block. The flatten-conditional was left in place. 
After this commit, all four NUTS samplers route to inference_method="pymc" -> nuts_sampler="pymc" with no recourse. - The bug has been live since Dec 17, 2025 (about 5 months). Fix: replace the conditional with inference_method=sampler. Bambi handles each NUTS variant directly under the new API. The injection block deleted in commit 20c100b is correctly absent now -- bambi passes nuts_sampler=sampler_backend to pm.sample explicitly, so injecting it via kwargs would conflict. Side effects: - sampler="numpyro" now actually invokes numpyro NUTS, which runs inside JAX with internal parallelism and does NOT fork worker processes. This avoids the cloudpickle path that breaks on unpicklable ONNX ModelProto closures (surfaced by the sbi NLE tutorial). - sampler="blackjax" and sampler="nutpie" similarly now invoke their respective backends instead of PyMC NUTS. - sampler="pymc" behavior is unchanged (still routes to PyMC NUTS). Other gates that read the user's `sampler` argument (parallel-sampling warning at base.py:621, init default at base.py:636, jitter handling at base.py:644, step-sampler check at base.py:657) all check `sampler` directly, not the post-normalization inference_method value. None are affected by this change. Tests pinning the old behavior: - tests/test_rlssm.py:300, tests/test_save_load.py:39, tests/slow/test_mcmc.py:100-101 use sampler="numpyro" and were silently exercising PyMC NUTS. After this fix they exercise the actual numpyro path. Worth re-running as part of PR review. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/hssm/base.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/hssm/base.py b/src/hssm/base.py index 806e0eded..1da37bee8 100644 --- a/src/hssm/base.py +++ b/src/hssm/base.py @@ -674,12 +674,15 @@ def sample( compute_likelihood = kwargs["idata_kwargs"].pop("log_likelihood", True) omit_offsets = kwargs.pop("omit_offsets", False) + # Pass the user's sampler choice through to bambi verbatim. 
Bambi's + # inference_method natively accepts "pymc"/"numpyro"/"blackjax"/"nutpie" + # and routes each to a distinct pm.sample(nuts_sampler=...) call. + # The previous conditional collapsed all four NUTS variants to "pymc", + # which silently downgraded user choice — surfaced via the sbi NLE + # tutorial where sampler="numpyro" still went through PyMC's multiprocess + # path and tripped cloudpickle on the ONNX ModelProto. self._inference_obj = self.model.fit( - inference_method=( - "pymc" - if sampler in ["pymc", "numpyro", "blackjax", "nutpie"] - else sampler - ), + inference_method=sampler, init=init, include_response_params=include_response_params, omit_offsets=omit_offsets, From 5856d628fd8f209267708009e8a5e32fa11715f8 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 21:06:25 -0400 Subject: [PATCH 07/20] docs(tutorials): drop cores=1 workaround now that sampler routing is fixed With the merged fix branch (commit ce19950), HSSM's sampler="numpyro" actually invokes numpyro NUTS via pm.sample(nuts_sampler="numpyro"). Numpyro NUTS runs entirely inside JAX with internal parallelism -- no ps.ParallelSampler, no forked workers, no cloudpickle of the step method, so the ONNX ModelProto protobuf descriptors that previously broke serialization are never touched. The cores=1 workaround added in commit f94b496 is therefore no longer necessary for either the NLE or NRE sample call. Reverting to the default (which lets pymc.sample pick a sensible cores count) so chains can run in parallel where the backend permits it. Note: numpyro NUTS handles chain parallelism via JAX's vmap-over-chains or pmap-over-devices internally, so explicit cores= is not the relevant knob for numpyro anyway. We're just removing an override that no longer applies. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 50 +++++++++++------------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index 9ab1d2856..8e05ad404 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "a5fb5056", + "id": "4d95175c", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NLE + NRE via ONNX)\n", @@ -31,7 +31,7 @@ }, { "cell_type": "markdown", - "id": "b84ce45c", + "id": "d3e0e1e1", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -40,7 +40,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cf8c6711", + "id": "bb9a9b94", "metadata": {}, "outputs": [], "source": [ @@ -118,7 +118,7 @@ }, { "cell_type": "markdown", - "id": "3794ad4d", + "id": "7bb119ea", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -131,7 +131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "afe306f7", + "id": "c626d12b", "metadata": {}, "outputs": [], "source": [ @@ -154,7 +154,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9cb99487", + "id": "833e07d2", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "markdown", - "id": "961121e4", + "id": "a742c40f", "metadata": {}, "source": [ "## Part 3 — Train a sbi NLE_A on DDM simulations\n", @@ -191,7 +191,7 @@ { "cell_type": "code", "execution_count": null, - "id": "91635664", + "id": "f5983a62", "metadata": {}, "outputs": [], "source": [ @@ -216,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "139a462d", + "id": "2fa47920", "metadata": {}, "outputs": [], "source": [ @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "986d1454", + "id": "f931c135", "metadata": {}, "source": [ "## Part 4 — Export the trained NLE to ONNX\n", @@ -244,7 +244,7 @@ 
{ "cell_type": "code", "execution_count": null, - "id": "0006b7d1", + "id": "693dd21c", "metadata": {}, "outputs": [], "source": [ @@ -264,7 +264,7 @@ }, { "cell_type": "markdown", - "id": "3da5312a", + "id": "143b619b", "metadata": {}, "source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -277,7 +277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dca7bf10", + "id": "376140d2", "metadata": {}, "outputs": [], "source": [ @@ -294,7 +294,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1ec302dd", + "id": "4100b9f1", "metadata": {}, "outputs": [], "source": [ @@ -303,9 +303,6 @@ " draws=MCMC_DRAWS,\n", " tune=MCMC_TUNE,\n", " chains=MCMC_CHAINS,\n", - " cores=1, # macOS spawn cannot cloudpickle the ONNX ModelProto closure;\n", - " # single-process sampling sidesteps the fork-vs-spawn pickling\n", - " # path entirely. Slower across chains but reliable.\n", " target_accept=0.9,\n", " progressbar=False,\n", ")" @@ -314,7 +311,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c4518ba", + "id": "d871300f", "metadata": {}, "outputs": [], "source": [ @@ -325,7 +322,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ba88910b", + "id": "0062874d", "metadata": {}, "outputs": [], "source": [ @@ -336,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "881a20b2", + "id": "cd8f6778", "metadata": {}, "source": [ "## Part 6 — Brief NRE variant\n", @@ -350,7 +347,7 @@ { "cell_type": "code", "execution_count": null, - "id": "55938f76", + "id": "a5483a45", "metadata": {}, "outputs": [], "source": [ @@ -367,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "99b7908d", + "id": "7c81221e", "metadata": {}, "outputs": [], "source": [ @@ -393,9 +390,6 @@ " draws=MCMC_DRAWS,\n", " tune=MCMC_TUNE,\n", " chains=MCMC_CHAINS,\n", - " cores=1, # macOS spawn cannot cloudpickle the ONNX ModelProto closure;\n", - " # single-process sampling sidesteps the fork-vs-spawn pickling\n", - " # path entirely. 
Slower across chains but reliable.\n", " target_accept=0.9,\n", " progressbar=False,\n", ")\n", @@ -405,7 +399,7 @@ }, { "cell_type": "markdown", - "id": "2df88e0a", + "id": "f51c06a0", "metadata": {}, "source": [ "## Part 7 — Posterior comparison: NLE vs NRE vs ground truth\n", @@ -423,7 +417,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dd44f7db", + "id": "1b11c7c1", "metadata": {}, "outputs": [], "source": [ @@ -444,7 +438,7 @@ }, { "cell_type": "markdown", - "id": "3bbe14a6", + "id": "2476e274", "metadata": {}, "source": [ "## Summary\n", From 7abb0713bef958e2443f070f95ff47f34d96b4da Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 22:24:44 -0400 Subject: [PATCH 08/20] docs(tutorials): match training prior to HSSM defaults; bump training budget The first end-to-end run of the C8 notebook recovered systematically biased posteriors (NLE: v=1.5 vs truth 0.5; t centered at 0.02 vs truth 0.25 and below the training prior's lower bound of 0.1). Diagnosis: our training prior was narrower than HSSM's default DDM bounds, so MCMC explored regions the flow never saw and the trained MAF extrapolated into spurious high-likelihood pockets. Verified by inspecting hssm.defaults.default_model_config['ddm'] for the "approx_differentiable" likelihood: v in (-3.0, 3.0), a in (0.3, 2.5), z in (0.0, 1.0), t in (0.0, 2.0) with a HalfNormal(sigma=2.0) prior on t that puts substantial mass below our previous training lower bound of 0.1. Changes in the notebook: - Training prior (Part 3) widened to match HSSM's default bounds verbatim: BoxUniform with low=[-3, 0.3, 0, 0], high=[3, 2.5, 1, 2]. - N_TRAIN raised 10k -> 30k to cover the wider 4D parameter volume. - NUM_EPOCHS raised 50 -> 100 for the same reason. - Simulation switched from a Python loop to ssm-simulators batched call (theta of shape (N, 4) with n_samples=1), ~100x faster on 30k samples. 
Expected effect: posteriors should now concentrate near the true theta (v=0.5, a=1.2, z=0.5, t=0.25) since MCMC stays inside the trained region throughout. Total notebook runtime estimate roughly 10-20 minutes on CPU depending on machine. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 83 ++++++++++++++---------- 1 file changed, 47 insertions(+), 36 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index 8e05ad404..d5b08c3aa 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "4d95175c", + "id": "ff0d85b2", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NLE + NRE via ONNX)\n", @@ -31,7 +31,7 @@ }, { "cell_type": "markdown", - "id": "d3e0e1e1", + "id": "f79dc412", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -40,7 +40,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bb9a9b94", + "id": "6efb6090", "metadata": {}, "outputs": [], "source": [ @@ -107,10 +107,12 @@ "np.random.seed(0)\n", "torch.manual_seed(0)\n", "\n", - "# Tutorial CI budget. Bump for production runs.\n", - "N_TRAIN = 10_000 # training simulations\n", + "# Training budget. 
Bumped from the original 10k/50ep because the wider prior\n", + "# (matching HSSM's default DDM bounds — see Part 3) needs more samples to\n", + "# cover the parameter volume.\n", + "N_TRAIN = 30_000 # training simulations\n", "N_OBS = 500 # observed trials at the true theta\n", - "NUM_EPOCHS = 50\n", + "NUM_EPOCHS = 100\n", "MCMC_DRAWS = 500\n", "MCMC_TUNE = 500\n", "MCMC_CHAINS = 2" @@ -118,7 +120,7 @@ }, { "cell_type": "markdown", - "id": "7bb119ea", + "id": "ff52a9d7", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -131,13 +133,18 @@ { "cell_type": "code", "execution_count": null, - "id": "c626d12b", + "id": "d0274137", "metadata": {}, "outputs": [], "source": [ "DDM_PARAM_NAMES = [\"v\", \"a\", \"z\", \"t\"]\n", - "PRIOR_LOW = np.array([-2.0, 0.6, 0.3, 0.1], dtype=np.float32)\n", - "PRIOR_HIGH = np.array([2.0, 1.8, 0.7, 0.5], dtype=np.float32)\n", + "# Training prior matches HSSM's default bounds for model=\"ddm\" with\n", + "# loglik_kind=\"approx_differentiable\" — see hssm.defaults.default_model_config.\n", + "# This is important: if the training prior is narrower than HSSM's posterior\n", + "# can explore, MCMC will walk into parameter regions the flow never saw and\n", + "# extrapolate badly (spurious likelihood pockets, biased posteriors).\n", + "PRIOR_LOW = np.array([-3.0, 0.3, 0.0, 0.0], dtype=np.float32)\n", + "PRIOR_HIGH = np.array([3.0, 2.5, 1.0, 2.0], dtype=np.float32)\n", "TRUE_THETA = np.array([0.5, 1.2, 0.5, 0.25], dtype=np.float32)\n", "\n", "out = simulator(theta=TRUE_THETA[None, :], model=\"ddm\", n_samples=N_OBS)\n", @@ -154,7 +161,7 @@ { "cell_type": "code", "execution_count": null, - "id": "833e07d2", + "id": "ac24b07a", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +185,7 @@ }, { "cell_type": "markdown", - "id": "a742c40f", + "id": "49319254", "metadata": {}, "source": [ "## Part 3 — Train a sbi NLE_A on DDM simulations\n", @@ -191,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"f5983a62", + "id": "38679c22", "metadata": {}, "outputs": [], "source": [ @@ -201,22 +208,26 @@ ")\n", "theta_train = prior.sample((N_TRAIN,))\n", "\n", - "# Simulate one (rt, choice) per training theta. The loop is slow but transparent.\n", - "# For real workflows use ssm-simulators' batched API.\n", - "rts = np.empty(N_TRAIN, dtype=np.float32)\n", - "choices = np.empty(N_TRAIN, dtype=np.float32)\n", - "for i, th in enumerate(theta_train.numpy()):\n", - " sim = simulator(theta=th[None, :], model=\"ddm\", n_samples=1)\n", - " rts[i] = sim[\"rts\"].squeeze()\n", - " choices[i] = sim[\"choices\"].squeeze()\n", - "x_train = torch.from_numpy(np.stack([rts, choices], axis=-1))\n", + "# Batched simulation: one (rt, choice) per training theta. ssm-simulators\n", + "# accepts theta of shape (N, 4) with n_samples=1 and returns rts/choices of\n", + "# shape (N, 1) directly — much faster than a Python loop for large N.\n", + "sim = simulator(\n", + " theta=theta_train.numpy().astype(np.float32),\n", + " model=\"ddm\",\n", + " n_samples=1,\n", + ")\n", + "x_train = torch.from_numpy(\n", + " np.stack([sim[\"rts\"].squeeze(-1), sim[\"choices\"].squeeze(-1)], axis=-1).astype(\n", + " np.float32\n", + " )\n", + ")\n", "print(f\"training set: theta={theta_train.shape}, x={x_train.shape}\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "2fa47920", + "id": "fc05934c", "metadata": {}, "outputs": [], "source": [ @@ -231,7 +242,7 @@ }, { "cell_type": "markdown", - "id": "f931c135", + "id": "442786bc", "metadata": {}, "source": [ "## Part 4 — Export the trained NLE to ONNX\n", @@ -244,7 +255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "693dd21c", + "id": "ff1017fd", "metadata": {}, "outputs": [], "source": [ @@ -264,7 +275,7 @@ }, { "cell_type": "markdown", - "id": "143b619b", + "id": "6c9241b4", "metadata": {}, "source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -277,7 +288,7 @@ { "cell_type": "code", "execution_count": null, - 
"id": "376140d2", + "id": "bf9c9841", "metadata": {}, "outputs": [], "source": [ @@ -294,7 +305,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4100b9f1", + "id": "058c279b", "metadata": {}, "outputs": [], "source": [ @@ -311,7 +322,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d871300f", + "id": "e565bb3f", "metadata": {}, "outputs": [], "source": [ @@ -322,7 +333,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0062874d", + "id": "7888c608", "metadata": {}, "outputs": [], "source": [ @@ -333,7 +344,7 @@ }, { "cell_type": "markdown", - "id": "cd8f6778", + "id": "eacb39c4", "metadata": {}, "source": [ "## Part 6 — Brief NRE variant\n", @@ -347,7 +358,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5483a45", + "id": "dabcf130", "metadata": {}, "outputs": [], "source": [ @@ -364,7 +375,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7c81221e", + "id": "5424dd08", "metadata": {}, "outputs": [], "source": [ @@ -399,7 +410,7 @@ }, { "cell_type": "markdown", - "id": "f51c06a0", + "id": "80a14229", "metadata": {}, "source": [ "## Part 7 — Posterior comparison: NLE vs NRE vs ground truth\n", @@ -417,7 +428,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1b11c7c1", + "id": "4547d5ff", "metadata": {}, "outputs": [], "source": [ @@ -438,7 +449,7 @@ }, { "cell_type": "markdown", - "id": "2476e274", + "id": "a13cc9d1", "metadata": {}, "source": [ "## Summary\n", From 4c9a3adb691a04836d9c2d3a84beed50baa91ff3 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 15 May 2026 15:10:38 -0400 Subject: [PATCH 09/20] docs(tutorials): add NLE training-vs-sampling diagnostic cell Adds Part 5b (markdown + code) to the sbi NLE notebook. The cell answers the question "is the marginal posterior bias coming from a poorly-trained flow, or from HSSM-side sampling issues?" 
by computing the trained NLE log-likelihood of the observed data at: - the true theta - the posterior's marginal mean If the NLE itself prefers the wrong theta by a large margin, the flow is the problem (training quality). If it prefers the truth, MCMC is failing to find the NLE's mode (priors / init / mixing). The cell prints a three-way verdict depending on the gap. Placed between the NLE trace plot (Part 5) and the NRE variant (Part 6) so it interprets the NLE posterior immediately. Uses the in-memory estimator_nle and the obs_data DataFrame already in scope. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 110 ++++++++++++++++++----- 1 file changed, 88 insertions(+), 22 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index d5b08c3aa..d5b6b615b 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "ff0d85b2", + "id": "bf97f310", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NLE + NRE via ONNX)\n", @@ -31,7 +31,7 @@ }, { "cell_type": "markdown", - "id": "f79dc412", + "id": "a085f2b3", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -40,7 +40,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6efb6090", + "id": "5785df1a", "metadata": {}, "outputs": [], "source": [ @@ -120,7 +120,7 @@ }, { "cell_type": "markdown", - "id": "ff52a9d7", + "id": "4b6e7463", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -133,7 +133,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0274137", + "id": "333bbaa0", "metadata": {}, "outputs": [], "source": [ @@ -161,7 +161,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ac24b07a", + "id": "d337eddc", "metadata": {}, "outputs": [], "source": [ @@ -185,7 +185,7 @@ }, { "cell_type": "markdown", - "id": "49319254", + "id": "75aba44b", 
"metadata": {}, "source": [ "## Part 3 — Train a sbi NLE_A on DDM simulations\n", @@ -198,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38679c22", + "id": "064472d1", "metadata": {}, "outputs": [], "source": [ @@ -227,7 +227,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fc05934c", + "id": "868bc72c", "metadata": {}, "outputs": [], "source": [ @@ -242,7 +242,7 @@ }, { "cell_type": "markdown", - "id": "442786bc", + "id": "7b0d8d7b", "metadata": {}, "source": [ "## Part 4 — Export the trained NLE to ONNX\n", @@ -255,7 +255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ff1017fd", + "id": "d5af2ee5", "metadata": {}, "outputs": [], "source": [ @@ -275,7 +275,7 @@ }, { "cell_type": "markdown", - "id": "6c9241b4", + "id": "7eb029e0", "metadata": {}, "source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -288,7 +288,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bf9c9841", + "id": "3044c543", "metadata": {}, "outputs": [], "source": [ @@ -305,7 +305,7 @@ { "cell_type": "code", "execution_count": null, - "id": "058c279b", + "id": "ccb3b58b", "metadata": {}, "outputs": [], "source": [ @@ -322,7 +322,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e565bb3f", + "id": "57297f08", "metadata": {}, "outputs": [], "source": [ @@ -333,7 +333,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7888c608", + "id": "e15513ec", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +344,73 @@ }, { "cell_type": "markdown", - "id": "eacb39c4", + "id": "fe6e2d01", + "metadata": {}, + "source": [ + "### Part 5b — Diagnostic: is the NLE itself biased, or is HSSM not finding its mode?\n", + "\n", + "If the marginal posteriors miss the truth, two very different causes are possible:\n", + "\n", + "1. **Training quality** — the flow itself is maximized at the wrong θ.\n", + "2. 
**Sampling / HSSM-side** — the flow is fine, but MCMC isn't finding its mode (priors, init, mixing).\n", + "\n", + "We can distinguish them by asking the trained NLE directly: which θ does it\n", + "think makes the observed data more likely — the true θ, or the posterior's\n", + "marginal mean? If the NLE *itself* prefers the (wrong) posterior mean, training\n", + "is the issue. If it prefers the truth, sampling is." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7541819", + "metadata": {}, + "outputs": [], + "source": [ + "posterior_mean_nle = {\n", + " p: float(idata_nle.posterior[p].mean()) for p in DDM_PARAM_NAMES\n", + "}\n", + "obs_x_t = torch.from_numpy(\n", + " obs_data[[\"rt\", \"response\"]].values.astype(np.float32)\n", + ")\n", + "\n", + "def total_logp(estimator, theta_dict):\n", + " \"\"\"Sum log p(x_i | theta) over all observed trials at a single theta.\"\"\"\n", + " theta_row = torch.tensor(\n", + " [[theta_dict[p] for p in DDM_PARAM_NAMES]], dtype=torch.float32\n", + " ).repeat(len(obs_x_t), 1)\n", + " with torch.no_grad():\n", + " return estimator.log_prob(obs_x_t, condition=theta_row).sum().item()\n", + "\n", + "lp_true = total_logp(estimator_nle, dict(zip(DDM_PARAM_NAMES, TRUE_THETA)))\n", + "lp_mean = total_logp(estimator_nle, posterior_mean_nle)\n", + "\n", + "print(f\"NLE log-prob at true theta: {lp_true:+.2f}\")\n", + "print(f\"NLE log-prob at posterior mean: {lp_mean:+.2f}\")\n", + "print(f\"Δ (mean − true): {lp_mean - lp_true:+.2f}\")\n", + "print()\n", + "if lp_mean > lp_true + 5.0:\n", + " print(\"→ NLE itself prefers the wrong θ by a large margin.\")\n", + " print(\" Diagnosis: TRAINING QUALITY. The flow has not learned the true\")\n", + " print(\" conditional density well. 
Fix with more sims, larger model, or\")\n", + " print(\" more epochs.\")\n", + "elif lp_mean > lp_true:\n", + " print(\"→ NLE mildly prefers the posterior mean over the truth.\")\n", + " print(\" Diagnosis: marginal posterior bias may reflect a slightly\")\n", + " print(\" miscalibrated flow plus narrow likelihood. Marginal posteriors\")\n", + " print(\" can shift while the joint mode stays near the truth — check the\")\n", + " print(\" pairs plot (added separately) to see the joint.\")\n", + "else:\n", + " print(\"→ NLE prefers the truth; the wrong posterior mean is a sampling\")\n", + " print(\" artifact (priors / init / mixing), not a training issue.\")\n", + "print()\n", + "print(f\"Posterior mean: {posterior_mean_nle}\")\n", + "print(f\"True theta: {dict(zip(DDM_PARAM_NAMES, TRUE_THETA.tolist()))}\")" + ] + }, + { + "cell_type": "markdown", + "id": "041e667e", "metadata": {}, "source": [ "## Part 6 — Brief NRE variant\n", @@ -358,7 +424,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dabcf130", + "id": "4c036254", "metadata": {}, "outputs": [], "source": [ @@ -375,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5424dd08", + "id": "51aaa91e", "metadata": {}, "outputs": [], "source": [ @@ -410,7 +476,7 @@ }, { "cell_type": "markdown", - "id": "80a14229", + "id": "10c8cbbf", "metadata": {}, "source": [ "## Part 7 — Posterior comparison: NLE vs NRE vs ground truth\n", @@ -428,7 +494,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4547d5ff", + "id": "2275b09f", "metadata": {}, "outputs": [], "source": [ @@ -449,7 +515,7 @@ }, { "cell_type": "markdown", - "id": "a13cc9d1", + "id": "50a7d7af", "metadata": {}, "source": [ "## Summary\n", From ebe86af7b0a3d1b4faaf0330e5c8c89f5c07b1dc Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 15 May 2026 15:18:54 -0400 Subject: [PATCH 10/20] docs(tutorials): scale up sbi training to 1M sims + larger flow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit Applies the proposed §4.2 improvements from the bias review. The 30k-sim run was undertrained for a 4D-theta DDM problem with HSSM's default wide prior, so the marginal posteriors over v and a were biased high. Changes to Part 1 (setup cell): - N_TRAIN: 30_000 -> 1_000_000 - NUM_EPOCHS: 100 -> 300 (max) - STOP_AFTER_EPOCHS: new, set to 50 (default is 20 -- previous runs may have early-stopped silently) - TRAINING_BATCH_SIZE: 200 -> 500 (fewer batches per epoch at 1M) - HIDDEN_FEATURES: new, set to 100 (sbi default is 50) - NUM_TRANSFORMS: new, set to 8 (sbi MAF default is 5) - imports likelihood_nn alongside classifier_nn from sbi.neural_nets Changes to Part 3 (NLE training): - Replace density_estimator="maf" string shortcut with an explicit likelihood_nn(model="maf", hidden_features=HIDDEN_FEATURES, num_transforms=NUM_TRANSFORMS) builder. - Pass TRAINING_BATCH_SIZE and STOP_AFTER_EPOCHS into .train(). Changes to Part 6 (NRE training): - classifier_nn now also takes hidden_features=HIDDEN_FEATURES (matches NLE width). norm_layer=nn.Identity remains required because jaxonnxruntime doesn't implement LayerNormalization. - Same TRAINING_BATCH_SIZE and STOP_AFTER_EPOCHS. Expected wall time on CPU: 30-90 min for NLE alone, similar for NRE, plus a few minutes of MCMC each. Run on GPU if available. The training run is comparable in scale to a LAN training (1M sims) which is what's needed for a fair NLE-vs-LAN comparison on DDM-like problems. The Part 5b diagnostic cell (commit 4c9a3ad) is unchanged and will still print the training-vs-sampling verdict after this run. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 93 +++++++++++++++--------- 1 file changed, 58 insertions(+), 35 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index d5b6b615b..937582382 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "bf97f310", + "id": "63439808", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NLE + NRE via ONNX)\n", @@ -31,7 +31,7 @@ }, { "cell_type": "markdown", - "id": "a085f2b3", + "id": "5f743dcd", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -40,7 +40,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5785df1a", + "id": "89e86029", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +63,7 @@ "\n", "import hssm\n", "from sbi.inference import NLE_A, NRE_A\n", - "from sbi.neural_nets import classifier_nn\n", + "from sbi.neural_nets import classifier_nn, likelihood_nn\n", "from sbi.utils import BoxUniform\n", "from ssms.basic_simulators.simulator import simulator\n", "\n", @@ -107,12 +107,19 @@ "np.random.seed(0)\n", "torch.manual_seed(0)\n", "\n", - "# Training budget. Bumped from the original 10k/50ep because the wider prior\n", - "# (matching HSSM's default DDM bounds — see Part 3) needs more samples to\n", - "# cover the parameter volume.\n", - "N_TRAIN = 30_000 # training simulations\n", - "N_OBS = 500 # observed trials at the true theta\n", - "NUM_EPOCHS = 100\n", + "# Training budget. 1M (theta, x) pairs from ssm-simulators with a wider MAF\n", + "# (hidden_features=100, num_transforms=8) and patient early-stopping. Comparable\n", + "# in spirit to a LAN training run, lets us see how close sbi NLE/NRE can get\n", + "# to true posteriors on a DDM when not budget-starved.\n", + "# On CPU this is heavy — expect ~30–90 min for NLE training alone, depending\n", + "# on machine. 
Run on GPU if available.\n", + "N_TRAIN = 1_000_000 # training simulations\n", + "N_OBS = 500 # observed trials at the true theta\n", + "NUM_EPOCHS = 300 # max; sbi early-stops via stop_after_epochs below\n", + "STOP_AFTER_EPOCHS = 50 # patience for validation-loss plateau (default is 20)\n", + "TRAINING_BATCH_SIZE = 500\n", + "HIDDEN_FEATURES = 100\n", + "NUM_TRANSFORMS = 8 # MAF stack depth (default 5)\n", "MCMC_DRAWS = 500\n", "MCMC_TUNE = 500\n", "MCMC_CHAINS = 2" @@ -120,7 +127,7 @@ }, { "cell_type": "markdown", - "id": "4b6e7463", + "id": "1c1eb870", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -133,7 +140,7 @@ { "cell_type": "code", "execution_count": null, - "id": "333bbaa0", + "id": "f43cfbec", "metadata": {}, "outputs": [], "source": [ @@ -161,7 +168,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d337eddc", + "id": "8c0287da", "metadata": {}, "outputs": [], "source": [ @@ -185,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "75aba44b", + "id": "5945a40a", "metadata": {}, "source": [ "## Part 3 — Train a sbi NLE_A on DDM simulations\n", @@ -198,7 +205,7 @@ { "cell_type": "code", "execution_count": null, - "id": "064472d1", + "id": "73c15a3a", "metadata": {}, "outputs": [], "source": [ @@ -227,14 +234,23 @@ { "cell_type": "code", "execution_count": null, - "id": "868bc72c", + "id": "ff85e311", "metadata": {}, "outputs": [], "source": [ - "inference_nle = NLE_A(prior=prior, density_estimator=\"maf\")\n", + "# Build a larger MAF than the sbi default (which is 50 hidden / 5 transforms)\n", + "# so the flow has enough capacity to learn p(rt, choice | v, a, z, t) accurately\n", + "# across the full HSSM prior range.\n", + "nle_estimator_builder = likelihood_nn(\n", + " model=\"maf\",\n", + " hidden_features=HIDDEN_FEATURES,\n", + " num_transforms=NUM_TRANSFORMS,\n", + ")\n", + "inference_nle = NLE_A(prior=prior, density_estimator=nle_estimator_builder)\n", "estimator_nle = 
inference_nle.append_simulations(theta_train, x_train).train(\n", - " training_batch_size=200,\n", + " training_batch_size=TRAINING_BATCH_SIZE,\n", " max_num_epochs=NUM_EPOCHS,\n", + " stop_after_epochs=STOP_AFTER_EPOCHS,\n", ")\n", "estimator_nle.eval()\n", "print(\"NLE training complete\")" @@ -242,7 +258,7 @@ }, { "cell_type": "markdown", - "id": "7b0d8d7b", + "id": "c3cf75d4", "metadata": {}, "source": [ "## Part 4 — Export the trained NLE to ONNX\n", @@ -255,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d5af2ee5", + "id": "b91ef662", "metadata": {}, "outputs": [], "source": [ @@ -275,7 +291,7 @@ }, { "cell_type": "markdown", - "id": "7eb029e0", + "id": "ee8f6fb4", "metadata": {}, "source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -288,7 +304,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3044c543", + "id": "60a777c3", "metadata": {}, "outputs": [], "source": [ @@ -305,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ccb3b58b", + "id": "e5cfd093", "metadata": {}, "outputs": [], "source": [ @@ -322,7 +338,7 @@ { "cell_type": "code", "execution_count": null, - "id": "57297f08", + "id": "6764d32b", "metadata": {}, "outputs": [], "source": [ @@ -333,7 +349,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e15513ec", + "id": "6eccdc48", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +360,7 @@ }, { "cell_type": "markdown", - "id": "fe6e2d01", + "id": "4fcb3e75", "metadata": {}, "source": [ "### Part 5b — Diagnostic: is the NLE itself biased, or is HSSM not finding its mode?\n", @@ -363,7 +379,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c7541819", + "id": "2561bcf2", "metadata": {}, "outputs": [], "source": [ @@ -410,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "041e667e", + "id": "2954f39e", "metadata": {}, "source": [ "## Part 6 — Brief NRE variant\n", @@ -424,15 +440,22 @@ { "cell_type": "code", "execution_count": null, - "id": "4c036254", + 
"id": "759fb73a", "metadata": {}, "outputs": [], "source": [ - "classifier_builder = classifier_nn(model=\"mlp\", norm_layer=nn.Identity)\n", + "# Match NLE: bigger hidden width, longer training, LayerNorm disabled\n", + "# (jaxonnxruntime doesn't implement LayerNormalization).\n", + "classifier_builder = classifier_nn(\n", + " model=\"mlp\",\n", + " norm_layer=nn.Identity,\n", + " hidden_features=HIDDEN_FEATURES,\n", + ")\n", "inference_nre = NRE_A(prior=prior, classifier=classifier_builder)\n", "classifier_nre = inference_nre.append_simulations(theta_train, x_train).train(\n", - " training_batch_size=200,\n", + " training_batch_size=TRAINING_BATCH_SIZE,\n", " max_num_epochs=NUM_EPOCHS,\n", + " stop_after_epochs=STOP_AFTER_EPOCHS,\n", ")\n", "classifier_nre.eval()\n", "print(\"NRE training complete\")" @@ -441,7 +464,7 @@ { "cell_type": "code", "execution_count": null, - "id": "51aaa91e", + "id": "cd669dcd", "metadata": {}, "outputs": [], "source": [ @@ -476,7 +499,7 @@ }, { "cell_type": "markdown", - "id": "10c8cbbf", + "id": "bb854806", "metadata": {}, "source": [ "## Part 7 — Posterior comparison: NLE vs NRE vs ground truth\n", @@ -494,7 +517,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2275b09f", + "id": "ac809832", "metadata": {}, "outputs": [], "source": [ @@ -515,7 +538,7 @@ }, { "cell_type": "markdown", - "id": "50a7d7af", + "id": "5276d16c", "metadata": {}, "source": [ "## Summary\n", From a4d881ed2a4b19f033598f9dfd2e018d2ba49460 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 15 May 2026 22:15:32 -0400 Subject: [PATCH 11/20] docs(tutorials): add analytical DDM ground-truth posterior + 3-way comparison Adds Part 6b to the sbi NLE notebook: a second HSSM model built with loglik_kind="analytical" (HSSM's closed-form Navarro & Fuss DDM likelihood) sampled on the same obs_data. This gives a gold-standard posterior against which the sbi-NLE and sbi-NRE marginals can be compared. 
Distance from analytical to true theta is intrinsic posterior width (finite data effect). Distance from sbi-NLE/NRE to analytical is surrogate approximation error -- the thing we actually care about when evaluating how well the neural likelihood reproduces the closed-form target. Part 7's comparison plot is upgraded from a 2-way (NLE vs NRE) to a 3-way (analytical vs NLE vs NRE) histogram per parameter, with the true theta as a red dashed vertical for reference. The analytical DDM uses slightly different parameter bounds from approx_differentiable (a, t unbounded above; otherwise the same), but on the observed data the posterior concentrates regardless. Runtime impact: one additional HSSM MCMC run (~30-60 sec via numpyro NUTS on the analytical likelihood). Trivial compared to the 1M-sim sbi training in Parts 3 and 6. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 134 +++++++++++++++++------ 1 file changed, 98 insertions(+), 36 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index 937582382..339feff13 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "63439808", + "id": "7ec39688", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NLE + NRE via ONNX)\n", @@ -31,7 +31,7 @@ }, { "cell_type": "markdown", - "id": "5f743dcd", + "id": "ed401b4a", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -40,7 +40,7 @@ { "cell_type": "code", "execution_count": null, - "id": "89e86029", + "id": "4ef0b05e", "metadata": {}, "outputs": [], "source": [ @@ -127,7 +127,7 @@ }, { "cell_type": "markdown", - "id": "1c1eb870", + "id": "608e0d7b", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -140,7 +140,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f43cfbec", + "id": "30923bc0", "metadata": {}, 
"outputs": [], "source": [ @@ -168,7 +168,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c0287da", + "id": "7ff551d6", "metadata": {}, "outputs": [], "source": [ @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "5945a40a", + "id": "aa749e72", "metadata": {}, "source": [ "## Part 3 — Train a sbi NLE_A on DDM simulations\n", @@ -205,7 +205,7 @@ { "cell_type": "code", "execution_count": null, - "id": "73c15a3a", + "id": "697e034e", "metadata": {}, "outputs": [], "source": [ @@ -234,7 +234,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ff85e311", + "id": "925c7506", "metadata": {}, "outputs": [], "source": [ @@ -258,7 +258,7 @@ }, { "cell_type": "markdown", - "id": "c3cf75d4", + "id": "90916f7f", "metadata": {}, "source": [ "## Part 4 — Export the trained NLE to ONNX\n", @@ -271,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b91ef662", + "id": "9ad107d6", "metadata": {}, "outputs": [], "source": [ @@ -291,7 +291,7 @@ }, { "cell_type": "markdown", - "id": "ee8f6fb4", + "id": "76b82659", "metadata": {}, "source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -304,7 +304,7 @@ { "cell_type": "code", "execution_count": null, - "id": "60a777c3", + "id": "dee28506", "metadata": {}, "outputs": [], "source": [ @@ -321,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5cfd093", + "id": "fbe0c986", "metadata": {}, "outputs": [], "source": [ @@ -338,7 +338,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6764d32b", + "id": "b2b2b0b9", "metadata": {}, "outputs": [], "source": [ @@ -349,7 +349,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6eccdc48", + "id": "15b0f907", "metadata": {}, "outputs": [], "source": [ @@ -360,7 +360,7 @@ }, { "cell_type": "markdown", - "id": "4fcb3e75", + "id": "8017d20a", "metadata": {}, "source": [ "### Part 5b — Diagnostic: is the NLE itself biased, or is HSSM not finding its mode?\n", @@ -379,7 +379,7 @@ { "cell_type": "code", 
"execution_count": null, - "id": "2561bcf2", + "id": "c5a4a29d", "metadata": {}, "outputs": [], "source": [ @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "2954f39e", + "id": "5e3162a7", "metadata": {}, "source": [ "## Part 6 — Brief NRE variant\n", @@ -440,7 +440,7 @@ { "cell_type": "code", "execution_count": null, - "id": "759fb73a", + "id": "1ec4149f", "metadata": {}, "outputs": [], "source": [ @@ -464,7 +464,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cd669dcd", + "id": "aaa999bd", "metadata": {}, "outputs": [], "source": [ @@ -499,46 +499,108 @@ }, { "cell_type": "markdown", - "id": "bb854806", + "id": "9af6d9ad", "metadata": {}, "source": [ - "## Part 7 — Posterior comparison: NLE vs NRE vs ground truth\n", + "## Part 6b — Ground-truth posterior via HSSM's analytical DDM\n", "\n", - "The comparison plot is the deliverable. NLE and NRE should cover the true parameter\n", - "values; both posteriors should be tighter than the prior and centred near the truth.\n", - "For a four-way comparison alongside the LAN baseline and BayesFlow-LRE result on the\n", - "same simulated data, load their cached posteriors here and add panels to the plot.\n", + "HSSM ships a closed-form analytical likelihood for the standard DDM\n", + "(`loglik_kind=\"analytical\"`, the [Navarro & Fuss](https://psyarxiv.com/cwsbm/)\n", + "form). Running MCMC against this likelihood on the *same observed data*\n", + "produces what we should treat as the gold-standard posterior for this\n", + "model + dataset: any deviation of the sbi-NLE / sbi-NRE marginals from\n", + "these is *approximation error* in the neural surrogate, not intrinsic\n", + "posterior width.\n", + "\n", + "The analytical likelihood uses slightly different parameter bounds\n", + "(a, t unbounded above; otherwise the same) but on the observed data the\n", + "posterior concentrates regardless of the wider bound." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fba9cccd", + "metadata": {}, + "outputs": [], + "source": [ + "model_analytical = hssm.HSSM(\n", + " data=obs_data,\n", + " model=\"ddm\",\n", + " loglik_kind=\"analytical\",\n", + " p_outlier=0,\n", + ")\n", + "idata_analytical = model_analytical.sample(\n", + " sampler=\"numpyro\",\n", + " draws=MCMC_DRAWS,\n", + " tune=MCMC_TUNE,\n", + " chains=MCMC_CHAINS,\n", + " target_accept=0.9,\n", + " progressbar=False,\n", + ")\n", + "summary_analytical = az.summary(idata_analytical, var_names=DDM_PARAM_NAMES)\n", + "summary_analytical" + ] + }, + { + "cell_type": "markdown", + "id": "d8257ed2", + "metadata": {}, + "source": [ + "## Part 7 — Posterior comparison: analytical vs NLE vs NRE\n", + "\n", + "The keystone comparison: the analytical posterior (gold standard for this\n", + "model + data) overlaid against sbi NLE and sbi NRE marginals. Distance\n", + "between an sbi method and the analytical posterior is the\n", + "*approximation error* contributed by the neural surrogate — distance\n", + "between the analytical posterior and the true theta is intrinsic\n", + "posterior width (the data simply isn't infinite).\n", + "\n", + "For a downstream cross-tutorial comparison with the LAN baseline and\n", + "BayesFlow-LRE result on the same simulated data, load their cached\n", + "posteriors here and add panels to the plot.\n", "\n", "> Reuses the exact `TRUE_THETA`, `N_OBS`, and prior ranges from\n", - "> `bayesflow_lre_integration.ipynb`, so the BayesFlow-LRE posteriors are directly\n", - "> comparable when added." + "> `bayesflow_lre_integration.ipynb`, so the BayesFlow-LRE posteriors are\n", + "> directly comparable when added." 
] }, { "cell_type": "code", "execution_count": null, - "id": "ac809832", + "id": "5b30d302", "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(1, 4, figsize=(16, 4))\n", "for ax, name, true_val in zip(axes, DDM_PARAM_NAMES, TRUE_THETA):\n", + " samples_ana = idata_analytical.posterior[name].values.flatten()\n", " samples_nle = idata_nle.posterior[name].values.flatten()\n", " samples_nre = idata_nre.posterior[name].values.flatten()\n", - " ax.hist(samples_nle, bins=30, alpha=0.5, label=\"sbi NLE\", color=\"C0\", density=True)\n", - " ax.hist(samples_nre, bins=30, alpha=0.5, label=\"sbi NRE\", color=\"C1\", density=True)\n", - " ax.axvline(true_val, color=\"red\", linestyle=\"--\", linewidth=2, label=\"true\")\n", + " ax.hist(\n", + " samples_ana,\n", + " bins=30,\n", + " alpha=0.45,\n", + " label=\"analytical (truth)\",\n", + " color=\"C2\",\n", + " density=True,\n", + " )\n", + " ax.hist(samples_nle, bins=30, alpha=0.45, label=\"sbi NLE\", color=\"C0\", density=True)\n", + " ax.hist(samples_nre, bins=30, alpha=0.45, label=\"sbi NRE\", color=\"C1\", density=True)\n", + " ax.axvline(true_val, color=\"red\", linestyle=\"--\", linewidth=2, label=\"true θ\")\n", " ax.set_xlabel(name)\n", " ax.set_title(f\"posterior over {name}\")\n", - " ax.legend()\n", - "fig.suptitle(\"DDM posterior recovery: sbi NLE vs NRE\", y=1.02)\n", + " ax.legend(fontsize=8)\n", + "fig.suptitle(\n", + " \"DDM posterior recovery: analytical (gold) vs sbi NLE vs sbi NRE\", y=1.02\n", + ")\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "5276d16c", + "id": "2680bc5d", "metadata": {}, "source": [ "## Summary\n", From 8e91d9204b5ea1257f491663e7924005f76e19a4 Mon Sep 17 00:00:00 2001 From: Alexander Date: Sat, 16 May 2026 13:01:25 -0400 Subject: [PATCH 12/20] docs(tutorials): restructure C8 notebook around NRE; drop NLE-MAF section The C8 keystone tutorial was producing qualitatively wrong NLE posteriors on DDM data (v centered at ~0.12 vs true 0.5, 
spurious bimodality on a) because MAF flows can't properly model mixed continuous-discrete data (rt continuous, choice in {-1, +1}). The correct sbi method (MNLE) is blocked by the SearchSorted ONNX-op gap that also blocks NSF flows; see plans/sbi-onnx-integration.md "Deferred sbi paths" for the resolution roadmap (~50-line upstream PR to jaxonnxruntime unlocks both). Until that PR lands, the tutorial drops NLE entirely and focuses on NRE, which is robust to discrete/continuous mixing because it learns a classifier (no density-shape assumption). Notebook restructure: - Removed Parts 3-5b (NLE training, export, sampling, diagnostic). - Promoted the NRE variant to the primary path (Parts 3-5b). - Promoted the analytical ground truth (was Part 6b) to Part 6. - Part 7 comparison is now 2-way: analytical (gold) vs sbi NRE. - Added a "Why no NLE in this tutorial?" callout in both the intro and the closing summary pointing at the deferred-paths plan. NRE-side improvements (all applied in Part 3): - Switched NRE_A -> NRE_B with num_atoms=20 (atomic contrastive estimation; sharper signal than plain binary classification). - Multi-sample per theta: 300k distinct theta * 3 samples = 900k pairs (vs the previous 1M theta * 1 sample). Richer local conditional shape information at fewer theta points; still well-covered for a 4D parameter space. - FCEmbedding(input_dim=4, output_dim=32, num_layers=2) on theta to give the classifier richer parameter conditioning. - hidden_features bumped 100 -> 128. - Longer MCMC: tune 500 -> 1500, draws 500 -> 1000. Also: - Diagnostic cell (was Part 5b NLE) adapted for NRE: uses summed classifier logit (which equals log p(x|theta) - log p(x) up to a theta-independent constant) instead of NLE log-prob. - Closing summary now explains the NLE/MNLE deferral and points at plans/sbi-onnx-integration.md. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 425 +++++++++++------------ 1 file changed, 194 insertions(+), 231 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index 339feff13..798896e8c 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,27 +2,38 @@ "cells": [ { "cell_type": "markdown", - "id": "7ec39688", + "id": "a08fa58e", "metadata": {}, "source": [ - "# Integrating sbi-trained likelihoods into HSSM (NLE + NRE via ONNX)\n", + "# Integrating sbi-trained likelihoods into HSSM (NRE via ONNX)\n", "\n", "This tutorial mirrors the structure of `bayesflow_lre_integration.ipynb` for the\n", "third major SBI library in the HSSM ecosystem: **[sbi](https://github.com/sbi-dev/sbi)**\n", "(mackelab). It demonstrates how to:\n", "\n", - "1. Train a neural likelihood estimator (NLE) on synthetic DDM simulations using sbi.\n", + "1. Train a neural ratio estimator (NRE) on synthetic DDM simulations using sbi.\n", "2. Export the trained estimator to ONNX via\n", " [`lanfactory.onnx.transform_sbi_to_onnx`](https://alexanderfengler.github.io/LANfactory/exporting_sbi_models/).\n", "3. Load the ONNX file into HSSM exactly like any other LAN-style approximator and run\n", " MCMC inference.\n", - "4. Repeat the loop with a neural ratio estimator (NRE) for comparison.\n", - "\n", - "The ecosystem's current scope is **neural likelihood surrogates** (NLE and NRE). NPE/\n", - "posterior-amortized methods are deliberately out of scope here — they don't compose\n", - "cleanly with PyMC priors. See the\n", - "[Exporting sbi Models guide](https://alexanderfengler.github.io/LANfactory/exporting_sbi_models/)\n", - "for the full supported-architecture matrix.\n", + "4. 
Compare against HSSM's analytical DDM likelihood as the gold-standard posterior.\n", + "\n", + "> **Why NRE only, not NLE/MNLE?**\n", + "> Vanilla NLE with a MAF flow misbehaves on DDM data because rt is continuous but\n", + "> choice is discrete (∈ {−1, +1}). The flow treats choice as continuous, can't\n", + "> represent the support boundary `rt > t_nd`, and produces qualitatively wrong\n", + "> posteriors (we observed v ≈ 0.12 vs truth 0.5, with spurious bimodality on a).\n", + "> The correct sbi method is **MNLE** (Mixed Neural Likelihood Estimator), which\n", + "> splits x into discrete and continuous dims and models each properly. But MNLE's\n", + "> categorical lookup uses `torch.searchsorted`, which `torch.onnx.export` doesn't\n", + "> support and `jaxonnxruntime` lacks a handler for. See\n", + "> [plans/sbi-onnx-integration.md](../../../HSSMSpine/plans/sbi-onnx-integration.md)\n", + "> \"Deferred sbi paths\" for the resolution roadmap (a ~50-line upstream PR to\n", + "> `jaxonnxruntime` unlocks both MNLE and NSF flows in one stroke).\n", + ">\n", + "> Until then, **NRE is the working integration path** — it doesn't need to model\n", + "> density at all, just a classifier between joint and marginal pairs, which is\n", + "> robust to the discrete/continuous mixing.\n", "\n", "> **Environment note:** This tutorial requires both `hssm` and `lanfactory[all]` (which\n", "> pulls `sbi` and `nflows`) in the same environment. 
JAX/flax/numpyro pins must be\n", @@ -31,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "ed401b4a", + "id": "c9c1f918", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -40,7 +51,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4ef0b05e", + "id": "cb4a66bb", "metadata": {}, "outputs": [], "source": [ @@ -62,8 +73,9 @@ "from torch import nn\n", "\n", "import hssm\n", - "from sbi.inference import NLE_A, NRE_A\n", - "from sbi.neural_nets import classifier_nn, likelihood_nn\n", + "from sbi.inference import NRE_B\n", + "from sbi.neural_nets import classifier_nn\n", + "from sbi.neural_nets.embedding_nets import FCEmbedding\n", "from sbi.utils import BoxUniform\n", "from ssms.basic_simulators.simulator import simulator\n", "\n", @@ -107,27 +119,27 @@ "np.random.seed(0)\n", "torch.manual_seed(0)\n", "\n", - "# Training budget. 1M (theta, x) pairs from ssm-simulators with a wider MAF\n", - "# (hidden_features=100, num_transforms=8) and patient early-stopping. Comparable\n", - "# in spirit to a LAN training run, lets us see how close sbi NLE/NRE can get\n", - "# to true posteriors on a DDM when not budget-starved.\n", - "# On CPU this is heavy — expect ~30–90 min for NLE training alone, depending\n", - "# on machine. 
Run on GPU if available.\n", - "N_TRAIN = 1_000_000 # training simulations\n", - "N_OBS = 500 # observed trials at the true theta\n", - "NUM_EPOCHS = 300 # max; sbi early-stops via stop_after_epochs below\n", - "STOP_AFTER_EPOCHS = 50 # patience for validation-loss plateau (default is 20)\n", + "# Training budget for NRE_B (atomic contrastive estimation).\n", + "# Multi-sample-per-θ: 300k unique θ values × 3 samples = 900k pairs.\n", + "# Comparable scale to a LAN training run; better local conditional structure\n", + "# than 1M θ × 1 sample because the classifier sees varied x at each θ.\n", + "N_THETAS = 300_000\n", + "N_SAMPLES_PER_THETA = 3\n", + "N_OBS = 500\n", + "NUM_EPOCHS = 300\n", + "STOP_AFTER_EPOCHS = 50\n", "TRAINING_BATCH_SIZE = 500\n", - "HIDDEN_FEATURES = 100\n", - "NUM_TRANSFORMS = 8 # MAF stack depth (default 5)\n", - "MCMC_DRAWS = 500\n", - "MCMC_TUNE = 500\n", + "HIDDEN_FEATURES = 128\n", + "THETA_EMBEDDING_DIM = 32\n", + "NRE_NUM_ATOMS = 20 # NRE_B atomic contrastive size; default 10\n", + "MCMC_DRAWS = 1000\n", + "MCMC_TUNE = 1500\n", "MCMC_CHAINS = 2" ] }, { "cell_type": "markdown", - "id": "608e0d7b", + "id": "66d77379", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -140,7 +152,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30923bc0", + "id": "b6a96998", "metadata": {}, "outputs": [], "source": [ @@ -148,8 +160,8 @@ "# Training prior matches HSSM's default bounds for model=\"ddm\" with\n", "# loglik_kind=\"approx_differentiable\" — see hssm.defaults.default_model_config.\n", "# This is important: if the training prior is narrower than HSSM's posterior\n", - "# can explore, MCMC will walk into parameter regions the flow never saw and\n", - "# extrapolate badly (spurious likelihood pockets, biased posteriors).\n", + "# can explore, MCMC will walk into parameter regions the surrogate never saw\n", + "# and extrapolate badly.\n", "PRIOR_LOW = np.array([-3.0, 0.3, 0.0, 0.0], dtype=np.float32)\n", 
"PRIOR_HIGH = np.array([3.0, 2.5, 1.0, 2.0], dtype=np.float32)\n", "TRUE_THETA = np.array([0.5, 1.2, 0.5, 0.25], dtype=np.float32)\n", @@ -168,7 +180,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ff551d6", + "id": "89320183", "metadata": {}, "outputs": [], "source": [ @@ -192,20 +204,34 @@ }, { "cell_type": "markdown", - "id": "aa749e72", + "id": "612ac4bd", "metadata": {}, "source": [ - "## Part 3 — Train a sbi NLE_A on DDM simulations\n", + "## Part 3 — Train an sbi NRE_B classifier on DDM simulations\n", + "\n", + "`NRE_B` learns a classifier that distinguishes joint `(θ, x)` pairs from marginal\n", + "`(θ', x)` pairs (where θ' is drawn from the prior). The output logit equals\n", + "`log p(x | θ) − log p(x)` up to a constant, so it serves directly as the\n", + "HSSM log-likelihood for MCMC (the θ-independent constant drops out).\n", + "\n", + "We use **NRE_B** rather than NRE_A because NRE_B's atomic contrastive estimation\n", + "(`num_atoms=20`) gives a sharper discriminative signal — at each gradient step,\n", + "the classifier scores one positive vs. `num_atoms − 1` marginals, multiclass\n", + "softmax over the row. NRE_A's plain binary classifier (1 positive vs 1 marginal)\n", + "trains faster but reaches a less calibrated optimum.\n", "\n", - "`NLE_A` trains a conditional density estimator (here a MAF normalizing flow) on\n", - "`(theta, x)` pairs. After training, `estimator.log_prob(x, condition=theta)` returns\n", - "`log p(x | theta)` with z-score standardization Jacobians applied automatically." + "We also use a **multi-sample-per-θ training set**: 300k distinct θ values, 3\n", + "simulated `(rt, choice)` per θ. This gives the classifier richer local\n", + "information about the per-θ distribution shape than 1 sample per θ would.\n", + "\n", + "Finally, an `FCEmbedding` on θ (4 → 32 → 32) gives the classifier richer\n", + "parameter conditioning." 
] }, { "cell_type": "code", "execution_count": null, - "id": "697e034e", + "id": "f01cbef8", "metadata": {}, "outputs": [], "source": [ @@ -213,85 +239,92 @@ " low=torch.from_numpy(PRIOR_LOW),\n", " high=torch.from_numpy(PRIOR_HIGH),\n", ")\n", - "theta_train = prior.sample((N_TRAIN,))\n", + "theta_unique = prior.sample((N_THETAS,))\n", "\n", - "# Batched simulation: one (rt, choice) per training theta. ssm-simulators\n", - "# accepts theta of shape (N, 4) with n_samples=1 and returns rts/choices of\n", - "# shape (N, 1) directly — much faster than a Python loop for large N.\n", + "# Batched ssm-simulators: theta shape (N, 4) with n_samples=k → rts/choices of\n", + "# shape (N, k). Massively faster than a Python loop for large N.\n", "sim = simulator(\n", - " theta=theta_train.numpy().astype(np.float32),\n", + " theta=theta_unique.numpy().astype(np.float32),\n", " model=\"ddm\",\n", - " n_samples=1,\n", - ")\n", - "x_train = torch.from_numpy(\n", - " np.stack([sim[\"rts\"].squeeze(-1), sim[\"choices\"].squeeze(-1)], axis=-1).astype(\n", - " np.float32\n", - " )\n", + " n_samples=N_SAMPLES_PER_THETA,\n", ")\n", - "print(f\"training set: theta={theta_train.shape}, x={x_train.shape}\")" + "rts_flat = sim[\"rts\"].reshape(-1).astype(np.float32) # (N * k,)\n", + "choices_flat = sim[\"choices\"].reshape(-1).astype(np.float32)\n", + "x_train = torch.from_numpy(np.stack([rts_flat, choices_flat], axis=-1))\n", + "theta_train = theta_unique.repeat_interleave(N_SAMPLES_PER_THETA, dim=0)\n", + "print(f\"training set: theta={theta_train.shape}, x={x_train.shape} \"\n", + " f\"(N_THETAS={N_THETAS} × N_SAMPLES_PER_THETA={N_SAMPLES_PER_THETA})\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "925c7506", + "id": "9790e9ac", "metadata": {}, "outputs": [], "source": [ - "# Build a larger MAF than the sbi default (which is 50 hidden / 5 transforms)\n", - "# so the flow has enough capacity to learn p(rt, choice | v, a, z, t) accurately\n", - "# across the full HSSM 
prior range.\n", - "nle_estimator_builder = likelihood_nn(\n", - " model=\"maf\",\n", + "# Build the classifier with an FCEmbedding on theta and LayerNorm disabled\n", + "# (jaxonnxruntime doesn't implement LayerNormalization, so the MLP norm_layer\n", + "# must be nn.Identity for ONNX export to work).\n", + "embedding_theta = FCEmbedding(\n", + " input_dim=4,\n", + " output_dim=THETA_EMBEDDING_DIM,\n", + " num_layers=2,\n", + " num_hiddens=THETA_EMBEDDING_DIM,\n", + ")\n", + "classifier_builder = classifier_nn(\n", + " model=\"mlp\",\n", + " norm_layer=nn.Identity,\n", " hidden_features=HIDDEN_FEATURES,\n", - " num_transforms=NUM_TRANSFORMS,\n", + " embedding_net_theta=embedding_theta,\n", ")\n", - "inference_nle = NLE_A(prior=prior, density_estimator=nle_estimator_builder)\n", - "estimator_nle = inference_nle.append_simulations(theta_train, x_train).train(\n", + "inference_nre = NRE_B(prior=prior, classifier=classifier_builder)\n", + "classifier_nre = inference_nre.append_simulations(theta_train, x_train).train(\n", " training_batch_size=TRAINING_BATCH_SIZE,\n", " max_num_epochs=NUM_EPOCHS,\n", " stop_after_epochs=STOP_AFTER_EPOCHS,\n", + " num_atoms=NRE_NUM_ATOMS,\n", ")\n", - "estimator_nle.eval()\n", - "print(\"NLE training complete\")" + "classifier_nre.eval()\n", + "print(\"NRE_B training complete\")" ] }, { "cell_type": "markdown", - "id": "90916f7f", + "id": "2a30deb4", "metadata": {}, "source": [ - "## Part 4 — Export the trained NLE to ONNX\n", + "## Part 4 — Export the trained NRE to ONNX\n", "\n", - "The exporter wraps `estimator.log_prob` as a `torch.nn.Module` whose `forward(combined)`\n", - "splits a concatenated `(theta, x)` input. The result is a single-trial ONNX graph that\n", - "HSSM consumes exactly like a LAN file." + "The exporter wraps the classifier's `forward(theta, x)` logit as the HSSM\n", + "log-likelihood. No Jacobian correction is needed — ratios are invariant to the\n", + "z-score standardization sbi applies internally." 
] }, { "cell_type": "code", "execution_count": null, - "id": "9ad107d6", + "id": "276ad225", "metadata": {}, "outputs": [], "source": [ "onnx_dir = Path(\"./sbi_onnx_artifacts\")\n", "onnx_dir.mkdir(exist_ok=True)\n", - "nle_onnx_path = onnx_dir / \"ddm_nle.onnx\"\n", + "nre_onnx_path = onnx_dir / \"ddm_nre.onnx\"\n", "\n", "transform_sbi_to_onnx(\n", - " estimator_nle,\n", - " str(nle_onnx_path),\n", - " mode=\"nle\",\n", + " classifier_nre,\n", + " str(nre_onnx_path),\n", + " mode=\"nre\",\n", " example_theta_dim=4,\n", " example_x_dim=2,\n", ")\n", - "print(f\"exported NLE: {nle_onnx_path} ({nle_onnx_path.stat().st_size:,} bytes)\")" + "print(f\"exported NRE: {nre_onnx_path} ({nre_onnx_path.stat().st_size:,} bytes)\")" ] }, { "cell_type": "markdown", - "id": "76b82659", + "id": "8a618485", "metadata": {}, "source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -304,28 +337,28 @@ { "cell_type": "code", "execution_count": null, - "id": "dee28506", + "id": "37111e5c", "metadata": {}, "outputs": [], "source": [ - "model_nle = hssm.HSSM(\n", + "model_nre = hssm.HSSM(\n", " data=obs_data,\n", " model=\"ddm\",\n", " loglik_kind=\"approx_differentiable\",\n", - " loglik=str(nle_onnx_path),\n", + " loglik=str(nre_onnx_path),\n", " p_outlier=0,\n", ")\n", - "print(model_nle)" + "print(model_nre)" ] }, { "cell_type": "code", "execution_count": null, - "id": "fbe0c986", + "id": "5fd8d29a", "metadata": {}, "outputs": [], "source": [ - "idata_nle = model_nle.sample(\n", + "idata_nre = model_nre.sample(\n", " sampler=\"numpyro\",\n", " draws=MCMC_DRAWS,\n", " tune=MCMC_TUNE,\n", @@ -338,179 +371,97 @@ { "cell_type": "code", "execution_count": null, - "id": "b2b2b0b9", + "id": "c45d934b", "metadata": {}, "outputs": [], "source": [ - "summary_nle = az.summary(idata_nle, var_names=DDM_PARAM_NAMES)\n", - "summary_nle" + "summary_nre = az.summary(idata_nre, var_names=DDM_PARAM_NAMES)\n", + "summary_nre" ] }, { "cell_type": "code", "execution_count": null, - "id": 
"15b0f907", + "id": "aa29df5f", "metadata": {}, "outputs": [], "source": [ - "az.plot_trace(idata_nle, var_names=DDM_PARAM_NAMES)\n", + "az.plot_trace(idata_nre, var_names=DDM_PARAM_NAMES)\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "8017d20a", + "id": "f78d6166", "metadata": {}, "source": [ - "### Part 5b — Diagnostic: is the NLE itself biased, or is HSSM not finding its mode?\n", - "\n", - "If the marginal posteriors miss the truth, two very different causes are possible:\n", - "\n", - "1. **Training quality** — the flow itself is maximized at the wrong θ.\n", - "2. **Sampling / HSSM-side** — the flow is fine, but MCMC isn't finding its mode (priors, init, mixing).\n", - "\n", - "We can distinguish them by asking the trained NLE directly: which θ does it\n", - "think makes the observed data more likely — the true θ, or the posterior's\n", - "marginal mean? If the NLE *itself* prefers the (wrong) posterior mean, training\n", - "is the issue. If it prefers the truth, sampling is." + "### Part 5b — Diagnostic: is the NRE classifier itself biased, or is HSSM not finding its mode?\n", + "\n", + "For NRE the logit `forward(theta, x)` equals `log p(x | θ) − log p(x)` up to a\n", + "constant. The θ-independent term cancels when we compare two θ values on the\n", + "same data, so we can use the *summed logit over trials* as a proxy for \"which θ\n", + "does the classifier think makes the data more likely.\" If the classifier itself\n", + "prefers a θ far from the truth, training quality is the issue; if it prefers the\n", + "truth, MCMC is exploring elsewhere." 
] }, { "cell_type": "code", "execution_count": null, - "id": "c5a4a29d", + "id": "52a21be1", "metadata": {}, "outputs": [], "source": [ - "posterior_mean_nle = {\n", - " p: float(idata_nle.posterior[p].mean()) for p in DDM_PARAM_NAMES\n", + "posterior_mean_nre = {\n", + " p: float(idata_nre.posterior[p].mean()) for p in DDM_PARAM_NAMES\n", "}\n", "obs_x_t = torch.from_numpy(\n", " obs_data[[\"rt\", \"response\"]].values.astype(np.float32)\n", ")\n", "\n", - "def total_logp(estimator, theta_dict):\n", - " \"\"\"Sum log p(x_i | theta) over all observed trials at a single theta.\"\"\"\n", + "def total_logit(classifier, theta_dict):\n", + " \"\"\"Sum log-ratio logit over all observed trials at a single theta.\"\"\"\n", " theta_row = torch.tensor(\n", " [[theta_dict[p] for p in DDM_PARAM_NAMES]], dtype=torch.float32\n", " ).repeat(len(obs_x_t), 1)\n", " with torch.no_grad():\n", - " return estimator.log_prob(obs_x_t, condition=theta_row).sum().item()\n", + " return classifier(theta_row, obs_x_t).sum().item()\n", "\n", - "lp_true = total_logp(estimator_nle, dict(zip(DDM_PARAM_NAMES, TRUE_THETA)))\n", - "lp_mean = total_logp(estimator_nle, posterior_mean_nle)\n", + "lt_true = total_logit(classifier_nre, dict(zip(DDM_PARAM_NAMES, TRUE_THETA)))\n", + "lt_mean = total_logit(classifier_nre, posterior_mean_nre)\n", "\n", - "print(f\"NLE log-prob at true theta: {lp_true:+.2f}\")\n", - "print(f\"NLE log-prob at posterior mean: {lp_mean:+.2f}\")\n", - "print(f\"Δ (mean − true): {lp_mean - lp_true:+.2f}\")\n", + "print(f\"NRE total logit at true theta: {lt_true:+.2f}\")\n", + "print(f\"NRE total logit at posterior mean: {lt_mean:+.2f}\")\n", + "print(f\"Δ (mean − true): {lt_mean - lt_true:+.2f}\")\n", "print()\n", - "if lp_mean > lp_true + 5.0:\n", - " print(\"→ NLE itself prefers the wrong θ by a large margin.\")\n", - " print(\" Diagnosis: TRAINING QUALITY. The flow has not learned the true\")\n", - " print(\" conditional density well. 
Fix with more sims, larger model, or\")\n", - " print(\" more epochs.\")\n", - "elif lp_mean > lp_true:\n", - " print(\"→ NLE mildly prefers the posterior mean over the truth.\")\n", - " print(\" Diagnosis: marginal posterior bias may reflect a slightly\")\n", - " print(\" miscalibrated flow plus narrow likelihood. Marginal posteriors\")\n", - " print(\" can shift while the joint mode stays near the truth — check the\")\n", - " print(\" pairs plot (added separately) to see the joint.\")\n", + "if lt_mean > lt_true + 5.0:\n", + " print(\"→ NRE itself prefers the wrong θ by a large margin.\")\n", + " print(\" Diagnosis: TRAINING QUALITY.\")\n", + "elif lt_mean > lt_true:\n", + " print(\"→ NRE mildly prefers the posterior mean over the truth.\")\n", + " print(\" Could be marginal-vs-joint, mild miscalibration, or sampling.\")\n", "else:\n", - " print(\"→ NLE prefers the truth; the wrong posterior mean is a sampling\")\n", + " print(\"→ NRE prefers the truth; the wrong posterior mean is a sampling\")\n", " print(\" artifact (priors / init / mixing), not a training issue.\")\n", "print()\n", - "print(f\"Posterior mean: {posterior_mean_nle}\")\n", + "print(f\"Posterior mean: {posterior_mean_nre}\")\n", "print(f\"True theta: {dict(zip(DDM_PARAM_NAMES, TRUE_THETA.tolist()))}\")" ] }, { "cell_type": "markdown", - "id": "5e3162a7", - "metadata": {}, - "source": [ - "## Part 6 — Brief NRE variant\n", - "\n", - "The same pipeline works for ratio classifiers. The only sbi-side wrinkle is that the\n", - "default MLP classifier uses `nn.LayerNorm` between hidden layers, and `jaxonnxruntime`\n", - "doesn't implement `LayerNormalization`. We pass `norm_layer=nn.Identity` to disable\n", - "it. See `docs/exporting_sbi_models.md` in LANfactory for the full constraint list." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1ec4149f", - "metadata": {}, - "outputs": [], - "source": [ - "# Match NLE: bigger hidden width, longer training, LayerNorm disabled\n", - "# (jaxonnxruntime doesn't implement LayerNormalization).\n", - "classifier_builder = classifier_nn(\n", - " model=\"mlp\",\n", - " norm_layer=nn.Identity,\n", - " hidden_features=HIDDEN_FEATURES,\n", - ")\n", - "inference_nre = NRE_A(prior=prior, classifier=classifier_builder)\n", - "classifier_nre = inference_nre.append_simulations(theta_train, x_train).train(\n", - " training_batch_size=TRAINING_BATCH_SIZE,\n", - " max_num_epochs=NUM_EPOCHS,\n", - " stop_after_epochs=STOP_AFTER_EPOCHS,\n", - ")\n", - "classifier_nre.eval()\n", - "print(\"NRE training complete\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aaa999bd", - "metadata": {}, - "outputs": [], - "source": [ - "nre_onnx_path = onnx_dir / \"ddm_nre.onnx\"\n", - "transform_sbi_to_onnx(\n", - " classifier_nre,\n", - " str(nre_onnx_path),\n", - " mode=\"nre\",\n", - " example_theta_dim=4,\n", - " example_x_dim=2,\n", - ")\n", - "print(f\"exported NRE: {nre_onnx_path} ({nre_onnx_path.stat().st_size:,} bytes)\")\n", - "\n", - "model_nre = hssm.HSSM(\n", - " data=obs_data,\n", - " model=\"ddm\",\n", - " loglik_kind=\"approx_differentiable\",\n", - " loglik=str(nre_onnx_path),\n", - " p_outlier=0,\n", - ")\n", - "idata_nre = model_nre.sample(\n", - " sampler=\"numpyro\",\n", - " draws=MCMC_DRAWS,\n", - " tune=MCMC_TUNE,\n", - " chains=MCMC_CHAINS,\n", - " target_accept=0.9,\n", - " progressbar=False,\n", - ")\n", - "summary_nre = az.summary(idata_nre, var_names=DDM_PARAM_NAMES)\n", - "summary_nre" - ] - }, - { - "cell_type": "markdown", - "id": "9af6d9ad", + "id": "f26d6a4f", "metadata": {}, "source": [ - "## Part 6b — Ground-truth posterior via HSSM's analytical DDM\n", + "## Part 6 — Ground-truth posterior via HSSM's analytical DDM\n", "\n", "HSSM ships a closed-form 
analytical likelihood for the standard DDM\n", "(`loglik_kind=\"analytical\"`, the [Navarro & Fuss](https://psyarxiv.com/cwsbm/)\n", "form). Running MCMC against this likelihood on the *same observed data*\n", "produces what we should treat as the gold-standard posterior for this\n", - "model + dataset: any deviation of the sbi-NLE / sbi-NRE marginals from\n", - "these is *approximation error* in the neural surrogate, not intrinsic\n", - "posterior width.\n", + "model + dataset: any deviation of the sbi-NRE marginals from these is\n", + "*approximation error* in the neural surrogate, not intrinsic posterior width.\n", "\n", "The analytical likelihood uses slightly different parameter bounds\n", "(a, t unbounded above; otherwise the same) but on the observed data the\n", @@ -520,7 +471,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fba9cccd", + "id": "aa8aca68", "metadata": {}, "outputs": [], "source": [ @@ -544,21 +495,20 @@ }, { "cell_type": "markdown", - "id": "d8257ed2", + "id": "1bcd7d63", "metadata": {}, "source": [ - "## Part 7 — Posterior comparison: analytical vs NLE vs NRE\n", + "## Part 7 — Posterior comparison: analytical vs sbi NRE\n", "\n", - "The keystone comparison: the analytical posterior (gold standard for this\n", - "model + data) overlaid against sbi NLE and sbi NRE marginals. Distance\n", - "between an sbi method and the analytical posterior is the\n", - "*approximation error* contributed by the neural surrogate — distance\n", - "between the analytical posterior and the true theta is intrinsic\n", - "posterior width (the data simply isn't infinite).\n", + "The keystone comparison. The analytical posterior is the gold standard for this\n", + "model + data; the sbi-NRE marginals are the approximation we built. 
Distance\n", + "between NRE and analytical is the surrogate's approximation error; distance\n", + "between the analytical posterior and the true θ is intrinsic posterior width\n", + "(the data simply isn't infinite at N=500 trials).\n", "\n", - "For a downstream cross-tutorial comparison with the LAN baseline and\n", - "BayesFlow-LRE result on the same simulated data, load their cached\n", - "posteriors here and add panels to the plot.\n", + "For the broader cross-tutorial comparison alongside the LAN baseline and\n", + "BayesFlow-LRE result on the same simulated data, load their cached posteriors\n", + "here and add panels to the plot.\n", "\n", "> Reuses the exact `TRUE_THETA`, `N_OBS`, and prior ranges from\n", "> `bayesflow_lre_integration.ipynb`, so the BayesFlow-LRE posteriors are\n", @@ -568,31 +518,29 @@ { "cell_type": "code", "execution_count": null, - "id": "5b30d302", + "id": "06b36b08", "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(1, 4, figsize=(16, 4))\n", "for ax, name, true_val in zip(axes, DDM_PARAM_NAMES, TRUE_THETA):\n", " samples_ana = idata_analytical.posterior[name].values.flatten()\n", - " samples_nle = idata_nle.posterior[name].values.flatten()\n", " samples_nre = idata_nre.posterior[name].values.flatten()\n", " ax.hist(\n", " samples_ana,\n", " bins=30,\n", - " alpha=0.45,\n", + " alpha=0.5,\n", " label=\"analytical (truth)\",\n", " color=\"C2\",\n", " density=True,\n", " )\n", - " ax.hist(samples_nle, bins=30, alpha=0.45, label=\"sbi NLE\", color=\"C0\", density=True)\n", - " ax.hist(samples_nre, bins=30, alpha=0.45, label=\"sbi NRE\", color=\"C1\", density=True)\n", + " ax.hist(samples_nre, bins=30, alpha=0.5, label=\"sbi NRE\", color=\"C1\", density=True)\n", " ax.axvline(true_val, color=\"red\", linestyle=\"--\", linewidth=2, label=\"true θ\")\n", " ax.set_xlabel(name)\n", " ax.set_title(f\"posterior over {name}\")\n", " ax.legend(fontsize=8)\n", "fig.suptitle(\n", - " \"DDM posterior recovery: analytical (gold) 
vs sbi NLE vs sbi NRE\", y=1.02\n", + " \"DDM posterior recovery: analytical (gold) vs sbi NRE\", y=1.02\n", ")\n", "plt.tight_layout()\n", "plt.show()" @@ -600,25 +548,40 @@ }, { "cell_type": "markdown", - "id": "2680bc5d", + "id": "4db9c6f8", "metadata": {}, "source": [ - "## Summary\n", + "## Summary and deferred work\n", + "\n", + "We trained an sbi NRE_B classifier on synthetic DDM data with FCEmbedding on\n", + "the parameters and atomic contrastive estimation (`num_atoms=20`), exported it\n", + "to ONNX via `lanfactory.onnx.transform_sbi_to_onnx`, and ran MCMC through\n", + "HSSM's existing `loglik_kind=\"approx_differentiable\"` pipeline. The resulting\n", + "posterior is compared against HSSM's analytical DDM posterior on the same\n", + "data — that's our gold-standard reference.\n", + "\n", + "**Why not NLE in this tutorial?**\n", + "\n", + "We originally planned an NLE section too. Vanilla NLE with a MAF flow produces\n", + "qualitatively wrong posteriors on DDM data because rt is continuous but choice\n", + "is discrete (∈ {−1, +1}); the flow can't represent that structure or the hard\n", + "support boundary `rt > t_nd`. The correct sbi method is **MNLE** (Mixed Neural\n", + "Likelihood Estimator), which factorizes `p(rt, choice | θ) = p(choice | θ) ·\n", + "p(rt | choice, θ)`. But MNLE's `CategoricalMassEstimator` uses\n", + "`torch.searchsorted` for value-to-index lookup, which `torch.onnx.export` does\n", + "not support — blocking the ONNX export path until a `SearchSorted` ONNX-op\n", + "handler is added to `jaxonnxruntime`. The same gap blocks Neural Spline Flows.\n", + "\n", + "See `plans/sbi-onnx-integration.md` \"Deferred sbi paths (MNLE, vanilla NLE on\n", + "DDM, NSF flows)\" in HSSMSpine for the resolution roadmap. 
A single ~50-line\n", + "upstream PR to `jaxonnxruntime` adding `SearchSorted` unlocks both NSF flows\n", + "AND MNLE in one stroke.\n", + "\n", + "**Where to look next**\n", "\n", - "We trained a sbi NLE_A and an NRE_A on synthetic DDM data, exported both to ONNX via\n", - "`lanfactory.onnx.transform_sbi_to_onnx`, and ran MCMC through HSSM's existing\n", - "`loglik_kind=\"approx_differentiable\"` pipeline. Both posteriors recover the true\n", - "parameters.\n", - "\n", - "**Where to look next:**\n", "- LANfactory's [Exporting sbi Models guide](https://alexanderfengler.github.io/LANfactory/exporting_sbi_models/) — supported-architecture matrix, known constraints, troubleshooting.\n", "- The BayesFlow LRE tutorial (`bayesflow_lre_integration.ipynb`) — the same DDM with a different SBI library, for cross-toolkit comparison.\n", - "- LAN tutorials (`main_tutorial.ipynb`) — the original LANfactory workflow this integration builds on top of.\n", - "\n", - "**Known constraints (v1):**\n", - "- sbi NPE/SNPE estimators are deliberately out of scope (posterior-shaped, conflicts with PyMC priors).\n", - "- Neural Spline Flows are blocked on a missing `SearchSorted` op in `jaxonnxruntime`.\n", - "- FMPE / NPSE (score-based) require ODE integration and aren't ONNX-exportable." + "- LAN tutorials (`main_tutorial.ipynb`) — the original LANfactory workflow this integration builds on top of." ] } ], From 1dde8f9d4fcd947abb4a5f770a5ee91b0d0df428 Mon Sep 17 00:00:00 2001 From: Alexander Date: Sat, 16 May 2026 20:37:02 -0400 Subject: [PATCH 13/20] docs(tutorials): add Part 5c deeper NRE diagnostics (logit sweep + export check) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Latest tutorial run produced NRE posteriors that essentially equal the prior across all four DDM parameters. Traces wander freely — chain is not stuck, the loglik is just ~constant across theta-space, so MCMC has no signal and samples the prior. 
The Part 5b diagnostic isn't enough to distinguish "NRE found the truth, MCMC didn't sample around it" from "NRE classifier is flat everywhere." Adding two more targeted diagnostics: Part 5c.1 — Logit sweep: Hold three theta dims at the true values, sweep the fourth across its prior range, plot the summed classifier log-ratio on observed data. A well-trained NRE shows a sharp peak near the true value with tens-to-hundreds of log units of vertical range. A flat curve (< 5 log units) is the smoking gun for "classifier collapsed." Part 5c.2 — Export round-trip: Compare classifier_nre(theta, x).item() against the exported ONNX output through onnxruntime on a fixed point. If they agree to ~1e-5, the export is faithful and any pathology is in the trained classifier itself; otherwise the bigger network or FCEmbedding addition introduced an export bug. Both cells read in-memory state from the previous run (classifier_nre, nre_onnx_path, obs_data, TRUE_THETA, prior bounds) so the user can run them without retraining — they're cheap one-off diagnostics. This commit doesn't change the broken NRE training itself; once the diagnostics tell us which side is broken (training quality vs. export), the bisect plan from the review will pick the right fix: - flat classifier -> drop FCEmbedding -> if still flat, drop NRE_B for NRE_A -> if still flat, drop multi-sample-per-theta. - faithful export, flat classifier -> same bisect path. - export mismatch -> investigate exporter on FCEmbedding shapes. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 156 +++++++++++++++++++---- 1 file changed, 133 insertions(+), 23 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index 798896e8c..9a760ed78 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "a08fa58e", + "id": "f5014a28", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NRE via ONNX)\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "c9c1f918", + "id": "f48d8b15", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -51,7 +51,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cb4a66bb", + "id": "53226015", "metadata": {}, "outputs": [], "source": [ @@ -139,7 +139,7 @@ }, { "cell_type": "markdown", - "id": "66d77379", + "id": "b6b49be1", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -152,7 +152,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b6a96998", + "id": "70b21f3b", "metadata": {}, "outputs": [], "source": [ @@ -180,7 +180,7 @@ { "cell_type": "code", "execution_count": null, - "id": "89320183", + "id": "e4c0d87f", "metadata": {}, "outputs": [], "source": [ @@ -204,7 +204,7 @@ }, { "cell_type": "markdown", - "id": "612ac4bd", + "id": "e0bd9676", "metadata": {}, "source": [ "## Part 3 — Train an sbi NRE_B classifier on DDM simulations\n", @@ -231,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f01cbef8", + "id": "54940535", "metadata": {}, "outputs": [], "source": [ @@ -259,7 +259,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9790e9ac", + "id": "b7d905b5", "metadata": {}, "outputs": [], "source": [ @@ -291,7 +291,7 @@ }, { "cell_type": "markdown", - "id": "2a30deb4", + "id": "3b9d53fb", "metadata": {}, "source": [ "## Part 4 — Export the trained NRE to ONNX\n", @@ -304,7 
+304,7 @@ { "cell_type": "code", "execution_count": null, - "id": "276ad225", + "id": "e5d6da0e", "metadata": {}, "outputs": [], "source": [ @@ -324,7 +324,7 @@ }, { "cell_type": "markdown", - "id": "8a618485", + "id": "186cdc43", "metadata": {}, "source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -337,7 +337,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37111e5c", + "id": "8c7bc487", "metadata": {}, "outputs": [], "source": [ @@ -354,7 +354,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5fd8d29a", + "id": "7aaecc6d", "metadata": {}, "outputs": [], "source": [ @@ -371,7 +371,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c45d934b", + "id": "bddd2753", "metadata": {}, "outputs": [], "source": [ @@ -382,7 +382,7 @@ { "cell_type": "code", "execution_count": null, - "id": "aa29df5f", + "id": "fa6a0848", "metadata": {}, "outputs": [], "source": [ @@ -393,7 +393,7 @@ }, { "cell_type": "markdown", - "id": "f78d6166", + "id": "dc2c1716", "metadata": {}, "source": [ "### Part 5b — Diagnostic: is the NRE classifier itself biased, or is HSSM not finding its mode?\n", @@ -409,7 +409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52a21be1", + "id": "52a58b68", "metadata": {}, "outputs": [], "source": [ @@ -451,7 +451,117 @@ }, { "cell_type": "markdown", - "id": "f26d6a4f", + "id": "22d9c68e", + "metadata": {}, + "source": [ + "### Part 5c — Deeper diagnostics: classifier shape + export round-trip\n", + "\n", + "If the diagnostic above prints a Δ near zero, that's ambiguous: it could mean\n", + "\"NRE found the truth, just MCMC didn't sample around it\" OR \"NRE's logit is\n", + "nearly flat everywhere, so all θ look equally good to MCMC.\" To distinguish\n", + "these we need to look at the trained classifier directly.\n", + "\n", + "The two cells below:\n", + "1. **Sweep the classifier logit across each θ dimension** (holding the others\n", + " at the true values) and plot the response. 
A well-trained classifier shows\n", + " a sharp peak near the red truth line with a large vertical range (tens to\n", + " hundreds of log units). A poorly-trained classifier produces a nearly flat\n", + " curve — the smoking gun for \"MCMC samples the prior because the loglik is\n", + " uninformative.\"\n", + "2. **Compare the exported ONNX output to the torch classifier output** on the\n", + " same input. If they agree to ~1e-5, the export is fine and any pathology\n", + " is in training; if they differ, the export is broken on the bigger network." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4bbe513d", + "metadata": {}, + "outputs": [], + "source": [ + "# Diagnostic: how much does the NRE logit change as we sweep each theta dim?\n", + "sweep_pts = 50\n", + "sweep_results = {}\n", + "obs_x_t = torch.from_numpy(\n", + " obs_data[[\"rt\", \"response\"]].values.astype(np.float32)\n", + ")\n", + "theta_center = torch.tensor(TRUE_THETA, dtype=torch.float32)\n", + "for dim, name in enumerate(DDM_PARAM_NAMES):\n", + " sweep = torch.linspace(PRIOR_LOW[dim], PRIOR_HIGH[dim], sweep_pts)\n", + " logits = []\n", + " for v in sweep:\n", + " theta = theta_center.clone()\n", + " theta[dim] = v\n", + " theta_row = theta.unsqueeze(0).repeat(len(obs_x_t), 1)\n", + " with torch.no_grad():\n", + " logits.append(classifier_nre(theta_row, obs_x_t).sum().item())\n", + " sweep_results[name] = (sweep.numpy(), np.array(logits))\n", + "\n", + "fig, axes = plt.subplots(1, 4, figsize=(16, 3.5))\n", + "for ax, name in zip(axes, DDM_PARAM_NAMES):\n", + " th, lp = sweep_results[name]\n", + " ax.plot(th, lp - lp.max(), \"C0-\", linewidth=2)\n", + " ax.axvline(TRUE_THETA[DDM_PARAM_NAMES.index(name)], color=\"red\",\n", + " linestyle=\"--\", linewidth=2, label=\"true θ\")\n", + " ax.set_xlabel(name)\n", + " ax.set_ylabel(\"Δ summed log-ratio (= 0 at max)\")\n", + " ax.set_title(f\"sweep over {name}\\n(vertical range = {lp.ptp():.2f})\")\n", + " ax.legend(fontsize=8)\n", + 
"fig.suptitle(\n", + " \"Trained NRE classifier: log-ratio along each θ axis (others held at truth)\",\n", + " y=1.02,\n", + ")\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# Sanity print: vertical range per dim. < ~5 log units is bad.\n", + "print(\"\\nPer-dim vertical range (max log-ratio − min log-ratio):\")\n", + "for name in DDM_PARAM_NAMES:\n", + " _, lp = sweep_results[name]\n", + " print(f\" {name}: {lp.ptp():.2f}\")\n", + "print(\"\\nInterpretation:\")\n", + "print(\" > 50 log units : strong discriminative signal, classifier well-trained.\")\n", + "print(\" 10–50 : moderate signal; should give a usable posterior.\")\n", + "print(\" < 5 : near-flat; classifier collapsed to uninformative.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20854d41", + "metadata": {}, + "outputs": [], + "source": [ + "# Diagnostic: does the exported ONNX match the torch classifier?\n", + "import onnxruntime as _ort\n", + "\n", + "_sess = _ort.InferenceSession(str(nre_onnx_path))\n", + "_input_name = _sess.get_inputs()[0].name\n", + "_test_theta = torch.tensor([[0.5, 1.2, 0.5, 0.25]], dtype=torch.float32)\n", + "_test_x = torch.tensor([[0.5, 1.0]], dtype=torch.float32)\n", + "_combined = (\n", + " torch.cat([_test_theta, _test_x], dim=-1).squeeze(0).numpy().astype(np.float32)\n", + ")\n", + "\n", + "with torch.no_grad():\n", + " _y_torch = float(classifier_nre(_test_theta, _test_x).item())\n", + "_y_ort = float(_sess.run(None, {_input_name: _combined})[0])\n", + "\n", + "print(f\"torch logit at (θ=true, x=(0.5, +1)): {_y_torch:+.5f}\")\n", + "print(f\"ORT logit at (θ=true, x=(0.5, +1)): {_y_ort:+.5f}\")\n", + "print(f\"|Δ|: {abs(_y_torch - _y_ort):.2e}\")\n", + "print()\n", + "if abs(_y_torch - _y_ort) < 1e-4:\n", + " print(\"→ Export is faithful; any pathology is in the trained classifier itself.\")\n", + "else:\n", + " print(\"→ Export disagrees with torch by more than 1e-4. 
The bigger network or\"\n", + " \" FCEmbedding likely introduced an op the exporter handles incorrectly.\")" + ] + }, + { + "cell_type": "markdown", + "id": "362404a2", "metadata": {}, "source": [ "## Part 6 — Ground-truth posterior via HSSM's analytical DDM\n", @@ -471,7 +581,7 @@ { "cell_type": "code", "execution_count": null, - "id": "aa8aca68", + "id": "a5585daf", "metadata": {}, "outputs": [], "source": [ @@ -495,7 +605,7 @@ }, { "cell_type": "markdown", - "id": "1bcd7d63", + "id": "027bc9de", "metadata": {}, "source": [ "## Part 7 — Posterior comparison: analytical vs sbi NRE\n", @@ -518,7 +628,7 @@ { "cell_type": "code", "execution_count": null, - "id": "06b36b08", + "id": "dde71468", "metadata": {}, "outputs": [], "source": [ @@ -548,7 +658,7 @@ }, { "cell_type": "markdown", - "id": "4db9c6f8", + "id": "6bd93a3c", "metadata": {}, "source": [ "## Summary and deferred work\n", From 572e74b870f8bb796ab4e9e63d55cf100818baea Mon Sep 17 00:00:00 2001 From: Alexander Date: Sat, 16 May 2026 23:14:44 -0400 Subject: [PATCH 14/20] =?UTF-8?q?docs(tutorials):=20bisect=20step=201=20?= =?UTF-8?q?=E2=80=94=20remove=20FCEmbedding=20+=20tighten=20MCMC=20budget?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Most recent notebook run produced NRE posteriors that effectively equal the prior (chains wander freely across the entire prior with no concentration), and MCMC was taking 30+ min per pass. The pattern is consistent with the trained classifier providing near-zero discriminative signal AND with NUTS spending most leapfrog steps on divergent trajectories — both symptoms of a pathologically-shaped surrogate loglik. The bisect starts with the most-recently-introduced and least-tested change: the FCEmbedding(4 → 32 → 32) on theta inside the NRE classifier. 
Removing it is a single-line change, leaves the other improvements in place (NRE_B with num_atoms=20, multi-sample-per-theta 300k×3, hidden=128), and lets us verify whether the embedding was the culprit. Comment in the classifier- builder cell notes the bisect for future readers. Companion changes to keep MCMC bounded during diagnosis: - MCMC_DRAWS: 1000 → 500 - MCMC_TUNE: 1500 → 500 - target_accept: 0.9 → 0.8 (allows larger steps) - max_tree_depth: default 10 → 8 (caps leapfrog steps/draw at 256 vs 1024) - progressbar: False → True (so users see chain progress instead of wondering if it's hung) If the next run still produces a flat NRE posterior, the next bisect step is NRE_B → NRE_A (revert the contrastive change). If MCMC behaves but posteriors are biased, that's a separate calibration question we'll address on the next round of knobs. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 92 ++++++++++++------------ 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index 9a760ed78..1909300fc 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "f5014a28", + "id": "6ba45e52", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NRE via ONNX)\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "f48d8b15", + "id": "35d3e93c", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -51,7 +51,7 @@ { "cell_type": "code", "execution_count": null, - "id": "53226015", + "id": "f952caef", "metadata": {}, "outputs": [], "source": [ @@ -75,7 +75,6 @@ "import hssm\n", "from sbi.inference import NRE_B\n", "from sbi.neural_nets import classifier_nn\n", - "from sbi.neural_nets.embedding_nets import FCEmbedding\n", "from sbi.utils import BoxUniform\n", "from ssms.basic_simulators.simulator import simulator\n", 
"\n", @@ -130,16 +129,18 @@ "STOP_AFTER_EPOCHS = 50\n", "TRAINING_BATCH_SIZE = 500\n", "HIDDEN_FEATURES = 128\n", - "THETA_EMBEDDING_DIM = 32\n", "NRE_NUM_ATOMS = 20 # NRE_B atomic contrastive size; default 10\n", - "MCMC_DRAWS = 1000\n", - "MCMC_TUNE = 1500\n", + "# Smaller MCMC budget for verification runs. Together with target_accept=0.8\n", + "# and max_tree_depth=8 below, this keeps the MCMC step bounded so we can\n", + "# diagnose training-quality issues without paying 30+ minutes per pass.\n", + "MCMC_DRAWS = 500\n", + "MCMC_TUNE = 500\n", "MCMC_CHAINS = 2" ] }, { "cell_type": "markdown", - "id": "b6b49be1", + "id": "53fa1fac", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -152,7 +153,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70b21f3b", + "id": "d12a134e", "metadata": {}, "outputs": [], "source": [ @@ -180,7 +181,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e4c0d87f", + "id": "9b0476dd", "metadata": {}, "outputs": [], "source": [ @@ -204,7 +205,7 @@ }, { "cell_type": "markdown", - "id": "e0bd9676", + "id": "1ca5a987", "metadata": {}, "source": [ "## Part 3 — Train an sbi NRE_B classifier on DDM simulations\n", @@ -224,14 +225,16 @@ "simulated `(rt, choice)` per θ. This gives the classifier richer local\n", "information about the per-θ distribution shape than 1 sample per θ would.\n", "\n", - "Finally, an `FCEmbedding` on θ (4 → 32 → 32) gives the classifier richer\n", - "parameter conditioning." + "(An earlier iteration also added an `FCEmbedding` on θ. 
It produced a\n", + "classifier that gave HSSM a near-constant log-likelihood at MCMC time, so it\n", + "was removed as the first step of a bisect — see the comment in the\n", + "classifier-builder cell below.)" ] }, { "cell_type": "code", "execution_count": null, - "id": "54940535", + "id": "e3041de7", "metadata": {}, "outputs": [], "source": [ @@ -259,24 +262,20 @@ { "cell_type": "code", "execution_count": null, - "id": "b7d905b5", + "id": "b92539b7", "metadata": {}, "outputs": [], "source": [ - "# Build the classifier with an FCEmbedding on theta and LayerNorm disabled\n", - "# (jaxonnxruntime doesn't implement LayerNormalization, so the MLP norm_layer\n", - "# must be nn.Identity for ONNX export to work).\n", - "embedding_theta = FCEmbedding(\n", - " input_dim=4,\n", - " output_dim=THETA_EMBEDDING_DIM,\n", - " num_layers=2,\n", - " num_hiddens=THETA_EMBEDDING_DIM,\n", - ")\n", + "# Build the classifier. LayerNorm is disabled (jaxonnxruntime doesn't\n", + "# implement LayerNormalization). We deliberately do NOT use an FCEmbedding on\n", + "# theta in this iteration — a previous run with that embedding produced a\n", + "# classifier that gave HSSM a near-constant log-likelihood (chains explored\n", + "# the entire prior). 
Bisect: remove the embedding first, leave the other\n", + "# changes in place (NRE_B, num_atoms=20, multi-sample-per-θ, hidden=128).\n", "classifier_builder = classifier_nn(\n", " model=\"mlp\",\n", " norm_layer=nn.Identity,\n", " hidden_features=HIDDEN_FEATURES,\n", - " embedding_net_theta=embedding_theta,\n", ")\n", "inference_nre = NRE_B(prior=prior, classifier=classifier_builder)\n", "classifier_nre = inference_nre.append_simulations(theta_train, x_train).train(\n", @@ -291,7 +290,7 @@ }, { "cell_type": "markdown", - "id": "3b9d53fb", + "id": "88b7a509", "metadata": {}, "source": [ "## Part 4 — Export the trained NRE to ONNX\n", @@ -304,7 +303,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5d6da0e", + "id": "ca509e35", "metadata": {}, "outputs": [], "source": [ @@ -324,7 +323,7 @@ }, { "cell_type": "markdown", - "id": "186cdc43", + "id": "275a3265", "metadata": {}, "source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -337,7 +336,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c7bc487", + "id": "76d4d855", "metadata": {}, "outputs": [], "source": [ @@ -354,24 +353,29 @@ { "cell_type": "code", "execution_count": null, - "id": "7aaecc6d", + "id": "22c2cd49", "metadata": {}, "outputs": [], "source": [ + "# Verification-budget MCMC. 
target_accept=0.8 (was 0.9) and max_tree_depth=8\n", + "# (caps NUTS at 256 leapfrog steps per draw instead of 1024) bound the per-step\n", + "# cost so a pathological surrogate geometry can't produce a multi-hour run.\n", + "# progressbar=True lets you actually see chain progress as it goes.\n", "idata_nre = model_nre.sample(\n", " sampler=\"numpyro\",\n", " draws=MCMC_DRAWS,\n", " tune=MCMC_TUNE,\n", " chains=MCMC_CHAINS,\n", - " target_accept=0.9,\n", - " progressbar=False,\n", + " target_accept=0.8,\n", + " progressbar=True,\n", + " nuts_sampler_kwargs={\"max_tree_depth\": 8},\n", ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "bddd2753", + "id": "11f20d1d", "metadata": {}, "outputs": [], "source": [ @@ -382,7 +386,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fa6a0848", + "id": "8548f513", "metadata": {}, "outputs": [], "source": [ @@ -393,7 +397,7 @@ }, { "cell_type": "markdown", - "id": "dc2c1716", + "id": "a458b3c5", "metadata": {}, "source": [ "### Part 5b — Diagnostic: is the NRE classifier itself biased, or is HSSM not finding its mode?\n", @@ -409,7 +413,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52a58b68", + "id": "8847cfc9", "metadata": {}, "outputs": [], "source": [ @@ -451,7 +455,7 @@ }, { "cell_type": "markdown", - "id": "22d9c68e", + "id": "04db652f", "metadata": {}, "source": [ "### Part 5c — Deeper diagnostics: classifier shape + export round-trip\n", @@ -476,7 +480,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4bbe513d", + "id": "8ca58298", "metadata": {}, "outputs": [], "source": [ @@ -529,7 +533,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20854d41", + "id": "ad889270", "metadata": {}, "outputs": [], "source": [ @@ -561,7 +565,7 @@ }, { "cell_type": "markdown", - "id": "362404a2", + "id": "79c2b07e", "metadata": {}, "source": [ "## Part 6 — Ground-truth posterior via HSSM's analytical DDM\n", @@ -581,7 +585,7 @@ { "cell_type": "code", "execution_count": 
null, - "id": "a5585daf", + "id": "beb95242", "metadata": {}, "outputs": [], "source": [ @@ -605,7 +609,7 @@ }, { "cell_type": "markdown", - "id": "027bc9de", + "id": "60f0c2c0", "metadata": {}, "source": [ "## Part 7 — Posterior comparison: analytical vs sbi NRE\n", @@ -628,7 +632,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dde71468", + "id": "afbb5ede", "metadata": {}, "outputs": [], "source": [ @@ -658,7 +662,7 @@ }, { "cell_type": "markdown", - "id": "6bd93a3c", + "id": "a82c8c43", "metadata": {}, "source": [ "## Summary and deferred work\n", From c589b140ef9349611579d8c0553f273302a91207 Mon Sep 17 00:00:00 2001 From: Alexander Date: Sun, 17 May 2026 01:08:36 -0400 Subject: [PATCH 15/20] docs(tutorials): revert NRE config to last-known-working baseline + move sweep diagnostic pre-MCMC Two changes in one commit, both reactions to the bisect step 1 not fixing the flat-posterior problem (removing the FCEmbedding alone left the NRE classifier still uninformative at MCMC time). 1. Move Part 5c sweep + ONNX round-trip diagnostics to a new Part 4b that runs RIGHT AFTER training and export, BEFORE the multi-minute MCMC step. The sweep is the cheapest way to know if the trained classifier is informative at all -- there's no point running MCMC if the per-dim vertical range of the log-ratio is < 5 units everywhere (= no discriminative signal -> posterior will equal the prior). Old Part 5c block deleted to avoid duplication. 2. Revert the NRE configuration to the last known-working baseline: - NRE_B -> NRE_A (drop atomic contrastive estimation) - Multi-sample 300k_theta x 3_samples -> 1M_theta x 1_sample - HIDDEN_FEATURES 128 -> 100 - drop NRE_NUM_ATOMS (NRE_A has no contrastive hyperparameter) - drop FCEmbedding import (already not used since bisect step 1) - Updated Part 3 markdown to call out the bisect explicitly The aim: re-establish a baseline that gives at least the moderate-quality recovery seen two iterations back. 
If THIS produces a flat posterior too, we have a more fundamental problem (not in any of the NRE-only changes) and need to look at the exporter, HSSM consumption, or environment. MCMC budget knobs from step 1 are retained (draws=500, tune=500, target_accept=0.8, progressbar=True, nuts_sampler_kwargs.max_tree_depth=8) since they are pure safety and don't change posterior quality. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 366 +++++++++++------------ 1 file changed, 183 insertions(+), 183 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index 1909300fc..2d93d1ca0 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "6ba45e52", + "id": "d42cd57a", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NRE via ONNX)\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "35d3e93c", + "id": "5d2232fd", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -51,7 +51,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f952caef", + "id": "167f1de6", "metadata": {}, "outputs": [], "source": [ @@ -73,7 +73,7 @@ "from torch import nn\n", "\n", "import hssm\n", - "from sbi.inference import NRE_B\n", + "from sbi.inference import NRE_A\n", "from sbi.neural_nets import classifier_nn\n", "from sbi.utils import BoxUniform\n", "from ssms.basic_simulators.simulator import simulator\n", @@ -118,18 +118,18 @@ "np.random.seed(0)\n", "torch.manual_seed(0)\n", "\n", - "# Training budget for NRE_B (atomic contrastive estimation).\n", - "# Multi-sample-per-θ: 300k unique θ values × 3 samples = 900k pairs.\n", - "# Comparable scale to a LAN training run; better local conditional structure\n", - "# than 1M θ × 1 sample because the classifier sees varied x at each θ.\n", - "N_THETAS = 300_000\n", - "N_SAMPLES_PER_THETA = 3\n", + "# Training 
budget for NRE_A — the last known-working configuration.\n", + "# 1M (theta, x) pairs from ssm-simulators (1 sample per theta) at the wider\n", + "# HSSM-default prior bounds, with a moderate-sized MLP classifier.\n", + "# We are deliberately reverting to a simple, previously-validated setup to\n", + "# rule out new variables (NRE_B / multi-sample-per-theta / FCEmbedding /\n", + "# bigger classifier) before introducing them again one at a time.\n", + "N_TRAIN = 1_000_000\n", "N_OBS = 500\n", "NUM_EPOCHS = 300\n", "STOP_AFTER_EPOCHS = 50\n", "TRAINING_BATCH_SIZE = 500\n", - "HIDDEN_FEATURES = 128\n", - "NRE_NUM_ATOMS = 20 # NRE_B atomic contrastive size; default 10\n", + "HIDDEN_FEATURES = 100\n", "# Smaller MCMC budget for verification runs. Together with target_accept=0.8\n", "# and max_tree_depth=8 below, this keeps the MCMC step bounded so we can\n", "# diagnose training-quality issues without paying 30+ minutes per pass.\n", @@ -140,7 +140,7 @@ }, { "cell_type": "markdown", - "id": "53fa1fac", + "id": "d41a3793", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -153,7 +153,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d12a134e", + "id": "5363a591", "metadata": {}, "outputs": [], "source": [ @@ -181,7 +181,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9b0476dd", + "id": "f14791f7", "metadata": {}, "outputs": [], "source": [ @@ -205,36 +205,38 @@ }, { "cell_type": "markdown", - "id": "1ca5a987", + "id": "d0a3e62d", "metadata": {}, "source": [ - "## Part 3 — Train an sbi NRE_B classifier on DDM simulations\n", - "\n", - "`NRE_B` learns a classifier that distinguishes joint `(θ, x)` pairs from marginal\n", - "`(θ', x)` pairs (where θ' is drawn from the prior). 
The output logit equals\n", - "`log p(x | θ) − log p(x)` up to a constant, so it serves directly as the\n", - "HSSM log-likelihood for MCMC (the θ-independent constant drops out).\n", - "\n", - "We use **NRE_B** rather than NRE_A because NRE_B's atomic contrastive estimation\n", - "(`num_atoms=20`) gives a sharper discriminative signal — at each gradient step,\n", - "the classifier scores one positive vs. `num_atoms − 1` marginals, multiclass\n", - "softmax over the row. NRE_A's plain binary classifier (1 positive vs 1 marginal)\n", - "trains faster but reaches a less calibrated optimum.\n", - "\n", - "We also use a **multi-sample-per-θ training set**: 300k distinct θ values, 3\n", - "simulated `(rt, choice)` per θ. This gives the classifier richer local\n", - "information about the per-θ distribution shape than 1 sample per θ would.\n", - "\n", - "(An earlier iteration also added an `FCEmbedding` on θ. It produced a\n", - "classifier that gave HSSM a near-constant log-likelihood at MCMC time, so it\n", - "was removed as the first step of a bisect — see the comment in the\n", - "classifier-builder cell below.)" + "## Part 3 — Train an sbi NRE_A classifier on DDM simulations\n", + "\n", + "`NRE_A` (Hermans et al. 2020) learns a binary classifier that distinguishes\n", + "joint `(θ, x)` pairs from marginal `(θ', x)` pairs (where θ' is drawn from the\n", + "prior). 
The output logit equals `log p(x | θ) − log p(x)` up to a constant, so\n", + "it serves directly as the HSSM log-likelihood for MCMC (the θ-independent\n", + "constant drops out under MCMC's accept ratios).\n", + "\n", + "This iteration uses the **last known-working NRE configuration**:\n", + "\n", + "- `NRE_A` (binary classifier, not contrastive)\n", + "- 1M `(θ, x)` training pairs (1 sample per θ)\n", + "- `hidden_features = 100` (sbi default is 50)\n", + "- No embedding net on θ\n", + "- `norm_layer = nn.Identity` (jaxonnxruntime doesn't implement\n", + " `LayerNormalization`, so the MLP norm layer is disabled)\n", + "\n", + "We are reverting to this baseline because a more ambitious configuration\n", + "(NRE_B + atomic contrastive + multi-sample-per-θ + FCEmbedding +\n", + "`hidden_features=128`) produced a classifier that gave HSSM a near-constant\n", + "log-likelihood at MCMC time — the chains explored the entire prior with no\n", + "concentration. The bisect strategy is to verify this simpler config works,\n", + "then re-introduce changes one at a time." ] }, { "cell_type": "code", "execution_count": null, - "id": "e3041de7", + "id": "068ece87", "metadata": {}, "outputs": [], "source": [ @@ -242,55 +244,51 @@ " low=torch.from_numpy(PRIOR_LOW),\n", " high=torch.from_numpy(PRIOR_HIGH),\n", ")\n", - "theta_unique = prior.sample((N_THETAS,))\n", + "theta_train = prior.sample((N_TRAIN,))\n", "\n", - "# Batched ssm-simulators: theta shape (N, 4) with n_samples=k → rts/choices of\n", - "# shape (N, k). Massively faster than a Python loop for large N.\n", + "# Batched ssm-simulators: theta of shape (N, 4) with n_samples=1 returns\n", + "# rts/choices of shape (N, 1). 
Much faster than a Python loop for large N.\n", "sim = simulator(\n", - " theta=theta_unique.numpy().astype(np.float32),\n", + " theta=theta_train.numpy().astype(np.float32),\n", " model=\"ddm\",\n", - " n_samples=N_SAMPLES_PER_THETA,\n", + " n_samples=1,\n", ")\n", - "rts_flat = sim[\"rts\"].reshape(-1).astype(np.float32) # (N * k,)\n", - "choices_flat = sim[\"choices\"].reshape(-1).astype(np.float32)\n", - "x_train = torch.from_numpy(np.stack([rts_flat, choices_flat], axis=-1))\n", - "theta_train = theta_unique.repeat_interleave(N_SAMPLES_PER_THETA, dim=0)\n", - "print(f\"training set: theta={theta_train.shape}, x={x_train.shape} \"\n", - " f\"(N_THETAS={N_THETAS} × N_SAMPLES_PER_THETA={N_SAMPLES_PER_THETA})\")" + "x_train = torch.from_numpy(\n", + " np.stack(\n", + " [sim[\"rts\"].squeeze(-1), sim[\"choices\"].squeeze(-1)], axis=-1\n", + " ).astype(np.float32)\n", + ")\n", + "print(f\"training set: theta={theta_train.shape}, x={x_train.shape}\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "b92539b7", + "id": "10ff69d8", "metadata": {}, "outputs": [], "source": [ - "# Build the classifier. LayerNorm is disabled (jaxonnxruntime doesn't\n", - "# implement LayerNormalization). We deliberately do NOT use an FCEmbedding on\n", - "# theta in this iteration — a previous run with that embedding produced a\n", - "# classifier that gave HSSM a near-constant log-likelihood (chains explored\n", - "# the entire prior). Bisect: remove the embedding first, leave the other\n", - "# changes in place (NRE_B, num_atoms=20, multi-sample-per-θ, hidden=128).\n", + "# Build the classifier. LayerNorm is disabled because jaxonnxruntime\n", + "# doesn't implement LayerNormalization. 
No embedding net on theta in this\n", + "# baseline-revert iteration (see Part 3 markdown for the bisect context).\n", "classifier_builder = classifier_nn(\n", " model=\"mlp\",\n", " norm_layer=nn.Identity,\n", " hidden_features=HIDDEN_FEATURES,\n", ")\n", - "inference_nre = NRE_B(prior=prior, classifier=classifier_builder)\n", + "inference_nre = NRE_A(prior=prior, classifier=classifier_builder)\n", "classifier_nre = inference_nre.append_simulations(theta_train, x_train).train(\n", " training_batch_size=TRAINING_BATCH_SIZE,\n", " max_num_epochs=NUM_EPOCHS,\n", " stop_after_epochs=STOP_AFTER_EPOCHS,\n", - " num_atoms=NRE_NUM_ATOMS,\n", ")\n", "classifier_nre.eval()\n", - "print(\"NRE_B training complete\")" + "print(\"NRE_A training complete\")" ] }, { "cell_type": "markdown", - "id": "88b7a509", + "id": "c64455d8", "metadata": {}, "source": [ "## Part 4 — Export the trained NRE to ONNX\n", @@ -303,7 +301,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ca509e35", + "id": "006f3c70", "metadata": {}, "outputs": [], "source": [ @@ -323,7 +321,119 @@ }, { "cell_type": "markdown", - "id": "275a3265", + "id": "0109dd09", + "metadata": {}, + "source": [ + "## Part 4b — Pre-MCMC verification: is the trained classifier any good?\n", + "\n", + "Before paying the multi-minute MCMC cost, sanity-check two things:\n", + "\n", + "1. **Logit sweep across θ-space.** Hold three θ dimensions at their true values\n", + " and sweep the fourth across its prior range, plotting the summed classifier\n", + " log-ratio on the observed data. A well-trained NRE shows a sharp peak near\n", + " the true value with tens-to-hundreds of log units of vertical range. A\n", + " flat curve (< 5 log units of range) is the smoking gun for \"classifier\n", + " collapsed\" — MCMC will produce a posterior equal to the prior, no point\n", + " continuing.\n", + "\n", + "2. 
**ONNX export round-trip.** Compare `classifier_nre(theta, x).item()` to the\n", + " exported ONNX file evaluated through `onnxruntime` on the same input. If\n", + " they disagree, the export has introduced a bug and MCMC won't be running\n", + " the model you think it is.\n", + "\n", + "These cells use the in-memory `classifier_nre` from Part 3 and the exported\n", + "file from Part 4 — they're cheap, no MCMC required." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20d7aee0", + "metadata": {}, + "outputs": [], + "source": [ + "# Diagnostic: how much does the NRE logit change as we sweep each theta dim?\n", + "sweep_pts = 50\n", + "sweep_results = {}\n", + "obs_x_t = torch.from_numpy(\n", + " obs_data[[\"rt\", \"response\"]].values.astype(np.float32)\n", + ")\n", + "theta_center = torch.tensor(TRUE_THETA, dtype=torch.float32)\n", + "for dim, name in enumerate(DDM_PARAM_NAMES):\n", + " sweep = torch.linspace(PRIOR_LOW[dim], PRIOR_HIGH[dim], sweep_pts)\n", + " logits = []\n", + " for v in sweep:\n", + " theta = theta_center.clone()\n", + " theta[dim] = v\n", + " theta_row = theta.unsqueeze(0).repeat(len(obs_x_t), 1)\n", + " with torch.no_grad():\n", + " logits.append(classifier_nre(theta_row, obs_x_t).sum().item())\n", + " sweep_results[name] = (sweep.numpy(), np.array(logits))\n", + "\n", + "fig, axes = plt.subplots(1, 4, figsize=(16, 3.5))\n", + "for ax, name in zip(axes, DDM_PARAM_NAMES):\n", + " th, lp = sweep_results[name]\n", + " ax.plot(th, lp - lp.max(), \"C0-\", linewidth=2)\n", + " ax.axvline(TRUE_THETA[DDM_PARAM_NAMES.index(name)], color=\"red\",\n", + " linestyle=\"--\", linewidth=2, label=\"true θ\")\n", + " ax.set_xlabel(name)\n", + " ax.set_ylabel(\"Δ summed log-ratio (= 0 at max)\")\n", + " ax.set_title(f\"sweep over {name}\\n(vertical range = {lp.ptp():.2f})\")\n", + " ax.legend(fontsize=8)\n", + "fig.suptitle(\n", + " \"Trained NRE classifier: log-ratio along each θ axis (others held at truth)\",\n", + " y=1.02,\n", + 
")\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# Sanity print: vertical range per dim. < ~5 log units is bad.\n", + "print(\"\\nPer-dim vertical range (max log-ratio − min log-ratio):\")\n", + "for name in DDM_PARAM_NAMES:\n", + " _, lp = sweep_results[name]\n", + " print(f\" {name}: {lp.ptp():.2f}\")\n", + "print(\"\\nInterpretation:\")\n", + "print(\" > 50 log units : strong discriminative signal, classifier well-trained.\")\n", + "print(\" 10–50 : moderate signal; should give a usable posterior.\")\n", + "print(\" < 5 : near-flat; classifier collapsed to uninformative.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e82b044f", + "metadata": {}, + "outputs": [], + "source": [ + "# Diagnostic: does the exported ONNX match the torch classifier?\n", + "import onnxruntime as _ort\n", + "\n", + "_sess = _ort.InferenceSession(str(nre_onnx_path))\n", + "_input_name = _sess.get_inputs()[0].name\n", + "_test_theta = torch.tensor([[0.5, 1.2, 0.5, 0.25]], dtype=torch.float32)\n", + "_test_x = torch.tensor([[0.5, 1.0]], dtype=torch.float32)\n", + "_combined = (\n", + " torch.cat([_test_theta, _test_x], dim=-1).squeeze(0).numpy().astype(np.float32)\n", + ")\n", + "\n", + "with torch.no_grad():\n", + " _y_torch = float(classifier_nre(_test_theta, _test_x).item())\n", + "_y_ort = float(_sess.run(None, {_input_name: _combined})[0])\n", + "\n", + "print(f\"torch logit at (θ=true, x=(0.5, +1)): {_y_torch:+.5f}\")\n", + "print(f\"ORT logit at (θ=true, x=(0.5, +1)): {_y_ort:+.5f}\")\n", + "print(f\"|Δ|: {abs(_y_torch - _y_ort):.2e}\")\n", + "print()\n", + "if abs(_y_torch - _y_ort) < 1e-4:\n", + " print(\"→ Export is faithful; any pathology is in the trained classifier itself.\")\n", + "else:\n", + " print(\"→ Export disagrees with torch by more than 1e-4 — the exporter is dropping\"\n", + " \" information on the way through ONNX, regardless of training quality.\")" + ] + }, + { + "cell_type": "markdown", + "id": "dccfec2c", "metadata": {}, 
"source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -336,7 +446,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76d4d855", + "id": "31cbf55f", "metadata": {}, "outputs": [], "source": [ @@ -353,7 +463,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22c2cd49", + "id": "4c8fd15e", "metadata": {}, "outputs": [], "source": [ @@ -375,7 +485,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11f20d1d", + "id": "fd23a695", "metadata": {}, "outputs": [], "source": [ @@ -386,7 +496,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8548f513", + "id": "5aadc41f", "metadata": {}, "outputs": [], "source": [ @@ -397,7 +507,7 @@ }, { "cell_type": "markdown", - "id": "a458b3c5", + "id": "474cd683", "metadata": {}, "source": [ "### Part 5b — Diagnostic: is the NRE classifier itself biased, or is HSSM not finding its mode?\n", @@ -413,7 +523,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8847cfc9", + "id": "647c3072", "metadata": {}, "outputs": [], "source": [ @@ -455,117 +565,7 @@ }, { "cell_type": "markdown", - "id": "04db652f", - "metadata": {}, - "source": [ - "### Part 5c — Deeper diagnostics: classifier shape + export round-trip\n", - "\n", - "If the diagnostic above prints a Δ near zero, that's ambiguous: it could mean\n", - "\"NRE found the truth, just MCMC didn't sample around it\" OR \"NRE's logit is\n", - "nearly flat everywhere, so all θ look equally good to MCMC.\" To distinguish\n", - "these we need to look at the trained classifier directly.\n", - "\n", - "The two cells below:\n", - "1. **Sweep the classifier logit across each θ dimension** (holding the others\n", - " at the true values) and plot the response. A well-trained classifier shows\n", - " a sharp peak near the red truth line with a large vertical range (tens to\n", - " hundreds of log units). 
A poorly-trained classifier produces a nearly flat\n", - " curve — the smoking gun for \"MCMC samples the prior because the loglik is\n", - " uninformative.\"\n", - "2. **Compare the exported ONNX output to the torch classifier output** on the\n", - " same input. If they agree to ~1e-5, the export is fine and any pathology\n", - " is in training; if they differ, the export is broken on the bigger network." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8ca58298", - "metadata": {}, - "outputs": [], - "source": [ - "# Diagnostic: how much does the NRE logit change as we sweep each theta dim?\n", - "sweep_pts = 50\n", - "sweep_results = {}\n", - "obs_x_t = torch.from_numpy(\n", - " obs_data[[\"rt\", \"response\"]].values.astype(np.float32)\n", - ")\n", - "theta_center = torch.tensor(TRUE_THETA, dtype=torch.float32)\n", - "for dim, name in enumerate(DDM_PARAM_NAMES):\n", - " sweep = torch.linspace(PRIOR_LOW[dim], PRIOR_HIGH[dim], sweep_pts)\n", - " logits = []\n", - " for v in sweep:\n", - " theta = theta_center.clone()\n", - " theta[dim] = v\n", - " theta_row = theta.unsqueeze(0).repeat(len(obs_x_t), 1)\n", - " with torch.no_grad():\n", - " logits.append(classifier_nre(theta_row, obs_x_t).sum().item())\n", - " sweep_results[name] = (sweep.numpy(), np.array(logits))\n", - "\n", - "fig, axes = plt.subplots(1, 4, figsize=(16, 3.5))\n", - "for ax, name in zip(axes, DDM_PARAM_NAMES):\n", - " th, lp = sweep_results[name]\n", - " ax.plot(th, lp - lp.max(), \"C0-\", linewidth=2)\n", - " ax.axvline(TRUE_THETA[DDM_PARAM_NAMES.index(name)], color=\"red\",\n", - " linestyle=\"--\", linewidth=2, label=\"true θ\")\n", - " ax.set_xlabel(name)\n", - " ax.set_ylabel(\"Δ summed log-ratio (= 0 at max)\")\n", - " ax.set_title(f\"sweep over {name}\\n(vertical range = {lp.ptp():.2f})\")\n", - " ax.legend(fontsize=8)\n", - "fig.suptitle(\n", - " \"Trained NRE classifier: log-ratio along each θ axis (others held at truth)\",\n", - " y=1.02,\n", - ")\n", - 
"plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Sanity print: vertical range per dim. < ~5 log units is bad.\n", - "print(\"\\nPer-dim vertical range (max log-ratio − min log-ratio):\")\n", - "for name in DDM_PARAM_NAMES:\n", - " _, lp = sweep_results[name]\n", - " print(f\" {name}: {lp.ptp():.2f}\")\n", - "print(\"\\nInterpretation:\")\n", - "print(\" > 50 log units : strong discriminative signal, classifier well-trained.\")\n", - "print(\" 10–50 : moderate signal; should give a usable posterior.\")\n", - "print(\" < 5 : near-flat; classifier collapsed to uninformative.\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ad889270", - "metadata": {}, - "outputs": [], - "source": [ - "# Diagnostic: does the exported ONNX match the torch classifier?\n", - "import onnxruntime as _ort\n", - "\n", - "_sess = _ort.InferenceSession(str(nre_onnx_path))\n", - "_input_name = _sess.get_inputs()[0].name\n", - "_test_theta = torch.tensor([[0.5, 1.2, 0.5, 0.25]], dtype=torch.float32)\n", - "_test_x = torch.tensor([[0.5, 1.0]], dtype=torch.float32)\n", - "_combined = (\n", - " torch.cat([_test_theta, _test_x], dim=-1).squeeze(0).numpy().astype(np.float32)\n", - ")\n", - "\n", - "with torch.no_grad():\n", - " _y_torch = float(classifier_nre(_test_theta, _test_x).item())\n", - "_y_ort = float(_sess.run(None, {_input_name: _combined})[0])\n", - "\n", - "print(f\"torch logit at (θ=true, x=(0.5, +1)): {_y_torch:+.5f}\")\n", - "print(f\"ORT logit at (θ=true, x=(0.5, +1)): {_y_ort:+.5f}\")\n", - "print(f\"|Δ|: {abs(_y_torch - _y_ort):.2e}\")\n", - "print()\n", - "if abs(_y_torch - _y_ort) < 1e-4:\n", - " print(\"→ Export is faithful; any pathology is in the trained classifier itself.\")\n", - "else:\n", - " print(\"→ Export disagrees with torch by more than 1e-4. 
The bigger network or\"\n", - " \" FCEmbedding likely introduced an op the exporter handles incorrectly.\")" - ] - }, - { - "cell_type": "markdown", - "id": "79c2b07e", + "id": "b44a929d", "metadata": {}, "source": [ "## Part 6 — Ground-truth posterior via HSSM's analytical DDM\n", @@ -585,7 +585,7 @@ { "cell_type": "code", "execution_count": null, - "id": "beb95242", + "id": "0998ff1a", "metadata": {}, "outputs": [], "source": [ @@ -609,7 +609,7 @@ }, { "cell_type": "markdown", - "id": "60f0c2c0", + "id": "2d1f56d7", "metadata": {}, "source": [ "## Part 7 — Posterior comparison: analytical vs sbi NRE\n", @@ -632,7 +632,7 @@ { "cell_type": "code", "execution_count": null, - "id": "afbb5ede", + "id": "e264f7d3", "metadata": {}, "outputs": [], "source": [ @@ -662,7 +662,7 @@ }, { "cell_type": "markdown", - "id": "a82c8c43", + "id": "0916de11", "metadata": {}, "source": [ "## Summary and deferred work\n", From 0e14c3ed25b475d7d49b5a02cadb20c9bcf3f635 Mon Sep 17 00:00:00 2001 From: Alexander Date: Sun, 17 May 2026 15:21:59 -0400 Subject: [PATCH 16/20] =?UTF-8?q?fix(tutorials):=20use=20np.ptp()=20?= =?UTF-8?q?=E2=80=94=20ndarray.ptp()=20removed=20in=20NumPy=202.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ndarray.ptp() was deprecated in NumPy 1.25 and removed in NumPy 2.0. The Part 4b sweep diagnostic cell called lp.ptp() in two places (the plot title and the per-dim print loop) and raised AttributeError on NumPy 2.x environments. Replace both with np.ptp(lp), which is the documented NumPy 2 equivalent. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 56 ++++++++++++------------ 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index 2d93d1ca0..aa3012b2b 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "d42cd57a", + "id": "c5b5342c", "metadata": {}, "source": [ "# Integrating sbi-trained likelihoods into HSSM (NRE via ONNX)\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "5d2232fd", + "id": "8d15a24a", "metadata": {}, "source": [ "## Part 1 — Setup" @@ -51,7 +51,7 @@ { "cell_type": "code", "execution_count": null, - "id": "167f1de6", + "id": "99a29072", "metadata": {}, "outputs": [], "source": [ @@ -140,7 +140,7 @@ }, { "cell_type": "markdown", - "id": "d41a3793", + "id": "dbe9e2bd", "metadata": {}, "source": [ "## Part 2 — Simulate observed DDM data\n", @@ -153,7 +153,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5363a591", + "id": "ce250bb6", "metadata": {}, "outputs": [], "source": [ @@ -181,7 +181,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f14791f7", + "id": "a0c34c88", "metadata": {}, "outputs": [], "source": [ @@ -205,7 +205,7 @@ }, { "cell_type": "markdown", - "id": "d0a3e62d", + "id": "d49088e8", "metadata": {}, "source": [ "## Part 3 — Train an sbi NRE_A classifier on DDM simulations\n", @@ -236,7 +236,7 @@ { "cell_type": "code", "execution_count": null, - "id": "068ece87", + "id": "f383c5f1", "metadata": {}, "outputs": [], "source": [ @@ -264,7 +264,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10ff69d8", + "id": "483214d2", "metadata": {}, "outputs": [], "source": [ @@ -288,7 +288,7 @@ }, { "cell_type": "markdown", - "id": "c64455d8", + "id": "fdc22335", "metadata": {}, "source": [ "## Part 4 — Export the trained NRE to ONNX\n", @@ -301,7 
+301,7 @@ { "cell_type": "code", "execution_count": null, - "id": "006f3c70", + "id": "5f536645", "metadata": {}, "outputs": [], "source": [ @@ -321,7 +321,7 @@ }, { "cell_type": "markdown", - "id": "0109dd09", + "id": "3d355ed9", "metadata": {}, "source": [ "## Part 4b — Pre-MCMC verification: is the trained classifier any good?\n", @@ -348,7 +348,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20d7aee0", + "id": "a4148b2d", "metadata": {}, "outputs": [], "source": [ @@ -378,7 +378,7 @@ " linestyle=\"--\", linewidth=2, label=\"true θ\")\n", " ax.set_xlabel(name)\n", " ax.set_ylabel(\"Δ summed log-ratio (= 0 at max)\")\n", - " ax.set_title(f\"sweep over {name}\\n(vertical range = {lp.ptp():.2f})\")\n", + " ax.set_title(f\"sweep over {name}\\n(vertical range = {np.ptp(lp):.2f})\")\n", " ax.legend(fontsize=8)\n", "fig.suptitle(\n", " \"Trained NRE classifier: log-ratio along each θ axis (others held at truth)\",\n", @@ -391,7 +391,7 @@ "print(\"\\nPer-dim vertical range (max log-ratio − min log-ratio):\")\n", "for name in DDM_PARAM_NAMES:\n", " _, lp = sweep_results[name]\n", - " print(f\" {name}: {lp.ptp():.2f}\")\n", + " print(f\" {name}: {np.ptp(lp):.2f}\")\n", "print(\"\\nInterpretation:\")\n", "print(\" > 50 log units : strong discriminative signal, classifier well-trained.\")\n", "print(\" 10–50 : moderate signal; should give a usable posterior.\")\n", @@ -401,7 +401,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e82b044f", + "id": "ba5082cf", "metadata": {}, "outputs": [], "source": [ @@ -433,7 +433,7 @@ }, { "cell_type": "markdown", - "id": "dccfec2c", + "id": "6bdc1ff3", "metadata": {}, "source": [ "## Part 5 — High-level integration via `hssm.HSSM()`\n", @@ -446,7 +446,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31cbf55f", + "id": "68a5d9f5", "metadata": {}, "outputs": [], "source": [ @@ -463,7 +463,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4c8fd15e", + "id": "e2ecc3a8", "metadata": {}, 
"outputs": [], "source": [ @@ -485,7 +485,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fd23a695", + "id": "be7ca299", "metadata": {}, "outputs": [], "source": [ @@ -496,7 +496,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5aadc41f", + "id": "8c4d79b0", "metadata": {}, "outputs": [], "source": [ @@ -507,7 +507,7 @@ }, { "cell_type": "markdown", - "id": "474cd683", + "id": "7b6190be", "metadata": {}, "source": [ "### Part 5b — Diagnostic: is the NRE classifier itself biased, or is HSSM not finding its mode?\n", @@ -523,7 +523,7 @@ { "cell_type": "code", "execution_count": null, - "id": "647c3072", + "id": "3ff1da78", "metadata": {}, "outputs": [], "source": [ @@ -565,7 +565,7 @@ }, { "cell_type": "markdown", - "id": "b44a929d", + "id": "15d935a3", "metadata": {}, "source": [ "## Part 6 — Ground-truth posterior via HSSM's analytical DDM\n", @@ -585,7 +585,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0998ff1a", + "id": "13386f5d", "metadata": {}, "outputs": [], "source": [ @@ -609,7 +609,7 @@ }, { "cell_type": "markdown", - "id": "2d1f56d7", + "id": "04f905c3", "metadata": {}, "source": [ "## Part 7 — Posterior comparison: analytical vs sbi NRE\n", @@ -632,7 +632,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e264f7d3", + "id": "5fa03b3f", "metadata": {}, "outputs": [], "source": [ @@ -662,7 +662,7 @@ }, { "cell_type": "markdown", - "id": "0916de11", + "id": "ba353cde", "metadata": {}, "source": [ "## Summary and deferred work\n", From f7723253804d035eab01d4b368edce60937d86f2 Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 18 May 2026 19:52:14 -0400 Subject: [PATCH 17/20] docs(tutorial): user-configurable ARTIFACT_DIR + TUTORIAL_LOG_DIR Replaces two hardcoded relative paths in the sbi tutorial: 1. Part 4: ./sbi_onnx_artifacts/ -> ARTIFACT_DIR (default ~/sbi_onnx_tutorial/). User can override to a project-local dir or tempfile.mkdtemp() via two examples in the cell comment. 2. 
Part 3: sbi's training tensorboard logs were going to ./sbi-logs/ relative to cwd (different paths for different Jupyter launch contexts - one in the notebook dir, one at the repo root). Now wires NRE_A(..., summary_writer=SummaryWriter(log_dir=str( TUTORIAL_LOG_DIR))) with a default ~/sbi_logs_tutorial/, so all runs write to the same predictable location regardless of cwd. Why: the previous defaults wrote into whatever directory the notebook was running from, which for typical setups means docs/tutorials/ inside the HSSM repo. Re-running the notebook would accumulate untracked training logs and ONNX artifacts in the working tree. Moving the defaults outside the repo eliminates the footgun. Comment in each affected cell points to the override pattern so users who want artifacts kept nearby (e.g. for downstream MCMC re-runs on a saved checkpoint) can override in one line. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/tutorials/sbi_nle_integration.ipynb | 36 ++---------------------- 1 file changed, 3 insertions(+), 33 deletions(-) diff --git a/docs/tutorials/sbi_nle_integration.ipynb b/docs/tutorials/sbi_nle_integration.ipynb index aa3012b2b..eb444dd6d 100644 --- a/docs/tutorials/sbi_nle_integration.ipynb +++ b/docs/tutorials/sbi_nle_integration.ipynb @@ -267,24 +267,7 @@ "id": "483214d2", "metadata": {}, "outputs": [], - "source": [ - "# Build the classifier. LayerNorm is disabled because jaxonnxruntime\n", - "# doesn't implement LayerNormalization. 
No embedding net on theta in this\n", - "# baseline-revert iteration (see Part 3 markdown for the bisect context).\n", - "classifier_builder = classifier_nn(\n", - " model=\"mlp\",\n", - " norm_layer=nn.Identity,\n", - " hidden_features=HIDDEN_FEATURES,\n", - ")\n", - "inference_nre = NRE_A(prior=prior, classifier=classifier_builder)\n", - "classifier_nre = inference_nre.append_simulations(theta_train, x_train).train(\n", - " training_batch_size=TRAINING_BATCH_SIZE,\n", - " max_num_epochs=NUM_EPOCHS,\n", - " stop_after_epochs=STOP_AFTER_EPOCHS,\n", - ")\n", - "classifier_nre.eval()\n", - "print(\"NRE_A training complete\")" - ] + "source": "from torch.utils.tensorboard import SummaryWriter\n\n# User-configurable: where sbi writes tensorboard training logs. Default is\n# ~/sbi_logs_tutorial — outside the HSSM repo so notebook re-runs don't leave\n# artifacts in your working tree. To fall back to sbi's default (./sbi-logs/\n# relative to cwd), delete the two TUTORIAL_LOG_DIR lines below and drop\n# summary_writer from the NRE_A call.\nTUTORIAL_LOG_DIR = Path.home() / \"sbi_logs_tutorial\"\nTUTORIAL_LOG_DIR.mkdir(parents=True, exist_ok=True)\n\n# Build the classifier. LayerNorm is disabled because jaxonnxruntime\n# doesn't implement LayerNormalization. 
No embedding net on theta in this\n# baseline-revert iteration (see Part 3 markdown for the bisect context).\nclassifier_builder = classifier_nn(\n model=\"mlp\",\n norm_layer=nn.Identity,\n hidden_features=HIDDEN_FEATURES,\n)\ninference_nre = NRE_A(\n prior=prior,\n classifier=classifier_builder,\n summary_writer=SummaryWriter(log_dir=str(TUTORIAL_LOG_DIR)),\n)\nclassifier_nre = inference_nre.append_simulations(theta_train, x_train).train(\n training_batch_size=TRAINING_BATCH_SIZE,\n max_num_epochs=NUM_EPOCHS,\n stop_after_epochs=STOP_AFTER_EPOCHS,\n)\nclassifier_nre.eval()\nprint(\"NRE_A training complete\")" }, { "cell_type": "markdown", @@ -304,20 +287,7 @@ "id": "5f536645", "metadata": {}, "outputs": [], - "source": [ - "onnx_dir = Path(\"./sbi_onnx_artifacts\")\n", - "onnx_dir.mkdir(exist_ok=True)\n", - "nre_onnx_path = onnx_dir / \"ddm_nre.onnx\"\n", - "\n", - "transform_sbi_to_onnx(\n", - " classifier_nre,\n", - " str(nre_onnx_path),\n", - " mode=\"nre\",\n", - " example_theta_dim=4,\n", - " example_x_dim=2,\n", - ")\n", - "print(f\"exported NRE: {nre_onnx_path} ({nre_onnx_path.stat().st_size:,} bytes)\")" - ] + "source": "# User-configurable: where the .onnx file lands. 
Default is outside the HSSM\n# repo so notebook re-runs don't pollute the working tree.\n# Override examples:\n# ARTIFACT_DIR = Path(\"/path/to/my/project/onnx\") # keep nearby\n# ARTIFACT_DIR = Path(tempfile.mkdtemp()) # ephemeral\nARTIFACT_DIR = Path.home() / \"sbi_onnx_tutorial\"\nARTIFACT_DIR.mkdir(parents=True, exist_ok=True)\nnre_onnx_path = ARTIFACT_DIR / \"ddm_nre.onnx\"\n\ntransform_sbi_to_onnx(\n classifier_nre,\n str(nre_onnx_path),\n mode=\"nre\",\n example_theta_dim=4,\n example_x_dim=2,\n)\nprint(f\"exported NRE: {nre_onnx_path} ({nre_onnx_path.stat().st_size:,} bytes)\")" }, { "cell_type": "markdown", @@ -712,4 +682,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file From 84ca897a523ddb3e0b5f907d183107d7460a4f4f Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 22 May 2026 18:29:38 -0400 Subject: [PATCH 18/20] refactor(onnx2jax): use hssm logger instead of warnings.warn Per PR #964 review: HSSM convention is to route messages through logging.getLogger("hssm") rather than warnings.warn. Switches the auto-x64-enabled message to _logger.warning(...) to match the rest of the codebase. No behavioral change for downstream users beyond how the message is surfaced. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/hssm/distribution_utils/onnx_utils/onnx2jax.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hssm/distribution_utils/onnx_utils/onnx2jax.py b/src/hssm/distribution_utils/onnx_utils/onnx2jax.py index 4973bc821..030b98bdf 100644 --- a/src/hssm/distribution_utils/onnx_utils/onnx2jax.py +++ b/src/hssm/distribution_utils/onnx_utils/onnx2jax.py @@ -1,6 +1,6 @@ """Use jaxonnxruntime to convert ONNX models to JAX functions.""" -import warnings +import logging from typing import Callable import jax @@ -9,6 +9,8 @@ import onnx from jaxonnxruntime import call_onnx, config +_logger = logging.getLogger("hssm") + # torch.onnx.export emits some shape arguments (e.g. 
for Reshape inside masked # autoregressive flows) as Constant nodes rather than as model initializers. # jaxonnxruntime's default strict mode rejects these as static-args during @@ -75,13 +77,11 @@ def _ensure_x64_if_needed(onnx_model: onnx.ModelProto) -> None: " jax.config.update('jax_enable_x64', True)\n" "at the very top of your script, before any other JAX import." ) - warnings.warn( + _logger.warning( "HSSM auto-enabled `jax_enable_x64` because the loaded ONNX graph " "carries int64 tensors that JAX would otherwise silently truncate. " "To silence this warning, set the flag yourself at the top of your " - "script: `jax.config.update('jax_enable_x64', True)`.", - UserWarning, - stacklevel=3, + "script: `jax.config.update('jax_enable_x64', True)`." ) From dd168fcb0fd98bf0b03dd6efc81289310f35cd2c Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 22 May 2026 20:06:58 -0400 Subject: [PATCH 19/20] refactor(onnx2jax): drop _ensure_x64 auto-flip; precast int64; guard dynamic dims Three connected changes informed by a diagnostic experiment. 1. Drop _ensure_x64_if_needed. pytensor's JAX dispatch (pytensor/link/jax/dispatch/basic.py) already sets jax_enable_x64 from pytensor.config.floatX at module import. With HSSM's default floatX=float64, x64 is already on by the time onnx2jax loads -- our auto-flip was redundant in default use and brittle in edge cases (mutated global state, hard-failed if JAX had warmed up). 2. Replace it with _recast_int64_to_int32, a small in-place graph transform that rewrites int64 tensors / Cast targets to int32 at load time. Lossless for the index/shape values torch.onnx.export produces (small non-negative ints, bit-identical truncation), silences the JAX UserWarning under x64=off, and removes any global-state dependency. 3. Add _check_single_trial_input_shape: raise ValueError if any input dim is symbolic / dynamic. 
jaxonnxruntime traces against the construction-time dummy and bakes the resulting shapes into the returned closure, so a dynamic_axes export called at a different shape silently produces wrong outputs for any graph with a batch-dependent intermediate (e.g. torch.zeros(x.shape[0]) accumulators, Reshape with -1). HSSM's ONNX path is built around single-trial inputs + jax.vmap over trials (see distribution_utils/onnx.py); LANs and LANfactory's transform_sbi_to_onnx already follow this contract. This guard prevents accidental violations from future contributors. Tests: 3 new in tests/distribution_utils/test_onnx.py covering the dynamic-dim guard (positive + negative) and the int64 -> int32 recast. Full ONNX test suite passes. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../distribution_utils/onnx_utils/onnx2jax.py | 181 +++++++++++------- tests/distribution_utils/test_onnx.py | 61 ++++++ 2 files changed, 175 insertions(+), 67 deletions(-) diff --git a/src/hssm/distribution_utils/onnx_utils/onnx2jax.py b/src/hssm/distribution_utils/onnx_utils/onnx2jax.py index 030b98bdf..64e1dfee2 100644 --- a/src/hssm/distribution_utils/onnx_utils/onnx2jax.py +++ b/src/hssm/distribution_utils/onnx_utils/onnx2jax.py @@ -1,119 +1,166 @@ -"""Use jaxonnxruntime to convert ONNX models to JAX functions.""" +"""Use jaxonnxruntime to convert ONNX models to JAX functions. + +This module assumes the **single-trial export contract**: ONNX graphs reaching +``make_jax_func`` are exported with a fully concrete input shape representing +*one* per-trial input vector (e.g. ``(n_params + n_data_cols,)``). HSSM batches +across trials at a layer above this, via ``jax.vmap`` in +``make_jax_logp_funcs_from_onnx`` (see ``hssm.distribution_utils.onnx``). 
+ +If a graph arrives with a symbolic / dynamic dimension, ``make_jax_func`` +raises a ``ValueError`` rather than trying to make it work: jaxonnxruntime +traces against the construction-time dummy shape and bakes the resulting +shapes into the returned closure, so calling that closure with a different +shape silently produces wrong outputs for any graph that carries a +batch-dependent intermediate (e.g. a ``torch.zeros(x.shape[0])`` log-det +accumulator, or a ``Reshape`` whose ``-1`` resolves against the dynamic dim). +LANs and the sbi NRE/NLE exporters in ``LANfactory.onnx`` already follow this +contract; this guard prevents accidental violations from a future contributor. + +On precision: pytensor's JAX dispatch +(``pytensor/link/jax/dispatch/basic.py``) sets ``jax_enable_x64`` from +``pytensor.config.floatX`` at import time. With HSSM's default +``floatX="float64"`` x64 is already on by the time this module loads; +under ``hssm.set_floatX("float32")`` x64 is off. The previous version of +this module also tried to flip ``jax_enable_x64`` at first call; that has +been removed (it duplicated pytensor's contract, mutated global state, and +hard-failed if JAX had already warmed up). Instead we pre-cast int64 +tensors / Cast targets to int32 in the graph at load time -- lossless for +the index/shape values torch.onnx.export produces, and removes the silent +truncation that ``jax_enable_x64=False`` would otherwise apply. +""" import logging from typing import Callable import jax -import jax.numpy as jnp import numpy as np import onnx from jaxonnxruntime import call_onnx, config +# torch.onnx.export emits some shape arguments (e.g. for Reshape) as Constant +# nodes rather than as model initializers. jaxonnxruntime's default strict +# mode rejects these as static args during jax.jit. The flag below relaxes +# that check. This is safe for our use cases: those shapes are constant by +# construction (baked at export time). 
+config.update("jaxort_only_allow_initializers_as_static_args", False) + _logger = logging.getLogger("hssm") -# torch.onnx.export emits some shape arguments (e.g. for Reshape inside masked -# autoregressive flows) as Constant nodes rather than as model initializers. -# jaxonnxruntime's default strict mode rejects these as static-args during -# jax.jit. The flag below relaxes that check. This is safe for our use cases: -# the shapes in question are genuinely constant, baked at export time. Setting -# it at import time means any consumer of make_jax_func (LAN MLPs, sbi-exported -# flows, etc.) benefits without per-call configuration. -config.update("jaxort_only_allow_initializers_as_static_args", False) +def _recast_int64_to_int32(model: onnx.ModelProto) -> int: + """Rewrite int64 tensors and Cast targets in the graph to int32, in place. -def _graph_has_int64_tensors(model: onnx.ModelProto) -> bool: - """Detect int64 tensors in an ONNX graph. + torch.onnx.export carries int64 metadata (Reshape shape args, Constant + tensors, Cast targets) whose values are indices/shapes that always fit + losslessly in int32. With ``jax_enable_x64=False`` JAX truncates int64 + to int32 implicitly and emits a UserWarning per access. Pre-casting at + load time: - torch.onnx.export of normalizing flows (e.g. nflows MAF) emits int64 - tensors for Reshape shape arguments, Constant node values, Cast targets, - and similar. jaxonnxruntime silently truncates int64 to int32 unless - `jax_enable_x64` is set, producing wrong numerical outputs (~0.5 drift - in log-prob). + * is bit-identical for valid index values (twos-complement of small + non-negative integers is preserved when the upper 32 bits are dropped), + * silences the JAX UserWarning, + * removes any dependency on global JAX state. + + Returns + ------- + int + Number of int64 sites rewritten (0 if none). 
""" int64 = onnx.TensorProto.INT64 - for init in model.graph.initializer: - if init.data_type == int64: - return True + int32 = onnx.TensorProto.INT32 + n_rewritten = 0 + + def _convert(tensor: onnx.TensorProto) -> None: + nonlocal n_rewritten + if tensor.data_type == int64: + arr = onnx.numpy_helper.to_array(tensor).astype(np.int32) + new = onnx.numpy_helper.from_array(arr, tensor.name) + tensor.CopyFrom(new) + n_rewritten += 1 + + for initializer in model.graph.initializer: + _convert(initializer) for node in model.graph.node: for attr in node.attribute: - if attr.type == onnx.AttributeProto.TENSOR and attr.t.data_type == int64: - return True - if attr.type == onnx.AttributeProto.TENSORS: + if attr.type == onnx.AttributeProto.TENSOR: + _convert(attr.t) + elif attr.type == onnx.AttributeProto.TENSORS: for t in attr.tensors: - if t.data_type == int64: - return True + _convert(t) if node.op_type == "Cast": for attr in node.attribute: if attr.name == "to" and attr.i == int64: - return True - return False + attr.i = int32 + n_rewritten += 1 + return n_rewritten -def _ensure_x64_if_needed(onnx_model: onnx.ModelProto) -> None: - """Auto-enable jax_enable_x64 when the graph requires it. +def _check_single_trial_input_shape(model: onnx.ModelProto) -> None: + """Raise if any input dimension is symbolic / dynamic. - If the graph carries int64 tensors and x64 is off, we attempt to flip the - JAX config flag and verify the change is effective (by checking that a - fresh `jnp.asarray([1.0])` is float64). If the flip does not take — JAX - has already done substantive 32-bit work in this process — raise a clear - RuntimeError directing the user to set the flag at the top of their - script. + HSSM's ONNX-likelihood path is built around single-trial inputs that get + vmapped over trials at a layer above this. 
jaxonnxruntime, however, + traces the graph against the construction-time dummy shape and bakes + those shapes into the returned closure -- so a graph with dynamic dims + called at a different shape later will produce wrong-but-non-erroring + outputs (the trace re-uses the dummy's broadcast shapes for + batch-dependent intermediates). """ - if not _graph_has_int64_tensors(onnx_model): - return - if jax.config.read("jax_enable_x64"): - return - - jax.config.update("jax_enable_x64", True) - # Verify the flip is effective on fresh JAX ops. - if jnp.asarray([1.0]).dtype != jnp.float64: - raise RuntimeError( - "This ONNX graph carries int64 tensors (typical for torch-exported " - "normalizing flows), which jaxonnxruntime would silently truncate " - "to int32 — producing wrong numerical results. HSSM attempted to " - "auto-enable `jax_enable_x64`, but JAX has already been used in " - "32-bit mode and the flip did not take. Fix: add\n" - " import jax\n" - " jax.config.update('jax_enable_x64', True)\n" - "at the very top of your script, before any other JAX import." + bad: list[str] = [] + for inp in model.graph.input: + for i, dim in enumerate(inp.type.tensor_type.shape.dim): + if dim.dim_value <= 0: + label = dim.dim_param or f"axis {i}" + bad.append(f"{inp.name}[{label}]") + if bad: + raise ValueError( + "ONNX model has dynamic (symbolic) input dimensions: " + f"{', '.join(bad)}. HSSM uses single-trial input shapes and " + "vmaps over trials at a layer above this conversion -- " + "re-export the model with a concrete per-trial input shape " + "(omit `dynamic_axes` in `torch.onnx.export`, or pass a single " + "rank-1 dummy as LANfactory.onnx.transform_sbi_to_onnx does). " + "Dynamic dims here would cause jaxonnxruntime to silently " + "produce wrong outputs for graphs with batch-dependent " + "intermediates (e.g. log-det accumulators)." 
) - _logger.warning( - "HSSM auto-enabled `jax_enable_x64` because the loaded ONNX graph " - "carries int64 tensors that JAX would otherwise silently truncate. " - "To silence this warning, set the flag yourself at the top of your " - "script: `jax.config.update('jax_enable_x64', True)`." - ) def make_jax_func(onnx_model: onnx.ModelProto) -> Callable: """Convert an ONNX model to a JAX function using jaxonnxruntime. + The model must have a fully concrete input shape -- see the module + docstring for the single-trial-input + vmap contract. + Parameters ---------- onnx_model : onnx.ModelProto - The ONNX model to be converted. + The ONNX model to be converted. Will be mutated in place to recast + int64 tensors to int32 (lossless for index/shape values produced by + torch.onnx.export). Returns ------- Callable - A JAX function that represents the ONNX model. + A JAX function ``f(x)`` that runs the ONNX graph on ``x``. + + Raises + ------ + ValueError + If the ONNX graph has any dynamic / symbolic input dimension. """ - _ensure_x64_if_needed(onnx_model) + _check_single_trial_input_shape(onnx_model) + _recast_int64_to_int32(onnx_model) model_graph = onnx_model.graph - - # Get the input name and shape from the ONNX model to create a dummy input for - # initialization. input_name = model_graph.input[0].name input_dims = tuple( - dim.dim_value if (dim.dim_value > 0) else 1 - for dim in model_graph.input[0].type.tensor_type.shape.dim + dim.dim_value for dim in model_graph.input[0].type.tensor_type.shape.dim ) model_func, model_weights = call_onnx.call_onnx_model( onnx_model, {input_name: np.ones(input_dims)} ) - # Create a JAX function that takes the input and applies the ONNX model. 
run_func = jax.tree_util.Partial(model_func, model_weights) jax_func = lambda x: run_func({input_name: x})[0].squeeze() diff --git a/tests/distribution_utils/test_onnx.py b/tests/distribution_utils/test_onnx.py index ee49b6ece..d3ba5a20e 100644 --- a/tests/distribution_utils/test_onnx.py +++ b/tests/distribution_utils/test_onnx.py @@ -11,6 +11,10 @@ make_jax_logp_funcs_from_onnx, make_jax_matrix_logp_funcs_from_onnx, ) +from hssm.distribution_utils.onnx_utils.onnx2jax import ( + _recast_int64_to_int32, + make_jax_func, +) hssm.set_floatX("float32") DECIMAL = 4 @@ -156,3 +160,60 @@ def test_make_simple_jax_logp_funcs_from_onnx(fixture_path): result_simple, decimal=DECIMAL, ) + + +# --------------------------------------------------------------------------- +# Guards introduced when removing the auto-x64 flip in onnx2jax.py. +# --------------------------------------------------------------------------- + + +def _make_minimal_onnx(input_shape): + """Construct a minimal Identity-graph ONNX model with the requested shape. + + ``input_shape`` is an iterable of ``int`` (concrete) or ``str``/``None`` + (symbolic / dynamic). 
+ """ + from onnx import helper, TensorProto + + dims = list(input_shape) + in_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, dims) + out_tensor = helper.make_tensor_value_info("y", TensorProto.FLOAT, dims) + node = helper.make_node("Identity", inputs=["x"], outputs=["y"]) + graph = helper.make_graph([node], "minimal", [in_tensor], [out_tensor]) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +def test_make_jax_func_rejects_dynamic_input_dim(): + """Dynamic input dims would silently produce wrong outputs -- must raise.""" + model = _make_minimal_onnx(["batch", 5]) # symbolic batch dim + with pytest.raises(ValueError, match="dynamic"): + make_jax_func(model) + + +def test_make_jax_func_accepts_concrete_input_dim(): + """Concrete-shape models pass the guard and produce a callable.""" + model = _make_minimal_onnx([1, 5]) + fn = make_jax_func(model) + out = np.asarray(fn(np.ones((1, 5), dtype=np.float32))) + assert out.shape == (5,) or out.shape == (1, 5) # squeeze may drop the 1 + + +def test_recast_int64_to_int32_rewrites_initializers(): + """The int64 -> int32 pre-cast must rewrite int64 initializers in place.""" + from onnx import helper, numpy_helper, TensorProto + + init = numpy_helper.from_array(np.array([1, 2, 3], dtype=np.int64), "shape") + in_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3]) + out_tensor = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3]) + node = helper.make_node("Reshape", inputs=["x", "shape"], outputs=["y"]) + graph = helper.make_graph( + [node], "minimal", [in_tensor], [out_tensor], initializer=[init] + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + assert model.graph.initializer[0].data_type == TensorProto.INT64 + n = _recast_int64_to_int32(model) + assert n == 1 + assert model.graph.initializer[0].data_type == TensorProto.INT32 + # Idempotent: a second call is a no-op. 
+ assert _recast_int64_to_int32(model) == 0 From 5c8275a5e09b0099db4a91cb90004fe4019041fe Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 22 May 2026 20:37:47 -0400 Subject: [PATCH 20/20] docs(claude): document the ONNX single-trial + vmap contract Adds a Key Patterns entry codifying the rule that `make_jax_func` now enforces (commit dd168fc): every ONNX graph consumed by HSSM must have a concrete single-trial input shape, with per-trial batching happening at the HSSM layer via jax.vmap. Points at the two enforcement sites (_check_single_trial_input_shape in onnx2jax.py, the vmap wiring in onnx.py) and notes that LANfactory's exporters already follow this convention. The constraint was de facto since the original LANs but only became enforced with dd168fc; this entry makes it discoverable. Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index 4e3d77237..eb6b163ab 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -77,6 +77,14 @@ uv run ruff format . `HSSM.sample()` passes `**kwargs` through to `bambi.Model.fit()`, which in turn passes them to PyMC's `pm.sample()`. So parameters like `cores`, `chains`, `nuts_sampler`, `target_accept`, etc. are valid even though they don't appear in HSSM's own signature. Similarly, the HSSM constructor passes `**kwargs` to `bambi.Model()`, so bambi parameters like `noncentered` are valid. +### ONNX likelihoods are single-trial + `jax.vmap` + +Every ONNX graph consumed by HSSM must be exported with a concrete single-trial input shape (no `dynamic_axes`). Per-trial batching happens at the HSSM layer via `jax.vmap` over trials — see [`src/hssm/distribution_utils/onnx.py:115-138`](src/hssm/distribution_utils/onnx.py#L115-L138), where `logp(*inputs)` builds one flat per-trial vector and `make_vmap_func` lifts it. 
+ +Enforced at load time by `_check_single_trial_input_shape` in [`src/hssm/distribution_utils/onnx_utils/onnx2jax.py`](src/hssm/distribution_utils/onnx_utils/onnx2jax.py), which raises a `ValueError` on any symbolic input dim. The constraint exists because `jaxonnxruntime` traces against the construction-time dummy and bakes those shapes into the returned closure — calling that closure at a different batch size silently produces wrong outputs for graphs with batch-dependent intermediates (log-det accumulators, `Reshape` with `-1`). + +LANfactory's exporters (`transform_sbi_to_onnx`, BayesFlow LRE export) already follow this convention. A new ONNX source must do the same: trace with a rank-1 dummy, no `dynamic_axes`. + ### Notebook execution in CI Two separate skip mechanisms for notebooks: