
Integrate sbi NRE via ONNX exporter (keystone tutorial + HSSM-side patches) #964

Open
AlexanderFengler wants to merge 21 commits into main from sbi-integration

Conversation

@AlexanderFengler
Member

Summary

Brings sbi-trained neural ratio estimators (NRE) into HSSM via the
LANfactory ONNX exporter and a keystone tutorial. Includes the
HSSM-side patches needed so that ONNX graphs produced by
torch.onnx.export of sbi networks can be consumed cleanly through
HSSM's existing loglik_kind="approx_differentiable" path.

What this PR contains

HSSM-side patches

  • fix(onnx2jax): relax jaxonnxruntime strict-mode (2e76516).
    torch.onnx.export of nflows MAFs emits Reshape shape arguments
    as Constant nodes instead of model initializers. jaxonnxruntime's
    default strict mode rejects these. Patch sets
    jaxort_only_allow_initializers_as_static_args=False at module
    import — safe because shapes are genuinely constant.
  • fix(onnx2jax): auto-enable jax_enable_x64 (d1d7ffe). ONNX
    graphs from torch flows carry int64 index tensors. JAX silently
    truncates int64 → int32 unless jax_enable_x64 is set, producing
    ~0.5-unit drift in log-prob outputs. The patch walks the graph for
    int64 tensors; if any are present and x64 is off, it auto-flips the
    flag with a UserWarning and raises a precise RuntimeError only if
    JAX has already done substantive 32-bit work and the flip cannot
    take effect.
  • Inherits fix/sampler-routing (ce19950, see
    #963) via merge
    commit 5c628d0 — necessary because the keystone tutorial uses
    sampler="numpyro" and the silent downgrade to PyMC NUTS would
    trigger cloudpickle failures on the ONNX ModelProto under
    macOS spawn.

Keystone tutorial

  • docs/tutorials/sbi_nle_integration.ipynb — direct structural
    mirror of bayesflow_lre_integration.ipynb, demonstrating sbi NRE
    end-to-end: simulate DDM data, train NRE_A on 1M (theta, x)
    pairs, export to ONNX via
    lanfactory.onnx.transform_sbi_to_onnx, drop into hssm.HSSM(...),
    sample with numpyro NUTS, and compare against HSSM's analytical DDM
    posterior as the gold-standard reference.
  • Includes a Part 4b pre-MCMC sweep diagnostic (logit sweep across
    each θ dimension + ONNX round-trip check) to verify the trained
    classifier is informative before paying the multi-minute MCMC cost.
  • Part 5b post-MCMC diagnostic prints a training-vs-sampling
    verdict for the recovered posterior.

Tutorial outcome

The committed config (NRE_A, 1M × 1 samples, hidden_features=100,
no embedding net, norm_layer=nn.Identity) produces NRE posteriors
that broadly track HSSM's analytical DDM posterior. The residual
~0.05-unit marginal bias on v and z looks like a calibration issue
rather than classifier collapse; it is queued as a separate
investigation in the spine plan.

Deferred (documented in HSSMSpine/plans/sbi-onnx-integration.md)

  • NLE-MAF on DDM produces qualitatively wrong posteriors (mixed
    discrete-continuous data). The correct sbi method (MNLE) is
    blocked by the same SearchSorted ONNX-op gap that blocks NSF
    flows. Both unlock simultaneously via a small upstream PR to
    jaxonnxruntime. Until then, the tutorial covers NRE only.
  • Three-way comparison (LRE + LAN + sbi NRE) on the same data —
    requires cached posteriors from sibling tutorials.
  • Residual v/z calibration investigation.
  • NRE_B / BNRE / num_atoms / multi-sample-per-θ / FCEmbedding upgrades:
    attempted as a single cumulative step, which broke the recovery;
    reverted to the last-known-working baseline for v1. Each should be
    re-introduced one at a time, gated by the Part 4b sweep diagnostic.

Companion PRs

  • lnccbrown/HSSM#963:
    the sampler-routing fix this PR inherits via merge. Cleaner if
    merged first, but not strictly required (the merge brings the fix
    along).
  • lnccbrown/LANfactory#79:
    the transform_sbi_to_onnx exporter this tutorial uses. The
    tutorial has an importlib fallback so it runs even without
    LANfactory installed via pip.

Test plan

  • Existing HSSM ONNX test suite passes
    (pytest tests/distribution_utils/test_onnx.py tests/distribution_utils/test_onnx_model.py).
  • Notebook executes end-to-end in a coordinated cross-repo env
    (LANfactory + HSSM together).
  • Reviewer should glance at the Part 7 comparison plot for the
    posterior recovery.

🤖 Generated with Claude Code

AlexanderFengler and others added 17 commits May 13, 2026 23:00
Sets jaxort_only_allow_initializers_as_static_args = False at module
import time. The default strict mode rejects ONNX graphs whose Reshape
op shape comes from a Constant node rather than a model initializer.
torch.onnx.export emits exactly this pattern for masked autoregressive
flows from nflows (and likely other flow architectures), surfaced by
LANfactory commit f7c93c8.

Setting the flag here means any consumer of make_jax_func benefits
without per-call configuration. Safe for our use cases: shapes are
genuinely constant, baked at export time.

Side benefit: makes HSSM more robust to ONNX from any source emitting
this pattern, not only sbi-exported flows.

Part of the sbi to HSSM integration plan (see plans/sbi-onnx-integration.md
in HSSMSpine, sub-commit C2.5).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ONNX graphs exported by torch.onnx.export of normalizing flows (e.g. the
nflows MAF used by sbi NLE) carry int64 tensors for Reshape shape
arguments, Constant node values, and Cast targets. jaxonnxruntime
silently truncates int64 to int32 unless jax_enable_x64 is set,
producing wrong numerical results (~0.5 drift in log-prob, surfaced
during the LANfactory C3 NLE validation).

make_jax_func now walks the loaded ONNX graph for int64 tensors. If any
are present and jax_enable_x64 is off, HSSM:
  - attempts to flip the flag via jax.config.update
  - verifies the flip is effective (fresh jnp.asarray([1.0]) is float64)
  - emits a UserWarning pointing users to set it themselves to silence
  - raises a clear RuntimeError with the one-line fix if the flip did
    not take (JAX has already done substantive 32-bit work)

The detection is conservative: only scans initializers, Constant node
tensor attributes, and Cast `to` attributes. LAN MLP graphs do not
carry int64 in any of these places (verified: existing 8 HSSM ONNX
tests still pass without warnings).

This addresses the second of the two findings from C3 (see
plans/sbi-onnx-integration.md C7 row). Parallels the C2.5 strict-mode
patch in spirit but is more targeted: it only intervenes when the graph
actually requires the flag.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds docs/tutorials/sbi_nle_integration.ipynb, the keystone deliverable
of the sbi -> HSSM integration plan. Mirrors the structure of
bayesflow_lre_integration.ipynb so the two SBI-toolkit tutorials in the
HSSM docs sit side-by-side and tell a coherent story.

Structure (22 cells, 13 code + 9 markdown):
  Part 1 - Setup (jax_enable_x64, imports, CI budget constants)
  Part 2 - Simulate observed DDM data (ssm-simulators, N_OBS=500,
           TRUE_THETA matching BayesFlow tutorial)
  Part 3 - Train tiny sbi NLE_A with MAF on 10k training pairs
  Part 4 - Export to ONNX via lanfactory.onnx.transform_sbi_to_onnx
  Part 5 - High-level integration via hssm.HSSM(loglik=...onnx,
           loglik_kind="approx_differentiable"), numpyro sampling,
           summary, trace
  Part 6 - Brief NRE variant (norm_layer=Identity to disable LayerNorm)
  Part 7 - Posterior comparison plot: sbi NLE vs sbi NRE vs ground truth
  Closing summary with pointers to LANfactory exporter docs and the
  BayesFlow LRE neighbor tutorial

Reuses the exact TRUE_THETA, N_OBS, and prior ranges as the BayesFlow
LRE tutorial so cross-tutorial posterior comparisons are apples-to-
apples. Includes explicit documentation of the v1 constraints surfaced
during C2-C7 (2D minimum, norm_layer=Identity, x64-auto-flip).

Notebook outputs are intentionally empty (execution_count: null) -
execution requires a coordinated cross-repo env (LANfactory[all] + HSSM
in the same venv). Same env-resolution caveat as C7b. Run once locally
or in CI to populate outputs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The C8 keystone notebook currently cannot be run in either repo's local
uv venv because of an outstanding cross-repo JAX/flax/numpyro pin
conflict. Specifically, lanfactory's top-level __init__.py pulls
trainers/jax_mlp.py which imports flax — incompatible with the JAX
version that HSSM's numpyro pin requires.

Until the env alignment lands as a separate workstream, the notebook
now imports transform_sbi_to_onnx with a try/except fallback:

  - Clean path: `from lanfactory.onnx import transform_sbi_to_onnx`.
  - Fallback: load only lanfactory/onnx/sbi.py directly via
    importlib.util, bypassing lanfactory's top-level __init__.py and
    sidestepping the flax dependency.

The fallback walks several candidate paths (env var
LANFACTORY_SBI_PATH, then common Jupyter cwd contexts) so the notebook
runs from a fresh kernel without manual editing.

Once the cross-repo env is resolved the fallback branch becomes dead
code and can be removed — the clean import will just work.
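
The fallback described above can be sketched as follows. This is a minimal illustration rather than the notebook's actual cell: `load_module_from_candidates` is a hypothetical helper name, the demo loads a throwaway stub file instead of the real lanfactory/onnx/sbi.py, and it assumes the target file has no relative imports. The notebook additionally checks the LANFACTORY_SBI_PATH environment variable first, which is omitted here to keep the demo deterministic.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_module_from_candidates(module_name, candidates):
    """Load the first existing file in `candidates` as a standalone module.

    Loading by file path skips the parent package's __init__.py, so
    heavy imports placed there are never triggered.
    """
    for path in candidates:
        path = Path(path)
        if path.is_file():
            spec = importlib.util.spec_from_file_location(module_name, path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return module
    raise ImportError(f"no candidate path for {module_name!r} exists")


# Demo: a stub file stands in for lanfactory/onnx/sbi.py (hypothetical content).
with tempfile.TemporaryDirectory() as tmp:
    stub = Path(tmp) / "sbi.py"
    stub.write_text("def transform_sbi_to_onnx():\n    return 'stub'\n")
    # The first candidate does not exist, so the loader falls through to the stub.
    candidates = [Path(tmp) / "missing.py", stub]
    mod = load_module_from_candidates("sbi_standalone", candidates)
    result = mod.transform_sbi_to_onnx()
```

Wrapping the clean `from lanfactory.onnx import ...` in a try/except and calling a helper like this in the except branch keeps the fallback isolated, so it can be deleted in one piece once the environment conflict is resolved.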

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…f ONNX)

Second issue running the C8 notebook: pymc.sample raised
"TypeError: cannot pickle 'google._upb._message.EnumDescriptor' object"
inside cloudpickle when the parallel sampler tried to start worker
processes on macOS.

Root cause: HSSM's sampler="numpyro" path normalizes inference_method to
"pymc" in base.py:678, which means bambi dispatches pm.sample with
nuts_sampler="pymc" — PyMC's standard NUTS, not numpyro NUTS. On macOS
the default multiprocessing start method is spawn, which requires
cloudpickling the step method into worker processes. The step method
references the JAX-wrapped ONNX function whose closure in
jaxonnxruntime carries the onnx.ModelProto. ModelProto is a protobuf
message and contains C-extension EnumDescriptor objects that
cloudpickle cannot serialize.

Workaround: pass cores=1 to model.sample(). Single-process sampling
bypasses the multiprocess cloudpickle path entirely. Slower across
chains (no parallelism) but reliable. Both NLE (Part 5) and NRE
(Part 6) sample calls now include cores=1 with an explanatory comment.
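
The failure mode can be reproduced in miniature with the standard library. This is only an analogy under stated assumptions: it uses `pickle` rather than cloudpickle, and a `threading.Lock` stands in for the protobuf descriptor inside the ONNX ModelProto. `StepMethodStandIn` and `can_ship_to_worker` are hypothetical names.

```python
import pickle
import threading


class StepMethodStandIn:
    """Stand-in for a step method whose closure holds an unpicklable object."""

    def __init__(self):
        # The lock plays the role of the protobuf descriptor: a C-level
        # object with no pickle support.
        self.resource = threading.Lock()


def can_ship_to_worker(obj) -> bool:
    """Return True if `obj` survives the pickling that spawned workers need."""
    try:
        pickle.dumps(obj)
    except (TypeError, pickle.PicklingError):
        return False
    return True


ok_multi_process = can_ship_to_worker(StepMethodStandIn())
ok_plain_payload = can_ship_to_worker({"draws": 500, "tune": 500})
print(ok_multi_process, ok_plain_payload)
```

Single-process sampling never serializes the step method, which is why cores=1 sidesteps the error regardless of what the closure contains.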

Followup queued: HSSM's sampler="numpyro" silently downgrades to pymc
NUTS in this code path. Worth either (a) wiring nuts_sampler="numpyro"
through to bambi (numpyro NUTS does its own JAX-internal parallelism
without forking), or (b) updating the HSSM docstring so users know
sampler= currently only controls init / jitter and not the actual NUTS
backend. Tracked outside this commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
HSSMBase.sample collapsed sampler="numpyro" (and blackjax, nutpie) to
inference_method="pymc" before handing off to bambi, which then dispatched
to pm.sample(nuts_sampler="pymc"). The user's sampler choice was
silently downgraded to PyMC NUTS regardless of what they asked for.

Bambi natively accepts inference_method values "pymc", "numpyro",
"blackjax", "nutpie" (and "vi"/"laplace") and routes each to the
matching nuts_sampler. The collapse conditional negated this.

Regression archaeology:
  - Aug 5, 2024 (commit aef3f9b, "Fix compatibility with Bambi (#516)"):
    introduced the working pattern -- inference_method="mcmc" (generic
    NUTS marker) + kwargs["nuts_sampler"]="numpyro"/"blackjax"/etc.
    injected separately. Bambi's old "mcmc" inference_method was generic
    and read nuts_sampler from kwargs. Correct under old bambi semantics.
  - Dec 17, 2025 (commit 20c100b, "fix: update model.sample api to be
    consistent with bambi's"): bambi had renamed its inference_method
    values (mcmc -> pymc, nuts_numpyro -> numpyro, etc.). This commit
    mechanically updated the string list, but ALSO deleted the
    kwargs["nuts_sampler"] = ... injection block. The flatten-conditional
    was left in place. After this commit, all four NUTS samplers route
    to inference_method="pymc" -> nuts_sampler="pymc" with no recourse.
  - The bug has been live since Dec 17, 2025 (about 5 months).

Fix: replace the conditional with inference_method=sampler. Bambi handles
each NUTS variant directly under the new API. The injection block deleted
in commit 20c100b is correctly absent now -- bambi passes
nuts_sampler=sampler_backend to pm.sample explicitly, so injecting it via
kwargs would conflict.
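
The before/after routing can be summarized in a few lines. This is a hypothetical sketch of the logic described in this message, not the actual HSSMBase.sample code; `route_before` and `route_after` are illustrative names.

```python
# Sampler names taken from the commit message above.
NUTS_SAMPLERS = ("pymc", "numpyro", "blackjax", "nutpie")


def route_before(sampler: str) -> str:
    """Old behavior: every NUTS variant collapsed to bambi's "pymc" method."""
    return "pymc" if sampler in NUTS_SAMPLERS else sampler


def route_after(sampler: str) -> str:
    """Fixed behavior: pass the user's choice straight through to bambi."""
    return sampler


for name in NUTS_SAMPLERS:
    print(f"{name:>9}: before={route_before(name)!r} after={route_after(name)!r}")
```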

Side effects:
  - sampler="numpyro" now actually invokes numpyro NUTS, which runs
    inside JAX with internal parallelism and does NOT fork worker
    processes. This avoids the cloudpickle path that breaks on
    unpicklable ONNX ModelProto closures (surfaced by the sbi NLE
    tutorial).
  - sampler="blackjax" and sampler="nutpie" similarly now invoke their
    respective backends instead of PyMC NUTS.
  - sampler="pymc" behavior is unchanged (still routes to PyMC NUTS).

Other gates that read the user's `sampler` argument (parallel-sampling
warning at base.py:621, init default at base.py:636, jitter handling at
base.py:644, step-sampler check at base.py:657) all check `sampler`
directly, not the post-normalization inference_method value. None are
affected by this change.

Tests pinning the old behavior:
  - tests/test_rlssm.py:300, tests/test_save_load.py:39,
    tests/slow/test_mcmc.py:100-101 use sampler="numpyro" and were
    silently exercising PyMC NUTS. After this fix they exercise the
    actual numpyro path. Worth re-running as part of PR review.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Inherits the sampler-routing fix (commit ce19950) so the sbi NLE tutorial
can use numpyro NUTS directly instead of falling back to cores=1. The
cores=1 workaround is removed in the follow-up commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…fixed

With the merged fix branch (commit ce19950), HSSM's sampler="numpyro"
actually invokes numpyro NUTS via pm.sample(nuts_sampler="numpyro").
Numpyro NUTS runs entirely inside JAX with internal parallelism --
no ps.ParallelSampler, no forked workers, no cloudpickle of the step
method, so the ONNX ModelProto protobuf descriptors that previously
broke serialization are never touched.

The cores=1 workaround added in commit f94b496 is therefore no longer
necessary for either the NLE or NRE sample call. Reverting to the
default (which lets pymc.sample pick a sensible cores count) so chains
can run in parallel where the backend permits it.

Note: numpyro NUTS handles chain parallelism via JAX's vmap-over-chains
or pmap-over-devices internally, so explicit cores= is not the relevant
knob for numpyro anyway. We're just removing an override that no longer
applies.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… budget

The first end-to-end run of the C8 notebook recovered systematically
biased posteriors (NLE: v=1.5 vs truth 0.5; t centered at 0.02 vs truth
0.25 and below the training prior's lower bound of 0.1). Diagnosis: our
training prior was narrower than HSSM's default DDM bounds, so MCMC
explored regions the flow never saw and the trained MAF extrapolated
into spurious high-likelihood pockets.

Verified by inspecting hssm.defaults.default_model_config['ddm'] for the
"approx_differentiable" likelihood:
  v in (-3.0, 3.0), a in (0.3, 2.5), z in (0.0, 1.0), t in (0.0, 2.0)
with a HalfNormal(sigma=2.0) prior on t that puts substantial mass below
our previous training lower bound of 0.1.

Changes in the notebook:
  - Training prior (Part 3) widened to match HSSM's default bounds
    verbatim: BoxUniform with low=[-3, 0.3, 0, 0], high=[3, 2.5, 1, 2].
  - N_TRAIN raised 10k -> 30k to cover the wider 4D parameter volume.
  - NUM_EPOCHS raised 50 -> 100 for the same reason.
  - Simulation switched from a Python loop to ssm-simulators batched call
    (theta of shape (N, 4) with n_samples=1), ~100x faster on 30k samples.

Expected effect: posteriors should now concentrate near the true theta
(v=0.5, a=1.2, z=0.5, t=0.25) since MCMC stays inside the trained region
throughout. Total notebook runtime estimate roughly 10-20 minutes on
CPU depending on machine.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds Part 5b (markdown + code) to the sbi NLE notebook. The cell answers
the question "is the marginal posterior bias coming from a poorly-trained
flow, or from HSSM-side sampling issues?" by computing the trained NLE
log-likelihood of the observed data at:
  - the true theta
  - the posterior's marginal mean

If the NLE itself prefers the wrong theta by a large margin, the flow is
the problem (training quality). If it prefers the truth, MCMC is failing
to find the NLE's mode (priors / init / mixing). The cell prints a
three-way verdict depending on the gap.
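
A minimal sketch of such a three-way verdict is shown below. The function name, the `tol` threshold, and the message strings are assumptions made for illustration; the notebook cell may use different cutoffs and wording.

```python
def training_vs_sampling_verdict(ll_true, ll_post_mean, tol=2.0):
    """Classify the gap between log-likelihoods at two parameter points.

    `tol` (in log units) is a hypothetical threshold chosen for illustration.
    """
    gap = ll_post_mean - ll_true
    if gap > tol:
        return "training: the estimator prefers the wrong theta"
    if gap < -tol:
        return "sampling: the estimator prefers the truth, MCMC missed it"
    return "inconclusive: both points score about the same"


# Made-up log-likelihood values, purely to exercise the three branches.
print(training_vs_sampling_verdict(ll_true=-520.0, ll_post_mean=-480.0))
print(training_vs_sampling_verdict(ll_true=-480.0, ll_post_mean=-520.0))
print(training_vs_sampling_verdict(ll_true=-500.0, ll_post_mean=-500.5))
```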

Placed between the NLE trace plot (Part 5) and the NRE variant (Part 6)
so it interprets the NLE posterior immediately. Uses the in-memory
estimator_nle and the obs_data DataFrame already in scope.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Applies the proposed §4.2 improvements from the bias review. The 30k-sim
run was undertrained for a 4D-theta DDM problem with HSSM's default wide
prior, so the marginal posteriors over v and a were biased high.

Changes to Part 1 (setup cell):
  - N_TRAIN: 30_000 -> 1_000_000
  - NUM_EPOCHS: 100 -> 300 (max)
  - STOP_AFTER_EPOCHS: new, set to 50 (default is 20 -- previous runs
    may have early-stopped silently)
  - TRAINING_BATCH_SIZE: 200 -> 500 (fewer batches per epoch at 1M)
  - HIDDEN_FEATURES: new, set to 100 (sbi default is 50)
  - NUM_TRANSFORMS: new, set to 8 (sbi MAF default is 5)
  - imports likelihood_nn alongside classifier_nn from sbi.neural_nets

Changes to Part 3 (NLE training):
  - Replace density_estimator="maf" string shortcut with an explicit
    likelihood_nn(model="maf", hidden_features=HIDDEN_FEATURES,
    num_transforms=NUM_TRANSFORMS) builder.
  - Pass TRAINING_BATCH_SIZE and STOP_AFTER_EPOCHS into .train().

Changes to Part 6 (NRE training):
  - classifier_nn now also takes hidden_features=HIDDEN_FEATURES
    (matches NLE width). norm_layer=nn.Identity remains required
    because jaxonnxruntime doesn't implement LayerNormalization.
  - Same TRAINING_BATCH_SIZE and STOP_AFTER_EPOCHS.

Expected wall time on CPU: 30-90 min for NLE alone, similar for NRE,
plus a few minutes of MCMC each. Run on GPU if available. The
training run is comparable in scale to a LAN training (1M sims) which
is what's needed for a fair NLE-vs-LAN comparison on DDM-like
problems.

The Part 5b diagnostic cell (commit 4c9a3ad) is unchanged and will
still print the training-vs-sampling verdict after this run.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…mparison

Adds Part 6b to the sbi NLE notebook: a second HSSM model built with
loglik_kind="analytical" (HSSM's closed-form Navarro & Fuss DDM
likelihood) sampled on the same obs_data. This gives a gold-standard
posterior against which the sbi-NLE and sbi-NRE marginals can be
compared.

Distance from analytical to true theta is intrinsic posterior width
(finite data effect). Distance from sbi-NLE/NRE to analytical is
surrogate approximation error -- the thing we actually care about when
evaluating how well the neural likelihood reproduces the closed-form
target.

Part 7's comparison plot is upgraded from a 2-way (NLE vs NRE) to a
3-way (analytical vs NLE vs NRE) histogram per parameter, with the
true theta as a red dashed vertical for reference.

The analytical DDM uses slightly different parameter bounds from
approx_differentiable (a, t unbounded above; otherwise the same), but
on the observed data the posterior concentrates regardless.

Runtime impact: one additional HSSM MCMC run (~30-60 sec via numpyro
NUTS on the analytical likelihood). Trivial compared to the 1M-sim
sbi training in Parts 3 and 6.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tion

The C8 keystone tutorial was producing qualitatively wrong NLE posteriors
on DDM data (v centered at ~0.12 vs true 0.5, spurious bimodality on a)
because MAF flows can't properly model mixed continuous-discrete data
(rt continuous, choice in {-1, +1}). The correct sbi method (MNLE) is
blocked by the SearchSorted ONNX-op gap that also blocks NSF flows; see
plans/sbi-onnx-integration.md "Deferred sbi paths" for the resolution
roadmap (~50-line upstream PR to jaxonnxruntime unlocks both).

Until that PR lands, the tutorial drops NLE entirely and focuses on
NRE, which is robust to discrete/continuous mixing because it learns a
classifier (no density-shape assumption).

Notebook restructure:
  - Removed Parts 3-5b (NLE training, export, sampling, diagnostic).
  - Promoted the NRE variant to the primary path (Parts 3-5b).
  - Promoted the analytical ground truth (was Part 6b) to Part 6.
  - Part 7 comparison is now 2-way: analytical (gold) vs sbi NRE.
  - Added a "Why no NLE in this tutorial?" callout in both the intro
    and the closing summary pointing at the deferred-paths plan.

NRE-side improvements (all applied in Part 3):
  - Switched NRE_A -> NRE_B with num_atoms=20 (atomic contrastive
    estimation; sharper signal than plain binary classification).
  - Multi-sample per theta: 300k distinct theta * 3 samples = 900k pairs
    (vs the previous 1M theta * 1 sample). Richer local conditional
    shape information at fewer theta points; still well-covered for a
    4D parameter space.
  - FCEmbedding(input_dim=4, output_dim=32, num_layers=2) on theta to
    give the classifier richer parameter conditioning.
  - hidden_features bumped 100 -> 128.
  - Longer MCMC: tune 500 -> 1500, draws 500 -> 1000.

Also:
  - Diagnostic cell (was Part 5b NLE) adapted for NRE: uses summed
    classifier logit (which equals log p(x|theta) - log p(x) up to a
    theta-independent constant) instead of NLE log-prob.
  - Closing summary now explains the NLE/MNLE deferral and points at
    plans/sbi-onnx-integration.md.
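
For reference, the identity behind the summed-logit diagnostic can be written out. This is the standard result for a binary classifier trained with equal class weights to separate joint (theta, x) pairs from shuffled pairs, and it assumes the classifier d* is at its optimum and that trials x_i are conditionally independent given theta. A trained network only approximates it.

```latex
\begin{aligned}
\operatorname{logit} d^{*}(\theta, x)
  &= \log \frac{p(\theta, x)}{p(\theta)\,p(x)}
   = \log p(x \mid \theta) - \log p(x), \\
\sum_{i=1}^{N} \operatorname{logit} d^{*}(\theta, x_i)
  &= \sum_{i=1}^{N} \log p(x_i \mid \theta)
   \;-\; \underbrace{\sum_{i=1}^{N} \log p(x_i)}_{\text{independent of } \theta}.
\end{aligned}
```

Because the second sum does not depend on theta, comparing summed logits at two parameter points is equivalent to comparing log-likelihoods at those points.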

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…port check)

Latest tutorial run produced NRE posteriors that essentially equal the
prior across all four DDM parameters. Traces wander freely — chain is
not stuck, the loglik is just ~constant across theta-space, so MCMC has
no signal and samples the prior.

The Part 5b diagnostic isn't enough to distinguish "NRE found the truth,
MCMC didn't sample around it" from "NRE classifier is flat everywhere."
Adding two more targeted diagnostics:

  Part 5c.1 — Logit sweep:
    Hold three theta dims at the true values, sweep the fourth across
    its prior range, plot the summed classifier log-ratio on observed
    data. A well-trained NRE shows a sharp peak near the true value
    with tens-to-hundreds of log units of vertical range. A flat curve
    (< 5 log units) is the smoking gun for "classifier collapsed."

  Part 5c.2 — Export round-trip:
    Compare classifier_nre(theta, x).item() against the exported ONNX
    output through onnxruntime on a fixed point. If they agree to
    ~1e-5, the export is faithful and any pathology is in the trained
    classifier itself; otherwise the bigger network or FCEmbedding
    addition introduced an export bug.

Both cells read in-memory state from the previous run (classifier_nre,
nre_onnx_path, obs_data, TRUE_THETA, prior bounds) so the user can run
them without retraining — they're cheap one-off diagnostics.
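
The logit sweep can be sketched with NumPy alone. This is an illustration under stated assumptions: `toy_log_ratio` is a hypothetical quadratic stand-in for the summed classifier log-ratio on observed data, and the bounds and true theta are the values quoted elsewhere in this PR. The 5-log-unit cutoff is the one named above.

```python
import numpy as np

TRUE_THETA = np.array([0.5, 1.2, 0.5, 0.25])  # v, a, z, t
BOUNDS = np.array([[-3.0, 3.0], [0.3, 2.5], [0.0, 1.0], [0.0, 2.0]])


def toy_log_ratio(theta):
    """Hypothetical stand-in for the summed classifier log-ratio."""
    return -200.0 * np.sum((theta - TRUE_THETA) ** 2)


def sweep(log_ratio_fn, dim, n_points=101):
    """Vary one theta dimension across its range, holding the rest at TRUE_THETA."""
    grid = np.linspace(BOUNDS[dim, 0], BOUNDS[dim, 1], n_points)
    values = np.empty(n_points)
    for i, g in enumerate(grid):
        theta = TRUE_THETA.copy()
        theta[dim] = g
        values[i] = log_ratio_fn(theta)
    return grid, values


def is_informative(values, min_range=5.0):
    """A sweep is informative if its vertical range exceeds `min_range` log units."""
    return bool(np.ptp(values) >= min_range)


# A peaked surrogate passes on every dimension; a nearly constant one fails.
ranges = [float(np.ptp(sweep(toy_log_ratio, d)[1])) for d in range(4)]
flat_ok = is_informative(sweep(lambda th: 0.01 * th.sum(), 0)[1])
grid_v, values_v = sweep(toy_log_ratio, 0)
peak_v = float(grid_v[np.argmax(values_v)])
print(ranges, flat_ok, peak_v)
```

In the notebook, the stand-in function would be replaced by a call that evaluates the trained classifier on the observed data and sums the per-trial logits.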

This commit doesn't change the broken NRE training itself; once the
diagnostics tell us which side is broken (training quality vs. export),
the bisect plan from the review will pick the right fix:
  - flat classifier  -> drop FCEmbedding -> if still flat, drop NRE_B
    for NRE_A -> if still flat, drop multi-sample-per-theta.
  - faithful export, flat classifier -> same bisect path.
  - export mismatch -> investigate exporter on FCEmbedding shapes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…dget

Most recent notebook run produced NRE posteriors that effectively equal the
prior (chains wander freely across the entire prior with no concentration),
and MCMC was taking 30+ min per pass. The pattern is consistent with the
trained classifier providing near-zero discriminative signal AND with NUTS
spending most leapfrog steps on divergent trajectories — both symptoms of a
pathologically-shaped surrogate loglik.

The bisect starts with the most-recently-introduced and least-tested change:
the FCEmbedding(4 → 32 → 32) on theta inside the NRE classifier. Removing it
is a single-line change, leaves the other improvements in place (NRE_B with
num_atoms=20, multi-sample-per-theta 300k×3, hidden=128), and lets us
verify whether the embedding was the culprit. Comment in the classifier-
builder cell notes the bisect for future readers.

Companion changes to keep MCMC bounded during diagnosis:
  - MCMC_DRAWS: 1000 → 500
  - MCMC_TUNE:  1500 → 500
  - target_accept: 0.9 → 0.8 (allows larger steps)
  - max_tree_depth: default 10 → 8 (caps leapfrog steps/draw at 256 vs 1024)
  - progressbar: False → True (so users see chain progress instead of
    wondering if it's hung)
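
The leapfrog figures quoted for max_tree_depth follow from the tree doubling at each level. A small sketch, treating 2**depth as an approximate upper bound per draw (the helper name is hypothetical):

```python
def leapfrog_cap(max_tree_depth: int) -> int:
    """Approximate upper bound on leapfrog steps per draw for a given tree depth."""
    return 2 ** max_tree_depth


old_cap = leapfrog_cap(10)   # default depth
new_cap = leapfrog_cap(8)    # depth used during diagnosis
speedup_bound = old_cap // new_cap
print(old_cap, new_cap, speedup_bound)
```

So in the worst case, where every draw hits the depth limit, the lower setting cuts the per-draw gradient evaluations by about a factor of four.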

If the next run still produces a flat NRE posterior, the next bisect step
is NRE_B → NRE_A (revert the contrastive change). If MCMC behaves but
posteriors are biased, that's a separate calibration question we'll address
on the next round of knobs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ove sweep diagnostic pre-MCMC

Two changes in one commit, both reactions to the bisect step 1 not fixing the
flat-posterior problem (removing the FCEmbedding alone left the NRE classifier
still uninformative at MCMC time).

1. Move Part 5c sweep + ONNX round-trip diagnostics to a new Part 4b that
   runs RIGHT AFTER training and export, BEFORE the multi-minute MCMC step.
   The sweep is the cheapest way to know if the trained classifier is
   informative at all -- there's no point running MCMC if the per-dim
   vertical range of the log-ratio is < 5 units everywhere (= no
   discriminative signal -> posterior will equal the prior). Old Part 5c
   block deleted to avoid duplication.

2. Revert the NRE configuration to the last known-working baseline:

     - NRE_B -> NRE_A (drop atomic contrastive estimation)
     - Multi-sample 300k_theta x 3_samples -> 1M_theta x 1_sample
     - HIDDEN_FEATURES 128 -> 100
     - drop NRE_NUM_ATOMS (NRE_A has no contrastive hyperparameter)
     - drop FCEmbedding import (already not used since bisect step 1)
     - Updated Part 3 markdown to call out the bisect explicitly

   The aim: re-establish a baseline that gives at least the moderate-quality
   recovery seen two iterations back. If THIS produces a flat posterior too,
   we have a more fundamental problem (not in any of the NRE-only changes)
   and need to look at the exporter, HSSM consumption, or environment.

MCMC budget knobs from step 1 are retained (draws=500, tune=500,
target_accept=0.8, progressbar=True, nuts_sampler_kwargs.max_tree_depth=8)
since they are pure safety and don't change posterior quality.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ndarray.ptp() was deprecated in NumPy 1.25 and removed in NumPy 2.0.
The Part 4b sweep diagnostic cell called lp.ptp() in two places (the
plot title and the per-dim print loop) and raised AttributeError on
NumPy 2.x environments. Replace both with np.ptp(lp), which is the
documented NumPy 2 equivalent.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces two hardcoded relative paths in the sbi tutorial:

1. Part 4: ./sbi_onnx_artifacts/ -> ARTIFACT_DIR (default
   ~/sbi_onnx_tutorial/). User can override to a project-local dir or
   tempfile.mkdtemp() via two examples in the cell comment.

2. Part 3: sbi's training tensorboard logs were going to ./sbi-logs/
   relative to cwd (different paths for different Jupyter launch
   contexts - one in the notebook dir, one at the repo root).
   Now wires NRE_A(..., summary_writer=SummaryWriter(log_dir=str(
   TUTORIAL_LOG_DIR))) with a default ~/sbi_logs_tutorial/, so all
   runs write to the same predictable location regardless of cwd.

Why: the previous defaults wrote into whatever directory the notebook
was running from, which for typical setups means docs/tutorials/ inside
the HSSM repo. Re-running the notebook would accumulate untracked
training logs and ONNX artifacts in the working tree. Moving the
defaults outside the repo eliminates the footgun.

Comment in each affected cell points to the override pattern so users
who want artifacts kept nearby (e.g. for downstream MCMC re-runs on a
saved checkpoint) can override in one line.
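
The override pattern can be sketched as below. `resolve_dir` is a hypothetical helper, not a function in the notebook; the default directory name is the one given above, and nothing is created at the default location.

```python
import tempfile
from pathlib import Path

DEFAULT_ARTIFACT_DIR = Path.home() / "sbi_onnx_tutorial"


def resolve_dir(override=None, default=DEFAULT_ARTIFACT_DIR):
    """Return the override if given, else a default outside the repository."""
    return Path(override).expanduser() if override is not None else default


artifact_dir = resolve_dir()                   # default location in the home dir
scratch_dir = resolve_dir(tempfile.mkdtemp())  # throwaway override
print(artifact_dir, scratch_dir)
```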

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
digicosmos86 previously approved these changes May 20, 2026
Collaborator

@digicosmos86 digicosmos86 left a comment


LGTM! One small nitpick

" jax.config.update('jax_enable_x64', True)\n"
"at the very top of your script, before any other JAX import."
)
warnings.warn(
Collaborator


Why warnings.warn instead of logger.warn? I think we tend to prefer logging to print and warning

Per PR #964 review: HSSM convention is to route messages through
logging.getLogger("hssm") rather than warnings.warn. Switches the
auto-x64-enabled message to _logger.warning(...) to match the rest
of the codebase. No behavioral change for downstream users beyond
how the message is surfaced.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
AlexanderFengler and others added 2 commits May 22, 2026 20:06
…dynamic dims

Three connected changes informed by a diagnostic experiment.

1. Drop _ensure_x64_if_needed. pytensor's JAX dispatch
   (pytensor/link/jax/dispatch/basic.py) already sets jax_enable_x64
   from pytensor.config.floatX at module import. With HSSM's default
   floatX=float64, x64 is already on by the time onnx2jax loads --
   our auto-flip was redundant in default use and brittle in edge
   cases (mutated global state, hard-failed if JAX had warmed up).

2. Replace it with _recast_int64_to_int32, a small in-place graph
   transform that rewrites int64 tensors / Cast targets to int32 at
   load time. Lossless for the index/shape values torch.onnx.export
   produces (small non-negative ints, bit-identical truncation),
   silences the JAX UserWarning under x64=off, and removes any
   global-state dependency.

3. Add _check_single_trial_input_shape: raise ValueError if any
   input dim is symbolic / dynamic. jaxonnxruntime traces against
   the construction-time dummy and bakes the resulting shapes into
   the returned closure, so a dynamic_axes export called at a
   different shape silently produces wrong outputs for any graph
   with a batch-dependent intermediate (e.g. torch.zeros(x.shape[0])
   accumulators, Reshape with -1). HSSM's ONNX path is built around
   single-trial inputs + jax.vmap over trials (see
   distribution_utils/onnx.py); LANs and LANfactory's
   transform_sbi_to_onnx already follow this contract. This guard
   prevents accidental violations from future contributors.
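
The shape guard in item 3 can be illustrated without an onnx dependency. In a real ONNX graph each input dimension carries either a concrete value or a symbolic name; the plain-Python stand-in below models that as an int versus a string or None. The function name mirrors the one in this commit, but the body is an assumption, not the code in onnx2jax.py.

```python
def check_single_trial_input_shape(input_shapes):
    """Raise ValueError if any input dimension is symbolic or otherwise dynamic.

    `input_shapes` maps an input name to a sequence of dims, where a concrete
    dim is a positive int and a dynamic dim is a string (symbolic name) or None.
    """
    for name, dims in input_shapes.items():
        for axis, dim in enumerate(dims):
            if not (isinstance(dim, int) and dim > 0):
                raise ValueError(
                    f"input {name!r} has a dynamic dim {dim!r} at axis {axis}; "
                    "export with a fixed single-trial shape instead"
                )


# A concrete single-trial shape passes silently.
check_single_trial_input_shape({"input": (6,)})

# A dynamic_axes-style export with a symbolic batch dim is rejected.
try:
    check_single_trial_input_shape({"input": ("batch", 6)})
    rejected = False
except ValueError:
    rejected = True
print(rejected)
```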

Tests: 3 new in tests/distribution_utils/test_onnx.py covering the
dynamic-dim guard (positive + negative) and the int64 -> int32
recast. Full ONNX test suite passes.
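
The losslessness claim for the int64 -> int32 recast can be checked numerically. The sketch below uses NumPy only and a hypothetical helper, `recast_is_lossless`; it does not touch an ONNX graph, and the sample values are illustrative.

```python
import numpy as np

INT32_MAX = int(np.iinfo(np.int32).max)


def recast_is_lossless(values):
    """True if every int64 value survives a round trip through int32."""
    arr = np.asarray(values, dtype=np.int64)
    round_trip = arr.astype(np.int32).astype(np.int64)
    return bool(np.array_equal(round_trip, arr))


shape_like = [1, 4, 100, 0, -1]   # typical index / shape values, incl. Reshape's -1
too_big = [INT32_MAX + 1]         # wraps around when narrowed to int32

lossless_small = recast_is_lossless(shape_like)
lossless_big = recast_is_lossless(too_big)
print(lossless_small, lossless_big)
```

A check of this kind could also serve as a defensive guard before rewriting a tensor, so that an unexpectedly large int64 value is reported rather than silently wrapped.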

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a Key Patterns entry codifying the rule that `make_jax_func`
now enforces (commit dd168fc): every ONNX graph consumed by HSSM
must have a concrete single-trial input shape, with per-trial
batching happening at the HSSM layer via jax.vmap. Points at the
two enforcement sites (_check_single_trial_input_shape in
onnx2jax.py, the vmap wiring in onnx.py) and notes that
LANfactory's exporters already follow this convention.

The constraint was de facto since the original LANs but only
became enforced with dd168fc; this entry makes it discoverable.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

2 participants