Integrate sbi NRE via ONNX exporter (keystone tutorial + HSSM-side patches)#964
Open
AlexanderFengler wants to merge 21 commits into
Open
Integrate sbi NRE via ONNX exporter (keystone tutorial + HSSM-side patches)#964AlexanderFengler wants to merge 21 commits into
AlexanderFengler wants to merge 21 commits into
Conversation
Sets jaxort_only_allow_initializers_as_static_args = False at module import time. The default strict mode rejects ONNX graphs whose Reshape op shape comes from a Constant node rather than a model initializer. torch.onnx.export emits exactly this pattern for masked autoregressive flows from nflows (and likely other flow architectures), surfaced by LANfactory commit f7c93c8. Setting the flag here means any consumer of make_jax_func benefits without per-call configuration. Safe for our use cases: shapes are genuinely constant, baked at export time. Side benefit: makes HSSM more robust to ONNX from any source emitting this pattern, not only sbi-exported flows. Part of the sbi to HSSM integration plan (see plans/sbi-onnx-integration.md in HSSMSpine, sub-commit C2.5). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ONNX graphs exported by torch.onnx.export of normalizing flows (e.g. the
nflows MAF used by sbi NLE) carry int64 tensors for Reshape shape
arguments, Constant node values, and Cast targets. jaxonnxruntime
silently truncates int64 to int32 unless jax_enable_x64 is set,
producing wrong numerical results (~0.5 drift in log-prob, surfaced
during the LANfactory C3 NLE validation).
make_jax_func now walks the loaded ONNX graph for int64 tensors. If any
are present and jax_enable_x64 is off, HSSM:
- attempts to flip the flag via jax.config.update
- verifies the flip is effective (fresh jnp.asarray([1.0]) is float64)
- emits a UserWarning pointing users to set it themselves to silence
- raises a clear RuntimeError with the one-line fix if the flip did
not take (JAX has already done substantive 32-bit work)
The detection is conservative: only scans initializers, Constant node
tensor attributes, and Cast `to` attributes. LAN MLP graphs do not
carry int64 in any of these places (verified: existing 8 HSSM ONNX
tests still pass without warnings).
This addresses the second of the two findings from C3 (see
plans/sbi-onnx-integration.md C7 row). Parallels the C2.5 strict-mode
patch in spirit but is more targeted: it only intervenes when the graph
actually requires the flag.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds docs/tutorials/sbi_nle_integration.ipynb, the keystone deliverable
of the sbi -> HSSM integration plan. Mirrors the structure of
bayesflow_lre_integration.ipynb so the two SBI-toolkit tutorials in the
HSSM docs sit side-by-side and tell a coherent story.
Structure (22 cells, 13 code + 9 markdown):
Part 1 - Setup (jax_enable_x64, imports, CI budget constants)
Part 2 - Simulate observed DDM data (ssm-simulators, N_OBS=500,
TRUE_THETA matching BayesFlow tutorial)
Part 3 - Train tiny sbi NLE_A with MAF on 10k training pairs
Part 4 - Export to ONNX via lanfactory.onnx.transform_sbi_to_onnx
Part 5 - High-level integration via hssm.HSSM(loglik=...onnx,
loglik_kind="approx_differentiable"), numpyro sampling,
summary, trace
Part 6 - Brief NRE variant (norm_layer=Identity to disable LayerNorm)
Part 7 - Posterior comparison plot: sbi NLE vs sbi NRE vs ground truth
Closing summary with pointers to LANfactory exporter docs and the
BayesFlow LRE neighbor tutorial
Reuses the exact TRUE_THETA, N_OBS, and prior ranges as the BayesFlow
LRE tutorial so cross-tutorial posterior comparisons are apples-to-
apples. Includes explicit documentation of the v1 constraints surfaced
during C2-C7 (2D minimum, norm_layer=Identity, x64-auto-flip).
Notebook outputs are intentionally empty (execution_count: null) -
execution requires a coordinated cross-repo env (LANfactory[all] + HSSM
in the same venv). Same env-resolution caveat as C7b. Run once locally
or in CI to populate outputs.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The C8 keystone notebook currently cannot be run in either repo's local
uv venv because of an outstanding cross-repo JAX/flax/numpyro pin
conflict. Specifically, lanfactory's top-level __init__.py pulls
trainers/jax_mlp.py which imports flax — incompatible with the JAX
version that HSSM's numpyro pin requires.
Until the env alignment lands as a separate workstream, the notebook
now imports transform_sbi_to_onnx with a try/except fallback:
- Clean path: `from lanfactory.onnx import transform_sbi_to_onnx`.
- Fallback: load only lanfactory/onnx/sbi.py directly via
importlib.util, bypassing lanfactory's top-level __init__.py and
sidestepping the flax dependency.
The fallback walks several candidate paths (env var
LANFACTORY_SBI_PATH, then common Jupyter cwd contexts) so the notebook
runs from a fresh kernel without manual editing.
Once the cross-repo env is resolved the fallback branch becomes dead
code and can be removed — the clean import will just work.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…f ONNX) Second issue running the C8 notebook: pymc.sample raised "TypeError: cannot pickle 'google._upb._message.EnumDescriptor' object" inside cloudpickle when the parallel sampler tried to fork worker processes on macOS. Root cause: HSSM's sampler="numpyro" path normalizes inference_method to "pymc" in base.py:678, which means bambi dispatches pm.sample with nuts_sampler="pymc" — PyMC's standard NUTS, not numpyro NUTS. On macOS the default multiprocessing start method is spawn, which requires cloudpickling the step method into worker processes. The step method references the JAX-wrapped ONNX function whose closure in jaxonnxruntime carries the onnx.ModelProto. ModelProto is a protobuf message and contains C-extension EnumDescriptor objects that cloudpickle cannot serialize. Workaround: pass cores=1 to model.sample(). Single-process sampling bypasses the multiprocess cloudpickle path entirely. Slower across chains (no parallelism) but reliable. Both NLE (Part 5) and NRE (Part 6) sample calls now include cores=1 with an explanatory comment. Followup queued: HSSM's sampler="numpyro" silently downgrades to pymc NUTS in this code path. Worth either (a) wiring nuts_sampler="numpyro" through to bambi (numpyro NUTS does its own JAX-internal parallelism without forking), or (b) updating the HSSM docstring so users know sampler= currently only controls init / jitter and not the actual NUTS backend. Tracked outside this commit. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
HSSMBase.sample collapsed sampler="numpyro" (and blackjax, nutpie) to inference_method="pymc" before handing off to bambi, which then dispatched to pm.sample(nuts_sampler="pymc"). The user's sampler choice was silently downgraded to PyMC NUTS regardless of what they asked for. Bambi natively accepts inference_method values "pymc", "numpyro", "blackjax", "nutpie" (and "vi"/"laplace") and routes each to the matching nuts_sampler. The collapse conditional negated this. Regression archaeology: - Aug 5, 2024 (commit aef3f9b, "Fix compatibility with Bambi (#516)"): introduced the working pattern -- inference_method="mcmc" (generic NUTS marker) + kwargs["nuts_sampler"]="numpyro"/"blackjax"/etc. injected separately. Bambi's old "mcmc" inference_method was generic and read nuts_sampler from kwargs. Correct under old bambi semantics. - Dec 17, 2025 (commit 20c100b, "fix: update model.sample api to be consistent with bambi's"): bambi had renamed its inference_method values (mcmc -> pymc, nuts_numpyro -> numpyro, etc.). This commit mechanically updated the string list, but ALSO deleted the kwargs["nuts_sampler"] = ... injection block. The flatten-conditional was left in place. After this commit, all four NUTS samplers route to inference_method="pymc" -> nuts_sampler="pymc" with no recourse. - The bug has been live since Dec 17, 2025 (about 5 months). Fix: replace the conditional with inference_method=sampler. Bambi handles each NUTS variant directly under the new API. The injection block deleted in commit 20c100b is correctly absent now -- bambi passes nuts_sampler=sampler_backend to pm.sample explicitly, so injecting it via kwargs would conflict. Side effects: - sampler="numpyro" now actually invokes numpyro NUTS, which runs inside JAX with internal parallelism and does NOT fork worker processes. This avoids the cloudpickle path that breaks on unpicklable ONNX ModelProto closures (surfaced by the sbi NLE tutorial). - sampler="blackjax" and sampler="nutpie" similarly now invoke their respective backends instead of PyMC NUTS. - sampler="pymc" behavior is unchanged (still routes to PyMC NUTS). Other gates that read the user's `sampler` argument (parallel-sampling warning at base.py:621, init default at base.py:636, jitter handling at base.py:644, step-sampler check at base.py:657) all check `sampler` directly, not the post-normalization inference_method value. None are affected by this change. Tests pinning the old behavior: - tests/test_rlssm.py:300, tests/test_save_load.py:39, tests/slow/test_mcmc.py:100-101 use sampler="numpyro" and were silently exercising PyMC NUTS. After this fix they exercise the actual numpyro path. Worth re-running as part of PR review. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Inherits the sampler-routing fix (commit ce19950) so the sbi NLE tutorial can use numpyro NUTS directly instead of falling back to cores=1. The cores=1 workaround is removed in the follow-up commit. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…fixed With the merged fix branch (commit ce19950), HSSM's sampler="numpyro" actually invokes numpyro NUTS via pm.sample(nuts_sampler="numpyro"). Numpyro NUTS runs entirely inside JAX with internal parallelism -- no ps.ParallelSampler, no forked workers, no cloudpickle of the step method, so the ONNX ModelProto protobuf descriptors that previously broke serialization are never touched. The cores=1 workaround added in commit f94b496 is therefore no longer necessary for either the NLE or NRE sample call. Reverting to the default (which lets pymc.sample pick a sensible cores count) so chains can run in parallel where the backend permits it. Note: numpyro NUTS handles chain parallelism via JAX's vmap-over-chains or pmap-over-devices internally, so explicit cores= is not the relevant knob for numpyro anyway. We're just removing an override that no longer applies. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… budget
The first end-to-end run of the C8 notebook recovered systematically
biased posteriors (NLE: v=1.5 vs truth 0.5; t centered at 0.02 vs truth
0.25 and below the training prior's lower bound of 0.1). Diagnosis: our
training prior was narrower than HSSM's default DDM bounds, so MCMC
explored regions the flow never saw and the trained MAF extrapolated
into spurious high-likelihood pockets.
Verified by inspecting hssm.defaults.default_model_config['ddm'] for the
"approx_differentiable" likelihood:
v in (-3.0, 3.0), a in (0.3, 2.5), z in (0.0, 1.0), t in (0.0, 2.0)
with a HalfNormal(sigma=2.0) prior on t that puts substantial mass below
our previous training lower bound of 0.1.
Changes in the notebook:
- Training prior (Part 3) widened to match HSSM's default bounds
verbatim: BoxUniform with low=[-3, 0.3, 0, 0], high=[3, 2.5, 1, 2].
- N_TRAIN raised 10k -> 30k to cover the wider 4D parameter volume.
- NUM_EPOCHS raised 50 -> 100 for the same reason.
- Simulation switched from a Python loop to ssm-simulators batched call
(theta of shape (N, 4) with n_samples=1), ~100x faster on 30k samples.
Expected effect: posteriors should now concentrate near the true theta
(v=0.5, a=1.2, z=0.5, t=0.25) since MCMC stays inside the trained region
throughout. Total notebook runtime estimate roughly 10-20 minutes on
CPU depending on machine.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds Part 5b (markdown + code) to the sbi NLE notebook. The cell answers the question "is the marginal posterior bias coming from a poorly-trained flow, or from HSSM-side sampling issues?" by computing the trained NLE log-likelihood of the observed data at: - the true theta - the posterior's marginal mean If the NLE itself prefers the wrong theta by a large margin, the flow is the problem (training quality). If it prefers the truth, MCMC is failing to find the NLE's mode (priors / init / mixing). The cell prints a three-way verdict depending on the gap. Placed between the NLE trace plot (Part 5) and the NRE variant (Part 6) so it interprets the NLE posterior immediately. Uses the in-memory estimator_nle and the obs_data DataFrame already in scope. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Applies the proposed §4.2 improvements from the bias review. The 30k-sim
run was undertrained for a 4D-theta DDM problem with HSSM's default wide
prior, so the marginal posteriors over v and a were biased high.
Changes to Part 1 (setup cell):
- N_TRAIN: 30_000 -> 1_000_000
- NUM_EPOCHS: 100 -> 300 (max)
- STOP_AFTER_EPOCHS: new, set to 50 (default is 20 -- previous runs
may have early-stopped silently)
- TRAINING_BATCH_SIZE: 200 -> 500 (fewer batches per epoch at 1M)
- HIDDEN_FEATURES: new, set to 100 (sbi default is 50)
- NUM_TRANSFORMS: new, set to 8 (sbi MAF default is 5)
- imports likelihood_nn alongside classifier_nn from sbi.neural_nets
Changes to Part 3 (NLE training):
- Replace density_estimator="maf" string shortcut with an explicit
likelihood_nn(model="maf", hidden_features=HIDDEN_FEATURES,
num_transforms=NUM_TRANSFORMS) builder.
- Pass TRAINING_BATCH_SIZE and STOP_AFTER_EPOCHS into .train().
Changes to Part 6 (NRE training):
- classifier_nn now also takes hidden_features=HIDDEN_FEATURES
(matches NLE width). norm_layer=nn.Identity remains required
because jaxonnxruntime doesn't implement LayerNormalization.
- Same TRAINING_BATCH_SIZE and STOP_AFTER_EPOCHS.
Expected wall time on CPU: 30-90 min for NLE alone, similar for NRE,
plus a few minutes of MCMC each. Run on GPU if available. The
training run is comparable in scale to a LAN training (1M sims) which
is what's needed for a fair NLE-vs-LAN comparison on DDM-like
problems.
The Part 5b diagnostic cell (commit 4c9a3ad) is unchanged and will
still print the training-vs-sampling verdict after this run.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…mparison Adds Part 6b to the sbi NLE notebook: a second HSSM model built with loglik_kind="analytical" (HSSM's closed-form Navarro & Fuss DDM likelihood) sampled on the same obs_data. This gives a gold-standard posterior against which the sbi-NLE and sbi-NRE marginals can be compared. Distance from analytical to true theta is intrinsic posterior width (finite data effect). Distance from sbi-NLE/NRE to analytical is surrogate approximation error -- the thing we actually care about when evaluating how well the neural likelihood reproduces the closed-form target. Part 7's comparison plot is upgraded from a 2-way (NLE vs NRE) to a 3-way (analytical vs NLE vs NRE) histogram per parameter, with the true theta as a red dashed vertical for reference. The analytical DDM uses slightly different parameter bounds from approx_differentiable (a, t unbounded above; otherwise the same), but on the observed data the posterior concentrates regardless. Runtime impact: one additional HSSM MCMC run (~30-60 sec via numpyro NUTS on the analytical likelihood). Trivial compared to the 1M-sim sbi training in Parts 3 and 6. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tion
The C8 keystone tutorial was producing qualitatively wrong NLE posteriors
on DDM data (v centered at ~0.12 vs true 0.5, spurious bimodality on a)
because MAF flows can't properly model mixed continuous-discrete data
(rt continuous, choice in {-1, +1}). The correct sbi method (MNLE) is
blocked by the SearchSorted ONNX-op gap that also blocks NSF flows; see
plans/sbi-onnx-integration.md "Deferred sbi paths" for the resolution
roadmap (~50-line upstream PR to jaxonnxruntime unlocks both).
Until that PR lands, the tutorial drops NLE entirely and focuses on
NRE, which is robust to discrete/continuous mixing because it learns a
classifier (no density-shape assumption).
Notebook restructure:
- Removed Parts 3-5b (NLE training, export, sampling, diagnostic).
- Promoted the NRE variant to the primary path (Parts 3-5b).
- Promoted the analytical ground truth (was Part 6b) to Part 6.
- Part 7 comparison is now 2-way: analytical (gold) vs sbi NRE.
- Added a "Why no NLE in this tutorial?" callout in both the intro
and the closing summary pointing at the deferred-paths plan.
NRE-side improvements (all applied in Part 3):
- Switched NRE_A -> NRE_B with num_atoms=20 (atomic contrastive
estimation; sharper signal than plain binary classification).
- Multi-sample per theta: 300k distinct theta * 3 samples = 900k pairs
(vs the previous 1M theta * 1 sample). Richer local conditional
shape information at fewer theta points; still well-covered for a
4D parameter space.
- FCEmbedding(input_dim=4, output_dim=32, num_layers=2) on theta to
give the classifier richer parameter conditioning.
- hidden_features bumped 100 -> 128.
- Longer MCMC: tune 500 -> 1500, draws 500 -> 1000.
Also:
- Diagnostic cell (was Part 5b NLE) adapted for NRE: uses summed
classifier logit (which equals log p(x|theta) - log p(x) up to a
theta-independent constant) instead of NLE log-prob.
- Closing summary now explains the NLE/MNLE deferral and points at
plans/sbi-onnx-integration.md.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…port check)
Latest tutorial run produced NRE posteriors that essentially equal the
prior across all four DDM parameters. Traces wander freely — chain is
not stuck, the loglik is just ~constant across theta-space, so MCMC has
no signal and samples the prior.
The Part 5b diagnostic isn't enough to distinguish "NRE found the truth,
MCMC didn't sample around it" from "NRE classifier is flat everywhere."
Adding two more targeted diagnostics:
Part 5c.1 — Logit sweep:
Hold three theta dims at the true values, sweep the fourth across
its prior range, plot the summed classifier log-ratio on observed
data. A well-trained NRE shows a sharp peak near the true value
with tens-to-hundreds of log units of vertical range. A flat curve
(< 5 log units) is the smoking gun for "classifier collapsed."
Part 5c.2 — Export round-trip:
Compare classifier_nre(theta, x).item() against the exported ONNX
output through onnxruntime on a fixed point. If they agree to
~1e-5, the export is faithful and any pathology is in the trained
classifier itself; otherwise the bigger network or FCEmbedding
addition introduced an export bug.
Both cells read in-memory state from the previous run (classifier_nre,
nre_onnx_path, obs_data, TRUE_THETA, prior bounds) so the user can run
them without retraining — they're cheap one-off diagnostics.
This commit doesn't change the broken NRE training itself; once the
diagnostics tell us which side is broken (training quality vs. export),
the bisect plan from the review will pick the right fix:
- flat classifier -> drop FCEmbedding -> if still flat, drop NRE_B
for NRE_A -> if still flat, drop multi-sample-per-theta.
- faithful export, flat classifier -> same bisect path.
- export mismatch -> investigate exporter on FCEmbedding shapes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…dget
Most recent notebook run produced NRE posteriors that effectively equal the
prior (chains wander freely across the entire prior with no concentration),
and MCMC was taking 30+ min per pass. The pattern is consistent with the
trained classifier providing near-zero discriminative signal AND with NUTS
spending most leapfrog steps on divergent trajectories — both symptoms of a
pathologically-shaped surrogate loglik.
The bisect starts with the most-recently-introduced and least-tested change:
the FCEmbedding(4 → 32 → 32) on theta inside the NRE classifier. Removing it
is a single-line change, leaves the other improvements in place (NRE_B with
num_atoms=20, multi-sample-per-theta 300k×3, hidden=128), and lets us
verify whether the embedding was the culprit. Comment in the classifier-
builder cell notes the bisect for future readers.
Companion changes to keep MCMC bounded during diagnosis:
- MCMC_DRAWS: 1000 → 500
- MCMC_TUNE: 1500 → 500
- target_accept: 0.9 → 0.8 (allows larger steps)
- max_tree_depth: default 10 → 8 (caps leapfrog steps/draw at 256 vs 1024)
- progressbar: False → True (so users see chain progress instead of
wondering if it's hung)
If the next run still produces a flat NRE posterior, the next bisect step
is NRE_B → NRE_A (revert the contrastive change). If MCMC behaves but
posteriors are biased, that's a separate calibration question we'll address
on the next round of knobs.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ove sweep diagnostic pre-MCMC
Two changes in one commit, both reactions to the bisect step 1 not fixing the
flat-posterior problem (removing the FCEmbedding alone left the NRE classifier
still uninformative at MCMC time).
1. Move Part 5c sweep + ONNX round-trip diagnostics to a new Part 4b that
runs RIGHT AFTER training and export, BEFORE the multi-minute MCMC step.
The sweep is the cheapest way to know if the trained classifier is
informative at all -- there's no point running MCMC if the per-dim
vertical range of the log-ratio is < 5 units everywhere (= no
discriminative signal -> posterior will equal the prior). Old Part 5c
block deleted to avoid duplication.
2. Revert the NRE configuration to the last known-working baseline:
- NRE_B -> NRE_A (drop atomic contrastive estimation)
- Multi-sample 300k_theta x 3_samples -> 1M_theta x 1_sample
- HIDDEN_FEATURES 128 -> 100
- drop NRE_NUM_ATOMS (NRE_A has no contrastive hyperparameter)
- drop FCEmbedding import (already not used since bisect step 1)
- Updated Part 3 markdown to call out the bisect explicitly
The aim: re-establish a baseline that gives at least the moderate-quality
recovery seen two iterations back. If THIS produces a flat posterior too,
we have a more fundamental problem (not in any of the NRE-only changes)
and need to look at the exporter, HSSM consumption, or environment.
MCMC budget knobs from step 1 are retained (draws=500, tune=500,
target_accept=0.8, progressbar=True, nuts_sampler_kwargs.max_tree_depth=8)
since they are pure safety and don't change posterior quality.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ndarray.ptp() was deprecated in NumPy 1.25 and removed in NumPy 2.0. The Part 4b sweep diagnostic cell called lp.ptp() in two places (the plot title and the per-dim print loop) and raised AttributeError on NumPy 2.x environments. Replace both with np.ptp(lp), which is the documented NumPy 2 equivalent. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
2 tasks
Replaces two hardcoded relative paths in the sbi tutorial: 1. Part 4: ./sbi_onnx_artifacts/ -> ARTIFACT_DIR (default ~/sbi_onnx_tutorial/). User can override to a project-local dir or tempfile.mkdtemp() via two examples in the cell comment. 2. Part 3: sbi's training tensorboard logs were going to ./sbi-logs/ relative to cwd (different paths for different Jupyter launch contexts - one in the notebook dir, one at the repo root). Now wires NRE_A(..., summary_writer=SummaryWriter(log_dir=str( TUTORIAL_LOG_DIR))) with a default ~/sbi_logs_tutorial/, so all runs write to the same predictable location regardless of cwd. Why: the previous defaults wrote into whatever directory the notebook was running from, which for typical setups means docs/tutorials/ inside the HSSM repo. Re-running the notebook would accumulate untracked training logs and ONNX artifacts in the working tree. Moving the defaults outside the repo eliminates the footgun. Comment in each affected cell points to the override pattern so users who want artifacts kept nearby (e.g. for downstream MCMC re-runs on a saved checkpoint) can override in one line. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
digicosmos86
previously approved these changes
May 20, 2026
Collaborator
digicosmos86
left a comment
There was a problem hiding this comment.
LGTM! One small nitpick
| " jax.config.update('jax_enable_x64', True)\n" | ||
| "at the very top of your script, before any other JAX import." | ||
| ) | ||
| warnings.warn( |
Collaborator
There was a problem hiding this comment.
Why warnings.warn instead of logger.warn? I think we tend to prefer logging to print and warning
Per PR #964 review: HSSM convention is to route messages through logging.getLogger("hssm") rather than warnings.warn. Switches the auto-x64-enabled message to _logger.warning(...) to match the rest of the codebase. No behavioral change for downstream users beyond how the message is surfaced. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…dynamic dims Three connected changes informed by a diagnostic experiment. 1. Drop _ensure_x64_if_needed. pytensor's JAX dispatch (pytensor/link/jax/dispatch/basic.py) already sets jax_enable_x64 from pytensor.config.floatX at module import. With HSSM's default floatX=float64, x64 is already on by the time onnx2jax loads -- our auto-flip was redundant in default use and brittle in edge cases (mutated global state, hard-failed if JAX had warmed up). 2. Replace it with _recast_int64_to_int32, a small in-place graph transform that rewrites int64 tensors / Cast targets to int32 at load time. Lossless for the index/shape values torch.onnx.export produces (small non-negative ints, bit-identical truncation), silences the JAX UserWarning under x64=off, and removes any global-state dependency. 3. Add _check_single_trial_input_shape: raise ValueError if any input dim is symbolic / dynamic. jaxonnxruntime traces against the construction-time dummy and bakes the resulting shapes into the returned closure, so a dynamic_axes export called at a different shape silently produces wrong outputs for any graph with a batch-dependent intermediate (e.g. torch.zeros(x.shape[0]) accumulators, Reshape with -1). HSSM's ONNX path is built around single-trial inputs + jax.vmap over trials (see distribution_utils/onnx.py); LANs and LANfactory's transform_sbi_to_onnx already follow this contract. This guard prevents accidental violations from future contributors. Tests: 3 new in tests/distribution_utils/test_onnx.py covering the dynamic-dim guard (positive + negative) and the int64 -> int32 recast. Full ONNX test suite passes. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a Key Patterns entry codifying the rule that `make_jax_func` now enforces (commit dd168fc): every ONNX graph consumed by HSSM must have a concrete single-trial input shape, with per-trial batching happening at the HSSM layer via jax.vmap. Points at the two enforcement sites (_check_single_trial_input_shape in onnx2jax.py, the vmap wiring in onnx.py) and notes that LANfactory's exporters already follow this convention. The constraint was de facto since the original LANs but only became enforced with dd168fc; this entry makes it discoverable. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Brings sbi-trained neural ratio estimators (NRE) into HSSM via the
LANfactory ONNX exporter and a keystone tutorial. Includes the
HSSM-side patches needed for ONNX graphs from
torch.onnx.exportofsbi networks to consume cleanly through HSSM's existing
loglik_kind="approx_differentiable"path.What this PR contains
HSSM-side patches
fix(onnx2jax): relax jaxonnxruntime strict-mode(2e76516).torch.onnx.exportof nflows MAFs emitsReshapeshape argumentsas
Constantnodes instead of model initializers. jaxonnxruntime'sdefault strict mode rejects these. Patch sets
jaxort_only_allow_initializers_as_static_args=Falseat moduleimport — safe because shapes are genuinely constant.
fix(onnx2jax): auto-enable jax_enable_x64(d1d7ffe). ONNXgraphs from torch flows carry int64 index tensors. JAX silently
truncates int64 → int32 unless
jax_enable_x64is set, producing~0.5-unit drift in log-prob outputs. The patch walks the graph for
int64 tensors; if present and x64 is off, auto-flips with a
UserWarning, raises a preciseRuntimeErroronly if JAX hasalready done substantive 32-bit work and the flip cannot take.
fix/sampler-routing(ce19950, see#963) via merge
commit
5c628d0— necessary because the keystone tutorial usessampler="numpyro"and the silent downgrade to PyMC NUTS wouldtrigger
cloudpicklefailures on the ONNXModelProtoundermacOS spawn.
Keystone tutorial
docs/tutorials/sbi_nle_integration.ipynb— direct structuralmirror of
bayesflow_lre_integration.ipynb, demonstrating sbi NREend-to-end: simulate DDM data, train
NRE_Aon 1M(theta, x)pairs, export to ONNX via
lanfactory.onnx.transform_sbi_to_onnx, drop intohssm.HSSM(...),sample with numpyro NUTS, and compare against HSSM's analytical DDM
posterior as the gold-standard reference.
each θ dimension + ONNX round-trip check) to verify the trained
classifier is informative before paying the multi-minute MCMC cost.
verdict for the recovered posterior.
Tutorial outcome
The committed config (
NRE_A, 1M × 1 samples,hidden_features=100,no embedding net,
norm_layer=nn.Identity) produces NRE posteriorsthat broadly track HSSM's analytical DDM posterior. Residual ~0.05-
unit marginal bias on
vandzis calibration-not-collapse; queuedas a separate investigation in the spine plan.
Deferred (documented in
HSSMSpine/plans/sbi-onnx-integration.md)discrete-continuous data). The correct sbi method (MNLE) is
blocked by the same
SearchSortedONNX-op gap that blocks NSFflows. Both unlock simultaneously via a small upstream PR to
jaxonnxruntime. Until then, the tutorial covers NRE only.requires cached posteriors from sibling tutorials.
— were attempted as a cumulative step and broke the recovery;
reverted to the last-known-working baseline for v1. Each should be
re-introduced one at a time with the Part 4b sweep diagnostic gating.
Companion PRs
the sampler-routing fix this PR inherits via merge. Cleaner if
merged first, but not strictly required (the merge brings the fix
along).
the
transform_sbi_to_onnxexporter this tutorial uses. Thetutorial has an
importlibfallback so it runs even withoutLANfactory installed via pip.
Test plan
(
pytest tests/distribution_utils/test_onnx.py tests/distribution_utils/test_onnx_model.py).(LANfactory + HSSM together).
posterior recovery.
🤖 Generated with Claude Code