Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 319 additions & 0 deletions docs/tutorials/bayesflow_nle_onnx_integration.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
Comment on lines +1 to +6
"# Integrating bayesflow-trained likelihoods into HSSM (NLE via ONNX)\n",
"\n",
"This notebook covers the bayesflow version of the ONNX-based NLE integration workflow in HSSM. Once you've trained a [bayesflow](https://github.com/bayesflow-org/bayesflow) `ContinuousApproximator` (NLE) or `RatioApproximator` (NRE), you can export it to a single ONNX file and hand the path to HSSM. From HSSM's side, the file is consumed identically to a LAN export or an sbi export — same `loglik=\"file.onnx\"` gesture, no library-specific glue.\n",
"\n",
"Two paths into HSSM exist, side by side:\n",
"\n",
"| Path | Source | Mechanism | When to use |\n",
"|---|---|---|---|\n",
"| `loglik=\"file.onnx\"` | sbi or bayesflow | ONNX file, framework-agnostic | Portability, sharing trained surrogates |\n",
"| `loglik=<jax_callable>` | bayesflow (this tutorial's sibling) | In-memory JAX callable | Fast iteration during model development |\n",
"\n",
"See [`bayesflow_lre_integration.ipynb`](./bayesflow_lre_integration.ipynb) for the JAX-callable path. This notebook covers the ONNX path."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Part 1 — Setup\n\n**Critical**: `KERAS_BACKEND=torch` must be set *before* importing `keras` or `bayesflow`. `torch.onnx.export` cannot trace a JAX-backed Keras model. On Apple silicon also pin `KERAS_TORCH_DEVICE=cpu` to avoid the orthogonal initializer's missing MPS op.\n\n**Note on the two ONNX-related setup lines below** (`jax_enable_x64` and the `jaxonnxruntime` strict-mode relax): these are the same workarounds that HSSM PR #964 plans to auto-handle inside `hssm.distribution_utils.onnx2jax`. Until that PR lands on `main`, this tutorial sets them explicitly so it runs standalone. Once #964 merges, both lines can be deleted — HSSM will take care of them on import."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "import os\n\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\nos.environ[\"KERAS_TORCH_DEVICE\"] = \"cpu\"\n\nimport jax\n\n# x64 BEFORE other JAX-touching imports; ONNX graphs from torch.onnx.export\n# carry int64 shape/index tensors that get silently truncated under JAX's\n# default int32 mode, producing wrong log-prob values inside HSSM.\njax.config.update(\"jax_enable_x64\", True)\n\n# Relax jaxonnxruntime's default strict mode on Reshape shape arguments —\n# torch.onnx.export emits them as Constant nodes, not initializers, which the\n# strict default rejects. Safe because the shapes are genuinely constant.\n# HSSM PR #964 sets this automatically inside hssm.distribution_utils.onnx2jax;\n# until that PR lands on main, we set it here.\nfrom jaxonnxruntime import config as _jaxonnx_config\n_jaxonnx_config.update(\"jaxort_only_allow_initializers_as_static_args\", False)\n\nimport tempfile # noqa: F401 — available for the Part 4 override example\nfrom pathlib import Path\n\nimport arviz as az\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nimport bayesflow as bf\nimport keras\nfrom bayesflow.datasets import OfflineDataset\nfrom bayesflow.networks.inference.coupling.transforms import AffineTransform\nfrom ssms.basic_simulators.simulator import simulator\n\nimport hssm\nfrom lanfactory.onnx import transform_bayesflow_to_onnx\n\nprint(\"keras backend:\", keras.backend.backend())\nprint(\"bayesflow: \", bf.__version__)\nprint(\"hssm: \", hssm.__version__)"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 2 — Simulate observed DDM data\n",
"\n",
"We use [`ssm-simulators`](https://github.com/AlexanderFengler/ssm-simulators) for ground-truth DDM samples at a known parameter vector. HSSM consumes a `DataFrame` with `rt` and `response` columns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DDM_PARAM_NAMES = [\"v\", \"a\", \"z\", \"t\"]\n",
"DDM_PARAM_LOW = np.array([-2.0, 0.6, 0.3, 0.1], dtype=np.float32)\n",
"DDM_PARAM_HIGH = np.array([2.0, 1.8, 0.7, 0.5], dtype=np.float32)\n",
"TRUE_THETA = np.array([0.5, 1.2, 0.5, 0.25], dtype=np.float32)\n",
"N_OBS = 500\n",
"\n",
"out = simulator(theta=TRUE_THETA[None, :], model=\"ddm\", n_samples=N_OBS)\n",
"obs_data = pd.DataFrame({\n",
" \"rt\": out[\"rts\"].squeeze().astype(np.float32),\n",
" \"response\": out[\"choices\"].squeeze().astype(np.float32),\n",
"})\n",
"obs_data.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n",
"for choice, label in [(1.0, \"choice +1\"), (-1.0, \"choice -1\")]:\n",
" rts = obs_data.loc[obs_data[\"response\"] == choice, \"rt\"]\n",
" ax.hist(rts, bins=40, alpha=0.5, label=label)\n",
"ax.set_xlabel(\"reaction time\")\n",
"ax.set_ylabel(\"count\")\n",
"ax.legend()\n",
"ax.set_title(f\"Observed DDM data at θ={dict(zip(DDM_PARAM_NAMES, TRUE_THETA))}\")\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 3 — Train a bayesflow CouplingFlow NLE on DDM simulations\n",
"\n",
"The CouplingFlow setup below is the v1 ONNX-friendly configuration documented in [LANfactory's `exporting_bayesflow_models.md`](https://github.com/lnccbrown/LANFactory/blob/main/docs/exporting_bayesflow_models.md). Three opinionated choices that aren't bayesflow's defaults:\n",
"\n",
"- `permutation=None` — `FixedPermutation` uses `keras.ops.take`, which exports as `aten::ravel`, unsupported in ONNX opset 17/20.\n",
"- `transform=AffineTransform(clamp=False)` (explicit instance) — default `clamp=True` emits `aten::asinh`. Also, `bf.networks.CouplingFlow(..., transform_kwargs={\"clamp\": False})` silently drops the kwarg (upstream bug); pass an instance.\n",
"- `activation=\"silu\"` — default `\"hard_silu\"` exports as a single fused `HardSwish` op that jaxonnxruntime can't run. Real SiLU decomposes to `Sigmoid` + `Mul`."
]
Comment on lines +83 to +90
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"keras.utils.set_random_seed(0)\n",
"rng = np.random.default_rng(0)\n",
"\n",
"N_TRAIN = 20_000\n",
"theta_train = rng.uniform(\n",
" DDM_PARAM_LOW, DDM_PARAM_HIGH, size=(N_TRAIN, len(DDM_PARAM_NAMES))\n",
").astype(np.float32)\n",
"\n",
"# One trial per (θ_i) — NLE convention: each row is (θ_i, x_i).\n",
"x_train = np.empty((N_TRAIN, 2), dtype=np.float32)\n",
"for i, th in enumerate(theta_train):\n",
" out = simulator(theta=th[None, :], model=\"ddm\", n_samples=1)\n",
" x_train[i, 0] = out[\"rts\"].squeeze()\n",
" x_train[i, 1] = out[\"choices\"].squeeze()\n",
"\n",
"print(\"theta_train:\", theta_train.shape, \" x_train:\", x_train.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"approximator = bf.ContinuousApproximator(\n",
" inference_network=bf.networks.CouplingFlow(\n",
" depth=6,\n",
" subnet_kwargs={\"widths\": (64, 64), \"activation\": \"silu\", \"dropout\": None},\n",
" permutation=None,\n",
" use_actnorm=False,\n",
" transform=AffineTransform(clamp=False),\n",
" ),\n",
" standardize=\"inference_variables\", # standardize x (rt, choice)\n",
")\n",
"approximator.build({\n",
" \"inference_variables\": (None, 2),\n",
" \"inference_conditions\": (None, len(DDM_PARAM_NAMES)),\n",
"})\n",
"approximator.compile(optimizer=keras.optimizers.Adam(learning_rate=5e-4))\n",
"\n",
"dataset = OfflineDataset(\n",
" data={\n",
" \"inference_variables\": x_train,\n",
" \"inference_conditions\": theta_train,\n",
" },\n",
" batch_size=256,\n",
" adapter=None, # MUST be identity for ONNX export; we use only the in-network Standardize layer\n",
")\n",
"history = approximator.fit(dataset=dataset, epochs=50, verbose=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Part 4 — Export the trained approximator to ONNX\n\nOne call. The exporter raises clearly if any v1 constraint is violated (wrong `KERAS_BACKEND`, non-identity adapter, missing `inference_network`).\n\n**Where to write the file.** The cell below sets `ARTIFACT_DIR` to `~/bayesflow_onnx_tutorial/` — outside the HSSM repo, so re-running the notebook doesn't leave artifacts in your working tree. Change it to any path you want to keep the trained ONNX around for downstream work; set it to `tempfile.mkdtemp()` for an ephemeral demo run."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# User-configurable: where the .onnx file lands. Default is outside the HSSM\n# repo so notebook re-runs don't pollute the working tree.\n# Override examples:\n# ARTIFACT_DIR = Path(\"/path/to/my/project/onnx\") # keep nearby\n# ARTIFACT_DIR = Path(tempfile.mkdtemp()) # ephemeral\nARTIFACT_DIR = Path.home() / \"bayesflow_onnx_tutorial\"\nARTIFACT_DIR.mkdir(parents=True, exist_ok=True)\nonnx_path = ARTIFACT_DIR / \"ddm_nle.onnx\"\n\ntransform_bayesflow_to_onnx(\n approximator,\n str(onnx_path),\n mode=\"nle\",\n example_theta_dim=len(DDM_PARAM_NAMES),\n example_x_dim=2,\n)\nprint(f\"wrote {onnx_path} ({onnx_path.stat().st_size:,} bytes)\")"
Comment on lines +152 to +159
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 5 — Hand the ONNX to HSSM\n",
"\n",
"The user gesture is identical to the sbi path and the LAN-MLP path. HSSM detects the `.onnx` extension, loads it via `jaxonnxruntime`, vmaps over trials, and wires it into a PyMC `Distribution`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_nle = hssm.HSSM(\n",
" data=obs_data,\n",
" model=\"ddm\",\n",
" loglik_kind=\"approx_differentiable\",\n",
" loglik=str(onnx_path),\n",
" p_outlier=0,\n",
")\n",
"model_nle"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idata_nle = model_nle.sample(\n",
" sampler=\"numpyro\",\n",
" draws=1000,\n",
" tune=1000,\n",
" chains=2,\n",
" target_accept=0.9,\n",
Comment thread
AlexanderFengler marked this conversation as resolved.
" mp_ctx=\"spawn\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"summary_nle = az.summary(idata_nle, var_names=DDM_PARAM_NAMES)\n",
"summary_nle"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"az.plot_trace(idata_nle, var_names=DDM_PARAM_NAMES)\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 6 — Ground-truth posterior via HSSM's analytical DDM\n",
"\n",
"DDM has a closed-form likelihood (Navarro & Fuss). Comparing the bayesflow-NLE posterior to this gives a fairness check: any drift comes from the neural approximation, not from the inference machinery."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_analytical = hssm.HSSM(\n",
" data=obs_data,\n",
" model=\"ddm\",\n",
" p_outlier=0,\n",
" # default loglik_kind here is the analytical Navarro & Fuss path\n",
")\n",
"idata_analytical = model_analytical.sample(\n",
" sampler=\"numpyro\",\n",
" draws=1000,\n",
" tune=1000,\n",
" chains=2,\n",
" target_accept=0.9,\n",
")\n",
"summary_analytical = az.summary(idata_analytical, var_names=DDM_PARAM_NAMES)\n",
"summary_analytical"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 7 — Posterior comparison: bayesflow NLE vs analytical\n",
"\n",
"Overlay marginals and mark the true values. If the bayesflow NLE was trained well, the two posterior densities should agree closely."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, len(DDM_PARAM_NAMES), figsize=(16, 4))\n",
"for ax, name, true_val in zip(axes, DDM_PARAM_NAMES, TRUE_THETA):\n",
" az.plot_kde(\n",
" idata_analytical.posterior[name].values.flatten(),\n",
" plot_kwargs={\"color\": \"black\", \"linestyle\": \"--\", \"label\": \"analytical\"},\n",
" ax=ax,\n",
" )\n",
" az.plot_kde(\n",
" idata_nle.posterior[name].values.flatten(),\n",
" plot_kwargs={\"color\": \"tab:blue\", \"label\": \"bayesflow NLE\"},\n",
" ax=ax,\n",
" )\n",
" ax.axvline(true_val, color=\"red\", lw=1, label=\"truth\")\n",
" ax.set_title(name)\n",
" ax.legend(fontsize=8)\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"- **User gesture is the same as sbi or LAN-MLP**: `hssm.HSSM(loglik=\"file.onnx\", loglik_kind=\"approx_differentiable\")`. HSSM doesn't need to know which framework trained the surrogate.\n",
"- **The v1 constraints on the bayesflow side** (CouplingFlow with `permutation=None`, explicit `AffineTransform(clamp=False)`, `silu` not `hard_silu` activation, identity adapter, `KERAS_BACKEND=torch` at export time) are all enforced or documented by `lanfactory.onnx.transform_bayesflow_to_onnx`. The exporter raises clearly when a constraint is violated.\n",
"- **For an in-memory JAX-callable alternative** (no ONNX file, faster iteration during model development), see [`bayesflow_lre_integration.ipynb`](./bayesflow_lre_integration.ipynb).\n",
"\n",
"Out of v1 scope (tracked as future work):\n",
"\n",
"- MNLE-style discrete + continuous mixed observations\n",
"- Non-identity bayesflow Adapters (would require either baking the tensor-able subset into the ONNX graph, or shipping the adapter spec alongside the ONNX file)\n",
"- Transformer / attention summary networks (LayerNorm + dynamic-shape ops)\n",
"- FlowMatching / DiffusionModel / ConsistencyModel inference networks (`log_prob` requires ODE integration)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}