From ea71510729304999afc89ced19fc5fd4f3c9d4a9 Mon Sep 17 00:00:00 2001 From: Alexander Date: Sun, 17 May 2026 23:41:05 -0400 Subject: [PATCH 1/4] docs(tutorials): add bayesflow NLE ONNX integration tutorial Sibling of sbi_nle_integration.ipynb. Demonstrates the end-to-end path: train a bayesflow CouplingFlow NLE on DDM data, export to ONNX via lanfactory.onnx.transform_bayesflow_to_onnx (companion LANfactory PR), and hand the file to HSSM via the same loglik="*.onnx" / loglik_kind="approx_differentiable" gesture that sbi exports already use. Final part overlays the bayesflow-NLE posterior against HSSM's analytical Navarro & Fuss DDM posterior as ground truth. Part 1 includes a temporary workaround line setting jaxort_only_allow_initializers_as_static_args=False directly so the notebook runs standalone on main. HSSM PR #964 sets this automatically inside hssm.distribution_utils.onnx2jax; once it merges, the manual line can be deleted. This is a docs-only addition - no HSSM code changes. Companion PRs: - LANfactory bayesflow-connector (stacked on #79 sbi-connector) - HSSMSpine: bayesflow-onnx-integration.md design doc + upstream-bugs-from-bayesflow-onnx-work.md catalog of upstream defects surfaced during this work Co-Authored-By: Claude Opus 4.7 (1M context) --- .../bayesflow_nle_onnx_integration.ipynb | 335 ++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 docs/tutorials/bayesflow_nle_onnx_integration.ipynb diff --git a/docs/tutorials/bayesflow_nle_onnx_integration.ipynb b/docs/tutorials/bayesflow_nle_onnx_integration.ipynb new file mode 100644 index 00000000..7fab956a --- /dev/null +++ b/docs/tutorials/bayesflow_nle_onnx_integration.ipynb @@ -0,0 +1,335 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Integrating bayesflow-trained likelihoods into HSSM (NLE via ONNX)\n", + "\n", + "This is the bayesflow sibling of `sbi_nle_integration.ipynb`. Once you've trained a [bayesflow](https://github.com/bayesflow-org/bayesflow) `ContinuousApproximator` (NLE) or `RatioApproximator` (NRE), you can export it to a single ONNX file and hand the path to HSSM. From HSSM's side, the file is consumed identically to a LAN export or an sbi export — same `loglik=\"file.onnx\"` gesture, no library-specific glue.\n", + "\n", + "Two paths into HSSM exist, side by side:\n", + "\n", + "| Path | Source | Mechanism | When to use |\n", + "|---|---|---|---|\n", + "| `loglik=\"file.onnx\"` | sbi or bayesflow | ONNX file, framework-agnostic | Portability, sharing trained surrogates |\n", + "| `loglik=` | bayesflow (this tutorial's sibling) | In-memory JAX callable | Fast iteration during model development |\n", + "\n", + "See [`bayesflow_lre_integration.ipynb`](./bayesflow_lre_integration.ipynb) for the JAX-callable path. This notebook covers the ONNX path." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Part 1 — Setup\n\n**Critical**: `KERAS_BACKEND=torch` must be set *before* importing `keras` or `bayesflow`. `torch.onnx.export` cannot trace a JAX-backed Keras model. On Apple silicon also pin `KERAS_TORCH_DEVICE=cpu` to avoid the orthogonal initializer's missing MPS op.\n\n**Note on the two ONNX-related setup lines below** (`jax_enable_x64` and the `jaxonnxruntime` strict-mode relax): these are the same workarounds that HSSM PR #964 plans to auto-handle inside `hssm.distribution_utils.onnx2jax`. Until that PR lands on `main`, this tutorial sets them explicitly so it runs standalone. Once #964 merges, both lines can be deleted — HSSM will take care of them on import." + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "import os\n\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\nos.environ[\"KERAS_TORCH_DEVICE\"] = \"cpu\"\n\nimport jax\n\n# x64 BEFORE other JAX-touching imports; ONNX graphs from torch.onnx.export\n# carry int64 shape/index tensors that get silently truncated under JAX's\n# default int32 mode, producing wrong log-prob values inside HSSM.\njax.config.update(\"jax_enable_x64\", True)\n\n# Relax jaxonnxruntime's default strict mode on Reshape shape arguments —\n# torch.onnx.export emits them as Constant nodes, not initializers, which the\n# strict default rejects. Safe because the shapes are genuinely constant.\n# HSSM PR #964 sets this automatically inside hssm.distribution_utils.onnx2jax;\n# until that PR lands on main, we set it here.\nfrom jaxonnxruntime import config as _jaxonnx_config\n_jaxonnx_config.update(\"jaxort_only_allow_initializers_as_static_args\", False)\n\nfrom pathlib import Path\n\nimport arviz as az\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nimport bayesflow as bf\nimport keras\nfrom bayesflow.datasets import OfflineDataset\nfrom bayesflow.networks.inference.coupling.transforms import AffineTransform\nfrom ssms.basic_simulators.simulator import simulator\n\nimport hssm\nfrom lanfactory.onnx import transform_bayesflow_to_onnx\n\nprint(\"keras backend:\", keras.backend.backend())\nprint(\"bayesflow: \", bf.__version__)\nprint(\"hssm: \", hssm.__version__)" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 2 — Simulate observed DDM data\n", + "\n", + "We use [`ssm-simulators`](https://github.com/AlexanderFengler/ssm-simulators) for ground-truth DDM samples at a known parameter vector. HSSM consumes a `DataFrame` with `rt` and `response` columns." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DDM_PARAM_NAMES = [\"v\", \"a\", \"z\", \"t\"]\n", + "DDM_PARAM_LOW = np.array([-2.0, 0.6, 0.3, 0.1], dtype=np.float32)\n", + "DDM_PARAM_HIGH = np.array([2.0, 1.8, 0.7, 0.5], dtype=np.float32)\n", + "TRUE_THETA = np.array([0.5, 1.2, 0.5, 0.25], dtype=np.float32)\n", + "N_OBS = 500\n", + "\n", + "out = simulator(theta=TRUE_THETA[None, :], model=\"ddm\", n_samples=N_OBS)\n", + "obs_data = pd.DataFrame({\n", + " \"rt\": out[\"rts\"].squeeze().astype(np.float32),\n", + " \"response\": out[\"choices\"].squeeze().astype(np.float32),\n", + "})\n", + "obs_data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", + "for choice, label in [(1.0, \"choice +1\"), (-1.0, \"choice -1\")]:\n", + " rts = obs_data.loc[obs_data[\"response\"] == choice, \"rt\"]\n", + " ax.hist(rts, bins=40, alpha=0.5, label=label)\n", + "ax.set_xlabel(\"reaction time\")\n", + "ax.set_ylabel(\"count\")\n", + "ax.legend()\n", + "ax.set_title(f\"Observed DDM data at θ={dict(zip(DDM_PARAM_NAMES, TRUE_THETA))}\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 3 — Train a bayesflow CouplingFlow NLE on DDM simulations\n", + "\n", + "The CouplingFlow setup below is the v1 ONNX-friendly configuration documented in [LANfactory's `exporting_bayesflow_models.md`](https://github.com/lnccbrown/LANFactory/blob/main/docs/exporting_bayesflow_models.md). Three opinionated choices that aren't bayesflow's defaults:\n", + "\n", + "- `permutation=None` — `FixedPermutation` uses `keras.ops.take`, which exports as `aten::ravel`, unsupported in ONNX opset 17/20.\n", + "- `transform=AffineTransform(clamp=False)` (explicit instance) — default `clamp=True` emits `aten::asinh`. Also, `bf.networks.CouplingFlow(..., transform_kwargs={\"clamp\": False})` silently drops the kwarg (upstream bug); pass an instance.\n", + "- `activation=\"silu\"` — default `\"hard_silu\"` exports as a single fused `HardSwish` op that jaxonnxruntime can't run. Real SiLU decomposes to `Sigmoid` + `Mul`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "keras.utils.set_random_seed(0)\n", + "rng = np.random.default_rng(0)\n", + "\n", + "N_TRAIN = 20_000\n", + "theta_train = rng.uniform(\n", + " DDM_PARAM_LOW, DDM_PARAM_HIGH, size=(N_TRAIN, len(DDM_PARAM_NAMES))\n", + ").astype(np.float32)\n", + "\n", + "# One trial per (θ_i) — NLE convention: each row is (θ_i, x_i).\n", + "x_train = np.empty((N_TRAIN, 2), dtype=np.float32)\n", + "for i, th in enumerate(theta_train):\n", + " out = simulator(theta=th[None, :], model=\"ddm\", n_samples=1)\n", + " x_train[i, 0] = out[\"rts\"].squeeze()\n", + " x_train[i, 1] = out[\"choices\"].squeeze()\n", + "\n", + "print(\"theta_train:\", theta_train.shape, \" x_train:\", x_train.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "approximator = bf.ContinuousApproximator(\n", + " inference_network=bf.networks.CouplingFlow(\n", + " depth=6,\n", + " subnet_kwargs={\"widths\": (64, 64), \"activation\": \"silu\", \"dropout\": None},\n", + " permutation=None,\n", + " use_actnorm=False,\n", + " transform=AffineTransform(clamp=False),\n", + " ),\n", + " standardize=\"inference_variables\", # standardize x (rt, choice)\n", + ")\n", + "approximator.build({\n", + " \"inference_variables\": (None, 2),\n", + " \"inference_conditions\": (None, len(DDM_PARAM_NAMES)),\n", + "})\n", + "approximator.compile(optimizer=keras.optimizers.Adam(learning_rate=5e-4))\n", + "\n", + "dataset = OfflineDataset(\n", + " data={\n", + " \"inference_variables\": x_train,\n", + " \"inference_conditions\": theta_train,\n", + " },\n", + " batch_size=256,\n", + " adapter=None, # MUST be identity for ONNX export; we use only the in-network Standardize layer\n", + ")\n", + "history = approximator.fit(dataset=dataset, epochs=50, verbose=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 4 — Export the trained approximator to ONNX\n", + "\n", + "One call. The exporter raises clearly if any v1 constraint is violated (wrong `KERAS_BACKEND`, non-identity adapter, missing `inference_network`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "onnx_dir = Path(\"./bayesflow_onnx_artifacts\")\n", + "onnx_dir.mkdir(exist_ok=True)\n", + "onnx_path = onnx_dir / \"ddm_nle.onnx\"\n", + "\n", + "transform_bayesflow_to_onnx(\n", + " approximator,\n", + " str(onnx_path),\n", + " mode=\"nle\",\n", + " example_theta_dim=len(DDM_PARAM_NAMES),\n", + " example_x_dim=2,\n", + ")\n", + "print(f\"wrote {onnx_path} ({onnx_path.stat().st_size:,} bytes)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 5 — Hand the ONNX to HSSM\n", + "\n", + "The user gesture is identical to the sbi path and the LAN-MLP path. HSSM detects the `.onnx` extension, loads it via `jaxonnxruntime`, vmaps over trials, and wires it into a PyMC `Distribution`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_nle = hssm.HSSM(\n", + " data=obs_data,\n", + " model=\"ddm\",\n", + " loglik_kind=\"approx_differentiable\",\n", + " loglik=str(onnx_path),\n", + " p_outlier=0,\n", + ")\n", + "model_nle" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "idata_nle = model_nle.sample(\n", + " sampler=\"numpyro\",\n", + " draws=1000,\n", + " tune=1000,\n", + " chains=2,\n", + " target_accept=0.9,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary_nle = az.summary(idata_nle, var_names=DDM_PARAM_NAMES)\n", + "summary_nle" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "az.plot_trace(idata_nle, var_names=DDM_PARAM_NAMES)\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 6 — Ground-truth posterior via HSSM's analytical DDM\n", + "\n", + "DDM has a closed-form likelihood (Navarro & Fuss). Comparing the bayesflow-NLE posterior to this gives a fairness check: any drift comes from the neural approximation, not from the inference machinery." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_analytical = hssm.HSSM(\n", + " data=obs_data,\n", + " model=\"ddm\",\n", + " p_outlier=0,\n", + " # default loglik_kind here is the analytical Navarro & Fuss path\n", + ")\n", + "idata_analytical = model_analytical.sample(\n", + " sampler=\"numpyro\",\n", + " draws=1000,\n", + " tune=1000,\n", + " chains=2,\n", + " target_accept=0.9,\n", + ")\n", + "summary_analytical = az.summary(idata_analytical, var_names=DDM_PARAM_NAMES)\n", + "summary_analytical" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 7 — Posterior comparison: bayesflow NLE vs analytical\n", + "\n", + "Overlay marginals and mark the true values. If the bayesflow NLE was trained well, the two posterior densities should agree closely." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, len(DDM_PARAM_NAMES), figsize=(16, 4))\n", + "for ax, name, true_val in zip(axes, DDM_PARAM_NAMES, TRUE_THETA):\n", + " az.plot_kde(\n", + " idata_analytical.posterior[name].values.flatten(),\n", + " plot_kwargs={\"color\": \"black\", \"linestyle\": \"--\", \"label\": \"analytical\"},\n", + " ax=ax,\n", + " )\n", + " az.plot_kde(\n", + " idata_nle.posterior[name].values.flatten(),\n", + " plot_kwargs={\"color\": \"tab:blue\", \"label\": \"bayesflow NLE\"},\n", + " ax=ax,\n", + " )\n", + " ax.axvline(true_val, color=\"red\", lw=1, label=\"truth\")\n", + " ax.set_title(name)\n", + " ax.legend(fontsize=8)\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "- **User gesture is the same as sbi or LAN-MLP**: `hssm.HSSM(loglik=\"file.onnx\", loglik_kind=\"approx_differentiable\")`. HSSM doesn't need to know which framework trained the surrogate.\n", + "- **The v1 constraints on the bayesflow side** (CouplingFlow with `permutation=None`, explicit `AffineTransform(clamp=False)`, `silu` not `hard_silu` activation, identity adapter, `KERAS_BACKEND=torch` at export time) are all enforced or documented by `lanfactory.onnx.transform_bayesflow_to_onnx`. The exporter raises clearly when a constraint is violated.\n", + "- **For an in-memory JAX-callable alternative** (no ONNX file, faster iteration during model development), see [`bayesflow_lre_integration.ipynb`](./bayesflow_lre_integration.ipynb).\n", + "\n", + "Out of v1 scope (tracked as future work):\n", + "\n", + "- MNLE-style discrete + continuous mixed observations\n", + "- Non-identity bayesflow Adapters (would require either baking the tensor-able subset into the ONNX graph, or shipping the adapter spec alongside the ONNX file)\n", + "- Transformer / attention summary networks (LayerNorm + dynamic-shape ops)\n", + "- FlowMatching / DiffusionModel / ConsistencyModel inference networks (`log_prob` requires ODE integration)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file From 657f65551c33b18b39271a170637f12532bfa988 Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 18 May 2026 18:26:56 -0400 Subject: [PATCH 2/4] docs(tutorial): user-configurable ARTIFACT_DIR for ONNX output Replaces the hardcoded ./bayesflow_onnx_artifacts/ path in Part 4 with a clearly-named ARTIFACT_DIR constant that defaults to ~/bayesflow_onnx_tutorial/ (outside the HSSM repo). Comment block in the cell shows two override examples: an arbitrary user path, or tempfile.mkdtemp() for an ephemeral demo run. Adds `import tempfile` in Part 1 so the override example works out of the box. Why: the previous default (./bayesflow_onnx_artifacts/) wrote into whatever directory the notebook was running from, which for typical setups meant docs/tutorials/ inside the HSSM repo. Re-running the notebook would accumulate untracked artifacts in the working tree (the same pattern that left sbi-logs/ and sbi_onnx_artifacts/ around from earlier sbi tutorial work). Moving the default outside the repo eliminates that footgun. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../bayesflow_nle_onnx_integration.ipynb | 23 +++---------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/docs/tutorials/bayesflow_nle_onnx_integration.ipynb b/docs/tutorials/bayesflow_nle_onnx_integration.ipynb index 7fab956a..3074e4b9 100644 --- a/docs/tutorials/bayesflow_nle_onnx_integration.ipynb +++ b/docs/tutorials/bayesflow_nle_onnx_integration.ipynb @@ -28,7 +28,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "import os\n\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\nos.environ[\"KERAS_TORCH_DEVICE\"] = \"cpu\"\n\nimport jax\n\n# x64 BEFORE other JAX-touching imports; ONNX graphs from torch.onnx.export\n# carry int64 shape/index tensors that get silently truncated under JAX's\n# default int32 mode, producing wrong log-prob values inside HSSM.\njax.config.update(\"jax_enable_x64\", True)\n\n# Relax jaxonnxruntime's default strict mode on Reshape shape arguments —\n# torch.onnx.export emits them as Constant nodes, not initializers, which the\n# strict default rejects. Safe because the shapes are genuinely constant.\n# HSSM PR #964 sets this automatically inside hssm.distribution_utils.onnx2jax;\n# until that PR lands on main, we set it here.\nfrom jaxonnxruntime import config as _jaxonnx_config\n_jaxonnx_config.update(\"jaxort_only_allow_initializers_as_static_args\", False)\n\nfrom pathlib import Path\n\nimport arviz as az\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nimport bayesflow as bf\nimport keras\nfrom bayesflow.datasets import OfflineDataset\nfrom bayesflow.networks.inference.coupling.transforms import AffineTransform\nfrom ssms.basic_simulators.simulator import simulator\n\nimport hssm\nfrom lanfactory.onnx import transform_bayesflow_to_onnx\n\nprint(\"keras backend:\", keras.backend.backend())\nprint(\"bayesflow: \", bf.__version__)\nprint(\"hssm: \", hssm.__version__)" + "source": "import os\n\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\nos.environ[\"KERAS_TORCH_DEVICE\"] = \"cpu\"\n\nimport jax\n\n# x64 BEFORE other JAX-touching imports; ONNX graphs from torch.onnx.export\n# carry int64 shape/index tensors that get silently truncated under JAX's\n# default int32 mode, producing wrong log-prob values inside HSSM.\njax.config.update(\"jax_enable_x64\", True)\n\n# Relax jaxonnxruntime's default strict mode on Reshape shape arguments —\n# torch.onnx.export emits them as Constant nodes, not initializers, which the\n# strict default rejects. Safe because the shapes are genuinely constant.\n# HSSM PR #964 sets this automatically inside hssm.distribution_utils.onnx2jax;\n# until that PR lands on main, we set it here.\nfrom jaxonnxruntime import config as _jaxonnx_config\n_jaxonnx_config.update(\"jaxort_only_allow_initializers_as_static_args\", False)\n\nimport tempfile # noqa: F401 — available for the Part 4 override example\nfrom pathlib import Path\n\nimport arviz as az\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nimport bayesflow as bf\nimport keras\nfrom bayesflow.datasets import OfflineDataset\nfrom bayesflow.networks.inference.coupling.transforms import AffineTransform\nfrom ssms.basic_simulators.simulator import simulator\n\nimport hssm\nfrom lanfactory.onnx import transform_bayesflow_to_onnx\n\nprint(\"keras backend:\", keras.backend.backend())\nprint(\"bayesflow: \", bf.__version__)\nprint(\"hssm: \", hssm.__version__)" }, { "cell_type": "markdown", @@ -149,31 +149,14 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "## Part 4 — Export the trained approximator to ONNX\n", - "\n", - "One call. The exporter raises clearly if any v1 constraint is violated (wrong `KERAS_BACKEND`, non-identity adapter, missing `inference_network`)." - ] + "source": "## Part 4 — Export the trained approximator to ONNX\n\nOne call. The exporter raises clearly if any v1 constraint is violated (wrong `KERAS_BACKEND`, non-identity adapter, missing `inference_network`).\n\n**Where to write the file.** The cell below sets `ARTIFACT_DIR` to `~/bayesflow_onnx_tutorial/` — outside the HSSM repo, so re-running the notebook doesn't leave artifacts in your working tree. Change it to any path you want to keep the trained ONNX around for downstream work; set it to `tempfile.mkdtemp()` for an ephemeral demo run." }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "onnx_dir = Path(\"./bayesflow_onnx_artifacts\")\n", - "onnx_dir.mkdir(exist_ok=True)\n", - "onnx_path = onnx_dir / \"ddm_nle.onnx\"\n", - "\n", - "transform_bayesflow_to_onnx(\n", - " approximator,\n", - " str(onnx_path),\n", - " mode=\"nle\",\n", - " example_theta_dim=len(DDM_PARAM_NAMES),\n", - " example_x_dim=2,\n", - ")\n", - "print(f\"wrote {onnx_path} ({onnx_path.stat().st_size:,} bytes)\")" - ] + "source": "# User-configurable: where the .onnx file lands. Default is outside the HSSM\n# repo so notebook re-runs don't pollute the working tree.\n# Override examples:\n# ARTIFACT_DIR = Path(\"/path/to/my/project/onnx\") # keep nearby\n# ARTIFACT_DIR = Path(tempfile.mkdtemp()) # ephemeral\nARTIFACT_DIR = Path.home() / \"bayesflow_onnx_tutorial\"\nARTIFACT_DIR.mkdir(parents=True, exist_ok=True)\nonnx_path = ARTIFACT_DIR / \"ddm_nle.onnx\"\n\ntransform_bayesflow_to_onnx(\n approximator,\n str(onnx_path),\n mode=\"nle\",\n example_theta_dim=len(DDM_PARAM_NAMES),\n example_x_dim=2,\n)\nprint(f\"wrote {onnx_path} ({onnx_path.stat().st_size:,} bytes)\")" }, { "cell_type": "markdown", From 2aa198b5a0e7945e79b29cbacd506363d86a0e89 Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Fri, 22 May 2026 18:51:25 -0400 Subject: [PATCH 3/4] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- docs/tutorials/bayesflow_nle_onnx_integration.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/bayesflow_nle_onnx_integration.ipynb b/docs/tutorials/bayesflow_nle_onnx_integration.ipynb index 3074e4b9..bc2cd1aa 100644 --- a/docs/tutorials/bayesflow_nle_onnx_integration.ipynb +++ b/docs/tutorials/bayesflow_nle_onnx_integration.ipynb @@ -6,7 +6,7 @@ "source": [ "# Integrating bayesflow-trained likelihoods into HSSM (NLE via ONNX)\n", "\n", - "This is the bayesflow sibling of `sbi_nle_integration.ipynb`. Once you've trained a [bayesflow](https://github.com/bayesflow-org/bayesflow) `ContinuousApproximator` (NLE) or `RatioApproximator` (NRE), you can export it to a single ONNX file and hand the path to HSSM. From HSSM's side, the file is consumed identically to a LAN export or an sbi export — same `loglik=\"file.onnx\"` gesture, no library-specific glue.\n", + "This notebook covers the bayesflow version of the ONNX-based NLE integration workflow in HSSM. Once you've trained a [bayesflow](https://github.com/bayesflow-org/bayesflow) `ContinuousApproximator` (NLE) or `RatioApproximator` (NRE), you can export it to a single ONNX file and hand the path to HSSM. From HSSM's side, the file is consumed identically to a LAN export or an sbi export — same `loglik=\"file.onnx\"` gesture, no library-specific glue.\n", "\n", "Two paths into HSSM exist, side by side:\n", "\n", From 3439832aca039ab002e3b6ce9de3e896e76ca97a Mon Sep 17 00:00:00 2001 From: Alexander Fengler Date: Fri, 22 May 2026 18:51:53 -0400 Subject: [PATCH 4/4] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- docs/tutorials/bayesflow_nle_onnx_integration.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/tutorials/bayesflow_nle_onnx_integration.ipynb b/docs/tutorials/bayesflow_nle_onnx_integration.ipynb index bc2cd1aa..2e80b485 100644 --- a/docs/tutorials/bayesflow_nle_onnx_integration.ipynb +++ b/docs/tutorials/bayesflow_nle_onnx_integration.ipynb @@ -195,6 +195,7 @@ " tune=1000,\n", " chains=2,\n", " target_accept=0.9,\n", + " mp_ctx=\"spawn\",\n", ")" ] },