From ac16eda231c56d47fd8b58d9dd127d0c4aee922a Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 13 May 2026 17:31:02 -0400 Subject: [PATCH 1/8] feat(onnx): scaffold sbi-to-ONNX exporter (C1) Adds a stub transform_sbi_to_onnx in lanfactory/onnx/sbi.py as a sibling of the existing LAN exporter. Extends the `all` extra to pull sbi and nflows, and adds jaxonnxruntime to the dev group for round-trip testing. First commit of the sbi -> HSSM integration plan (plans/sbi-onnx-integration.md in HSSMSpine). Implementation lands in C3 (NLE) and C4 (NRE). Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 8 +- src/lanfactory/onnx/__init__.py | 3 +- src/lanfactory/onnx/sbi.py | 66 +++++++++++ uv.lock | 192 ++++++++++++++++++++++++++++++++ 4 files changed, 267 insertions(+), 2 deletions(-) create mode 100644 src/lanfactory/onnx/sbi.py diff --git a/pyproject.toml b/pyproject.toml index a460283..7f67387 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,12 @@ keywords = [ [project.optional-dependencies] mlflow = ["mlflow>=3.6.0"] hf = ["huggingface-hub>=0.20.0"] -all = ["mlflow>=3.6.0", "huggingface-hub>=0.20.0"] +all = [ + "mlflow>=3.6.0", + "huggingface-hub>=0.20.0", + "sbi>=0.26", + "nflows>=0.14", +] [dependency-groups] dev = [ @@ -74,6 +79,7 @@ dev = [ "ruff>=0.14.4", "types-PyYAML", "mlflow>=3.6.0", + "jaxonnxruntime>=0.3", ] [tool.setuptools.packages.find] diff --git a/src/lanfactory/onnx/__init__.py b/src/lanfactory/onnx/__init__.py index 8db9b26..1924f47 100755 --- a/src/lanfactory/onnx/__init__.py +++ b/src/lanfactory/onnx/__init__.py @@ -1,3 +1,4 @@ +from .sbi import transform_sbi_to_onnx from .transform_onnx import transform_to_onnx -__all__ = ["transform_to_onnx"] +__all__ = ["transform_to_onnx", "transform_sbi_to_onnx"] diff --git a/src/lanfactory/onnx/sbi.py b/src/lanfactory/onnx/sbi.py new file mode 100644 index 0000000..bdf1c8c --- /dev/null +++ b/src/lanfactory/onnx/sbi.py @@ -0,0 +1,66 @@ +"""Export trained sbi estimators to ONNX for 
HSSM consumption. + +The single public entry point is :func:`transform_sbi_to_onnx`, which wraps a +trained sbi density or ratio estimator and writes a single-trial ONNX graph +that HSSM's ``loglik_kind="approx_differentiable"`` path can load via +``jaxonnxruntime``. + +This module is intentionally a sibling of :mod:`lanfactory.onnx.transform_onnx` +— the LAN exporter — so that "train a network and emit an ONNX HSSM can read" +stays a single conceptual home in LANfactory regardless of which library +trained the network. + +Implementation lands in C3 (NLE path) and C4 (NRE path). See +``plans/sbi-onnx-integration.md`` in HSSMSpine for the full plan. +""" + +from __future__ import annotations + +from typing import Any, Literal + +__all__ = ["transform_sbi_to_onnx"] + + +def transform_sbi_to_onnx( + estimator: Any, + path: str, + *, + mode: Literal["nle", "nre"] = "nle", + example_theta_dim: int | None = None, + example_x_dim: int | None = None, + opset: int = 17, +) -> None: + """Export a trained sbi estimator to a single-trial ONNX graph. + + Parameters + ---------- + estimator + A trained sbi estimator. For ``mode="nle"`` this is a + ``ConditionalDensityEstimator`` (as returned by ``NLE_A.train()``); for + ``mode="nre"`` it is a ``RatioEstimator`` (from ``NRE_A``/``B``/``C``, + ``BNRE``). + path + Filesystem path to write the ``.onnx`` artifact to. + mode + ``"nle"`` exports ``estimator.log_prob`` as the log-likelihood with the + standardization Jacobian baked in. ``"nre"`` exports the classifier + logit as the log-likelihood up to a θ-independent constant. + example_theta_dim + Parameter-vector dimensionality used to trace the graph. Required. + example_x_dim + Observation-vector dimensionality used to trace the graph. Required. + opset + ONNX opset version. Pinned to 17 by default for reproducibility against + ``jaxonnxruntime``. + + Notes + ----- + Only likelihood-shaped families are supported. 
NPE/posterior estimators, + score-based / flow-matching estimators (FMPE, NPSE), TabPFN-based + estimators, and neural spline flows (blocked on a missing ``SearchSorted`` + op in ``jaxonnxruntime``) are rejected with a clear error at export time. + """ + raise NotImplementedError( + "transform_sbi_to_onnx is scaffolded but not yet implemented. " + "The NLE path lands in commit C3; see plans/sbi-onnx-integration.md." + ) diff --git a/uv.lock b/uv.lock index 4b3648a..1f96995 100644 --- a/uv.lock +++ b/uv.lock @@ -1039,6 +1039,57 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685 }, ] +[[package]] +name = "grpcio" +version = "1.80.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/48/af6173dbca4454f4637a4678b67f52ca7e0c1ed7d5894d89d434fecede05/grpcio-1.80.0.tar.gz", hash = "sha256:29aca15edd0688c22ba01d7cc01cb000d72b2033f4a3c72a81a19b56fd143257", size = 12978905 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/cd/bb7b7e54084a344c03d68144450da7ddd5564e51a298ae1662de65f48e2d/grpcio-1.80.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:886457a7768e408cdce226ad1ca67d2958917d306523a0e21e1a2fdaa75c9c9c", size = 6050363 }, + { url = "https://files.pythonhosted.org/packages/16/02/1417f5c3460dea65f7a2e3c14e8b31e77f7ffb730e9bfadd89eda7a9f477/grpcio-1.80.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:7b641fc3f1dc647bfd80bd713addc68f6d145956f64677e56d9ebafc0bd72388", size = 12026037 }, + { url = "https://files.pythonhosted.org/packages/43/98/c910254eedf2cae368d78336a2de0678e66a7317d27c02522392f949b5c6/grpcio-1.80.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = 
"sha256:33eb763f18f006dc7fee1e69831d38d23f5eccd15b2e0f92a13ee1d9242e5e02", size = 6602306 }, + { url = "https://files.pythonhosted.org/packages/7c/f8/88ca4e78c077b2b2113d95da1e1ab43efd43d723c9a0397d26529c2c1a56/grpcio-1.80.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:52d143637e3872633fc7dd7c3c6a1c84e396b359f3a72e215f8bf69fd82084fc", size = 7301535 }, + { url = "https://files.pythonhosted.org/packages/f9/96/f28660fe2fe0f153288bf4a04e4910b7309d442395135c88ed4f5b3b8b40/grpcio-1.80.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c51bf8ac4575af2e0678bccfb07e47321fc7acb5049b4482832c5c195e04e13a", size = 6808669 }, + { url = "https://files.pythonhosted.org/packages/47/eb/3f68a5e955779c00aeef23850e019c1c1d0e032d90633ba49c01ad5a96e0/grpcio-1.80.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:50a9871536d71c4fba24ee856abc03a87764570f0c457dd8db0b4018f379fed9", size = 7409489 }, + { url = "https://files.pythonhosted.org/packages/5b/a7/d2f681a4bfb881be40659a309771f3bdfbfdb1190619442816c3f0ffc079/grpcio-1.80.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a72d84ad0514db063e21887fbacd1fd7acb4d494a564cae22227cd45c7fbf199", size = 8423167 }, + { url = "https://files.pythonhosted.org/packages/97/8a/29b4589c204959aa35ce5708400a05bba72181807c45c47b3ec000c39333/grpcio-1.80.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f7691a6788ad9196872f95716df5bc643ebba13c97140b7a5ee5c8e75d1dea81", size = 7846761 }, + { url = "https://files.pythonhosted.org/packages/6b/d2/ed143e097230ee121ac5848f6ff14372dba91289b10b536d54fb1b7cbae7/grpcio-1.80.0-cp310-cp310-win32.whl", hash = "sha256:46c2390b59d67f84e882694d489f5b45707c657832d7934859ceb8c33f467069", size = 4156534 }, + { url = "https://files.pythonhosted.org/packages/d5/c9/df8279bb49b29409995e95efa85b72973d62f8aeff89abee58c91f393710/grpcio-1.80.0-cp310-cp310-win_amd64.whl", hash = "sha256:dc053420fc75749c961e2a4c906398d7c15725d36ccc04ae6d16093167223b58", size = 4889869 
}, + { url = "https://files.pythonhosted.org/packages/5d/db/1d56e5f5823257b291962d6c0ce106146c6447f405b60b234c4f222a7cde/grpcio-1.80.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:dfab85db094068ff42e2a3563f60ab3dddcc9d6488a35abf0132daec13209c8a", size = 6055009 }, + { url = "https://files.pythonhosted.org/packages/6e/18/c83f3cad64c5ca63bca7e91e5e46b0d026afc5af9d0a9972472ceba294b3/grpcio-1.80.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:5c07e82e822e1161354e32da2662f741a4944ea955f9f580ec8fb409dd6f6060", size = 12035295 }, + { url = "https://files.pythonhosted.org/packages/0f/8e/e14966b435be2dda99fbe89db9525ea436edc79780431a1c2875a3582644/grpcio-1.80.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba0915d51fd4ced2db5ff719f84e270afe0e2d4c45a7bdb1e8d036e4502928c2", size = 6610297 }, + { url = "https://files.pythonhosted.org/packages/cc/26/d5eb38f42ce0e3fdc8174ea4d52036ef8d58cc4426cb800f2610f625dd75/grpcio-1.80.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:3cb8130ba457d2aa09fa6b7c3ed6b6e4e6a2685fce63cb803d479576c4d80e21", size = 7300208 }, + { url = "https://files.pythonhosted.org/packages/25/51/bd267c989f85a17a5b3eea65a6feb4ff672af41ca614e5a0279cc0ea381c/grpcio-1.80.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:09e5e478b3d14afd23f12e49e8b44c8684ac3c5f08561c43a5b9691c54d136ab", size = 6813442 }, + { url = "https://files.pythonhosted.org/packages/9e/d9/d80eef735b19e9169e30164bbf889b46f9df9127598a83d174eb13a48b26/grpcio-1.80.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:00168469238b022500e486c1c33916acf2f2a9b2c022202cf8a1885d2e3073c1", size = 7414743 }, + { url = "https://files.pythonhosted.org/packages/de/f2/567f5bd5054398ed6b0509b9a30900376dcf2786bd936812098808b49d8d/grpcio-1.80.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8502122a3cc1714038e39a0b071acb1207ca7844208d5ea0d091317555ee7106", size = 8426046 }, + { url = 
"https://files.pythonhosted.org/packages/62/29/73ef0141b4732ff5eacd68430ff2512a65c004696997f70476a83e548e7e/grpcio-1.80.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce1794f4ea6cc3ca29463f42d665c32ba1b964b48958a66497917fe9069f26e6", size = 7851641 }, + { url = "https://files.pythonhosted.org/packages/46/69/abbfa360eb229a8623bab5f5a4f8105e445bd38ce81a89514ba55d281ad0/grpcio-1.80.0-cp311-cp311-win32.whl", hash = "sha256:51b4a7189b0bef2aa30adce3c78f09c83526cf3dddb24c6a96555e3b97340440", size = 4154368 }, + { url = "https://files.pythonhosted.org/packages/6f/d4/ae92206d01183b08613e846076115f5ac5991bae358d2a749fa864da5699/grpcio-1.80.0-cp311-cp311-win_amd64.whl", hash = "sha256:02e64bb0bb2da14d947a49e6f120a75e947250aebe65f9629b62bb1f5c14e6e9", size = 4894235 }, + { url = "https://files.pythonhosted.org/packages/5c/e8/a2b749265eb3415abc94f2e619bbd9e9707bebdda787e61c593004ec927a/grpcio-1.80.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:c624cc9f1008361014378c9d776de7182b11fe8b2e5a81bc69f23a295f2a1ad0", size = 6015616 }, + { url = "https://files.pythonhosted.org/packages/3e/97/b1282161a15d699d1e90c360df18d19165a045ce1c343c7f313f5e8a0b77/grpcio-1.80.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:f49eddcac43c3bf350c0385366a58f36bed8cc2c0ec35ef7b74b49e56552c0c2", size = 12014204 }, + { url = "https://files.pythonhosted.org/packages/6e/5e/d319c6e997b50c155ac5a8cb12f5173d5b42677510e886d250d50264949d/grpcio-1.80.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d334591df610ab94714048e0d5b4f3dd5ad1bee74dfec11eee344220077a79de", size = 6563866 }, + { url = "https://files.pythonhosted.org/packages/ae/f6/fdd975a2cb4d78eb67769a7b3b3830970bfa2e919f1decf724ae4445f42c/grpcio-1.80.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:0cb517eb1d0d0aaf1d87af7cc5b801d686557c1d88b2619f5e31fab3c2315921", size = 7273060 }, + { url = 
"https://files.pythonhosted.org/packages/db/f0/a3deb5feba60d9538a962913e37bd2e69a195f1c3376a3dd44fe0427e996/grpcio-1.80.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4e78c4ac0d97dc2e569b2f4bcbbb447491167cb358d1a389fc4af71ab6f70411", size = 6782121 }, + { url = "https://files.pythonhosted.org/packages/ca/84/36c6dcfddc093e108141f757c407902a05085e0c328007cb090d56646cdf/grpcio-1.80.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2ed770b4c06984f3b47eb0517b1c69ad0b84ef3f40128f51448433be904634cd", size = 7383811 }, + { url = "https://files.pythonhosted.org/packages/7c/ef/f3a77e3dc5b471a0ec86c564c98d6adfa3510d38f8ee99010410858d591e/grpcio-1.80.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:256507e2f524092f1473071a05e65a5b10d84b82e3ff24c5b571513cfaa61e2f", size = 8393860 }, + { url = "https://files.pythonhosted.org/packages/9b/8d/9d4d27ed7f33d109c50d6b5ce578a9914aa68edab75d65869a17e630a8d1/grpcio-1.80.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9a6284a5d907c37db53350645567c522be314bac859a64a7a5ca63b77bb7958f", size = 7830132 }, + { url = "https://files.pythonhosted.org/packages/14/e4/9990b41c6d7a44e1e9dee8ac11d7a9802ba1378b40d77468a7761d1ad288/grpcio-1.80.0-cp312-cp312-win32.whl", hash = "sha256:c71309cfce2f22be26aa4a847357c502db6c621f1a49825ae98aa0907595b193", size = 4140904 }, + { url = "https://files.pythonhosted.org/packages/2f/2c/296f6138caca1f4b92a31ace4ae1b87dab692fc16a7a3417af3bb3c805bf/grpcio-1.80.0-cp312-cp312-win_amd64.whl", hash = "sha256:9fe648599c0e37594c4809d81a9e77bd138cc82eb8baa71b6a86af65426723ff", size = 4880944 }, + { url = "https://files.pythonhosted.org/packages/2f/3a/7c3c25789e3f069e581dc342e03613c5b1cb012c4e8c7d9d5cf960a75856/grpcio-1.80.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:e9e408fc016dffd20661f0126c53d8a31c2821b5c13c5d67a0f5ed5de93319ad", size = 6017243 }, + { url = 
"https://files.pythonhosted.org/packages/04/19/21a9806eb8240e174fd1ab0cd5b9aa948bb0e05c2f2f55f9d5d7405e6d08/grpcio-1.80.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:92d787312e613754d4d8b9ca6d3297e69994a7912a32fa38c4c4e01c272974b0", size = 12010840 }, + { url = "https://files.pythonhosted.org/packages/18/3a/23347d35f76f639e807fb7a36fad3068aed100996849a33809591f26eca6/grpcio-1.80.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8ac393b58aa16991a2f1144ec578084d544038c12242da3a215966b512904d0f", size = 6567644 }, + { url = "https://files.pythonhosted.org/packages/ff/40/96e07ecb604a6a67ae6ab151e3e35b132875d98bc68ec65f3e5ab3e781d7/grpcio-1.80.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:68e5851ac4b9afe07e7f84483803ad167852570d65326b34d54ca560bfa53fb6", size = 7277830 }, + { url = "https://files.pythonhosted.org/packages/9b/e2/da1506ecea1f34a5e365964644b35edef53803052b763ca214ba3870c856/grpcio-1.80.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:873ff5d17d68992ef6605330127425d2fc4e77e612fa3c3e0ed4e668685e3140", size = 6783216 }, + { url = "https://files.pythonhosted.org/packages/44/83/3b20ff58d0c3b7f6caaa3af9a4174d4023701df40a3f39f7f1c8e7c48f9d/grpcio-1.80.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:2bea16af2750fd0a899bf1abd9022244418b55d1f37da2202249ba4ba673838d", size = 7385866 }, + { url = "https://files.pythonhosted.org/packages/47/45/55c507599c5520416de5eefecc927d6a0d7af55e91cfffb2e410607e5744/grpcio-1.80.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ba0db34f7e1d803a878284cd70e4c63cb6ae2510ba51937bf8f45ba997cefcf7", size = 8391602 }, + { url = "https://files.pythonhosted.org/packages/10/bb/dd06f4c24c01db9cf11341b547d0a016b2c90ed7dbbb086a5710df7dd1d7/grpcio-1.80.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8eb613f02d34721f1acf3626dfdb3545bd3c8505b0e52bf8b5710a28d02e8aa7", size = 7826752 }, + { url = 
"https://files.pythonhosted.org/packages/f9/1e/9d67992ba23371fd63d4527096eb8c6b76d74d52b500df992a3343fd7251/grpcio-1.80.0-cp313-cp313-win32.whl", hash = "sha256:93b6f823810720912fd131f561f91f5fed0fda372b6b7028a2681b8194d5d294", size = 4142310 }, + { url = "https://files.pythonhosted.org/packages/cf/e6/283326a27da9e2c3038bc93eeea36fb118ce0b2d03922a9cda6688f53c5b/grpcio-1.80.0-cp313-cp313-win_amd64.whl", hash = "sha256:e172cf795a3ba5246d3529e4d34c53db70e888fa582a8ffebd2e6e48bc0cba50", size = 4882833 }, +] + [[package]] name = "gunicorn" version = "23.0.0" @@ -1374,6 +1425,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/79/f6e80f7f4cacfc9f03e64ac57ecb856b140de7c2f939b25f8dcf1aff63f9/jaxlib-0.6.2-cp313-cp313t-manylinux2014_x86_64.whl", hash = "sha256:3abd536e44b05fb1657507e3ff1fc3691f99613bae3921ecab9e82f27255f784", size = 90066675 }, ] +[[package]] +name = "jaxonnxruntime" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jax" }, + { name = "jaxlib" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/18/12/2e087eb9930d3dcf9f2ca9d745a68d324cf6f7c70896f864832e8b88bebc/jaxonnxruntime-0.3.0.tar.gz", hash = "sha256:64340d83f280f725ef068326aedc87489a39f5da67ceebcdbcb24ce777cf8198", size = 111451 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6b/99/979546f0d1f57bdd4479da74e31de15d006f0035fcf47680723aa0693741/jaxonnxruntime-0.3.0-py3-none-any.whl", hash = "sha256:72814405d611d549c1a172cfff214c1e7cff5d2b3737c3129970630f8ea5e466", size = 177663 }, +] + [[package]] name = "jedi" version = "0.19.2" @@ -1738,6 +1804,8 @@ dependencies = [ all = [ { name = "huggingface-hub" }, { name = "mlflow" }, + { name = "nflows" }, + 
{ name = "sbi" }, ] hf = [ { name = "huggingface-hub" }, @@ -1753,6 +1821,7 @@ dev = [ { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "ipywidgets" }, + { name = "jaxonnxruntime" }, { name = "jupyterlab" }, { name = "mlflow" }, { name = "mypy" }, @@ -1777,8 +1846,10 @@ requires-dist = [ { name = "matplotlib", specifier = ">=3.10.1" }, { name = "mlflow", marker = "extra == 'all'", specifier = ">=3.6.0" }, { name = "mlflow", marker = "extra == 'mlflow'", specifier = ">=3.6.0" }, + { name = "nflows", marker = "extra == 'all'", specifier = ">=0.14" }, { name = "onnx", specifier = ">=1.17.0" }, { name = "pandas", specifier = ">=2.2.3" }, + { name = "sbi", marker = "extra == 'all'", specifier = ">=0.26" }, { name = "scipy", specifier = ">=1.15.2" }, { name = "ssm-simulators", specifier = ">=0.12.2" }, { name = "torch", specifier = ">=2.7.0" }, @@ -1793,6 +1864,7 @@ dev = [ { name = "ipykernel", specifier = ">=6.29.5" }, { name = "ipython", specifier = ">=8.31.0" }, { name = "ipywidgets", specifier = ">=8.1.2" }, + { name = "jaxonnxruntime", specifier = ">=0.3" }, { name = "jupyterlab", specifier = ">=4.2.4" }, { name = "mlflow", specifier = ">=3.6.0" }, { name = "mypy", specifier = ">=1.11.1" }, @@ -1820,6 +1892,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509 }, ] +[[package]] +name = "markdown" +version = "3.10.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2b/f4/69fa6ed85ae003c2378ffa8f6d2e3234662abd02c10d216c0ba96081a238/markdown-3.10.2.tar.gz", hash = 
"sha256:994d51325d25ad8aa7ce4ebaec003febcce822c3f8c911e3b17c52f7f589f950", size = 368805 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/1f/77fa3081e4f66ca3576c896ae5d31c3002ac6607f9747d2e3aa49227e464/markdown-3.10.2-py3-none-any.whl", hash = "sha256:e91464b71ae3ee7afd3017d9f358ef0baf158fd9a298db92f1d4761133824c36", size = 108180 }, +] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -2305,6 +2386,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406 }, ] +[[package]] +name = "nflows" +version = "0.14" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "tensorboard" }, + { name = "torch" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bd/16/a484db41aab28332f42080435c9342fa87cfc9a4fce5495521ea1e80ca27/nflows-0.14.tar.gz", hash = "sha256:6299844a62f9999fcdf2d95cb2d01c091a50136bd17826e303aba646b2d11b55", size = 45784 } + [[package]] name = "nodeenv" version = "1.9.1" @@ -3703,6 +3798,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/bd/4168a751ddbbf43e86544b4de8b5c3b7be8d7167a2a5cb977d274e04f0a1/ruff-0.14.4-py3-none-win_arm64.whl", hash = "sha256:dd09c292479596b0e6fec8cd95c65c3a6dc68e9ad17b8f2382130f87ff6a75bb", size = 12663065 }, ] +[[package]] +name = "sbi" +version = "0.26.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "matplotlib" }, + { name = "nflows" }, + { name = "numpy", version = "2.2.6", source 
= { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pillow" }, + { name = "scikit-learn" }, + { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "skorch" }, + { name = "tensorboard" }, + { name = "torch" }, + { name = "tqdm" }, + { name = "zuko" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/50/5ab2aa359d0089a97ab855a47c6349166d1bb9a654f7a3496e57be895063/sbi-0.26.1.tar.gz", hash = "sha256:bcdcd9f19318815e8e6314523f23b727b509f02c8e657cf3696c30231d23fc07", size = 4133653 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/97/9c/dfec7c59a04e1656c3e71daa0fce2dfc74362fe8819a60ff6e66c3b9c21d/sbi-0.26.1-py3-none-any.whl", hash = "sha256:3d18b78f79bb2005154f02bb5ec7cd281873e6e2def72e989dd486019d85dd40", size = 517907 }, +] + [[package]] name = "scikit-learn" version = "1.7.0" @@ -3951,6 +4071,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, ] +[[package]] +name = "skorch" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scikit-learn" }, + { name = "scipy", version = "1.15.3", source = { registry = 
"https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "tabulate" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/21/12/90d072b197bef5033c694ceca3fc5714edda122d8a5ef003d8d03febcb1e/skorch-1.3.1.tar.gz", hash = "sha256:7081a0c9ab2361d524826f90c84b04a74cf55338c2b2028fa59a2e39a9019e43", size = 249195 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/4d/6fbe78427fa6b5c54cbdbd3b9bdacd214a5bcf9bf4dd247fc9537fba1644/skorch-1.3.1-py3-none-any.whl", hash = "sha256:bb06c65a15d0bfc765928a0b3fadf569222e7ec772f81b21d422603d52b4ad32", size = 268491 }, +] + [[package]] name = "smmap" version = "5.0.2" @@ -4103,6 +4241,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353 }, ] +[[package]] +name = "tabulate" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/46/58/8c37dea7bbf769b20d58e7ace7e5edfe65b849442b00ffcdd56be88697c6/tabulate-0.10.0.tar.gz", hash = "sha256:e2cfde8f79420f6deeffdeda9aaec3b6bc5abce947655d17ac662b126e48a60d", size = 91754 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl", hash = "sha256:f0b0622e567335c8fabaaa659f1b33bcb6ddfe2e496071b743aa113f8774f2d3", size = 39814 }, +] + +[[package]] +name = "tensorboard" +version = "2.20.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "grpcio" }, + { name = "markdown" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, 
marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "protobuf" }, + { name = "setuptools" }, + { name = "tensorboard-data-server" }, + { name = "werkzeug" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/d9/a5db55f88f258ac669a92858b70a714bbbd5acd993820b41ec4a96a4d77f/tensorboard-2.20.0-py3-none-any.whl", hash = "sha256:9dc9f978cb84c0723acf9a345d96c184f0293d18f166bb8d59ee098e6cfaaba6", size = 5525680 }, +] + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb", size = 2356 }, + { url = "https://files.pythonhosted.org/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60", size = 4823598 }, + { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363 }, +] + [[package]] name = "tensorstore" version = "0.1.76" @@ -4535,3 +4713,17 @@ sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50e wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276 }, ] + 
+[[package]] +name = "zuko" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8c/41/ddbe72cb64996d7826ba427c675252be0c38fbd9fbf8920d0fcdaf5d8e38/zuko-1.6.0.tar.gz", hash = "sha256:edc516e51bbbf9d64e7663b617cf9293c6e1e6bbfcb39559bc383383e6663b04", size = 45245 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/10/ff159867f522cd98e039e748c5e9777446e8b797a79be691c6b730676094/zuko-1.6.0-py3-none-any.whl", hash = "sha256:5c073b613a84a7cd65470ddb94855169020ac49432f73b85f25207e377248a4a", size = 48033 }, +] From f7c93c8f8ba7b5adc36d36835a21f6048681b39f Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 13 May 2026 22:37:45 -0400 Subject: [PATCH 2/8] test(sbi-onnx): round-trip spike tests for MLP and nflows MAF (C2) Adds two permanent regression-guard tests validating the torch.onnx.export to {onnxruntime, jaxonnxruntime} toolchain that the sbi exporter (C3) will sit on top of. Both tests assert three-way numerical agreement to 1e-5 on fixed inputs. The MAF spike surfaced a real friction: the nflows MAF exports a Reshape whose shape argument is a Constant node rather than a model initializer, which jaxonnxruntime default strict mode rejects. Setting jaxort_only_allow_initializers_as_static_args = False works around it. Architectural implication for C3: HSSM onnx2jax.py does not set this flag today, so sbi-exported flow graphs will fail to load through the HSSM make_jax_logp_funcs_from_onnx path as-is. C3 should either constant-fold the exported graph (preferred, keeps HSSM untouched) or we will need a small HSSM-side patch. 
Also adds onnxruntime>=1.17 and nflows>=0.14 to the dev dependency group so uv sync --group dev is sufficient to run these tests. Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 2 + tests/test_sbi_spike_maf_roundtrip.py | 109 ++++++++++++++++++++++++++ tests/test_sbi_spike_mlp_roundtrip.py | 71 +++++++++++++++++ uv.lock | 82 +++++++++++++++++++ 4 files changed, 264 insertions(+) create mode 100644 tests/test_sbi_spike_maf_roundtrip.py create mode 100644 tests/test_sbi_spike_mlp_roundtrip.py diff --git a/pyproject.toml b/pyproject.toml index 7f67387..9bdf9b3 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,8 @@ dev = [ "types-PyYAML", "mlflow>=3.6.0", "jaxonnxruntime>=0.3", + "onnxruntime>=1.17", + "nflows>=0.14", ] [tool.setuptools.packages.find] diff --git a/tests/test_sbi_spike_maf_roundtrip.py b/tests/test_sbi_spike_maf_roundtrip.py new file mode 100644 index 0000000..f0eb174 --- /dev/null +++ b/tests/test_sbi_spike_maf_roundtrip.py @@ -0,0 +1,109 @@ +"""Round-trip spike: nflows MAF.log_prob -> ONNX -> {onnxruntime, jaxonnxruntime}. + +Validates that a non-trivial flow architecture (masked autoregressive flow) +survives the toolchain. Confirms that masked dense layers, log-det-Jacobian +accumulation, and the affine-autoregressive ops translate cleanly into +jaxonnxruntime. Kept as a permanent regression guard per +plans/sbi-onnx-integration.md. +""" + +from pathlib import Path + +import jax +import numpy as np +import onnx +import onnxruntime as ort +import pytest +import torch +from jaxonnxruntime import call_onnx, config +from nflows.distributions.normal import StandardNormal +from nflows.flows.base import Flow +from nflows.transforms import ( + CompositeTransform, + MaskedAffineAutoregressiveTransform, + ReversePermutation, +) +from torch import nn + +# Friction discovered in C2: nflows' MAF exports a Reshape whose shape argument +# is a Constant node (not a model initializer). 
jaxonnxruntime's default strict +# mode rejects this. The flag below tells jaxonnxruntime to treat Constant nodes +# as legitimate static-args during jax.jit — which is correct for our exports +# since these are genuinely constant shapes baked at export time. +# +# Architectural implication: HSSM's onnx2jax.py does NOT set this flag today. +# The real sbi exporter (C3) must either (a) post-process the exported ONNX to +# fold Constant shape nodes into initializers, or (b) we'll need a small patch +# to HSSM. (a) is preferred to keep HSSM untouched per the integration plan. +config.update("jaxort_only_allow_initializers_as_static_args", False) + + +class _MAFLogProbModule(nn.Module): + """Wraps a flow so .forward(x) returns log_prob(x) — the thing we export.""" + + def __init__(self, flow: Flow) -> None: + super().__init__() + self.flow = flow + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.flow.log_prob(x) + + +def _build_maf(features: int = 4, num_layers: int = 3, hidden: int = 32) -> Flow: + base_dist = StandardNormal(shape=[features]) + transforms: list = [] + for _ in range(num_layers): + transforms.append(ReversePermutation(features=features)) + transforms.append( + MaskedAffineAutoregressiveTransform( + features=features, hidden_features=hidden + ) + ) + return Flow(CompositeTransform(transforms), base_dist) + + +@pytest.mark.flaky(reruns=2) +def test_maf_log_prob_three_way_agreement(tmp_path: Path) -> None: + torch.manual_seed(0) + + features = 4 + flow = _build_maf(features=features) + flow.eval() + module = _MAFLogProbModule(flow).eval() + + x_input = torch.randn(1, features, dtype=torch.float32) + + with torch.no_grad(): + y_torch = module(x_input).detach().numpy() + + onnx_path = tmp_path / "maf.onnx" + dummy_input = torch.randn(1, features, requires_grad=True) + torch.onnx.export( + module, + dummy_input, + str(onnx_path), + dynamo=False, + opset_version=17, + ) + + sess = ort.InferenceSession(str(onnx_path)) + input_name = 
sess.get_inputs()[0].name + y_ort = sess.run(None, {input_name: x_input.numpy()})[0] + + onnx_model = onnx.load(str(onnx_path)) + model_func, model_weights = call_onnx.call_onnx_model( + onnx_model, {input_name: np.asarray(x_input.numpy())} + ) + run_func = jax.tree_util.Partial(model_func, model_weights) + y_jax = np.asarray(run_func({input_name: x_input.numpy()})[0]) + + atol = 1e-5 + assert np.allclose(y_torch, y_ort, atol=atol), ( + f"torch vs onnxruntime mismatch: max |Δ| = {np.abs(y_torch - y_ort).max()}" + ) + assert np.allclose(y_torch, y_jax, atol=atol), ( + f"torch vs jaxonnxruntime mismatch: max |Δ| = {np.abs(y_torch - y_jax).max()}" + ) + assert np.allclose(y_ort, y_jax, atol=atol), ( + f"onnxruntime vs jaxonnxruntime mismatch: max |Δ| = {np.abs(y_ort - y_jax).max()}" + ) diff --git a/tests/test_sbi_spike_mlp_roundtrip.py b/tests/test_sbi_spike_mlp_roundtrip.py new file mode 100644 index 0000000..100beb0 --- /dev/null +++ b/tests/test_sbi_spike_mlp_roundtrip.py @@ -0,0 +1,71 @@ +"""Round-trip spike: torch MLP -> ONNX -> {onnxruntime, jaxonnxruntime}. + +Validates the toolchain assumptions sbi's exporter (lands in C3) will rely on, +without sbi in the loop. If torch.onnx.export, onnxruntime, or jaxonnxruntime +regress on a vanilla MLP, this test catches it before debugging the real +exporter. Kept as a permanent regression guard per plans/sbi-onnx-integration.md. 
+""" + +from pathlib import Path + +import jax +import numpy as np +import onnx +import onnxruntime as ort +import pytest +import torch +from jaxonnxruntime import call_onnx +from torch import nn + + +@pytest.mark.flaky(reruns=2) +def test_mlp_three_way_agreement(tmp_path: Path) -> None: + torch.manual_seed(0) + + input_dim = 6 + hidden_dim = 32 + + model = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.Tanh(), + nn.Linear(hidden_dim, hidden_dim), + nn.Tanh(), + nn.Linear(hidden_dim, 1), + ).eval() + + x_input = torch.randn(1, input_dim, dtype=torch.float32) + + with torch.no_grad(): + y_torch = model(x_input).detach().numpy() + + onnx_path = tmp_path / "mlp.onnx" + dummy_input = torch.randn(1, input_dim, requires_grad=True) + torch.onnx.export( + model, + dummy_input, + str(onnx_path), + dynamo=False, + opset_version=17, + ) + + sess = ort.InferenceSession(str(onnx_path)) + input_name = sess.get_inputs()[0].name + y_ort = sess.run(None, {input_name: x_input.numpy()})[0] + + onnx_model = onnx.load(str(onnx_path)) + model_func, model_weights = call_onnx.call_onnx_model( + onnx_model, {input_name: np.asarray(x_input.numpy())} + ) + run_func = jax.tree_util.Partial(model_func, model_weights) + y_jax = np.asarray(run_func({input_name: x_input.numpy()})[0]) + + atol = 1e-5 + assert np.allclose(y_torch, y_ort, atol=atol), ( + f"torch vs onnxruntime mismatch: max |Δ| = {np.abs(y_torch - y_ort).max()}" + ) + assert np.allclose(y_torch, y_jax, atol=atol), ( + f"torch vs jaxonnxruntime mismatch: max |Δ| = {np.abs(y_torch - y_jax).max()}" + ) + assert np.allclose(y_ort, y_jax, atol=atol), ( + f"onnxruntime vs jaxonnxruntime mismatch: max |Δ| = {np.abs(y_ort - y_jax).max()}" + ) diff --git a/uv.lock b/uv.lock index 1f96995..bd4656f 100644 --- a/uv.lock +++ b/uv.lock @@ -814,6 +814,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/f8/01bf35a3afd734345528f98d0353f2a978a476528ad4d7e78b70c4d149dd/flask_cors-6.0.1-py3-none-any.whl", hash = 
"sha256:c7b2cbfb1a31aa0d2e5341eea03a6805349f7a61647daee1a15c46bbe981494c", size = 13244 }, ] +[[package]] +name = "flatbuffers" +version = "25.12.19" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl", hash = "sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4", size = 26661 }, +] + [[package]] name = "flax" version = "0.10.7" @@ -1826,6 +1834,9 @@ dev = [ { name = "mlflow" }, { name = "mypy" }, { name = "nbconvert" }, + { name = "nflows" }, + { name = "onnxruntime", version = "1.24.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "onnxruntime", version = "1.26.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pre-commit" }, { name = "ptpython" }, { name = "pytest" }, @@ -1869,6 +1880,8 @@ dev = [ { name = "mlflow", specifier = ">=3.6.0" }, { name = "mypy", specifier = ">=1.11.1" }, { name = "nbconvert", specifier = ">=7.16.5" }, + { name = "nflows", specifier = ">=0.14" }, + { name = "onnxruntime", specifier = ">=1.17" }, { name = "pre-commit", specifier = ">=2.20.0" }, { name = "ptpython", specifier = ">=3.0.29" }, { name = "pytest", specifier = ">=8.3.1" }, @@ -2721,6 +2734,75 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/dd/6abe5d7bd23f5ed3ade8352abf30dff1c7a9e97fc1b0a17b5d7c726e98a9/onnx-1.18.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a69afd0baa372162948b52c13f3aa2730123381edf926d7ef3f68ca7cec6d0d0", size = 15865055 }, ] +[[package]] +name = "onnxruntime" +version = "1.24.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] +dependencies = [ + { name = "flatbuffers", marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.2.6", source = { registry = 
"https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "packaging", marker = "python_full_version < '3.11'" }, + { name = "protobuf", marker = "python_full_version < '3.11'" }, + { name = "sympy", marker = "python_full_version < '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/41/3253db975a90c3ce1d475e2a230773a21cd7998537f0657947df6fb79861/onnxruntime-1.24.3-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3e6456801c66b095c5cd68e690ca25db970ea5202bd0c5b84a2c3ef7731c5a3c", size = 17332766 }, + { url = "https://files.pythonhosted.org/packages/7e/c5/3af6b325f1492d691b23844d88ed26844c1164620860c5efe95c0e22782d/onnxruntime-1.24.3-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b2ebc54c6d8281dccff78d4b06e47d4cf07535937584ab759448390a70f4978", size = 15130330 }, + { url = "https://files.pythonhosted.org/packages/03/4b/f96b46c1866a293ed23ca2cf5e5a63d413ad3a951da60dd877e3c56cbbca/onnxruntime-1.24.3-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fb56575d7794bf0781156955610c9e651c9504c64d42ec880784b6106244882d", size = 17213247 }, + { url = "https://files.pythonhosted.org/packages/36/13/27cf4d8df2578747584e8758aeb0b673b60274048510257f1f084b15e80e/onnxruntime-1.24.3-cp311-cp311-win_amd64.whl", hash = "sha256:c958222ef9eff54018332beecd32d5d94a3ab079d8821937b333811bf4da0d39", size = 12595530 }, + { url = "https://files.pythonhosted.org/packages/19/8c/6d9f31e6bae72a8079be12ed8ba36c4126a571fad38ded0a1b96f60f6896/onnxruntime-1.24.3-cp311-cp311-win_arm64.whl", hash = "sha256:a8f761857ebaf58a85b9e42422d03207f1d39e6bb8fecfdbf613bac5b9710723", size = 12261715 }, + { url = "https://files.pythonhosted.org/packages/d0/7f/dfdc4e52600fde4c02d59bfe98c4b057931c1114b701e175aee311a9bc11/onnxruntime-1.24.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:0d244227dc5e00a9ae15a7ac1eba4c4460d7876dfecafe73fb00db9f1d914d91", size = 17342578 }, + { url = 
"https://files.pythonhosted.org/packages/1c/dc/1f5489f7b21817d4ad352bf7a92a252bd5b438bcbaa7ad20ea50814edc79/onnxruntime-1.24.3-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a9847b870b6cb462652b547bc98c49e0efb67553410a082fde1918a38707452", size = 15150105 }, + { url = "https://files.pythonhosted.org/packages/28/7c/fd253da53594ab8efbefdc85b3638620ab1a6aab6eb7028a513c853559ce/onnxruntime-1.24.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b354afce3333f2859c7e8706d84b6c552beac39233bcd3141ce7ab77b4cabb5d", size = 17237101 }, + { url = "https://files.pythonhosted.org/packages/71/5f/eaabc5699eeed6a9188c5c055ac1948ae50138697a0428d562ac970d7db5/onnxruntime-1.24.3-cp312-cp312-win_amd64.whl", hash = "sha256:44ea708c34965439170d811267c51281d3897ecfc4aa0087fa25d4a4c3eb2e4a", size = 12597638 }, + { url = "https://files.pythonhosted.org/packages/cc/5c/d8066c320b90610dbeb489a483b132c3b3879b2f93f949fb5d30cfa9b119/onnxruntime-1.24.3-cp312-cp312-win_arm64.whl", hash = "sha256:48d1092b44ca2ba6f9543892e7c422c15a568481403c10440945685faf27a8d8", size = 12270943 }, + { url = "https://files.pythonhosted.org/packages/51/8d/487ece554119e2991242d4de55de7019ac6e47ee8dfafa69fcf41d37f8ed/onnxruntime-1.24.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:34a0ea5ff191d8420d9c1332355644148b1bf1a0d10c411af890a63a9f662aa7", size = 17342706 }, + { url = "https://files.pythonhosted.org/packages/dd/25/8b444f463c1ac6106b889f6235c84f01eec001eaf689c3eff8c69cf48fae/onnxruntime-1.24.3-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1fd2ec7bb0fabe42f55e8337cfc9b1969d0d14622711aac73d69b4bd5abb5ed7", size = 15149956 }, + { url = "https://files.pythonhosted.org/packages/34/fc/c9182a3e1ab46940dd4f30e61071f59eee8804c1f641f37ce6e173633fb6/onnxruntime-1.24.3-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:df8e70e732fe26346faaeec9147fa38bef35d232d2495d27e93dd221a2d473a9", size = 17237370 }, 
+ { url = "https://files.pythonhosted.org/packages/05/7e/3b549e1f4538514118bff98a1bcd6481dd9a17067f8c9af77151621c9a5c/onnxruntime-1.24.3-cp313-cp313-win_amd64.whl", hash = "sha256:2d3706719be6ad41d38a2250998b1d87758a20f6ea4546962e21dc79f1f1fd2b", size = 12597939 }, + { url = "https://files.pythonhosted.org/packages/80/41/9696a5c4631a0caa75cc8bc4efd30938fd483694aa614898d087c3ee6d29/onnxruntime-1.24.3-cp313-cp313-win_arm64.whl", hash = "sha256:b082f3ba9519f0a1a1e754556bc7e635c7526ef81b98b3f78da4455d25f0437b", size = 12270705 }, + { url = "https://files.pythonhosted.org/packages/b7/65/a26c5e59e3b210852ee04248cf8843c81fe7d40d94cf95343b66efe7eec9/onnxruntime-1.24.3-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:72f956634bc2e4bd2e8b006bef111849bd42c42dea37bd0a4c728404fdaf4d34", size = 15161796 }, + { url = "https://files.pythonhosted.org/packages/f3/25/2035b4aa2ccb5be6acf139397731ec507c5f09e199ab39d3262b22ffa1ac/onnxruntime-1.24.3-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:78d1f25eed4ab9959db70a626ed50ee24cf497e60774f59f1207ac8556399c4d", size = 17240936 }, +] + +[[package]] +name = "onnxruntime" +version = "1.26.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.13'", + "python_full_version == '3.12.*'", + "python_full_version == '3.11.*'", +] +dependencies = [ + { name = "flatbuffers", marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "packaging", marker = "python_full_version >= '3.11'" }, + { name = "protobuf", marker = "python_full_version >= '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/81/29a9eb470994a75eb7b3ccf32be314d7c66675a00ac7b50294816cc2db27/onnxruntime-1.26.0-cp311-cp311-macosx_14_0_arm64.whl", hash = 
"sha256:ee1109ef4ef27cad90e823399e61e03b3c6c7bfe0fb820b4baf3678c15be8b3c", size = 18005108 }, + { url = "https://files.pythonhosted.org/packages/66/c7/73efa6c8a4000c38fcc14947d84f234a17e5d66f203b37b7f1ad4a7b46eb/onnxruntime-1.26.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:35c7c7b0ac2e02001d28fab6c9fc24e9abc5e6faa35e6e19c63cecf1406ba89f", size = 16043752 }, + { url = "https://files.pythonhosted.org/packages/b6/3f/8de630f595daf6ce884d4dd95afd2a60e70ec6572e52bfee3aa2229befab/onnxruntime-1.26.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11a8df4dcfe9ad5ff0bd71a7571dbed019fabc7594676c89fe8b86ea029c246f", size = 18176043 }, + { url = "https://files.pythonhosted.org/packages/9c/21/9f041de20787cd85498bd48e0ec4d098bf2a6c486e25b24b8dae1bf492b2/onnxruntime-1.26.0-cp311-cp311-win_amd64.whl", hash = "sha256:e6456718125fd777c673f3b78d4a9ab58d6adea641e9afae85ee6444f0e0e9a9", size = 13023165 }, + { url = "https://files.pythonhosted.org/packages/0e/82/3b9fe0ead2557cc3adf74c74c141bd1c7c4c6a9548c610af37df199f4512/onnxruntime-1.26.0-cp311-cp311-win_arm64.whl", hash = "sha256:cd920e45b730e4a87833e2910d8ca375aaca9da6ccc09e24bce463b3356d637f", size = 12789514 }, + { url = "https://files.pythonhosted.org/packages/81/b1/d111b1df656761f980d9e298a60039a9cb66036b1d039e777537743d0ac3/onnxruntime-1.26.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:05b028781b322ad74b57ce5b50aa5280bb1fe96ceec334628ade681e0b24c1ac", size = 18016624 }, + { url = "https://files.pythonhosted.org/packages/f6/a0/3f9d896a0385a36bd04345d6d0b802821a5782adde562e7e135f6bb71c73/onnxruntime-1.26.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:91f2bb870a4b9224eba0a6728c1fa7a9e552b8e59e1083c51fbbc3d013f2b5c0", size = 16052692 }, + { url = 
"https://files.pythonhosted.org/packages/7c/43/2a4e04f8dbeffad19bbcced4bcd4289bf478921518437404d6b92bdf213b/onnxruntime-1.26.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9b6dd70599005bd1bf29779f04a91978b92b5e719c11a20068a8f8e535f725b6", size = 18185439 }, + { url = "https://files.pythonhosted.org/packages/44/fc/026d0a7162b9c2153dac292baea9e027c42304dc1d9dc6f8ff5b4cfbaedd/onnxruntime-1.26.0-cp312-cp312-win_amd64.whl", hash = "sha256:a26374dc7fbcaae593601086b242120e13f2310558df0991da6dd8b8fac00414", size = 13026427 }, + { url = "https://files.pythonhosted.org/packages/3e/27/1dcf88e45e4c69db5f7b106f2dacc3801ba98994e082ca03e1dfdf7bfe57/onnxruntime-1.26.0-cp312-cp312-win_arm64.whl", hash = "sha256:54a8053410fd31fd66469bd754fcfe8a4df9f7eb44756b4b5479bf50c842d948", size = 12796647 }, + { url = "https://files.pythonhosted.org/packages/cf/a2/c801242685e0ce48a4ca51dfafbb588765e0446397e123be53ba5598f3f5/onnxruntime-1.26.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ccce19c5f771b8268902f77d9fed9e88f9499465d6780808faa6611a789d33f0", size = 18016563 }, + { url = "https://files.pythonhosted.org/packages/e2/64/0492c0b1db04e29b2630c87cfa36f9d6872b1ca8614b90c5cad58fac7d76/onnxruntime-1.26.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdbed8cf3b672b66acb032f33a253bc27f42bce6ece48ae3fab4fa483a5e96e0", size = 16052634 }, + { url = "https://files.pythonhosted.org/packages/3d/26/4d09ddc755a84fc8d5e192991626b0e0680e8f6c5d58f4f1d05c42bc48cf/onnxruntime-1.26.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c07af6fc6d5557835f2b6ee7a96d8b3235d0c57a8e230efdedaee106a8a3cbc6", size = 18185632 }, + { url = "https://files.pythonhosted.org/packages/77/89/3e52249aa08fa301e217ecba07b5246a8338fa2b401e109326e3fc5be0f9/onnxruntime-1.26.0-cp313-cp313-win_amd64.whl", hash = "sha256:61bec80655efa460591c2bc655392d57d2650ce85533a6b9b3b7a790d7ea7916", size = 13026751 }, + { url = 
"https://files.pythonhosted.org/packages/06/b3/c1c8782b14af6797c303de132d6eef26a9fb80dfacd3750ce57911d11c6b/onnxruntime-1.26.0-cp313-cp313-win_arm64.whl", hash = "sha256:a6677545ff451e3539a02746d2f207d8c5baa4a0a818886bb9d6a6eb9511ee89", size = 12796807 }, + { url = "https://files.pythonhosted.org/packages/c3/f5/47b0676408abec652c14b84d7173e389837832d850c24f87184277313e8d/onnxruntime-1.26.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5e016edc15d3c19f36807e1c6b10be5b27807688c32720f91b5ae480a95215d0", size = 16057265 }, + { url = "https://files.pythonhosted.org/packages/3b/45/33ab6deeef010ca844c877dd618cebc079590bbe52d2a3678e7223b1b908/onnxruntime-1.26.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f5fc48a91a046a6a5c9b147f83fb41d65d24d24923373b222cdd248f0f4f4aac", size = 18197590 }, +] + [[package]] name = "opentelemetry-api" version = "1.38.0" From bdcabd3835544cecdaa0874a4ac1ba1cd358759f Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 13 May 2026 23:42:14 -0400 Subject: [PATCH 3/8] feat(onnx): implement sbi NLE -> ONNX exporter and verify round-trip (C3) Replaces the C1 stub with the real transform_sbi_to_onnx implementation for mode=nle. The exporter wraps an sbi ConditionalDensityEstimator (NLE_A trained estimator) as a torch.nn.Module whose forward(combined) splits a concatenated (theta, x) input and returns log p(x | theta) with sbi standardization Jacobian baked into the traced graph. Exports a single-trial graph at opset 17, matching the LAN convention and HSSM vmap-over-trials expectation. Rejection paths: - Score-based, flow-matching, TabPFN estimators raise ValueError. - NLE mode requires .log_prob(input, condition); clear TypeError if absent. - NRE mode currently raises NotImplementedError (lands in C4). Tests in test_sbi_nle_export.py train a tiny 2D Gaussian NLE_A with MAF and verify: 1. 
Three-way numerical agreement (torch / onnxruntime / jaxonnxruntime) to atol=1e-5 on a fixed test point. 2. Gradient agreement (torch.autograd vs jax.grad of the translated graph) to atol=1e-4. 3. Sanity check that log-prob ordering matches the analytical Gaussian (near-mean point ranks above far point). 4. Three rejection-path tests for the error contracts above. Two findings surfaced during C3 that affect later commits: - 1D MAFs in sbi collapse to a degenerate Gaussian path with zero-width Gemm contractions that jaxonnxruntime cannot handle. The exporter must be exercised with >=2D theta and x. Documented in the simulator docstring. - jaxonnxruntime silently truncates int64 indices in exported flow graphs to int32, causing ~0.5 drift in log-prob outputs. The fix is to call jax.config.update("jax_enable_x64", True) immediately after `import jax`, BEFORE importing jax.numpy or anything else that may touch JAX dtypes. The test file sets this. C7 will decide whether HSSM onnx2jax.py should also set it globally (mirrors the C2.5 flag patch) or whether it stays a user responsibility documented in C6/C8. Also adds sbi-logs/ to .gitignore (sbi auto-writes tensorboard logs during training). Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitignore | 3 + src/lanfactory/onnx/sbi.py | 125 +++++++++++++++---- tests/test_sbi_nle_export.py | 235 +++++++++++++++++++++++++++++++++++ 3 files changed, 341 insertions(+), 22 deletions(-) create mode 100644 tests/test_sbi_nle_export.py diff --git a/.gitignore b/.gitignore index d429c47..a31a6c8 100755 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ notebooks/data/ notebooks/jax_nb_data/ notebooks/torch_nb_data/ +# sbi training tensorboard logs (sbi auto-writes here during NLE_A.train() etc.) +sbi-logs/ + notebooks/test_notebooks/data/ hssm_paper/ diff --git a/src/lanfactory/onnx/sbi.py b/src/lanfactory/onnx/sbi.py index bdf1c8c..2f21b56 100644 --- a/src/lanfactory/onnx/sbi.py +++ b/src/lanfactory/onnx/sbi.py @@ -5,29 +5,49 @@ that HSSM's ``loglik_kind="approx_differentiable"`` path can load via ``jaxonnxruntime``.
-This module is intentionally a sibling of :mod:`lanfactory.onnx.transform_onnx` -— the LAN exporter — so that "train a network and emit an ONNX HSSM can read" -stays a single conceptual home in LANfactory regardless of which library -trained the network. +This module is a sibling of :mod:`lanfactory.onnx.transform_onnx` (the LAN +exporter): "train a network and emit an ONNX HSSM can read" stays a single +conceptual home in LANfactory regardless of which library trained the network. -Implementation lands in C3 (NLE path) and C4 (NRE path). See -``plans/sbi-onnx-integration.md`` in HSSMSpine for the full plan. +The exported graph follows the LAN convention: a single concatenated input +of shape ``(1, theta_dim + x_dim)``. Inside the graph the input is split into +``theta`` and ``x`` and routed through the trained estimator's ``log_prob`` +(NLE mode) or classifier logit (NRE mode, lands in C4). HSSM vmaps this graph +over trials. """ from __future__ import annotations -from typing import Any, Literal +from typing import Literal + +import torch +from torch import nn __all__ = ["transform_sbi_to_onnx"] +# Estimator class names that we cannot export. Score-based / flow-matching +# require ODE integration which is not ONNX-exportable; TabPFN has awkward +# in-context input shape; neural-spline-flow estimators are blocked on a +# missing SearchSorted op in jaxonnxruntime (tracked for v1.x upstream PR). +_UNSUPPORTED_ESTIMATORS: frozenset[str] = frozenset( + { + "ScoreEstimator", + "ConditionalScoreEstimator", + "FlowMatchingEstimator", + "ConditionalFlowMatchingEstimator", + "TabPFNEstimator", + } +) + + def transform_sbi_to_onnx( - estimator: Any, + estimator: nn.Module, path: str, *, mode: Literal["nle", "nre"] = "nle", - example_theta_dim: int | None = None, - example_x_dim: int | None = None, + example_theta_dim: int, + example_x_dim: int, opset: int = 17, ) -> None: """Export a trained sbi estimator to a single-trial ONNX graph. 
@@ -37,30 +57,91 @@ def transform_sbi_to_onnx( estimator A trained sbi estimator. For ``mode="nle"`` this is a ``ConditionalDensityEstimator`` (as returned by ``NLE_A.train()``); for - ``mode="nre"`` it is a ``RatioEstimator`` (from ``NRE_A``/``B``/``C``, - ``BNRE``). + ``mode="nre"`` it is a ratio-estimator classifier (from ``NRE_A``/``B``/ + ``C``, ``BNRE``). path Filesystem path to write the ``.onnx`` artifact to. mode ``"nle"`` exports ``estimator.log_prob`` as the log-likelihood with the standardization Jacobian baked in. ``"nre"`` exports the classifier - logit as the log-likelihood up to a θ-independent constant. + logit as the log-likelihood up to a θ-independent constant (lands in + C4). example_theta_dim - Parameter-vector dimensionality used to trace the graph. Required. + Parameter-vector dimensionality used to trace the graph. example_x_dim - Observation-vector dimensionality used to trace the graph. Required. + Observation-vector dimensionality used to trace the graph. opset ONNX opset version. Pinned to 17 by default for reproducibility against ``jaxonnxruntime``. Notes ----- - Only likelihood-shaped families are supported. NPE/posterior estimators, - score-based / flow-matching estimators (FMPE, NPSE), TabPFN-based - estimators, and neural spline flows (blocked on a missing ``SearchSorted`` - op in ``jaxonnxruntime``) are rejected with a clear error at export time. + Only likelihood-shaped families are supported. NPE/posterior estimators are + rejected by convention (the caller asserts ``mode="nle"`` only for true + likelihood estimators). Score-based / flow-matching estimators (FMPE, + NPSE), TabPFN-based estimators, and neural spline flows (blocked on + missing ``SearchSorted`` in ``jaxonnxruntime``) are rejected with a clear + error. """ - raise NotImplementedError( - "transform_sbi_to_onnx is scaffolded but not yet implemented. " - "The NLE path lands in commit C3; see plans/sbi-onnx-integration.md." 
+ estimator_cls = type(estimator).__name__ + if estimator_cls in _UNSUPPORTED_ESTIMATORS: + raise ValueError( + f"transform_sbi_to_onnx does not support {estimator_cls}. " + "Score-based, flow-matching, and TabPFN estimators are out of v1 " + "scope; neural spline flows are blocked on a missing SearchSorted " + "op in jaxonnxruntime (queued as a v1.x upstream PR). See " + "plans/sbi-onnx-integration.md in HSSMSpine for the full matrix." + ) + + if mode == "nle": + if not hasattr(estimator, "log_prob"): + raise TypeError( + f"NLE mode requires an estimator with " + f".log_prob(input, condition); got {estimator_cls} which lacks " + f"it. If this is an NRE ratio classifier, use mode='nre' " + f"instead." + ) + wrapper: nn.Module = _NLELogProbWrapper( + estimator, example_theta_dim, example_x_dim + ) + elif mode == "nre": + raise NotImplementedError( + "NRE export lands in commit C4; see plans/sbi-onnx-integration.md." + ) + else: + raise ValueError(f"mode must be 'nle' or 'nre', got {mode!r}") + + wrapper.eval() + combined_input_dim = example_theta_dim + example_x_dim + dummy_input = torch.randn(1, combined_input_dim, requires_grad=True) + torch.onnx.export( + wrapper, + dummy_input, + path, + dynamo=False, + opset_version=opset, ) + + +class _NLELogProbWrapper(nn.Module): + """Wrap an NLE density estimator so forward(combined) returns log p(x|θ). + + The estimator's standardization stack is baked into the traced graph + automatically — sbi's ``ConditionalDensityEstimator.log_prob`` already + applies the z-score Jacobian correction internally on the torch side, so + tracing the outer ``.log_prob`` call captures the full corrected + likelihood. 
+ """ + + def __init__( + self, estimator: nn.Module, theta_dim: int, x_dim: int + ) -> None: + super().__init__() + self.estimator = estimator + self.theta_dim = theta_dim + self.x_dim = x_dim + + def forward(self, combined: torch.Tensor) -> torch.Tensor: + theta = combined[..., : self.theta_dim] + x = combined[..., self.theta_dim :] + return self.estimator.log_prob(x, condition=theta) diff --git a/tests/test_sbi_nle_export.py b/tests/test_sbi_nle_export.py new file mode 100644 index 0000000..103db6e --- /dev/null +++ b/tests/test_sbi_nle_export.py @@ -0,0 +1,235 @@ +"""C3 verification: train tiny sbi NLE_A with MAF on a Gaussian toy, export to +ONNX, and verify (i) three-way numerical agreement and (ii) gradient agreement +between torch and the jax-translated graph. + +The Gaussian toy gives us a closed-form likelihood (N(theta, 1)) to sanity-check +that training did something reasonable, separate from the toolchain-equivalence +checks. +""" + +from pathlib import Path + +import jax + +# Enable x64 BEFORE importing anything that may touch JAX dtypes. ONNX graphs +# from torch.onnx.export carry int64 shape/index tensors; JAX's default +# int32 silently truncates them inside jaxonnxruntime translation, producing +# wrong numerical values (~0.5 drift from the torch reference on MAF log_prob). +jax.config.update("jax_enable_x64", True) + +import jax.numpy as jnp # noqa: E402 +import numpy as np # noqa: E402 +import onnx # noqa: E402 +import onnxruntime as ort # noqa: E402 +import pytest # noqa: E402 +import torch # noqa: E402 +from jaxonnxruntime import call_onnx, config # noqa: E402 +from sbi.inference import NLE_A # noqa: E402 +from sbi.utils import BoxUniform # noqa: E402 + +from lanfactory.onnx import transform_sbi_to_onnx # noqa: E402 + +# Same friction as C2's MAF spike — torch.onnx.export emits Reshape shapes as +# Constant nodes. 
HSSM's onnx2jax patch (commit 2e76516) sets this globally for +# HSSM consumers; tests here exercise jaxonnxruntime directly and must set it +# themselves. +config.update("jaxort_only_allow_initializers_as_static_args", False) + +_THETA_DIM = 2 +_X_DIM = 2 + + +def _gaussian_simulator(theta: torch.Tensor) -> torch.Tensor: + """x | theta ~ N(theta, I) — analytical likelihood available. + + The 2D shape is deliberate: a 1D MAF in sbi collapses to a degenerate + Gaussian path that emits a zero-width Gemm contraction. jaxonnxruntime + cannot handle it. 2D keeps the flow non-degenerate. + """ + return theta + torch.randn_like(theta) + + +@pytest.fixture(scope="module") +def trained_nle() -> torch.nn.Module: + """Train a tiny NLE_A on a 2D Gaussian. Small budget keeps CI fast.""" + torch.manual_seed(0) + prior = BoxUniform( + low=torch.tensor([-3.0, -3.0]), + high=torch.tensor([3.0, 3.0]), + ) + inference = NLE_A(prior=prior, density_estimator="maf") + theta = prior.sample((2000,)) + x = _gaussian_simulator(theta) + estimator = inference.append_simulations(theta, x).train( + training_batch_size=200, + max_num_epochs=15, + ) + estimator.eval() + return estimator + + +def _load_jax_runner(onnx_path: Path, combined: np.ndarray): + onnx_model = onnx.load(str(onnx_path)) + input_name = onnx_model.graph.input[0].name + model_func, model_weights = call_onnx.call_onnx_model( + onnx_model, {input_name: combined} + ) + run_func = jax.tree_util.Partial(model_func, model_weights) + return run_func, input_name + + +@pytest.mark.flaky(reruns=2) +def test_nle_export_three_way_numerical_agreement( + trained_nle: torch.nn.Module, tmp_path: Path +) -> None: + onnx_path = tmp_path / "nle.onnx" + transform_sbi_to_onnx( + trained_nle, + str(onnx_path), + mode="nle", + example_theta_dim=_THETA_DIM, + example_x_dim=_X_DIM, + ) + + theta_t = torch.tensor([[0.5, -0.2]], dtype=torch.float32) + x_t = torch.tensor([[0.7, 0.3]], dtype=torch.float32) + combined = torch.cat([theta_t, x_t], 
dim=-1).numpy() + + with torch.no_grad(): + y_torch = trained_nle.log_prob(x_t, condition=theta_t).detach().numpy() + + sess = ort.InferenceSession(str(onnx_path)) + input_name = sess.get_inputs()[0].name + y_ort = sess.run(None, {input_name: combined})[0] + + run_func, jax_input_name = _load_jax_runner(onnx_path, combined) + y_jax = np.asarray(run_func({jax_input_name: combined})[0]) + + atol = 1e-5 + y_torch_flat = y_torch.flatten() + y_ort_flat = y_ort.flatten() + y_jax_flat = y_jax.flatten() + + assert np.allclose(y_torch_flat, y_ort_flat, atol=atol), ( + f"torch vs onnxruntime: max |Δ| = " + f"{np.abs(y_torch_flat - y_ort_flat).max()}" + ) + assert np.allclose(y_torch_flat, y_jax_flat, atol=atol), ( + f"torch vs jaxonnxruntime: max |Δ| = " + f"{np.abs(y_torch_flat - y_jax_flat).max()}" + ) + assert np.allclose(y_ort_flat, y_jax_flat, atol=atol), ( + f"onnxruntime vs jaxonnxruntime: max |Δ| = " + f"{np.abs(y_ort_flat - y_jax_flat).max()}" + ) + + +@pytest.mark.flaky(reruns=2) +def test_nle_export_gradient_agreement( + trained_nle: torch.nn.Module, tmp_path: Path +) -> None: + """jax.grad of the translated graph should match torch.autograd.grad.""" + onnx_path = tmp_path / "nle_grad.onnx" + transform_sbi_to_onnx( + trained_nle, + str(onnx_path), + mode="nle", + example_theta_dim=_THETA_DIM, + example_x_dim=_X_DIM, + ) + + theta_t = torch.tensor([[0.5, -0.2]], dtype=torch.float32, requires_grad=True) + x_t = torch.tensor([[0.7, 0.3]], dtype=torch.float32) + + logp = trained_nle.log_prob(x_t, condition=theta_t) + (grad_torch,) = torch.autograd.grad(logp.sum(), theta_t) + grad_torch_np = grad_torch.detach().numpy().flatten() + + theta_np = theta_t.detach().numpy() + x_np = x_t.numpy() + combined_init = np.concatenate([theta_np, x_np], axis=-1).astype(np.float32) + run_func, input_name = _load_jax_runner(onnx_path, combined_init) + + def jax_logp_of_theta(theta_arr: jnp.ndarray) -> jnp.ndarray: + combined = jnp.concatenate([theta_arr, jnp.asarray(x_np)], axis=-1) + 
return run_func({input_name: combined})[0].sum() + + grad_jax = jax.grad(jax_logp_of_theta)(jnp.asarray(theta_np)) + grad_jax_np = np.asarray(grad_jax).flatten() + + atol = 1e-4 + assert np.allclose(grad_torch_np, grad_jax_np, atol=atol), ( + f"torch vs jax gradient mismatch: max |Δ| = " + f"{np.abs(grad_torch_np - grad_jax_np).max()} " + f"(torch={grad_torch_np}, jax={grad_jax_np})" + ) + + +def test_nle_log_prob_ordering_matches_analytical_gaussian( + trained_nle: torch.nn.Module, +) -> None: + """Sanity: trained N(theta, 1) should rank a near-mean point above a far one. + + Not a precision test — just confirms that training produced a reasonable + likelihood, not a random surface. + """ + theta = torch.tensor([[0.0, 0.0]], dtype=torch.float32) + x_near = torch.tensor([[0.0, 0.0]], dtype=torch.float32) + x_far = torch.tensor([[2.5, 2.5]], dtype=torch.float32) + + with torch.no_grad(): + lp_near = trained_nle.log_prob(x_near, condition=theta).item() + lp_far = trained_nle.log_prob(x_far, condition=theta).item() + + assert lp_near > lp_far, ( + f"trained log-prob should be higher near the mean: " + f"lp_near={lp_near}, lp_far={lp_far}" + ) + + +def test_transform_rejects_unsupported_score_estimator(tmp_path: Path) -> None: + """Estimators in the unsupported set should fail loudly.""" + + class ScoreEstimator(torch.nn.Module): # noqa: D401 - name is the signal + pass + + with pytest.raises(ValueError, match="does not support"): + transform_sbi_to_onnx( + ScoreEstimator(), + str(tmp_path / "should_not_exist.onnx"), + mode="nle", + example_theta_dim=1, + example_x_dim=1, + ) + + +def test_transform_rejects_missing_log_prob(tmp_path: Path) -> None: + """NLE mode without .log_prob should raise a clear TypeError.""" + + class NotADensityEstimator(torch.nn.Module): + pass + + with pytest.raises(TypeError, match=r"\.log_prob\(input, condition\)"): + transform_sbi_to_onnx( + NotADensityEstimator(), + str(tmp_path / "should_not_exist.onnx"), + mode="nle", + 
example_theta_dim=1, + example_x_dim=1, + ) + + +def test_transform_nre_mode_not_yet_implemented(tmp_path: Path) -> None: + """NRE mode lands in C4 — until then, raise NotImplementedError.""" + + class FakeClassifier(torch.nn.Module): + pass + + with pytest.raises(NotImplementedError, match="C4"): + transform_sbi_to_onnx( + FakeClassifier(), + str(tmp_path / "should_not_exist.onnx"), + mode="nre", + example_theta_dim=1, + example_x_dim=1, + ) From f4a54fe101f48d0cccae4802aad9754507d0343a Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 13 May 2026 23:50:58 -0400 Subject: [PATCH 4/8] feat(onnx): NRE path for transform_sbi_to_onnx (C4) Extends transform_sbi_to_onnx to support mode="nre". The wrapper splits the concatenated (theta, x) input and routes through the trained RatioEstimator forward, returning the logit log r(x, theta). Up to a theta-independent constant the logit IS log p(x | theta), so MCMC and HSSM posterior path treat it as the likelihood. No Jacobian correction is needed since ratios are invariant to z-score standardization. Rejection: passing an estimator with .log_prob in mode="nre" raises TypeError, since that signals a density estimator (NLE) rather than a ratio classifier (NRE). The NLE path has the symmetric check. New test file tests/test_sbi_nre_export.py trains a tiny 2D Gaussian NRE_A and verifies the same three-way numerical agreement (atol=1e-5) and gradient agreement (atol=1e-4) as the NLE path, plus a sanity ordering check (log-ratio higher at near-theta than far-theta). The C3 NRE-not-implemented test was repurposed into a cross-mode rejection test: passing an NLE density estimator with mode="nre" now raises a clear TypeError instead of NotImplementedError. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lanfactory/onnx/sbi.py | 40 +++++++- tests/test_sbi_nle_export.py | 20 ++-- tests/test_sbi_nre_export.py | 172 +++++++++++++++++++++++++++++++++++ 3 files changed, 221 insertions(+), 11 deletions(-) create mode 100644 tests/test_sbi_nre_export.py diff --git a/src/lanfactory/onnx/sbi.py b/src/lanfactory/onnx/sbi.py index 2f21b56..dee41ce 100644 --- a/src/lanfactory/onnx/sbi.py +++ b/src/lanfactory/onnx/sbi.py @@ -105,8 +105,19 @@ def transform_sbi_to_onnx( estimator, example_theta_dim, example_x_dim ) elif mode == "nre": - raise NotImplementedError( - "NRE export lands in commit C4; see plans/sbi-onnx-integration.md." + # NRE classifiers expose forward(theta, x) returning a logit; they do + # NOT have .log_prob. If the user passes a density estimator with + # mode="nre", surface the mismatch loudly — silently exporting + # estimator.forward of an NLE flow would produce a graph that is not + # a log-ratio. + if hasattr(estimator, "log_prob"): + raise TypeError( + f"NRE mode expects a ratio classifier without .log_prob; " + f"got {estimator_cls} which has .log_prob. If this is an NLE " + f"density estimator, use mode='nle' instead." + ) + wrapper = _NRELogRatioWrapper( + estimator, example_theta_dim, example_x_dim ) else: raise ValueError(f"mode must be 'nle' or 'nre', got {mode!r}") @@ -145,3 +156,28 @@ def forward(self, combined: torch.Tensor) -> torch.Tensor: theta = combined[..., : self.theta_dim] x = combined[..., self.theta_dim :] return self.estimator.log_prob(x, condition=theta) + + +class _NRELogRatioWrapper(nn.Module): + """Wrap an NRE ratio classifier so forward(combined) returns log r(x, θ). + + For NRE, the classifier logit IS the log-ratio log p(x, θ) / p(x) p(θ), + which equals log p(x | θ) − log p(x). The θ-independent term log p(x) + drops out under MCMC's accept ratios and under HSSM's posterior path, so + we treat the raw logit as the exportable log-likelihood (up to a constant). 
+ No Jacobian correction is needed — the ratio is invariant to z-score + standardization of inputs. + """ + + def __init__( + self, estimator: nn.Module, theta_dim: int, x_dim: int + ) -> None: + super().__init__() + self.estimator = estimator + self.theta_dim = theta_dim + self.x_dim = x_dim + + def forward(self, combined: torch.Tensor) -> torch.Tensor: + theta = combined[..., : self.theta_dim] + x = combined[..., self.theta_dim :] + return self.estimator(theta, x) diff --git a/tests/test_sbi_nle_export.py b/tests/test_sbi_nle_export.py index 103db6e..9dca529 100644 --- a/tests/test_sbi_nle_export.py +++ b/tests/test_sbi_nle_export.py @@ -219,17 +219,19 @@ class NotADensityEstimator(torch.nn.Module): ) -def test_transform_nre_mode_not_yet_implemented(tmp_path: Path) -> None: - """NRE mode lands in C4 — until then, raise NotImplementedError.""" - - class FakeClassifier(torch.nn.Module): - pass +def test_nle_estimator_in_nre_mode_rejected( + trained_nle: torch.nn.Module, tmp_path: Path +) -> None: + """Passing an NLE density estimator with mode='nre' should raise. - with pytest.raises(NotImplementedError, match="C4"): + The presence of .log_prob is the signal that this is a density estimator + rather than a ratio classifier. + """ + with pytest.raises(TypeError, match=r"expects a ratio classifier"): transform_sbi_to_onnx( - FakeClassifier(), + trained_nle, str(tmp_path / "should_not_exist.onnx"), mode="nre", - example_theta_dim=1, - example_x_dim=1, + example_theta_dim=_THETA_DIM, + example_x_dim=_X_DIM, ) diff --git a/tests/test_sbi_nre_export.py b/tests/test_sbi_nre_export.py new file mode 100644 index 0000000..71abc00 --- /dev/null +++ b/tests/test_sbi_nre_export.py @@ -0,0 +1,172 @@ +"""C4 verification: train tiny sbi NRE_A on a Gaussian toy, export to ONNX, +and verify (i) three-way numerical agreement and (ii) gradient agreement +between torch and the jax-translated graph. 
+ +NRE classifier output IS the log-ratio log p(x | theta) / p(x), so up to a +theta-independent constant it serves as the log-likelihood. No Jacobian +correction is needed (ratio invariance under z-score standardization). +""" + +from pathlib import Path + +import jax + +# x64 required before any JAX import — see test_sbi_nle_export.py for details. +jax.config.update("jax_enable_x64", True) + +import jax.numpy as jnp # noqa: E402 +import numpy as np # noqa: E402 +import onnx # noqa: E402 +import onnxruntime as ort # noqa: E402 +import pytest # noqa: E402 +import torch # noqa: E402 +from jaxonnxruntime import call_onnx, config # noqa: E402 +from sbi.inference import NRE_A # noqa: E402 +from sbi.utils import BoxUniform # noqa: E402 + +from lanfactory.onnx import transform_sbi_to_onnx # noqa: E402 + +config.update("jaxort_only_allow_initializers_as_static_args", False) + +_THETA_DIM = 2 +_X_DIM = 2 + + +def _gaussian_simulator(theta: torch.Tensor) -> torch.Tensor: + """x | theta ~ N(theta, I).""" + return theta + torch.randn_like(theta) + + +@pytest.fixture(scope="module") +def trained_nre() -> torch.nn.Module: + """Train a tiny NRE_A on a 2D Gaussian. 
Small budget keeps CI fast.""" + torch.manual_seed(0) + prior = BoxUniform( + low=torch.tensor([-3.0, -3.0]), + high=torch.tensor([3.0, 3.0]), + ) + inference = NRE_A(prior=prior) + theta = prior.sample((2000,)) + x = _gaussian_simulator(theta) + classifier = inference.append_simulations(theta, x).train( + training_batch_size=200, + max_num_epochs=15, + ) + classifier.eval() + return classifier + + +def _load_jax_runner(onnx_path: Path, combined: np.ndarray): + onnx_model = onnx.load(str(onnx_path)) + input_name = onnx_model.graph.input[0].name + model_func, model_weights = call_onnx.call_onnx_model( + onnx_model, {input_name: combined} + ) + run_func = jax.tree_util.Partial(model_func, model_weights) + return run_func, input_name + + +@pytest.mark.flaky(reruns=2) +def test_nre_export_three_way_numerical_agreement( + trained_nre: torch.nn.Module, tmp_path: Path +) -> None: + onnx_path = tmp_path / "nre.onnx" + transform_sbi_to_onnx( + trained_nre, + str(onnx_path), + mode="nre", + example_theta_dim=_THETA_DIM, + example_x_dim=_X_DIM, + ) + + theta_t = torch.tensor([[0.5, -0.2]], dtype=torch.float32) + x_t = torch.tensor([[0.7, 0.3]], dtype=torch.float32) + combined = torch.cat([theta_t, x_t], dim=-1).numpy() + + with torch.no_grad(): + y_torch = trained_nre(theta_t, x_t).detach().numpy() + + sess = ort.InferenceSession(str(onnx_path)) + input_name = sess.get_inputs()[0].name + y_ort = sess.run(None, {input_name: combined})[0] + + run_func, jax_input_name = _load_jax_runner(onnx_path, combined) + y_jax = np.asarray(run_func({jax_input_name: combined})[0]) + + atol = 1e-5 + y_torch_flat = y_torch.flatten() + y_ort_flat = y_ort.flatten() + y_jax_flat = y_jax.flatten() + + assert np.allclose(y_torch_flat, y_ort_flat, atol=atol), ( + f"torch vs onnxruntime: max |Δ| = " + f"{np.abs(y_torch_flat - y_ort_flat).max()}" + ) + assert np.allclose(y_torch_flat, y_jax_flat, atol=atol), ( + f"torch vs jaxonnxruntime: max |Δ| = " + f"{np.abs(y_torch_flat - y_jax_flat).max()}" + ) 
+ assert np.allclose(y_ort_flat, y_jax_flat, atol=atol), ( + f"onnxruntime vs jaxonnxruntime: max |Δ| = " + f"{np.abs(y_ort_flat - y_jax_flat).max()}" + ) + + +@pytest.mark.flaky(reruns=2) +def test_nre_export_gradient_agreement( + trained_nre: torch.nn.Module, tmp_path: Path +) -> None: + """jax.grad of the translated graph should match torch.autograd.grad.""" + onnx_path = tmp_path / "nre_grad.onnx" + transform_sbi_to_onnx( + trained_nre, + str(onnx_path), + mode="nre", + example_theta_dim=_THETA_DIM, + example_x_dim=_X_DIM, + ) + + theta_t = torch.tensor([[0.5, -0.2]], dtype=torch.float32, requires_grad=True) + x_t = torch.tensor([[0.7, 0.3]], dtype=torch.float32) + + logr = trained_nre(theta_t, x_t) + (grad_torch,) = torch.autograd.grad(logr.sum(), theta_t) + grad_torch_np = grad_torch.detach().numpy().flatten() + + theta_np = theta_t.detach().numpy() + x_np = x_t.numpy() + combined_init = np.concatenate([theta_np, x_np], axis=-1).astype(np.float32) + run_func, input_name = _load_jax_runner(onnx_path, combined_init) + + def jax_logr_of_theta(theta_arr: jnp.ndarray) -> jnp.ndarray: + combined = jnp.concatenate([theta_arr, jnp.asarray(x_np)], axis=-1) + return run_func({input_name: combined})[0].sum() + + grad_jax = jax.grad(jax_logr_of_theta)(jnp.asarray(theta_np)) + grad_jax_np = np.asarray(grad_jax).flatten() + + atol = 1e-4 + assert np.allclose(grad_torch_np, grad_jax_np, atol=atol), ( + f"torch vs jax gradient mismatch: max |Δ| = " + f"{np.abs(grad_torch_np - grad_jax_np).max()} " + f"(torch={grad_torch_np}, jax={grad_jax_np})" + ) + + +def test_nre_log_ratio_ordering(trained_nre: torch.nn.Module) -> None: + """Sanity: the log-ratio at theta=mean(x_obs) should exceed that at a + distant theta. Not a precision test — just confirms training produced a + surface where the ratio behaves reasonably. 
+ """ + x_obs = torch.tensor([[0.0, 0.0]], dtype=torch.float32) + theta_near = torch.tensor([[0.0, 0.0]], dtype=torch.float32) + theta_far = torch.tensor([[2.5, 2.5]], dtype=torch.float32) + + with torch.no_grad(): + logr_near = trained_nre(theta_near, x_obs).item() + logr_far = trained_nre(theta_far, x_obs).item() + + assert logr_near > logr_far, ( + f"log-ratio should be higher near the true theta: " + f"logr_near={logr_near}, logr_far={logr_far}" + ) From b1fd1880b5a9516511cece869e698a809c4aa963 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 00:00:59 -0400 Subject: [PATCH 5/8] test(sbi-onnx): embedding-net regression coverage (C5) Adds tests/test_sbi_embeddings.py exercising NRE_A with two embedding nets on x: - FCEmbedding (representative flat-MLP embedding) - CNNEmbedding (1D conv stack; validates Conv / MaxPool / etc. survive torch.onnx.export and translate cleanly into jaxonnxruntime) Both tests train a tiny 2D-theta / 10-dim-x linear-Gaussian classifier and assert three-way numerical agreement (torch / onnxruntime / jaxonnxruntime, atol=1e-5). Other sbi embeddings (PermutationInvariantEmbedding, ResNetEmbedding1D, ResNetEmbedding2D, LRUEmbedding, TransformerEmbedding, CausalCNNEmbedding, SpectralConvEmbedding) are out of v1 scope; can be added as follow-up regressions if a user needs them. C5 finding: sbi build_mlp_classifier defaults to nn.LayerNorm between hidden layers, and jaxonnxruntime does NOT implement the LayerNormalization op (raises NotImplementedError at translation time). The fix is to pass norm_layer=nn.Identity through classifier_nn kwargs. This constraint will be documented in C6. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_sbi_embeddings.py | 171 +++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 tests/test_sbi_embeddings.py diff --git a/tests/test_sbi_embeddings.py b/tests/test_sbi_embeddings.py new file mode 100644 index 0000000..abf57f3 --- /dev/null +++ b/tests/test_sbi_embeddings.py @@ -0,0 +1,171 @@ +"""C5 verification: NRE_A with embedding nets on x exports and round-trips. + +Tests two representative embeddings: + - FCEmbedding (an extra MLP on x): the most common embedding pattern. + - CNNEmbedding (1D conv stack on x): exercises Conv / MaxPool / etc. in + the exported ONNX. The C2 / C3 spikes did not touch Conv ops. + +Other sbi embeddings (PermutationInvariantEmbedding, ResNetEmbedding1D, +TransformerEmbedding, ...) are out of v1 scope; can be added as follow-up +regression tests if a user needs them. +""" + +from pathlib import Path + +import jax + +jax.config.update("jax_enable_x64", True) + +import numpy as np # noqa: E402 +import onnx # noqa: E402 +import onnxruntime as ort # noqa: E402 +import pytest # noqa: E402 +import torch # noqa: E402 +from jaxonnxruntime import call_onnx, config # noqa: E402 +from sbi.inference import NRE_A # noqa: E402 +from sbi.neural_nets import classifier_nn # noqa: E402 +from sbi.neural_nets.embedding_nets import CNNEmbedding, FCEmbedding # noqa: E402 +from sbi.utils import BoxUniform # noqa: E402 +from torch import nn # noqa: E402 + +from lanfactory.onnx import transform_sbi_to_onnx # noqa: E402 + +config.update("jaxort_only_allow_initializers_as_static_args", False) + +# sbi's build_mlp_classifier defaults to nn.LayerNorm between hidden layers, but +# jaxonnxruntime does not implement the LayerNormalization op. Passing +# norm_layer=nn.Identity disables it. Documenting this constraint in C6 docs: +# users training their own NRE classifiers must disable LayerNorm for export. 
+ +_THETA_DIM = 2 +_X_DIM = 10 # 10-dim flat x — enough to make embedding non-trivial + + +def _simulator(theta: torch.Tensor) -> torch.Tensor: + """x | theta: stack of 10 i.i.d. N(theta[:, 0], 1) and N(theta[:, 1], 1). + + Concretely: x is a 10-vector whose first 5 dims are ~ N(theta[0], 1) and + last 5 dims are ~ N(theta[1], 1). Linear-Gaussian, easy enough for a + tiny NRE classifier to pick up. + """ + batch = theta.shape[0] + first_half = theta[:, 0:1] + torch.randn(batch, 5) + second_half = theta[:, 1:2] + torch.randn(batch, 5) + return torch.cat([first_half, second_half], dim=-1) + + +def _three_way_agreement( + trained_classifier: torch.nn.Module, onnx_path: Path +) -> None: + """Shared assertion: torch / onnxruntime / jaxonnxruntime all agree.""" + theta_t = torch.tensor([[0.3, -0.4]], dtype=torch.float32) + x_t = torch.randn(1, _X_DIM, dtype=torch.float32) + combined = torch.cat([theta_t, x_t], dim=-1).numpy() + + with torch.no_grad(): + y_torch = trained_classifier(theta_t, x_t).detach().numpy().flatten() + + sess = ort.InferenceSession(str(onnx_path)) + input_name = sess.get_inputs()[0].name + y_ort = sess.run(None, {input_name: combined})[0].flatten() + + onnx_model = onnx.load(str(onnx_path)) + model_func, model_weights = call_onnx.call_onnx_model( + onnx_model, {input_name: combined} + ) + run_func = jax.tree_util.Partial(model_func, model_weights) + y_jax = np.asarray(run_func({input_name: combined})[0]).flatten() + + atol = 1e-5 + assert np.allclose(y_torch, y_ort, atol=atol), ( + f"torch vs onnxruntime: max |Δ| = {np.abs(y_torch - y_ort).max()}" + ) + assert np.allclose(y_torch, y_jax, atol=atol), ( + f"torch vs jaxonnxruntime: max |Δ| = {np.abs(y_torch - y_jax).max()}" + ) + + +@pytest.mark.flaky(reruns=2) +def test_nre_with_fc_embedding(tmp_path: Path) -> None: + """NRE_A + FCEmbedding(x) → ONNX → round-trip.""" + torch.manual_seed(0) + prior = BoxUniform( + low=torch.tensor([-3.0, -3.0]), + high=torch.tensor([3.0, 3.0]), + ) + + embedding_x = 
FCEmbedding(input_dim=_X_DIM, output_dim=8, num_layers=2) + classifier_builder = classifier_nn( + model="mlp", + embedding_net_x=embedding_x, + norm_layer=nn.Identity, + ) + inference = NRE_A(prior=prior, classifier=classifier_builder) + + theta = prior.sample((1000,)) + x = _simulator(theta) + classifier = inference.append_simulations(theta, x).train( + training_batch_size=200, + max_num_epochs=10, + ) + classifier.eval() + + onnx_path = tmp_path / "nre_fc.onnx" + transform_sbi_to_onnx( + classifier, + str(onnx_path), + mode="nre", + example_theta_dim=_THETA_DIM, + example_x_dim=_X_DIM, + ) + _three_way_agreement(classifier, onnx_path) + + +@pytest.mark.flaky(reruns=2) +def test_nre_with_cnn_embedding(tmp_path: Path) -> None: + """NRE_A + CNNEmbedding(x) → ONNX → round-trip. + + Confirms that Conv / pooling ops survive torch.onnx.export and translate + cleanly into jaxonnxruntime. x is treated as a length-10 1D signal. + """ + torch.manual_seed(0) + prior = BoxUniform( + low=torch.tensor([-3.0, -3.0]), + high=torch.tensor([3.0, 3.0]), + ) + + embedding_x = CNNEmbedding( + input_shape=(_X_DIM,), + in_channels=1, + out_channels_per_layer=[4, 4], + num_conv_layers=2, + num_linear_layers=1, + num_linear_units=16, + output_dim=8, + kernel_size=3, + pool_kernel_size=2, + ) + classifier_builder = classifier_nn( + model="mlp", + embedding_net_x=embedding_x, + norm_layer=nn.Identity, + ) + inference = NRE_A(prior=prior, classifier=classifier_builder) + + theta = prior.sample((1000,)) + x = _simulator(theta) + classifier = inference.append_simulations(theta, x).train( + training_batch_size=200, + max_num_epochs=10, + ) + classifier.eval() + + onnx_path = tmp_path / "nre_cnn.onnx" + transform_sbi_to_onnx( + classifier, + str(onnx_path), + mode="nre", + example_theta_dim=_THETA_DIM, + example_x_dim=_X_DIM, + ) + _three_way_agreement(classifier, onnx_path) From 87704bbb7c5073cd1fdc338bdfccc92ce84794ab Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 00:03:22 
-0400 Subject: [PATCH 6/8] docs: integration guide for sbi -> ONNX exporter (C6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds docs/exporting_sbi_models.md as a Guides entry alongside the MLflow and HuggingFace integration guides. Wires it into mkdocs.yml nav. Adds a one-line mention in README.md pointing users to the guide. The guide covers: - Installation (pip install lanfactory[all]) - Quick-start examples for NLE and NRE - Supported architecture matrix (NLE+MAF, NRE+MLP/FC/CNN embeddings) - Explicitly-out-of-scope list (NSF, FMPE, NPSE, NPE, TabPFN) with one-sentence rationales each - Known constraints surfaced during C2-C5: * Use 2D+ for theta and x (1D MAFs degenerate in sbi) * Disable LayerNorm in NRE MLP classifiers (norm_layer=Identity) * Enable jax_enable_x64 before importing JAX in the consumer - Numerical guarantees from the regression tests (atol=1e-5 forward, atol=1e-4 gradients) - Float precision interaction with PyMC The new function transform_sbi_to_onnx is auto-documented on docs/api/onnx.md via the existing :::lanfactory.onnx mkdocstrings directive — no manual API page changes needed. Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 5 ++ docs/exporting_sbi_models.md | 162 +++++++++++++++++++++++++++++++++++ mkdocs.yml | 1 + 3 files changed, 168 insertions(+) create mode 100644 docs/exporting_sbi_models.md diff --git a/README.md b/README.md index ee30d43..cd70fff 100755 --- a/README.md +++ b/README.md @@ -13,6 +13,11 @@ Lightweight python package to help with training [LANs](https://elifesciences.org/articles/65074) (Likelihood approximation networks). +LANfactory also ships an ONNX exporter for [`sbi`](https://github.com/sbi-dev/sbi)-trained +neural likelihood (NLE) and neural ratio (NRE) estimators, producing files +HSSM can consume via its `loglik_kind="approx_differentiable"` path. See the +[Exporting sbi Models guide](docs/exporting_sbi_models.md). 
+ Please find the original [documentation here](https://alexanderfengler.github.io/LANfactory/). ### Cite LANfactory diff --git a/docs/exporting_sbi_models.md b/docs/exporting_sbi_models.md new file mode 100644 index 0000000..074a802 --- /dev/null +++ b/docs/exporting_sbi_models.md @@ -0,0 +1,162 @@ +# Exporting sbi-trained networks to ONNX + +LANfactory's [`transform_sbi_to_onnx`](api/onnx.md) wraps a trained +[`sbi`](https://github.com/sbi-dev/sbi) estimator and writes a single-trial +ONNX file that HSSM's `loglik_kind="approx_differentiable"` path can consume +exactly like a LAN export. Use it to bring sbi-trained NLE density estimators +or NRE ratio classifiers into a [HSSM](https://github.com/lnccbrown/HSSM) model. + +## Installation + +```bash +pip install lanfactory[all] +``` + +The `all` extra pulls `sbi>=0.26` and `nflows>=0.14` in addition to LANfactory's +other optional integrations. + +## Quick start (NLE) + +```python +import torch +from sbi.inference import NLE_A +from sbi.utils import BoxUniform +from lanfactory.onnx import transform_sbi_to_onnx + +# 1. Train a likelihood estimator (your simulator + prior here). +prior = BoxUniform(low=torch.tensor([-3.0, -3.0]), high=torch.tensor([3.0, 3.0])) +inference = NLE_A(prior=prior, density_estimator="maf") +theta = prior.sample((5_000,)) +x = my_simulator(theta) # shape: (5000, x_dim) +estimator = inference.append_simulations(theta, x).train() + +# 2. Export to a HSSM-compatible ONNX file. +transform_sbi_to_onnx( + estimator, + "ddm_nle.onnx", + mode="nle", + example_theta_dim=theta.shape[-1], + example_x_dim=x.shape[-1], +) + +# 3. Hand it to HSSM exactly like a LAN file. 
+import hssm +model = hssm.HSSM( + data=obs_data, + model="ddm", + model_config=my_model_config, + loglik_kind="approx_differentiable", + loglik="ddm_nle.onnx", + p_outlier=0, +) +idata = model.sample(sampler="numpyro", draws=500, tune=500, chains=2) +``` + +## Quick start (NRE) + +```python +from sbi.inference import NRE_A +inference = NRE_A(prior=prior) +classifier = inference.append_simulations(theta, x).train() +transform_sbi_to_onnx( + classifier, + "ddm_nre.onnx", + mode="nre", + example_theta_dim=theta.shape[-1], + example_x_dim=x.shape[-1], +) +``` + +The classifier logit is `log p(x, θ) / p(x) p(θ) = log p(x | θ) − log p(x)`. The +θ-independent `log p(x)` term drops out under MCMC and under HSSM's posterior +path, so the raw logit is consumed as the log-likelihood (up to a constant). No +Jacobian correction is needed — ratios are invariant to z-score +standardization. + +## Supported architectures (v1) + +| Method | Density / classifier | Embedding nets | Status | +|--------|---------------------|----------------|--------| +| **NLE_A** | MAF | none, FC on θ | ✅ supported | +| **NLE_A** | MDN, MoG | none, FC on θ | ✅ supported (untested at v1, expected to work) | +| **NRE_A / B / C / BNRE** | MLP classifier (with `norm_layer=nn.Identity`) | none, FCEmbedding, CNNEmbedding | ✅ supported | + +## Explicitly out of scope (v1) + +| Excluded | Reason | +|----------|--------| +| Neural Spline Flows (NSF coupling, NSF autoregressive) | `jaxonnxruntime` is missing the `SearchSorted` op. Targeted for a future upstream PR. | +| FMPE (flow-matching), NPSE (score-based) | `log_prob` requires ODE integration; not ONNX-exportable. | +| NPE / SNPE | Posterior-shaped, not likelihood-shaped. The HSSM ecosystem's current scope is neural likelihood surrogates. | +| TabPFN / NPE-PFN | Transformer with in-context inputs; awkward shape handling. Deferred. | + +The exporter rejects estimators whose class name is in the unsupported set with a +clear `ValueError`. 
If you encounter an unsupported architecture, please open an issue. + +## Known constraints + +Three constraints arose during validation and apply to anyone training their +own sbi estimators for export: + +1. **Use ≥2D for both θ and x.** sbi's `density_estimator="maf"` collapses to + a degenerate Gaussian path in 1D that emits zero-width Gemm contractions + `jaxonnxruntime` cannot translate. Use 2D or higher (this is the realistic + case anyway). + +2. **Disable LayerNorm in NRE MLP classifiers.** `jaxonnxruntime` does not + implement the `LayerNormalization` op. When using `classifier_nn(model="mlp", ...)`, + pass `norm_layer=nn.Identity` to skip it: + + ```python + from torch import nn + from sbi.neural_nets import classifier_nn + + classifier_builder = classifier_nn( + model="mlp", + embedding_net_x=my_embedding, + norm_layer=nn.Identity, # <-- required for ONNX export + ) + ``` + +3. **Enable JAX x64 before importing JAX in the consuming process.** ONNX + graphs from `torch.onnx.export` carry int64 shape/index tensors. With JAX's + default 32-bit mode, those get silently truncated to int32, producing + ~0.5-unit drift in log-prob outputs. Set: + + ```python + import jax + jax.config.update("jax_enable_x64", True) + # ...subsequent imports of jaxonnxruntime, hssm, etc. + ``` + + HSSM's `onnx2jax` consumer sets the related `jaxort_only_allow_initializers_as_static_args = False` + flag automatically, but the x64 setting is process-wide and must be opted + into by the caller. + +## Numerical guarantees + +The C2–C5 regression tests assert: + +- Forward pass: torch reference, `onnxruntime`, and `jaxonnxruntime` all agree + to `atol=1e-5` on fixed inputs. +- Gradients: `jax.grad` of the translated graph agrees with `torch.autograd.grad` + on the original estimator to `atol=1e-4`. + +If you run into precision issues smaller than these thresholds, please open +an issue with a minimal repro. + +## Float precision + +ONNX exports default to float32. 
PyMC defaults to float64. When sampling, either: + +- Cast at the JAX boundary, or +- Set `pytensor.config.floatX = "float32"` for the whole model. + +HSSM handles this consistently in its `approx_differentiable` path; if you're +hand-rolling a model with `pm.CustomDist` you'll need to do this yourself. + +## Related API + +- [`lanfactory.onnx.transform_sbi_to_onnx`](api/onnx.md) — the exporter. +- [`lanfactory.onnx.transform_to_onnx`](api/onnx.md) — the LAN-MLP exporter. + Same family, different network source. diff --git a/mkdocs.yml b/mkdocs.yml index 3e77ce1..8e7d01d 100755 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -14,6 +14,7 @@ nav: - Guides: - MLflow Integration: using_mlflow.md - HuggingFace Hub: using_huggingface.md + - Exporting sbi Models: exporting_sbi_models.md - API: - lanfactory: api/lanfactory.md - config: api/config.md From 4990e85d744d3d3a63a1d42ee3de989e03335ea2 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 18:12:43 -0400 Subject: [PATCH 7/8] test(sbi-onnx): end-to-end HSSM integration test (C7b) Adds tests/test_sbi_hssm_integration.py exercising the full keystone pipeline: 1. Train tiny sbi NLE_A on synthetic DDM data (ssm-simulators). 2. Export via lanfactory.onnx.transform_sbi_to_onnx. 3. Build HSSM model with model="ddm", loglik_kind="approx_differentiable", loglik=. 4. Short MCMC (500 draws + 500 tune, 2 chains) and verify posterior mean recovery within +/- 2 sigma and r_hat < 1.05. Two test functions: - test_hssm_model_builds_from_sbi_onnx: verifies the exported ONNX loads cleanly into hssm.HSSM (no sampling). - test_hssm_mcmc_recovers_ddm_parameters: full MCMC + recovery assertion. Skip guard via pytest.importorskip("hssm") so the test no-ops when HSSM is not in the env. Currently the test is a no-op in LANfactory's local uv venv because LANfactory's flax>=0.10.6 pin pulls a JAX version incompatible with HSSM's numpyro 0.21.0 pin. 
The test is intended to run only in a coordinated cross-repo CI environment that resolves both packages together. Plan tracks this as future ecosystem cleanup. The C7a HSSM patch (commit d1d7ffe on HSSM sbi-integration branch) makes jax_enable_x64 self-managed inside HSSM's onnx2jax, so this test does not need to set it explicitly. Marked @pytest.mark.flaky(reruns=2, reruns_delay=5) on both test functions to match HSSM's existing ONNX-test convention. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_sbi_hssm_integration.py | 154 +++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 tests/test_sbi_hssm_integration.py diff --git a/tests/test_sbi_hssm_integration.py b/tests/test_sbi_hssm_integration.py new file mode 100644 index 0000000..bd5b0d1 --- /dev/null +++ b/tests/test_sbi_hssm_integration.py @@ -0,0 +1,154 @@ +"""C7b verification: end-to-end pipeline from sbi training to HSSM MCMC. + +This test exercises the full keystone integration: + 1. Train a tiny sbi NLE_A on synthetic DDM data (via ssm-simulators). + 2. Export via lanfactory.onnx.transform_sbi_to_onnx. + 3. Build an HSSM model with model="ddm", loglik_kind="approx_differentiable", + loglik=(path to the exported .onnx), backend="jax". + 4. Run a short MCMC and verify posterior means recover the ground truth + (within ±2σ) and r_hat < 1.05. + +Skip guard: this test runs only when HSSM is importable in the test +environment. LANfactory's regular CI does not currently install HSSM, so the +test is a no-op locally — it is intended to run in a coordinated cross-repo +CI matrix where both packages are available with compatible JAX pins. See +plans/sbi-onnx-integration.md C7b for the environment-resolution note. + +The C7a HSSM patch (commit d1d7ffe on HSSM sbi-integration branch) makes the +`jax_enable_x64` flag self-managed inside HSSM's onnx2jax — this test does +not need to set it explicitly. +""" + +from pathlib import Path + +import pytest + +# Skip cleanly when HSSM is not in the environment. 
+hssm = pytest.importorskip("hssm") + +import numpy as np # noqa: E402 +import pandas as pd # noqa: E402 +import torch # noqa: E402 +from sbi.inference import NLE_A # noqa: E402 +from sbi.utils import BoxUniform # noqa: E402 +from ssms.basic_simulators.simulator import simulator # noqa: E402 + +from lanfactory.onnx import transform_sbi_to_onnx # noqa: E402 + +# DDM parameter order matches sbi simulator inputs and HSSM defaults. +_DDM_PARAM_NAMES = ["v", "a", "z", "t"] +_DDM_PARAM_LOW = np.array([-2.0, 0.6, 0.3, 0.1]) +_DDM_PARAM_HIGH = np.array([2.0, 1.8, 0.7, 0.5]) +_TRUE_THETA = np.array([0.5, 1.2, 0.5, 0.25]) +_N_OBS = 300 +_N_TRAIN = 5000 + + +def _simulate_ddm(theta: torch.Tensor) -> torch.Tensor: + """Simulate (rt, choice) per row of theta. Returns x of shape (batch, 2).""" + theta_np = theta.detach().numpy().astype(np.float32) + rts = np.empty(theta_np.shape[0], dtype=np.float32) + choices = np.empty(theta_np.shape[0], dtype=np.float32) + for i, th in enumerate(theta_np): + out = simulator(theta=th[None, :], model="ddm", n_samples=1) + rts[i] = out["rts"].squeeze() + choices[i] = out["choices"].squeeze() + return torch.from_numpy(np.stack([rts, choices], axis=-1)) + + +def _build_observed_dataframe(rng: np.random.Generator) -> pd.DataFrame: + """Generate N_OBS trials at the true theta as an HSSM-shaped DataFrame.""" + out = simulator(theta=_TRUE_THETA[None, :], model="ddm", n_samples=_N_OBS) + rts = out["rts"].squeeze().astype(np.float32) + choices = out["choices"].squeeze().astype(np.float32) + return pd.DataFrame({"rt": rts, "response": choices}) + + +@pytest.fixture(scope="module") +def trained_nle_for_ddm(tmp_path_factory) -> Path: + """Train tiny NLE_A on DDM and return the exported .onnx path.""" + torch.manual_seed(0) + prior = BoxUniform( + low=torch.from_numpy(_DDM_PARAM_LOW.astype(np.float32)), + high=torch.from_numpy(_DDM_PARAM_HIGH.astype(np.float32)), + ) + inference = NLE_A(prior=prior, density_estimator="maf") + + theta = 
prior.sample((_N_TRAIN,)) + x = _simulate_ddm(theta) + estimator = inference.append_simulations(theta, x).train( + training_batch_size=200, + max_num_epochs=30, + ) + estimator.eval() + + onnx_path = tmp_path_factory.mktemp("c7b") / "ddm_nle.onnx" + transform_sbi_to_onnx( + estimator, + str(onnx_path), + mode="nle", + example_theta_dim=len(_DDM_PARAM_NAMES), + example_x_dim=2, + ) + return onnx_path + + +@pytest.mark.flaky(reruns=2, reruns_delay=5) +def test_hssm_model_builds_from_sbi_onnx(trained_nle_for_ddm: Path) -> None: + """The exported ONNX should load cleanly into an HSSM model.""" + rng = np.random.default_rng(0) + obs_data = _build_observed_dataframe(rng) + + model = hssm.HSSM( + data=obs_data, + model="ddm", + loglik_kind="approx_differentiable", + loglik=str(trained_nle_for_ddm), + p_outlier=0, + ) + assert model is not None + + +@pytest.mark.flaky(reruns=2, reruns_delay=5) +def test_hssm_mcmc_recovers_ddm_parameters(trained_nle_for_ddm: Path) -> None: + """Short MCMC should recover the true DDM params within ±2σ.""" + rng = np.random.default_rng(0) + obs_data = _build_observed_dataframe(rng) + + model = hssm.HSSM( + data=obs_data, + model="ddm", + loglik_kind="approx_differentiable", + loglik=str(trained_nle_for_ddm), + p_outlier=0, + ) + + idata = model.sample( + draws=500, + tune=500, + chains=2, + cores=1, + progressbar=False, + target_accept=0.9, + ) + + summary = hssm.utils.summary(idata) if hasattr(hssm.utils, "summary") else None + # Fall back to arviz if the convenience method is not exposed. 
+ if summary is None: + import arviz as az + summary = az.summary(idata, var_names=_DDM_PARAM_NAMES) + + posterior_means = summary.loc[_DDM_PARAM_NAMES, "mean"].to_numpy() + posterior_sds = summary.loc[_DDM_PARAM_NAMES, "sd"].to_numpy() + r_hats = summary.loc[_DDM_PARAM_NAMES, "r_hat"].to_numpy() + + # Convergence + assert (r_hats < 1.05).all(), f"r_hat above 1.05 for some params: {r_hats}" + + # Recovery within ±2σ + deviations = np.abs(posterior_means - _TRUE_THETA) / posterior_sds + assert (deviations < 2.0).all(), ( + f"Posterior means more than 2σ from truth: " + f"true={_TRUE_THETA}, mean={posterior_means}, sd={posterior_sds}, " + f"deviations={deviations}" + ) From 222adf5929149acf87d20d800a939da4d32cd2ed Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 14 May 2026 19:53:21 -0400 Subject: [PATCH 8/8] fix(onnx): export sbi graphs with rank-1 input for HSSM vmap (C9) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Surfaced by running the C8 notebook in HSSM: pymc.sample raised "IndexError: list assignment index out of range" inside jaxonnxruntime/onnx_ops/slice.py:113 (sub_indx[axis] = slices[i]). Root cause: HSSM's make_jax_logp_funcs_from_onnx vmaps the per-trial loglik over a 1D concatenated input vector (param_vector + data) — see HSSM repos/HSSM/src/hssm/distribution_utils/onnx.py around line 115: input_vector = jnp.concatenate((param_vector, data)) return jax_func(input_vector) But the C3/C4 exporter was tracing with a 2D dummy of shape (1, theta_dim + x_dim), which made torch.onnx.export emit Slice ops with axes=[1]. Under HSSM's vmap the per-trial input is rank-1, so axes=[1] is out of bounds for the inner Slice handler. LAN exports don't trip on this because the LAN graph is pure MatMul/Add/activation — broadcast-rank-agnostic. Ours has explicit Slice ops from `combined[..., :theta_dim]` and `combined[..., theta_dim:]`.
Fix: - Trace the wrapper with a rank-1 dummy (shape (theta_dim+x_dim,)) so Slice ops emit axes=[0], which survives HSSM's vmap. - Inside _NLELogProbWrapper.forward and _NRELogRatioWrapper.forward, take a 1D combined input, split on axis 0, then .unsqueeze(0) the two halves to satisfy sbi's batched log_prob / classifier APIs. Reshape the (1, 1) output back to () so HSSM's downstream .squeeze() sees a clean scalar. - Updated module docstring to document the rank-1 contract and why. Tests: - test_sbi_nle_export.py, test_sbi_nre_export.py, test_sbi_embeddings.py: pass rank-1 inputs through onnxruntime and jaxonnxruntime; rank-1 theta_np_1d / x_np_1d for the gradient tests. - All 13 sbi tests still green at the same atol thresholds (1e-5 forward, 1e-4 gradients). User impact: anyone who already exported a .onnx with the old C3/C4 code needs to re-export with this commit. The exported .onnx is the durable artifact — no API change in the call site. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lanfactory/onnx/sbi.py | 35 +++++++++++++++++++++++------------ tests/test_sbi_embeddings.py | 3 ++- tests/test_sbi_nle_export.py | 14 ++++++++------ tests/test_sbi_nre_export.py | 14 ++++++++------ 4 files changed, 41 insertions(+), 25 deletions(-) diff --git a/src/lanfactory/onnx/sbi.py b/src/lanfactory/onnx/sbi.py index dee41ce..63965e6 100644 --- a/src/lanfactory/onnx/sbi.py +++ b/src/lanfactory/onnx/sbi.py @@ -9,11 +9,14 @@ exporter): "train a network and emit an ONNX HSSM can read" stays a single conceptual home in LANfactory regardless of which library trained the network. -The exported graph follows the LAN convention: a single concatenated input -of shape ``(1, theta_dim + x_dim)``. Inside the graph the input is split into -``theta`` and ``x`` and routed through the trained estimator's ``log_prob`` -(NLE mode) or classifier logit (NRE mode, lands in C4). HSSM vmaps this graph -over trials. 
+The exported graph follows the LAN-and-HSSM convention: a single concatenated +input of **rank 1, shape ``(theta_dim + x_dim,)``**. Inside the graph the +input is split into ``theta`` and ``x``, upranked to ``(1, …)`` to satisfy +sbi's batched ``log_prob`` API, and routed through the trained estimator. +HSSM vmaps this graph over trials, so the per-call input rank from HSSM is 1 +— matching the export. Tracing with a 2D ``(1, D)`` dummy would emit ``Slice`` +ops with ``axes=[1]`` that fail under HSSM's vmap (``IndexError: list +assignment index out of range`` inside ``jaxonnxruntime`` Slice handler). """ from __future__ import annotations @@ -124,7 +127,9 @@ def transform_sbi_to_onnx( wrapper.eval() combined_input_dim = example_theta_dim + example_x_dim - dummy_input = torch.randn(1, combined_input_dim, requires_grad=True) + # Trace with a rank-1 dummy so the resulting Slice ops use axes=[0], + # which survives HSSM's per-trial vmap (where the input arrives as 1D). + dummy_input = torch.randn(combined_input_dim, requires_grad=True) torch.onnx.export( wrapper, dummy_input, @@ -153,9 +158,13 @@ def __init__( self.x_dim = x_dim def forward(self, combined: torch.Tensor) -> torch.Tensor: - theta = combined[..., : self.theta_dim] - x = combined[..., self.theta_dim :] - return self.estimator.log_prob(x, condition=theta) + # combined: 1D, shape (theta_dim + x_dim,) — matches HSSM's per-trial + # vmap input. Split on axis 0 (rank-friendly), then unsqueeze for + # sbi's batched log_prob contract, and reshape the (1, 1) output back + # to a scalar so HSSM's downstream .squeeze() leaves it as (). 
+ theta = combined[: self.theta_dim].unsqueeze(0) + x = combined[self.theta_dim :].unsqueeze(0) + return self.estimator.log_prob(x, condition=theta).reshape(()) class _NRELogRatioWrapper(nn.Module): @@ -178,6 +187,8 @@ def __init__( self.x_dim = x_dim def forward(self, combined: torch.Tensor) -> torch.Tensor: - theta = combined[..., : self.theta_dim] - x = combined[..., self.theta_dim :] - return self.estimator(theta, x) + # combined: 1D, shape (theta_dim + x_dim,) — see _NLELogProbWrapper for + # the rationale around rank-1 tracing and vmap compatibility. + theta = combined[: self.theta_dim].unsqueeze(0) + x = combined[self.theta_dim :].unsqueeze(0) + return self.estimator(theta, x).reshape(()) diff --git a/tests/test_sbi_embeddings.py b/tests/test_sbi_embeddings.py index abf57f3..03d5b4b 100644 --- a/tests/test_sbi_embeddings.py +++ b/tests/test_sbi_embeddings.py @@ -60,7 +60,8 @@ def _three_way_agreement( """Shared assertion: torch / onnxruntime / jaxonnxruntime all agree.""" theta_t = torch.tensor([[0.3, -0.4]], dtype=torch.float32) x_t = torch.randn(1, _X_DIM, dtype=torch.float32) - combined = torch.cat([theta_t, x_t], dim=-1).numpy() + # Exported graph is rank-1; pass a 1D concatenated vector. + combined = torch.cat([theta_t, x_t], dim=-1).squeeze(0).numpy() with torch.no_grad(): y_torch = trained_classifier(theta_t, x_t).detach().numpy().flatten() diff --git a/tests/test_sbi_nle_export.py b/tests/test_sbi_nle_export.py index 9dca529..6562009 100644 --- a/tests/test_sbi_nle_export.py +++ b/tests/test_sbi_nle_export.py @@ -93,7 +93,8 @@ def test_nle_export_three_way_numerical_agreement( theta_t = torch.tensor([[0.5, -0.2]], dtype=torch.float32) x_t = torch.tensor([[0.7, 0.3]], dtype=torch.float32) - combined = torch.cat([theta_t, x_t], dim=-1).numpy() + # Exported graph is rank-1; pass a 1D concatenated vector. 
+ combined = torch.cat([theta_t, x_t], dim=-1).squeeze(0).numpy() with torch.no_grad(): y_torch = trained_nle.log_prob(x_t, condition=theta_t).detach().numpy() @@ -145,16 +146,17 @@ def test_nle_export_gradient_agreement( (grad_torch,) = torch.autograd.grad(logp.sum(), theta_t) grad_torch_np = grad_torch.detach().numpy().flatten() - theta_np = theta_t.detach().numpy() - x_np = x_t.numpy() - combined_init = np.concatenate([theta_np, x_np], axis=-1).astype(np.float32) + # Rank-1 vectors throughout — matches the new exporter contract. + theta_np_1d = theta_t.detach().numpy().squeeze(0) + x_np_1d = x_t.numpy().squeeze(0) + combined_init = np.concatenate([theta_np_1d, x_np_1d], axis=-1).astype(np.float32) run_func, input_name = _load_jax_runner(onnx_path, combined_init) def jax_logp_of_theta(theta_arr: jnp.ndarray) -> jnp.ndarray: - combined = jnp.concatenate([theta_arr, jnp.asarray(x_np)], axis=-1) + combined = jnp.concatenate([theta_arr, jnp.asarray(x_np_1d)], axis=-1) return run_func({input_name: combined})[0].sum() - grad_jax = jax.grad(jax_logp_of_theta)(jnp.asarray(theta_np)) + grad_jax = jax.grad(jax_logp_of_theta)(jnp.asarray(theta_np_1d)) grad_jax_np = np.asarray(grad_jax).flatten() atol = 1e-4 diff --git a/tests/test_sbi_nre_export.py b/tests/test_sbi_nre_export.py index 71abc00..a6db27e 100644 --- a/tests/test_sbi_nre_export.py +++ b/tests/test_sbi_nre_export.py @@ -81,7 +81,8 @@ def test_nre_export_three_way_numerical_agreement( theta_t = torch.tensor([[0.5, -0.2]], dtype=torch.float32) x_t = torch.tensor([[0.7, 0.3]], dtype=torch.float32) - combined = torch.cat([theta_t, x_t], dim=-1).numpy() + # Exported graph is rank-1; pass a 1D concatenated vector. 
+ combined = torch.cat([theta_t, x_t], dim=-1).squeeze(0).numpy() with torch.no_grad(): y_torch = trained_nre(theta_t, x_t).detach().numpy() @@ -133,16 +134,17 @@ def test_nre_export_gradient_agreement( (grad_torch,) = torch.autograd.grad(logr.sum(), theta_t) grad_torch_np = grad_torch.detach().numpy().flatten() - theta_np = theta_t.detach().numpy() - x_np = x_t.numpy() - combined_init = np.concatenate([theta_np, x_np], axis=-1).astype(np.float32) + # Rank-1 vectors throughout — matches the new exporter contract. + theta_np_1d = theta_t.detach().numpy().squeeze(0) + x_np_1d = x_t.numpy().squeeze(0) + combined_init = np.concatenate([theta_np_1d, x_np_1d], axis=-1).astype(np.float32) run_func, input_name = _load_jax_runner(onnx_path, combined_init) def jax_logr_of_theta(theta_arr: jnp.ndarray) -> jnp.ndarray: - combined = jnp.concatenate([theta_arr, jnp.asarray(x_np)], axis=-1) + combined = jnp.concatenate([theta_arr, jnp.asarray(x_np_1d)], axis=-1) return run_func({input_name: combined})[0].sum() - grad_jax = jax.grad(jax_logr_of_theta)(jnp.asarray(theta_np)) + grad_jax = jax.grad(jax_logr_of_theta)(jnp.asarray(theta_np_1d)) grad_jax_np = np.asarray(grad_jax).flatten() atol = 1e-4