v3: migrate main to immutable Tree, unified sweeps, and BFFG refit#73
Merged
gefanyang merged 34 commits intoJun 2, 2026
Merged
Conversation
Delete the legacy core (`hyperiax/{tree,models,execution,mcmc,plotting}`),
all examples / benchmarks / docs notebooks / existing tests, and `setup.py`.
Stop tracking `_version.py` and `.DS_Store`.
Repackage for Python 3.11+ on uv:
- `pyproject.toml` slimmed to `jax / jaxlib / numpy`; ete3 moved under
`[io]`; trimesh under `[prebuilt-shape]`; pytest into the PEP 735
`dev` dependency group.
- `.python-version` pins 3.11.
- `uv.lock` committed for reproducible installs.
- `makefile` switched to `uv run pytest` / `uv sync --all-extras --group dev`.
`hyperiax/__init__.py` is gutted to a docstring placeholder until Stage 1
introduces `hyperiax.core`. README updated with uv-based dev setup
instructions and a banner pointing at the v3 plan.
Verified: `uv sync --group dev` succeeds and `uv run pytest` collects 0
tests cleanly.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Build the L1 core of the v3 refactor. All three types are frozen dataclasses; Topology and Tree are JAX-pytree-registered so they flow through jit/scan without any static-arg hack. - `core/errors.py` — HyperiaxError, SchemaMismatch, MissingField, StructureMismatch. - `core/schema.py` — FieldSpec + Schema. Schema is an ordered tuple sorted by name for deterministic hash; this is what stabilises the pytree structure across topologically-identical trees. - `core/topology.py` — Topology with BFS-ordered parents and precomputed derived arrays (level_starts, masks, child_counts, equal-degree gather_child_idx, segment pbuckets/pbuckets_ref). All derived fields are np.ndarray (not jnp) so the object is Python-hashable and lives entirely in pytree aux_data. Hash is cached from parents.tobytes(). - `core/tree.py` — Tree composes Topology + Schema + dict[str, Array]. `empty / set / set_at / update / drop` all return a new Tree; the data dict is never mutated in place. Tree is intentionally not hashable, so users cannot accidentally pass it as `static_argnames`. Top-level `hyperiax` package re-exports the public surface. Tests (T-1/T-2/T-3 + edge cases): 36 passing. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Bring the first sweep online. The minimal demo from the plan now runs: build a symmetric binary tree, seed its leaves, run an up-sweep that averages children, read the root — and wrap the whole thing in `@jax.jit` without leaking tracers. - `core/builders.py` — `symmetric_topology(height, degree)` for d≥1 and any h≥0; `from_parents` re-export for explicit constructions. - `core/views.py` — `Node` / `Children` transient views. Attribute access (`node.value`, `children.value`) is the primary surface; item access works too. Missing fields produce a hyperiax-flavoured AttributeError pointing at `reads=`/`reads_children=` so the user can fix the declaration rather than chase a JAX KeyError. - `core/sweep.py` — `SweepFn` (frozen, hashable, value-equal on (direction, fn identity, reads/writes tuples)) and the `@up(reads=..., reads_children=..., writes=...)` decorator. `writes` is required; cross-direction violations (`reads_parent` on @up, `reads_children` on @down) are rejected at construction. - `core/dispatch.py` — `up_dispatch` validates the schema eagerly (so the traceback lands at the call site) and bails on unequal-degree (deferred to Stage 4). `_up_dispatch_jit` is `static_argnums=(0,)` on the SweepFn, walks parent levels deepest→root, gathers via `topo.gather_child_idx`, and vmaps the user function over the scope axis. The vmap is the key decision: from inside the user fn, `node.value` is `(*trailing,)` and `children.value` is `(k, *trailing)`, so `children.value.mean(0)` is naturally a reduction over the children axis. This is the same surface that Stage 4's `ChildrenAxis` proxy will present for unequal-degree trees. Tests now cover T-4 (per-level numerics on a 15-node binary tree, leaf mean recovery at the root, node-vs-children data interplay, params threading) and T-5 (the same `SweepFn` on a structurally-identical fresh tree hits the JIT cache — verified by counting `nonlocal` traces, not by timing). 52 tests total, all green. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Walks root → leaves, one level at a time, gathering each node's parent data via topo.parents and vmapping the user fn per-node. Unlike up, this works on any topology — every non-root node has exactly one parent, so no equal_degree fast path is needed. - `core/views.py` — `Parent` subclasses Node with empty __slots__, just a different __repr__ for clarity in error messages and isinstance checks. - `core/sweep.py` — `@down(reads=, reads_parent=, writes=)`. Signature symmetric with @up; same SweepFn primitive, different direction tag. `SweepFn.__call__` now dispatches to `down_dispatch` for `direction='down'`. - `core/dispatch.py` — `down_dispatch` skips the equal_degree check (down works on any topology) and `_down_dispatch_jit` mirrors the up structure: iterate level 1..depth, slice the level contiguously, gather parents via topo.parents[ls:le], vmap the user fn over (node, parent, params). Tests (T-6 + supporting): 12 new in tests/test_sweep_down.py covering numerical propagation, multi-dim values, params threading, unequal-degree topology (this is what makes down universally applicable), root-untouched semantics, outer-jit composition, jit cache hits on identical topology, the canonical up-then-down round trip (leaf mean broadcast to all nodes), and the validation paths. 64 tests total, all green. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Bring ragged trees online without changing the user-facing surface. The same `@hx.up` function that reduces `children.value.sum(0)` over a symmetric binary tree now runs unchanged on a tree where one parent has 3 children and another has 2. - `core/views.py` — `ChildrenAxis` is a slot-only proxy backed by a flat `(M_total, *trailing)` JAX array plus segment ids. It exposes `.sum/.prod/.max/.min/.mean(axis=0)`, each forwarding to the matching `jax.ops.segment_*` with `num_segments` held as a static Python int (R-2 in the plan: traced num_segments would yield poly-shape errors). Non-zero axes and any attempt to coerce the proxy to a dense array via `__array__` raise a hyperiax-flavoured error that tells the user to reduce first. - `core/dispatch.py` — `up_dispatch` now branches on `tree.topology.equal_degree`. The new `_up_dispatch_unequal` walks parent levels deepest→root using `pbuckets`/`pbuckets_ref`, gathers parent rows by id, wraps each children field in a `ChildrenAxis`, and scatters writes into the unique parent ids. Mixed-depth leaves are handled naturally: a level-L node that is itself a leaf never appears in any `pbuckets_ref[L+1]`, so the dispatcher leaves its data alone. Down dispatch was already topology-agnostic and needs no changes; a small `_check_writes` helper unifies the writes-key validation across all three paths. Tests (T-7 + supporting): 15 new in tests/test_sweep_up_unequal.py. Covers sum/prod/max/min/mean on the canonical ragged tree, node+children arithmetic, multi-dim trailing shapes, mixed-depth leaves (the genuinely ragged case where root has degree 2 and one of its children has degree 3), proxy guards (non-zero axis, `jnp.asarray` coercion), outer-jit composition, JIT cache hits, and an equal/unequal API-parity check asserting the same `@up` function runs on both topology shapes. 79 tests total, all green. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The core of why the v3 rewrite exists. The legacy `OrderedExecutor` did
`tree.data = {...}` in place, which leaked tracers between calls of an
outer `@jax.jit` train_step (this is what made the user abandon
hyperiax for Flax+optax pipelines). With Tree now an immutable JAX
pytree and every dispatcher path pure, these scenarios just work.
tests/test_regression.py covers:
T-10 (outer-jit):
- 20 consecutive `@jax.jit`'d step calls — no leaked tracer
- multi-sweep pipeline (up → down → up) inside one jit
- bit-for-bit parity between eager and jit'd outputs
T-11 (lax.scan):
- user-fn trace count under a 100-iter scan equals that of a single
call (body compiled once, not per iteration)
- structural check: the jaxpr of `scan(body, length=N)(tree)` has
exactly one top-level `scan` primitive
- numerical equivalence between an eager Python loop and `lax.scan`
- scan composes with the unequal-degree (segment_sum) path too
jax.grad:
- `jax.grad` through an up-sweep matches the hand-computed derivative
- same for down-sweep
- `jax.value_and_grad` works on the ChildrenAxis (segment) path
Tree.update + jit:
- adding a new field with `.update()` then running a sweep on the
extended schema works eagerly and under jit
Flax-style train_step:
- the exact pattern that broke the legacy code — params + grads +
optimizer-style update wrapping a sweep — runs cleanly through 10
jit'd iterations and produces sensible gradient descent on a toy loss.
92 tests total, all green.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
L2 module wrapping ete3 for Newick read/write. `hyperiax.io` is the
first non-core module, opt-in via `uv sync --extra io`.
- `hyperiax/io/__init__.py` — exposes `newick`; no eager ete3 import.
- `hyperiax/io/newick.py`:
* `read(source, *, schema=None, newick_format=1) -> Tree` — accepts a
Newick literal string or a path; returns a Tree whose schema always
includes `edge_length` (shape (), float32), plus any user-requested
extras. Node names live on `Topology.names`, not on tree.data.
* `write(tree, *, newick_format=1) -> str` — round-trip writer.
Patches in the root name manually because ete3 deliberately drops
it from every Newick output format (an ete3 quirk).
* `_ete_to_bfs_arrays` — BFS-traverses an ete3 tree to produce a
hyperiax-style `(parents, names, edge_lengths)` triple. BFS order
guarantees `parents[i] < i`, which is what Topology.from_parents
requires.
* `_require_ete3` — gates the import behind a helpful error message
pointing at the `[io]` extra.
ete3 is imported lazily inside `read`/`write`, so `import hyperiax.io`
works in a core-only install (verified by checking `sys.modules` after
the import).
Tests (T-12 + supporting): 12 in tests/test_io_newick.py with
`pytest.importorskip("ete3")` so the suite skips gracefully without
the extra. Covers bit-for-bit round-trips (simple unnamed and complex
named cases), edge-length and name extraction, BFS-ordering invariant,
extra schema fields, file-path reading, write-without-edge_length guard,
and a smoke test that a Newick-parsed tree feeds straight into an
@hx.up sweep.
104 tests total, all green.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
L2 recipe: the edge-length-weighted phylogenetic mean estimator, ported
from the legacy `PhyloMeanModel` (which subclassed `UpReducer` with
`{up, transform}` and a manual `reductions={...}` dict) into a single
`@hx.up` SweepFn.
- `hyperiax/prebuilt/__init__.py` — exports `phylo_mean`.
- `hyperiax/prebuilt/phylo_mean.py` — single sweep that reads
`estimated_value` and `edge_length` from children and writes back the
weighted mean on each parent:
hat_x_p = sum_c (v_c / l_c) / sum_c (1 / l_c)
The trailing-shape broadcast (scalar edges vs vector values) is
handled by reshaping `edges` to (k, 1, ...) before division.
Currently equal-degree only — `children.estimated_value /
children.edge_length` relies on JAX broadcasting that the unequal
ChildrenAxis proxy doesn't yet expose. The ragged path will be
enabled once the proxy gains elementwise arithmetic or a `.gather()`
fallback.
Tests (T-13 + supporting): 8 in tests/test_prebuilt_phylo_mean.py.
- Hand-computed expected values on a 7-node tree with non-uniform edges
(root=25 / inner=15,35 from leaves [10,20,30,40] under
edge_lengths=[0,0.5,0.5,1,1,2,2]).
- Leaves untouched after the sweep.
- Uniform edges → root = simple leaf mean.
- Vector-valued `estimated_value` (broadcasting works).
- Random edges + values on a depth-3 binary tree match a pure-numpy
reference walk to 1e-5 tolerance.
- Composition under `@jax.jit`.
- End-to-end: Newick → tree → phylo_mean.
- Unequal-degree currently raises (documented limitation).
112 tests total, all green.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two prebuilt modules covering most of T-14. Deliberately MVP — the
full SDE-BFFG variant with ODE-integrated backward filter + Brownian
bridge guiding is deferred to a follow-up because it more than doubles
the code surface and isn't needed for the T-14 numerical baseline.
What landed:
- `hyperiax/prebuilt/sde.py` — full port of the legacy `examples/SDE.py`:
factored `dot` / `solve` for (n·d,) flattened state on tensor-product
diffusivity, uniform-step `dts`, and `forward` (Euler-Maruyama via
`lax.scan`).
- `hyperiax/prebuilt/_gaussian_density.py` — Gaussian density helpers
(standard, precision, canonical, unnormalized) shared between the
current Gaussian BFFG and the future SDE-BFFG.
- `hyperiax/prebuilt/bffg_gaussian.py`:
* `gaussian_up(n, a, d=1)` — Gaussian BFFG backward filter as a single
@hx.up sweep. Per-edge math is the closed-form precision-update
`H_0 = (I + H_T·var·a)⁻¹·H_T`, `F_0 = (I + H_T·var·a)⁻¹·F_T`,
lifted verbatim from the legacy `Gaussian_up`. The summed children
(F_0, H_0) become the parent's (F_T, H_T); c_T is recomputed via
logphi_can on the fused canonical params.
* `init_gaussian_leaves(tree, leaf_values, obs_var, *, n, d=1)` —
seed the leaves with iid-Gaussian observation precision I_n / τ².
* `gaussian_down_unconditional(sigma)` — simplest forward sampling:
`child = parent + sqrt(edge_length)·σ(parent)·noise`.
Deferred (will arrive in a follow-up commit):
- `bffg_sde.py` — needs `backward_filter` (ODE on (H, F, c)) +
`forward_guided` (guided SDE bridge) ports; ~150 lines plus heavy
numerical testing.
- `gaussian_down_conditional` — needs the up-pass results
(c_0, F_0, H_0) carried as node fields, which means extending the
up sweep's writes set.
Tests (14 new, 126 total):
- SDE: `dot`/`solve` round-trip; `dts` sums to T; `forward` with zero
noise matches the exact explicit-Euler reference on OU; zero-drift
Brownian variance ~ T (Monte Carlo); the σ-branch and the a-branch
agree when a = σσᵀ; jit composes.
- Gaussian BFFG up: 3-node root posterior matches the closed-form
`(H, F) = sum 1/(τ²+l_iσ²), sum y_i/(τ²+l_iσ²)`; posterior mean
matches F/H; under outer @jax.jit and `lax.scan`.
- Cross-check: on a star tree with τ²→0 and σ²=1, BFFG root posterior
mean equals phylo_mean (single-level equivalence; multi-level diverges
because BFFG correctly propagates posterior precision and phylo_mean
doesn't).
- Unconditional down: zero noise keeps parent value; depth-3 Brownian
diffusion produces leaf variance ~ depth.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two L2 modules covering the kernel + LDDMM-dynamics half of the legacy shape pipeline. Pure-math; no tree machinery touched. - `prebuilt/shape_kernels.py` — verbatim port of `examples/shape.py`'s Laplace family (K_0..K_4) + Gaussian RBF, plus the matching `g_K0` / `g_K1` correlators. Each takes pairwise-difference tensors `(n, n, d)` and returns the `(n, n)` Gram matrix. `fibonacci_sphere` for ~uniform sphere sampling (pure numpy); `mesh_sphere` requires `trimesh` via the `[prebuilt-shape]` extra (lazy-imported with a clean error message pointing at the extra). - `prebuilt/lddmm.py` — `lddmm_drift()` (zero drift, free landmark Brownian motion) and `lddmm_covariance(kernel, *, n, d)` (kernel- Gram tensor-product covariance). These weren't in the legacy `examples/shape.py` — the notebooks defined them ad-hoc — collecting them here as the obvious LDDMM building blocks. They plug straight into `sde.forward` via the `a=...` parameter (no separate σ constructor needed; the forward path falls back to a cholesky of `a`). Tests (18 new, 144 total): - Shape kernels: r-floor at coincident landmarks; all Laplace kernels → α at r=0, Gaussian → α/2; Gram symmetry on real landmarks; monotone decay in distance; `g_K0` matches the closed form at r=1; fibonacci sphere returns unit vectors with near-zero mean position; `mesh_sphere` hits exact icosphere subdivision counts (target=42 → 42 verts); missing-trimesh path raises with a message pointing at `prebuilt-shape`. - LDDMM: drift is zero; covariance shape (n, n), symmetric, PD on well-separated landmarks; diagonal equals α; zero-noise integration is constant; nearly-coincident landmarks trace highly correlated trajectories (corr > 0.99); jit composes. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Final stage: enforce the L1/L2 dependency boundary in tests and bring
the CI workflows in line with the v3 setup (uv-managed, Python 3.11+,
no docs).
- `tests/test_core_dependencies.py` (T-16):
* Parses every `core/*.py` via `ast` (regex matched docstring text).
* Asserts no file imports any of matplotlib / scipy / tqdm / ete3 /
ete4 / flax / trimesh / plotly.
* Whitelist sanity-check: every absolute import in `core/` must
resolve to jax, numpy, hyperiax, or a stdlib module (detected by
sys.base_prefix membership of the spec origin).
* Manually verified to fire on a planted `import trimesh` in
`core/errors.py`.
- `.github/workflows/pytest.yaml`: rewritten to use
astral-sh/setup-uv@v4 + `uv sync --all-extras --group dev` + `uv run
pytest`. Triggers extended to cover both `main` and `v3` branches.
- `.github/workflows/pypi_push.yaml`: switched from
`python setup.py sdist` (setup.py was deleted in Stage 0) to
`uv build`, which produces both sdist and wheel; `fetch-depth: 0`
added so setuptools-scm can read the full git history for the
version tag.
- `.github/workflows/build_docs.yaml`: deleted. The docs notebook tree
was removed in Stage 0 and docs are out of scope for v3; reintroduce
later when docs come back online.
153 tests total, all green (+9 from T-16).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…l down
Completes the BFFG variant that was deferred in Stage 8. Closed-form
only (no ODE-integrated linear-drift path): handles the standard
``b = 0`` / ``β = B = 0`` case which covers Brownian motion and any
SDE whose auxiliary diffusivity is linearly interpolated between
``tildea0`` and ``tildeaT`` per edge.
- `prebuilt/bffg_sde.py`:
* `backward_filter(dts, params, c_T, v_T, F_T, H_T, tildea0, tildeaT)`
— closed-form per-edge filter. The integrated ``tildea`` over the
edge gives a tractable Φ-inverse ``I + H_T · integrated``, from
which ``H_0 = Φ_inv⁻¹ H_T`` and ``F_0 = Φ_inv⁻¹ F_T``. ``c_0``
follows the legacy correction
``c_T + ½ v_Tᵀ(H_T - H_0)v_T - ½ log|det Φ_inv(0)|``.
* `forward_guided(x0, dts, dWs, b, sigma, params, a, F_T, H_T,
tildea0, tildeaT)` — guided-bridge SDE via lax.scan. Drift gets the
bridge correction ``a·(F(t) - H(t)·x)·dt``; ``logpsi`` accumulates
the van der Meulen et al. correction terms.
* `sde_up(n_steps, a)` — up sweep wrapping `backward_filter`. Reads
``edge_length, v_0, c_T, v_T, F_T, H_T`` from each child, sums
``c_0, F_0, H_0`` over children, writes back to the parent's
``c_T, F_T, H_T`` and computes ``v_T = H_T⁻¹ F_T``.
* `sde_down_unconditional(n_steps, b, sigma, *, a=None)` — forward
Euler-Maruyama on each edge; child's ``value`` is the full
trajectory of shape (n_steps+1, n·d).
* `sde_down_conditional(n_steps, b, sigma, a)` — forward-guided
bridge using the up-sweep's ``(F_T, H_T)``; writes ``value`` and
``logpsi`` per node.
* `propagate_v_T_to_v_0()` — small down sweep that copies each
parent's posterior mean ``v_T`` to its children's linearization
point ``v_0`` (the root's ``v_0`` must be initialized separately,
typically by ``init_sde_leaves(..., root_value=...)``).
* `init_sde_leaves(...)` — analogue of `init_gaussian_leaves` that
also seeds ``v_T`` at leaves and (optionally) ``v_0, v_T``
everywhere from a root prior.
Tests (11, total 164):
- `backward_filter` Brownian (tildea=I) ≡ Gaussian closed-form
``H_0 = H_T / (1 + H_T·T)``, ``F_0 = F_T / (1 + H_T·T)``.
- `forward_guided` Brownian: ``logpsi → 0`` when ``a == tildea`` and
``b == 0``; zero-noise trajectory relaxes toward ``H⁻¹F`` (bridge
target).
- `sde_up` bit-for-bit equals `gaussian_up` on a 3-node Brownian tree;
root posterior matches the closed-form
``Σ 1/(τ²+l_iσ²)`` / ``Σ y_i/(τ²+l_iσ²)``.
- `propagate_v_T_to_v_0` copies parent ``v_T`` to children's ``v_0``.
- `sde_down_unconditional` zero noise → constant trajectory; unit-σ
Brownian leaf-terminal variance ≈ edge_length (Monte Carlo).
- End-to-end pipeline (up → propagate → conditional down) runs and
produces finite values.
- jit + `lax.scan` composition.
Deferred for a follow-up: the ODE-integrated path for time-varying
linear drift `B(t), β(t)` (uses `jax.experimental.ode.odeint`); the
matching `gaussian_down_conditional` (needs `c_0, F_0, H_0` carried as
node fields, which would extend `gaussian_up`'s writes set).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Finishes the two pieces that Stage 8 + the SDE follow-up deliberately deferred. Every BFFG variant from the legacy `examples/ABFFG.py` now has a home in the new prebuilt layer. ODE-integrated BFFG (general linear B, β): - `bffg_sde.backward_filter` and `forward_guided` now dispatch internally on `B`/`β`. The closed-form path is unchanged; the new ODE path integrates the (H, F, c) system via `jax.experimental.ode.odeint` with τ = T − t, unpacks the per-step series `F_t`, `H_t` for use by `forward_guided`, and supports d > 1 via tensor-product reshape (`F` flat (n·d,) ↔ matrix (n, d); `c` per d-column; trace term broadcast across d). Uniform `dts` assumed. - `bffg_sde.forward_guided` ODE path indexes the precomputed `F_t`, `H_t` arrays at each Euler step and adds the linear-drift term `tildeb(t, x) = β(t) + B(t)·x` to the `logpsi` correction. - `sde_up(n_steps, a, *, B=None, beta=None)` and `sde_down_conditional(n_steps, b, σ, a, *, B=None, beta=None)` now forward to the matching path. The conditional down re-runs the ODE filter inline per edge to recover `F_t`/`H_t` (avoids extending the up sweep's writes set; one extra `odeint` per edge per sample). - A small `_bridge_step_body` helper factors the shared Euler-step body between the closed-form and ODE branches of `forward_guided`. gaussian_down_conditional: - New sweep in `bffg_gaussian`. For each non-root node, re-derives the per-edge filter intermediates `(c_0, F_0, H_0)` from the node's own posterior `(c_T, F_T, H_T)` and `edge_length` using the same math as `gaussian_up` (linearization at `v_T = solve(H_T, F_T)`), then draws the conditional Gaussian sample given the parent's current value and computes the per-edge `logw` correction `logphi_H(v_T_post, x, (I + H_T·Σ_T)⁻¹·H_T) − logU(x, c_0, F_0, H_0)`. - No extension to `gaussian_up`'s writes set — the schema growth that was blocking this in the previous round is sidestepped by inline re-derivation. Costs one extra (n, n) solve per non-root node per sample. Tests (+11; 175 total, all green): ODE path (6): - `backward_filter` with `B=0, β=0` reproduces closed-form `H_0`, `F_0`, `c_0` to 1e-4. - `backward_filter` ODE per-step series ends at the boundary conditions (`H_t[-1] == H_T`). - `forward_guided` ODE with `B=0, β=0` matches closed-form trajectory and `logpsi` to 1e-3. - `sde_up` ODE matches closed-form sweep on a 3-node Brownian tree. - Full up → propagate → conditional-down pipeline with damped `B(t) = -0.5·I` runs and leaves land near observations. - ODE path composes under outer `@jax.jit`. `gaussian_down_conditional` (5): - Zero-noise + low-`obs_var` recovers leaf observations to 5e-2. - Determinism under zero noise (same input → same output). - Random noise produces finite values + finite `logw`. - Outer `@jax.jit` composition. - End-to-end up → conditional-down inside `jax.lax.scan`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Stage A of the diffrax migration: only the ODE-integrated backward
filter is replaced — the SDE forward/forward_guided paths stay on
hand-rolled lax.scan because their pre-sampled-dW data model
(MCMC Crank-Nicolson on `tree.data['noise']`) is fundamentally at
odds with diffrax's BrownianPath abstraction.
The substitution:
jax.experimental.ode.odeint → diffrax.diffeqsolve(
diffrax.ODETerm(...),
diffrax.Tsit5(),
stepsize_controller=
diffrax.PIDController(rtol=1e-7, atol=1e-9),
...)
Same packed state layout, same τ = T − t convention, same return shape.
The vector_field is rewritten to diffrax's `(t, y, args)` signature.
Concretely:
- `pyproject.toml`: new `[prebuilt-bffg]` extra pinning `diffrax>=0.6`.
Closed-form path still has zero non-jax dependencies — only the ODE
branch needs the extra.
- `prebuilt/bffg_sde.py::_backward_filter_ode`: diffrax import is lazy,
raising a clean `ImportError` pointing at the extra when missing
(same pattern as ete3 in `io.newick` and trimesh in `shape_kernels`).
- `tests/test_core_dependencies.py`: diffrax and equinox added to the
forbidden-in-core list so a future accidental import gets caught.
- `tests/test_prebuilt_bffg_sde.py`: tolerances tightened by 2-3 orders
of magnitude (1e-3/1e-4 → 1e-5/1e-6) reflecting the integrator
precision boost. Added a missing-diffrax regression test via
`monkeypatch` on `builtins.__import__`.
- `README.md`: documents the new extra.
Measured numerical agreement between ODE (B=β=0) and closed-form on the
1D Brownian smoke case:
H_0 abs err : 1.8e-7 (was ~1e-4 with odeint)
F_0 abs err : 1.8e-7
c_0 abs err : 6.0e-7
forward_guided trajectory & logpsi abs err: 0.0 (downstream of better F_t/H_t)
178 tests total, all green.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Restructure the docs/source/notebooks/ progression so the reader-facing narrative is built around inferential goals rather than around the names of the underlying algorithms: - 03 phylo_mean: closed-form point estimate (unchanged content; updated cross-references in the recap). - 04 phylo_bayesian (rename of 04_gaussian_bffg): the same ancestral-state problem as 03, upgraded to a full Bayesian posterior. Gaussian BFFG is the closed-form smoother that delivers it; the Bayesian framing is the headline. Hyperparameters are taken as known here. - 05 gaussian_mcmc (new placeholder): joint MCMC over latent states and hyperparameters under linear-Gaussian transitions. BFFG sits inside the kernel as an exact marginal-likelihood evaluator. - 06 gaussian_mle (new placeholder): the MLE / MAP counterpart of 05 — jax.grad through the BFFG marginal likelihood plus optax. A v3-specific capability that the v2 mutable-tree design could not support. - 07 sde_mcmc (rename of 05_sde_bffg_bridges): same MCMC machinery as 05 but with non-linear SDE transitions. BFFG becomes an approximation that Metropolis corrects via logpsi. The 2x2 conceptual matrix is now Gaussian vs SDE on one axis and Bayesian (MCMC) vs MLE (grad) on the other; the SDE-MLE square is deferred as future work. Cross-references in notebooks 01-03 are updated to point at the new filenames. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous logw formula compared a normalized Gaussian density at the parent's sampled value to logU at the same point, evaluating two quantities that don't match Theorem 14 of van der Meulen & Sommer (2025). Rewrite as log w(x) = log φ(H⁻¹F; μ(x), Q(x)+H⁻¹) − log φ(H⁻¹F; Φx+β, Q̃+H⁻¹) with μ(x)=x, Φ=I, β=0, and Q̃ = ℓ·a(v_T) the auxiliary linearised at the canonical posterior mean (matching gaussian_up). In the pure linear-Gaussian limit Q(x) = Q̃ so the two densities coincide and sum(logw) ≡ 0 — the paper's p.16 remark, now verified at machine precision. State-dependent a produces a genuine non-zero correction. Drop the obsolete (c_0, F_0, H_0) intermediates and the now-unused logphi_H / logU imports. The sampling code itself is unchanged. Add three regression tests: linear case sum(logw)=0 exactly; single-edge hand-computed Theorem 14 matches the prebuilt to atol=1e-6; nonlinear a gives sum(logw) with std > 0.1. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tion
Replace the placeholder with a complete (23 cells) notebook framing BFFG
as a guided proposal for parameter inference, not just a smoother.
Structure:
1. Setup + forward-simulate ground truth (depth-4 binary, 31 nodes,
16 leaves; x_root = 0 fixed)
2. Closed-form marginal log p(y|θ) via leaf joint Gaussian
3. BFFG-guided forward map (init_gaussian_leaves → gaussian_up →
gaussian_down_conditional) returning (x, sum(logw))
4. Empirical Theorem 14 collapse: sum(logw) ≡ 0 across 500 z draws
5. BFFG-guided MCMC kernel: pCN on z + RW on log θ, target =
log g_r(0;θ) + sum(logw) + log π(log θ); RW_SCALE tuned to ~46%
6. Multi-chain joblib (loky) + hand-written Gelman-Rubin
7. Triple verification: BFFG-MCMC histogram, grid-marginalised
analytic posterior, ground truth
8. Nonlinear teaser: state-dependent a → sum(logw) acquires real
spread, previews 07
9. Recap + references
End-to-end run verified: pCN acceptance = 1.0000 (Theorem 14 signature),
R̂ ≈ 1.001, BFFG-MCMC means/medians/95% CIs match analytic grid to 3
decimal places for both σ² and τ².
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…iable tree
Replace the placeholder with a 22-cell notebook that demonstrates
hyperparameter MLE for the same linear-Gaussian model as notebook 05,
this time via maximum-likelihood instead of MCMC.
The route is the EM identity
∇_θ log p(y | θ) = E_{x ~ p(x|y, θ)}[∇_θ log p(x, y | θ)]
with `gaussian_down_conditional` as an *exact* posterior sampler in the
linear-Gaussian regime (Theorem 14 collapse, verified in 05).
`stop_gradient` on the BFFG sample is what turns the autodiff into the
EM gradient instead of the (different) reparameterisation gradient —
one line, but load-bearing.
Structure:
1. Setup + simulate (same seed as 05 → same data)
2. Complete-data log-likelihood: per-edge prior + per-leaf obs
3. BFFG-as-posterior-sampler
4. EM score estimator with stop_gradient
5. Gradient parity vs closed-form marginal log-lik (M=500: |diff| < 0.1)
6. Adam training loop (LR=0.05, 400 steps), loss + trajectory plots
7. EM-MLE vs closed-form MLE vs grid argmin — all three agree to 1%
8. Bonus: jax.grad w.r.t. edge_length — autodiff traces the whole tree
9. Sidebar: MAP is +1 line of log-prior
End-to-end verified: EM-MLE recovers (σ²=0.41, τ²=0.22), matching the
closed-form MLE (0.41, 0.20) and grid argmin (0.40, 0.21) within MC
noise. The shared offset from truth (0.5, 0.1) is the finite-sample MLE
bias on 16 observations, not the estimator.
Add `optax>=0.2` to the notebooks extras (with the rest of the
previously-uncommitted notebooks block) and refresh uv.lock.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
prebuilt/: - Merge bffg_gaussian.py + bffg_sde.py into a single bffg.py. - Drop _gaussian_density.py: inline logphi via jax.scipy.stats.multivariate_normal.logpdf; canonical_leaf_messages becomes a private helper in bffg.py. - Delete lddmm.py and shape_kernels.py: dead in code, tests, and notebooks. Removes the prebuilt-shape extra (only served mesh_sphere). API: - Rename symmetric_topology arg height -> depth across tests and the docs landing page. Env: - Pin equinox<0.13.8 in prebuilt-bffg: 0.13.8 ships a broken wheel referencing a missing equinox.internal._loop submodule. Docs: - Bring in the sphinx scaffold (conf.py, index.rst, api/index.rst, Makefile) and per-prebuilt autosummary entries; drop the LDDMM / Shape-kernels rubrics. - Refresh notebooks 01-07 against the v3 API. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- extras: merge prebuilt-bffg+prebuilt-mcmc into a single `prebuilt`, rename `notebooks` to `notebook`, drop the standalone `docs` extra and roll sphinx tooling into the `dev` group - pin the equinox/jaxtyping/wadler-lindig/optimistix chain under `prebuilt` so `uv sync` no longer resolves a broken diffrax env - setuptools-scm: tag_regex + fallback_version for the v3 line - add ruff lint+format config and a pre-commit pipeline (ruff, whitespace, eof, yaml/toml, merge-conflict, large-file) - reformat the whole repo through ruff format as the new baseline - new `.github/workflows/docs.yaml`: strict sphinx build on push/PR, GitHub Pages deploy on push to v3 - rename root `makefile` → `Makefile`; move `docs/report/` out of the package tree and gitignore it - README + CLAUDE.md: drop the dead plan-doc link, fix the wrong `prebuilt-shape` extra, document the new install commands
- requirements/*.txt + environment.yml for non-uv install paths
- docs/Makefile: route sphinx-build through uv run (macOS PATH fix)
- prebuilt/bffg: refactor
- core/{dispatch,sweep}: updates
- prebuilt/{__init__,sde}: tweaks
- notebooks 01/05/06/07: refreshed outputs
- tests: expand bffg_sde + sweep_up
- README, .gitignore updates
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…onsolidation
Packaging
- Drop the prebuilt / io extras (diffrax and jax_tqdm are no longer used);
the core lib imports on bare jax + numpy. Simplify the sync target to
`uv sync --group dev`. Delete legacy requirements/ and environment.yml.
Layout
- Move Newick parsing into core/builders.py and delete hyperiax/io/.
- Add hyperiax/utils/ (pure-JAX RK4 / Euler-Maruyama solvers), used by bffg
in place of diffrax.
- Delete hyperiax/prebuilt/{mcmc,sde}.py; NumPyro is now the recommended
MCMC path (see notebooks 05 / 06).
Core dispatch
- One up-sweep path: _up_dispatch walks the segment layout (pbuckets) for
any topology. The equal-degree dense-vmap path is gone, and Topology
drops gather_child_idx / level_non_leaf_indices with it.
- Add Children.map(fn) — a segment-preserving per-child vmap so a non-linear
per-child transform feeds the same reduction surface as a plain field.
ChildrenAxis.flat is now exposed; writes_children supported on any topology.
BFFG rewrite (prebuilt/bffg.py)
- Rename gaussian_* -> discrete_*, sde_* -> continuous_*.
- Discrete backward filter implements the full linear-Gaussian auxiliary
(Phi x + beta, Q) per Theorem 14 §6.1; discrete forward-guiding weight
matches Theorem 14.3.
- Continuous backward filter caches the per-edge (H, F) trajectory via
writes_children; vertex message lives in new prec_v / ptnl_v fields.
- Per-edge linearisation anchors (Algorithm 3 §7.1): each node stores
anchor (+ anchor_pa on continuous edges); discrete_refine_anchor /
continuous_refine_anchor refine to the BFFG posterior mean.
- Two-anchor sigma-tilde interpolation along continuous edges (BF + FG
consistent).
- Track canonical-message log_norm so tree.log_norm[root] is the marginal
log-evidence.
phylo_mean migrated to children.map; now runs on ragged trees.
Tests
- Consolidate into one test_sweep_up.py covering both topologies (the
unified path makes them indistinguishable to user code).
- Drop the legacy gaussian_/sde_ BFFG tests; drop mcmc/sde tests; rename
test_io_newick -> test_newick; add test_utils_{ode,sde}; adjust
test_topology for the removed fields; flip test_prebuilt_phylo_mean's
unequal-degree-raises check to an unequal-degree-correctness check.
CI + docs scaffold updated.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- hyperiax/prebuilt/bffg.py: module-level docstring + Google-style Args/Returns/Notes for all 12 public symbols (discrete_/continuous_ schema, init_*_tree, *_bf_sweep, *_forward_sweep, *_fg_sweep, *_refine_anchor). No semantic changes — autodoc-friendly polish. - docs/source/api/index.rst: add discrete_refine_anchor and continuous_refine_anchor. - Rename docs/source/notebooks/07_sde_mcmc.ipynb -> 07_sde_bffg.ipynb; update the toctree (notebooks/index.rst) and the inline references in notebooks 03 / 04 / 05 / 06. Verified: pytest -q (164 passed), ruff check hyperiax/ tests/ (clean), make -C docs html SPHINXOPTS=-W (build succeeded). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The repo was tracking both `Makefile` (uppercase) and `makefile` (lowercase) as separate paths even though they point to the same physical file on the case-insensitive default macOS / Windows filesystems. On Linux CI this manifested as a duplicate entry, and locally it caused `git status` to permanently show `M Makefile` because the uppercase index entry held stale content while the lowercase entry tracked the up-to-date file. Keep only `makefile` (lowercase) — it already has the current `uv sync --group dev` content from the v3 release prep. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
b04864d
into
ComputationalEvolutionaryMorphometry:main
3 checks passed
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Purpose
This is the migration PR for merging the rewritten Hyperiax codebase into
main.The rewrite centers the public API on three JAX-friendly primitives:
Topologyfor immutable, hashable rooted-tree structure.Treefor immutable per-node JAX arrays with schema validation.@hx.upand@hx.downfor explicitTree -> Treemessage passing.The goal is to make the package homepage, installation path, CI, docs, and release workflow match the current architecture after this PR lands on
main.Highlights
Core
ChildrenAxisproxy for child reductions andchildren.map(fn)for non-reduction child transforms.Treemutation helpers that return new trees.hyperiax.core.core/andutils/stay pure JAX/NumPy/stdlib;prebuilt/stays separate.Prebuilt
hyperiax.prebuilt.bffgexposes the currentdiscrete_*andcontinuous_*BFFG sweeps.(Phi x + beta, Q)and Theorem 14 weight.phylo_meanuseschildren.map(...)and supports ragged trees.Packaging, Docs, And Release
mainbranch, with development/v3 migration wording removed.pyproject.tomlanduv.lock.vtags before the newvtag exists upstream.release: publishedandworkflow_dispatch; it does not publish onpushtomain.v3.0.0tag has been moved to the current v3 HEAD (79c0e02) so setuptools-scm builds the migration package as3.0.0. Nov3.0.1tag is present locally or onorigin.Main Integration
upstream/main@ca90a5dintov3inde195a1.modify/deleteconflicts for legacy example files deleted by the rewrite:examples/SDE.pyexamples/mcmc_Gaussian_BFFG_shapes_SDEs_Kunita.ipynbexamples/surface.#72forward SDE fix is covered byhyperiax.utils.sde.EulerMaruyama.step, which evaluatesdiffusion(t, y, args)at the current statey; continuous BFFG sweeps call this solver.Validation
uv sync --group devuv run python -c "<README quick-start example>"- returned[1. 1.]uv run pre-commit run --files README.md .github/workflows/pypi_push.yamluv run pre-commit run check-yaml --files .github/workflows/pypi_push.yamluv run ruff check hyperiax/ tests/- passeduv run pytest -q-164 passed in 8.04suv buildatv3.0.0- builthyperiax-3.0.0.tar.gzandhyperiax-3.0.0-py3-none-any.whlSETUPTOOLS_SCM_PRETEND_VERSION_FOR_HYPERIAX=3.0.0 uv run make -C docs html SPHINXOPTS=-W- build succeeded79c0e02: Sphinx build passed; pytest passed; Pages deploy skipped for PR as expectedmergeable=trueSuggested Review Order
hyperiax/core/dispatch.py,hyperiax/core/views.py,hyperiax/core/tree.py,hyperiax/core/topology.py.hyperiax/prebuilt/bffg.py.README.md,.github/workflows/,pyproject.toml,docs/source/.tests/test_core_dependencies.py,tests/test_sweep_up.py,tests/test_regression.py.