Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
2e76516
fix(onnx2jax): relax jaxonnxruntime strict-mode for Constant shape args
AlexanderFengler May 14, 2026
d1d7ffe
fix(onnx2jax): auto-enable jax_enable_x64 for int64-bearing graphs (C7a)
AlexanderFengler May 14, 2026
f90ee60
docs(tutorials): sbi NLE + NRE integration tutorial (C8 keystone)
AlexanderFengler May 14, 2026
a91f003
docs(tutorials): robust lanfactory import fallback in sbi NLE notebook
AlexanderFengler May 14, 2026
f94b496
docs(tutorials): use cores=1 in sbi NLE notebook (avoid cloudpickle o…
AlexanderFengler May 15, 2026
ce19950
fix(base): pass user's sampler choice through to bambi verbatim
AlexanderFengler May 15, 2026
5c628d0
Merge branch 'fix/sampler-routing' into sbi-integration
AlexanderFengler May 15, 2026
5856d62
docs(tutorials): drop cores=1 workaround now that sampler routing is …
AlexanderFengler May 15, 2026
7abb071
docs(tutorials): match training prior to HSSM defaults; bump training…
AlexanderFengler May 15, 2026
4c9a3ad
docs(tutorials): add NLE training-vs-sampling diagnostic cell
AlexanderFengler May 15, 2026
ebe86af
docs(tutorials): scale up sbi training to 1M sims + larger flow
AlexanderFengler May 15, 2026
a4d881e
docs(tutorials): add analytical DDM ground-truth posterior + 3-way co…
AlexanderFengler May 16, 2026
8e91d92
docs(tutorials): restructure C8 notebook around NRE; drop NLE-MAF sec…
AlexanderFengler May 16, 2026
1dde8f9
docs(tutorials): add Part 5c deeper NRE diagnostics (logit sweep + ex…
AlexanderFengler May 17, 2026
572e74b
docs(tutorials): bisect step 1 — remove FCEmbedding + tighten MCMC bu…
AlexanderFengler May 17, 2026
c589b14
docs(tutorials): revert NRE config to last-known-working baseline + m…
AlexanderFengler May 17, 2026
0e14c3e
fix(tutorials): use np.ptp() — ndarray.ptp() removed in NumPy 2.0
AlexanderFengler May 17, 2026
f772325
docs(tutorial): user-configurable ARTIFACT_DIR + TUTORIAL_LOG_DIR
AlexanderFengler May 18, 2026
84ca897
refactor(onnx2jax): use hssm logger instead of warnings.warn
AlexanderFengler May 22, 2026
dd168fc
refactor(onnx2jax): drop _ensure_x64 auto-flip; precast int64; guard …
AlexanderFengler May 23, 2026
5c8275a
docs(claude): document the ONNX single-trial + vmap contract
AlexanderFengler May 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ uv run ruff format .

`HSSM.sample()` passes `**kwargs` through to `bambi.Model.fit()`, which in turn passes them to PyMC's `pm.sample()`. So parameters like `cores`, `chains`, `nuts_sampler`, `target_accept`, etc. are valid even though they don't appear in HSSM's own signature. Similarly, the HSSM constructor passes `**kwargs` to `bambi.Model()`, so bambi parameters like `noncentered` are valid.

### ONNX likelihoods are single-trial + `jax.vmap`

Every ONNX graph consumed by HSSM must be exported with a concrete single-trial input shape (no `dynamic_axes`). Per-trial batching happens at the HSSM layer via `jax.vmap` over trials — see [`src/hssm/distribution_utils/onnx.py:115-138`](src/hssm/distribution_utils/onnx.py#L115-L138), where `logp(*inputs)` builds one flat per-trial vector and `make_vmap_func` lifts it.

Enforced at load time by `_check_single_trial_input_shape` in [`src/hssm/distribution_utils/onnx_utils/onnx2jax.py`](src/hssm/distribution_utils/onnx_utils/onnx2jax.py), which raises a `ValueError` on any symbolic input dim. The constraint exists because `jaxonnxruntime` traces against the construction-time dummy and bakes those shapes into the returned closure — calling that closure at a different batch size silently produces wrong outputs for graphs with batch-dependent intermediates (log-det accumulators, `Reshape` with `-1`).

LANfactory's exporters (`transform_sbi_to_onnx`, BayesFlow LRE export) already follow this convention. A new ONNX source must do the same: trace with a rank-1 dummy, no `dynamic_axes`.

### Notebook execution in CI

Two separate skip mechanisms for notebooks:
Expand Down
Loading