
v3: migrate main to immutable Tree, unified sweeps, and BFFG refit #73

Merged
gefanyang merged 34 commits into ComputationalEvolutionaryMorphometry:main from gefanyang:v3
Jun 2, 2026

Conversation


@gefanyang gefanyang commented May 29, 2026

Collaborator

Purpose

This is the migration PR for merging the rewritten Hyperiax codebase into main.

The rewrite centers the public API on three JAX-friendly primitives:

  • Topology for immutable, hashable rooted-tree structure.
  • Tree for immutable per-node JAX arrays with schema validation.
  • Sweeps declared with @hx.up and @hx.down for explicit Tree -> Tree message passing.

The goal is to make the package homepage, installation path, CI, docs, and release workflow match the current architecture after this PR lands on main.

Highlights

Core

  • Single segment-based up-sweep dispatch path for equal- and unequal-degree trees.
  • ChildrenAxis proxy for child reductions and children.map(fn) for non-reduction child transforms.
  • Immutable Tree mutation helpers that return new trees.
  • Pure-Python Newick helpers in hyperiax.core.
  • L1/L2 layering enforced by tests: core/ and utils/ stay pure JAX/NumPy/stdlib; prebuilt/ stays separate.

Prebuilt

  • hyperiax.prebuilt.bffg exposes the current discrete_* and continuous_* BFFG sweeps.
  • Discrete BF/FG implements the full linear-Gaussian auxiliary (Phi x + beta, Q) and Theorem 14 weight.
  • Continuous BF/FG implements two-anchor auxiliary diffusion interpolation and the Theorem 23 / Remark 24 log-weight integrand.
  • phylo_mean uses children.map(...) and supports ragged trees.

Packaging, Docs, And Release

  • README is now a formal package homepage for the post-merge main branch, with development/v3 migration wording removed.
  • Requirements are consolidated into pyproject.toml and uv.lock.
  • Sphinx docs and tutorial notebooks 01-07 match the current API.
  • Docs CI handles upstream PR merge refs that can see legacy non-v tags before the new v tag exists upstream.
  • PyPI workflow is present and intentionally triggers only on release: published and workflow_dispatch; it does not publish on push to main.
  • The v3.0.0 tag has been moved to the current v3 HEAD (79c0e02) so setuptools-scm builds the migration package as 3.0.0. No v3.0.1 tag is present locally or on origin.

Main Integration

  • Merged upstream/main@ca90a5d into v3 in de195a1.
  • The only merge conflicts were modify/delete conflicts for legacy example files deleted by the rewrite:
    • examples/SDE.py
    • examples/mcmc_Gaussian_BFFG_shapes_SDEs_Kunita.ipynb
  • Resolution: keep the rewrite's deletion of the legacy examples/ surface.
  • The #72 forward SDE fix is covered by hyperiax.utils.sde.EulerMaruyama.step, which evaluates diffusion(t, y, args) at the current state y; continuous BFFG sweeps call this solver.
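The point about evaluating the diffusion at the current state can be illustrated with a minimal scalar sketch. This is not the `hyperiax.utils.sde.EulerMaruyama` class itself; the function name `euler_maruyama_step` and the geometric-Brownian-motion test problem are assumptions chosen for illustration, and only the `(t, y, args)` coefficient signature is taken from the text above.

```python
def euler_maruyama_step(drift, diffusion, t, y, dt, dw, args=None):
    """One scalar Euler-Maruyama step.

    Both coefficients are evaluated at the current state y (not at a
    predicted or previous state), which is the behaviour described above.
    """
    return y + drift(t, y, args) * dt + diffusion(t, y, args) * dw


# Hypothetical test problem: dY = mu * Y dt + sigma * Y dW.
def drift(t, y, args):
    return args["mu"] * y


def diffusion(t, y, args):
    return args["sigma"] * y


args = {"mu": 0.1, "sigma": 0.2}
y0, dt, dw = 2.0, 0.01, 0.05
y1 = euler_maruyama_step(drift, diffusion, 0.0, y0, dt, dw, args)
# y1 = 2 + 0.1*2*0.01 + 0.2*2*0.05
```

The increment `dw` is passed in rather than sampled inside the step, which matches the pre-sampled-noise style described in the commit messages.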

Validation

  • uv sync --group dev
  • uv run python -c "<README quick-start example>" - returned [1. 1.]
  • uv run pre-commit run --files README.md .github/workflows/pypi_push.yaml
  • uv run pre-commit run check-yaml --files .github/workflows/pypi_push.yaml
  • uv run ruff check hyperiax/ tests/ - passed
  • uv run pytest -q - 164 passed in 8.04s
  • uv build at v3.0.0 - built hyperiax-3.0.0.tar.gz and hyperiax-3.0.0-py3-none-any.whl
  • SETUPTOOLS_SCM_PRETEND_VERSION_FOR_HYPERIAX=3.0.0 uv run make -C docs html SPHINXOPTS=-W - build succeeded
  • GitHub Actions on 79c0e02: Sphinx build passed; pytest passed; Pages deploy skipped for PR as expected
  • PR is mergeable=true

Suggested Review Order

  1. Core invariants and dispatch: hyperiax/core/dispatch.py, hyperiax/core/views.py, hyperiax/core/tree.py, hyperiax/core/topology.py.
  2. BFFG implementation: hyperiax/prebuilt/bffg.py.
  3. Packaging, release, and docs entry points: README.md, .github/workflows/, pyproject.toml, docs/source/.
  4. Tests that enforce architecture and JIT behavior: tests/test_core_dependencies.py, tests/test_sweep_up.py, tests/test_regression.py.

gefanyang and others added 29 commits May 12, 2026 13:04
Delete the legacy core (`hyperiax/{tree,models,execution,mcmc,plotting}`),
all examples / benchmarks / docs notebooks / existing tests, and `setup.py`.
Stop tracking `_version.py` and `.DS_Store`.

Repackage for Python 3.11+ on uv:
- `pyproject.toml` slimmed to `jax / jaxlib / numpy`; ete3 moved under
  `[io]`; trimesh under `[prebuilt-shape]`; pytest into the PEP 735
  `dev` dependency group.
- `.python-version` pins 3.11.
- `uv.lock` committed for reproducible installs.
- `makefile` switched to `uv run pytest` / `uv sync --all-extras --group dev`.

`hyperiax/__init__.py` is gutted to a docstring placeholder until Stage 1
introduces `hyperiax.core`. README updated with uv-based dev setup
instructions and a banner pointing at the v3 plan.

Verified: `uv sync --group dev` succeeds and `uv run pytest` collects 0
tests cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Build the L1 core of the v3 refactor. All three types are frozen
dataclasses; Topology and Tree are JAX-pytree-registered so they
flow through jit/scan without any static-arg hack.

- `core/errors.py` — HyperiaxError, SchemaMismatch, MissingField,
  StructureMismatch.
- `core/schema.py` — FieldSpec + Schema. Schema is an ordered tuple
  sorted by name for deterministic hash; this is what stabilises the
  pytree structure across topologically-identical trees.
- `core/topology.py` — Topology with BFS-ordered parents and precomputed
  derived arrays (level_starts, masks, child_counts, equal-degree
  gather_child_idx, segment pbuckets/pbuckets_ref). All derived fields
  are np.ndarray (not jnp) so the object is Python-hashable and lives
  entirely in pytree aux_data. Hash is cached from parents.tobytes().
- `core/tree.py` — Tree composes Topology + Schema + dict[str, Array].
  `empty / set / set_at / update / drop` all return a new Tree; the
  data dict is never mutated in place. Tree is intentionally not
  hashable, so users cannot accidentally pass it as `static_argnames`.

Top-level `hyperiax` package re-exports the public surface.

Tests (T-1/T-2/T-3 + edge cases): 36 passing.
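
The hashable-topology idea can be sketched in plain Python. This is a toy stand-in, not the actual `Topology` class: `ToyTopology` and `with_field` are hypothetical names, a tuple replaces the NumPy `parents` array, and only `level_starts` is derived.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTopology:
    """Frozen, hashable rooted tree stored as BFS-ordered parent indices."""

    parents: tuple  # parents[i] is the parent of node i; the root has -1

    def __post_init__(self):
        if not self.parents or self.parents[0] != -1:
            raise ValueError("node 0 must be the root (parent -1)")
        for i in range(1, len(self.parents)):
            if not 0 <= self.parents[i] < i:
                raise ValueError(f"parents[{i}] must be in [0, {i})")

    def level_starts(self):
        """Index where each BFS level begins, plus the node count at the end."""
        depth = [0] * len(self.parents)
        starts = [0]
        for i in range(1, len(self.parents)):
            depth[i] = depth[self.parents[i]] + 1
            if depth[i] > depth[i - 1]:
                starts.append(i)
        starts.append(len(self.parents))
        return tuple(starts)


def with_field(data, name, values):
    """Return a new field dict; the input dict is never mutated."""
    return {**data, name: tuple(values)}


topo_a = ToyTopology((-1, 0, 0, 1, 1, 2, 2))
topo_b = ToyTopology((-1, 0, 0, 1, 1, 2, 2))
cache = {topo_a: "compiled"}          # value-equal topologies share a cache slot
old_data = {}
new_data = with_field(old_data, "value", [0.0] * 7)
```

Because two structurally identical topologies hash and compare equal, anything keyed on the topology (such as a compilation cache) is reused across fresh trees.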

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Bring the first sweep online. The minimal demo from the plan now runs:
build a symmetric binary tree, seed its leaves, run an up-sweep that
averages children, read the root — and wrap the whole thing in
`@jax.jit` without leaking tracers.

- `core/builders.py` — `symmetric_topology(height, degree)` for d≥1 and
  any h≥0; `from_parents` re-export for explicit constructions.
- `core/views.py` — `Node` / `Children` transient views. Attribute access
  (`node.value`, `children.value`) is the primary surface; item access
  works too. Missing fields produce a hyperiax-flavoured AttributeError
  pointing at `reads=`/`reads_children=` so the user can fix the
  declaration rather than chase a JAX KeyError.
- `core/sweep.py` — `SweepFn` (frozen, hashable, value-equal on (direction,
  fn identity, reads/writes tuples)) and the `@up(reads=..., reads_children=...,
  writes=...)` decorator. `writes` is required; cross-direction violations
  (`reads_parent` on @up, `reads_children` on @down) are rejected at
  construction.
- `core/dispatch.py` — `up_dispatch` validates the schema eagerly (so the
  traceback lands at the call site) and bails on unequal-degree (deferred
  to Stage 4). `_up_dispatch_jit` is `static_argnums=(0,)` on the SweepFn,
  walks parent levels deepest→root, gathers via `topo.gather_child_idx`,
  and vmaps the user function over the scope axis. The vmap is the key
  decision: from inside the user fn, `node.value` is `(*trailing,)` and
  `children.value` is `(k, *trailing)`, so `children.value.mean(0)` is
  naturally a reduction over the children axis. This is the same surface
  that Stage 4's `ChildrenAxis` proxy will present for unequal-degree
  trees.

Tests now cover T-4 (per-level numerics on a 15-node binary tree, leaf
mean recovery at the root, node-vs-children data interplay, params
threading) and T-5 (the same `SweepFn` on a structurally-identical fresh
tree hits the JIT cache — verified by counting `nonlocal` traces, not by
timing). 52 tests total, all green.
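
The demo described above (seed the leaves, average children upward, read the root) can be mimicked without JAX. The sketch below is a sequential pure-Python analogue, not the vmapped dispatcher; `symmetric_parents` and `up_mean` are hypothetical helper names.

```python
def symmetric_parents(height, degree):
    """BFS-ordered parent indices of a complete tree (root has parent -1)."""
    parents, level = [-1], [0]
    for _ in range(height):
        next_level = []
        for p in level:
            for _ in range(degree):
                parents.append(p)
                next_level.append(len(parents) - 1)
        level = next_level
    return parents


def up_mean(parents, values):
    """Replace every internal node's value by the mean of its children.

    Visiting nodes in reverse index order is enough because parents[i] < i,
    so all children are final before their parent is written.
    Returns a new list; the input is not modified.
    """
    children = [[] for _ in parents]
    for i, p in enumerate(parents):
        if p >= 0:
            children[p].append(i)
    out = list(values)
    for node in reversed(range(len(parents))):
        if children[node]:
            out[node] = sum(out[c] for c in children[node]) / len(children[node])
    return out


parents = symmetric_parents(height=3, degree=2)      # 15-node binary tree
seeded = [0.0] * 7 + [float(v) for v in range(1, 9)]  # leaves hold 1..8
result = up_mean(parents, seeded)
```

On an equal-degree tree the root ends up holding the plain mean of the leaves, which is the "leaf mean recovery" property the tests check.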

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Walks root → leaves, one level at a time, gathering each node's parent
data via topo.parents and vmapping the user fn per-node. Unlike up,
this works on any topology — every non-root node has exactly one
parent, so no equal_degree fast path is needed.

- `core/views.py` — `Parent` subclasses Node with empty __slots__, just a
  different __repr__ for clarity in error messages and isinstance checks.
- `core/sweep.py` — `@down(reads=, reads_parent=, writes=)`. Signature
  symmetric with @up; same SweepFn primitive, different direction tag.
  `SweepFn.__call__` now dispatches to `down_dispatch` for `direction='down'`.
- `core/dispatch.py` — `down_dispatch` skips the equal_degree check
  (down works on any topology) and `_down_dispatch_jit` mirrors the up
  structure: iterate level 1..depth, slice the level contiguously,
  gather parents via topo.parents[ls:le], vmap the user fn over (node,
  parent, params).

Tests (T-6 + supporting): 12 new in tests/test_sweep_down.py covering
numerical propagation, multi-dim values, params threading, unequal-degree
topology (this is what makes down universally applicable), root-untouched
semantics, outer-jit composition, jit cache hits on identical topology,
the canonical up-then-down round trip (leaf mean broadcast to all nodes),
and the validation paths.

64 tests total, all green.
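
A down sweep needs only the `parents` array, which is why it works on any topology. A minimal sequential sketch (the real dispatcher processes one level at a time under `vmap`; `down_sweep` is a hypothetical name):

```python
def down_sweep(parents, values, fn):
    """Walk root -> leaves, writing fn(node_value, parent_value) to each node.

    In BFS order the parent of node i sits at a smaller index, so it is
    already final when node i is visited. The root is left untouched.
    """
    out = list(values)
    for node in range(1, len(parents)):
        out[node] = fn(out[node], out[parents[node]])
    return out


# Unequal-degree tree: the root has 2 children, node 1 has 3, node 2 has none.
ragged_parents = [-1, 0, 0, 1, 1, 1]

# Broadcast the root value to every node.
broadcast = down_sweep(ragged_parents, [3.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                       lambda node, parent: parent)

# Accumulate depth: each node becomes parent + 1.
depths = down_sweep(ragged_parents, [0] * 6, lambda node, parent: parent + 1)
```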

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Bring ragged trees online without changing the user-facing surface.
The same `@hx.up` function that reduces `children.value.sum(0)` over a
symmetric binary tree now runs unchanged on a tree where one parent has
3 children and another has 2.

- `core/views.py` — `ChildrenAxis` is a slot-only proxy backed by a flat
  `(M_total, *trailing)` JAX array plus segment ids. It exposes
  `.sum/.prod/.max/.min/.mean(axis=0)`, each forwarding to the matching
  `jax.ops.segment_*` with `num_segments` held as a static Python int
  (R-2 in the plan: traced num_segments would yield poly-shape errors).
  Non-zero axes and any attempt to coerce the proxy to a dense array via
  `__array__` raise a hyperiax-flavoured error that tells the user to
  reduce first.
- `core/dispatch.py` — `up_dispatch` now branches on
  `tree.topology.equal_degree`. The new `_up_dispatch_unequal` walks
  parent levels deepest→root using `pbuckets`/`pbuckets_ref`, gathers
  parent rows by id, wraps each children field in a `ChildrenAxis`, and
  scatters writes into the unique parent ids. Mixed-depth leaves are
  handled naturally: a level-L node that is itself a leaf never appears
  in any `pbuckets_ref[L+1]`, so the dispatcher leaves its data alone.
  Down dispatch was already topology-agnostic and needs no changes; a
  small `_check_writes` helper unifies the writes-key validation across
  all three paths.

Tests (T-7 + supporting): 15 new in tests/test_sweep_up_unequal.py.
Covers sum/prod/max/min/mean on the canonical ragged tree, node+children
arithmetic, multi-dim trailing shapes, mixed-depth leaves (the genuinely
ragged case where root has degree 2 and one of its children has degree
3), proxy guards (non-zero axis, `jnp.asarray` coercion), outer-jit
composition, JIT cache hits, and an equal/unequal API-parity check
asserting the same `@up` function runs on both topology shapes.

79 tests total, all green.
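
The segment-based reduction behind the ragged path can be sketched in plain Python. The helper below mirrors the argument order of `jax.ops.segment_sum(data, segment_ids, num_segments)` but is a pure-Python stand-in, not the `ChildrenAxis` proxy; the data values are made up.

```python
def segment_sum(data, segment_ids, num_segments):
    """Sum the entries of a flat child array that share a segment id."""
    out = [0.0] * num_segments
    for x, s in zip(data, segment_ids):
        out[s] += x
    return out


def segment_mean(data, segment_ids, num_segments):
    """Per-segment mean, built from two segment sums."""
    totals = segment_sum(data, segment_ids, num_segments)
    counts = segment_sum([1.0] * len(data), segment_ids, num_segments)
    return [t / c for t, c in zip(totals, counts)]


# One parent with 3 children and another with 2, stored as one flat array.
child_values = [1.0, 2.0, 3.0, 10.0, 20.0]
parent_slot = [0, 0, 0, 1, 1]

sums = segment_sum(child_values, parent_slot, num_segments=2)
means = segment_mean(child_values, parent_slot, num_segments=2)
```

Keeping `num_segments` as a plain Python integer corresponds to the static-`num_segments` requirement noted above.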

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The core of why the v3 rewrite exists. The legacy `OrderedExecutor` did
`tree.data = {...}` in place, which leaked tracers between calls of an
outer `@jax.jit` train_step (this is what made the user abandon
hyperiax for Flax+optax pipelines). With Tree now an immutable JAX
pytree and every dispatcher path pure, these scenarios just work.

tests/test_regression.py covers:

T-10 (outer-jit):
- 20 consecutive `@jax.jit`'d step calls — no leaked tracer
- multi-sweep pipeline (up → down → up) inside one jit
- bit-for-bit parity between eager and jit'd outputs

T-11 (lax.scan):
- user-fn trace count under a 100-iter scan equals that of a single
  call (body compiled once, not per iteration)
- structural check: the jaxpr of `scan(body, length=N)(tree)` has
  exactly one top-level `scan` primitive
- numerical equivalence between an eager Python loop and `lax.scan`
- scan composes with the unequal-degree (segment_sum) path too

jax.grad:
- `jax.grad` through an up-sweep matches the hand-computed derivative
- same for down-sweep
- `jax.value_and_grad` works on the ChildrenAxis (segment) path

Tree.update + jit:
- adding a new field with `.update()` then running a sweep on the
  extended schema works eagerly and under jit

Flax-style train_step:
- the exact pattern that broke the legacy code — params + grads +
  optimizer-style update wrapping a sweep — runs cleanly through 10
  jit'd iterations and produces sensible gradient descent on a toy loss.

92 tests total, all green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
L2 module wrapping ete3 for Newick read/write. `hyperiax.io` is the
first non-core module, opt-in via `uv sync --extra io`.

- `hyperiax/io/__init__.py` — exposes `newick`; no eager ete3 import.
- `hyperiax/io/newick.py`:
  * `read(source, *, schema=None, newick_format=1) -> Tree` — accepts a
    Newick literal string or a path; returns a Tree whose schema always
    includes `edge_length` (shape (), float32), plus any user-requested
    extras. Node names live on `Topology.names`, not on tree.data.
  * `write(tree, *, newick_format=1) -> str` — round-trip writer.
    Patches in the root name manually because ete3 deliberately drops
    it from every Newick output format (an ete3 quirk).
  * `_ete_to_bfs_arrays` — BFS-traverses an ete3 tree to produce a
    hyperiax-style `(parents, names, edge_lengths)` triple. BFS order
    guarantees `parents[i] < i`, which is what Topology.from_parents
    requires.
  * `_require_ete3` — gates the import behind a helpful error message
    pointing at the `[io]` extra.

ete3 is imported lazily inside `read`/`write`, so `import hyperiax.io`
works in a core-only install (verified by checking `sys.modules` after
the import).

Tests (T-12 + supporting): 12 in tests/test_io_newick.py with
`pytest.importorskip("ete3")` so the suite skips gracefully without
the extra. Covers bit-for-bit round-trips (simple unnamed and complex
named cases), edge-length and name extraction, BFS-ordering invariant,
extra schema fields, file-path reading, write-without-edge_length guard,
and a smoke test that a Newick-parsed tree feeds straight into an
@hx.up sweep.

104 tests total, all green.
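
The BFS-ordering invariant can be illustrated without ete3. The sketch below walks a nested-tuple tree instead of an ete3 object; `bfs_arrays` and the `(name, edge_length, children)` layout are assumptions made for this example, not the signature of `_ete_to_bfs_arrays`.

```python
from collections import deque


def bfs_arrays(root):
    """Flatten a nested (name, edge_length, children) tree in BFS order.

    Returns (parents, names, edge_lengths). A node is appended only after
    its parent has been dequeued, so parents[i] < i holds for every i > 0.
    """
    parents, names, lengths = [], [], []
    queue = deque([(root, -1)])
    while queue:
        (name, length, kids), parent = queue.popleft()
        idx = len(parents)
        parents.append(parent)
        names.append(name)
        lengths.append(length)
        for kid in kids:
            queue.append((kid, idx))
    return parents, names, lengths


toy = ("root", 0.0, [
    ("A", 1.0, [("C", 0.5, []), ("D", 0.5, [])]),
    ("B", 2.0, []),
])
bfs_parents, bfs_names, bfs_lengths = bfs_arrays(toy)
```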

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
L2 recipe: the edge-length-weighted phylogenetic mean estimator, ported
from the legacy `PhyloMeanModel` (which subclassed `UpReducer` with
`{up, transform}` and a manual `reductions={...}` dict) into a single
`@hx.up` SweepFn.

- `hyperiax/prebuilt/__init__.py` — exports `phylo_mean`.
- `hyperiax/prebuilt/phylo_mean.py` — single sweep that reads
  `estimated_value` and `edge_length` from children and writes back the
  weighted mean on each parent:

      hat_x_p = sum_c (v_c / l_c) / sum_c (1 / l_c)

  The trailing-shape broadcast (scalar edges vs vector values) is
  handled by reshaping `edges` to (k, 1, ...) before division.
  Currently equal-degree only — `children.estimated_value /
  children.edge_length` relies on JAX broadcasting that the unequal
  ChildrenAxis proxy doesn't yet expose. The ragged path will be
  enabled once the proxy gains elementwise arithmetic or a `.gather()`
  fallback.

Tests (T-13 + supporting): 8 in tests/test_prebuilt_phylo_mean.py.
- Hand-computed expected values on a 7-node tree with non-uniform edges
  (root=25 / inner=15,35 from leaves [10,20,30,40] under
  edge_lengths=[0,0.5,0.5,1,1,2,2]).
- Leaves untouched after the sweep.
- Uniform edges → root = simple leaf mean.
- Vector-valued `estimated_value` (broadcasting works).
- Random edges + values on a depth-3 binary tree match a pure-numpy
  reference walk to 1e-5 tolerance.
- Composition under `@jax.jit`.
- End-to-end: Newick → tree → phylo_mean.
- Unequal-degree currently raises (documented limitation).

112 tests total, all green.
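
The hand-computed values quoted in the test list can be reproduced with a short pure-Python version of the weighted-mean formula. This is a sequential sketch rather than the `@hx.up` sweep; `phylo_mean_reference` is a hypothetical name.

```python
def phylo_mean_reference(parents, edge_lengths, values):
    """Inverse-edge-length weighted mean, propagated from leaves to root.

    hat_x_p = sum_c (v_c / l_c) / sum_c (1 / l_c)
    """
    children = [[] for _ in parents]
    for i, p in enumerate(parents):
        if p >= 0:
            children[p].append(i)
    out = list(values)
    for node in reversed(range(len(parents))):  # children before parents
        kids = children[node]
        if kids:
            num = sum(out[c] / edge_lengths[c] for c in kids)
            den = sum(1.0 / edge_lengths[c] for c in kids)
            out[node] = num / den
    return out


tree_parents = [-1, 0, 0, 1, 1, 2, 2]
tree_edges = [0.0, 0.5, 0.5, 1.0, 1.0, 2.0, 2.0]   # root edge is unused
tree_values = [0.0, 0.0, 0.0, 10.0, 20.0, 30.0, 40.0]
estimate = phylo_mean_reference(tree_parents, tree_edges, tree_values)
```

With these inputs the inner nodes come out as 15 and 35 and the root as 25, and the leaf entries are unchanged.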

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two prebuilt modules covering most of T-14. Deliberately MVP — the
full SDE-BFFG variant with ODE-integrated backward filter + Brownian
bridge guiding is deferred to a follow-up because it more than doubles
the code surface and isn't needed for the T-14 numerical baseline.

What landed:

- `hyperiax/prebuilt/sde.py` — full port of the legacy `examples/SDE.py`:
  factored `dot` / `solve` for (n·d,) flattened state on tensor-product
  diffusivity, uniform-step `dts`, and `forward` (Euler-Maruyama via
  `lax.scan`).
- `hyperiax/prebuilt/_gaussian_density.py` — Gaussian density helpers
  (standard, precision, canonical, unnormalized) shared between the
  current Gaussian BFFG and the future SDE-BFFG.
- `hyperiax/prebuilt/bffg_gaussian.py`:
    * `gaussian_up(n, a, d=1)` — Gaussian BFFG backward filter as a single
      @hx.up sweep. Per-edge math is the closed-form precision-update
      `H_0 = (I + H_T·var·a)⁻¹·H_T`, `F_0 = (I + H_T·var·a)⁻¹·F_T`,
      lifted verbatim from the legacy `Gaussian_up`. The summed children
      (F_0, H_0) become the parent's (F_T, H_T); c_T is recomputed via
      logphi_can on the fused canonical params.
    * `init_gaussian_leaves(tree, leaf_values, obs_var, *, n, d=1)` —
      seed the leaves with iid-Gaussian observation precision I_n / τ².
    * `gaussian_down_unconditional(sigma)` — simplest forward sampling:
      `child = parent + sqrt(edge_length)·σ(parent)·noise`.

Deferred (will arrive in a follow-up commit):
- `bffg_sde.py` — needs `backward_filter` (ODE on (H, F, c)) +
  `forward_guided` (guided SDE bridge) ports; ~150 lines plus heavy
  numerical testing.
- `gaussian_down_conditional` — needs the up-pass results
  (c_0, F_0, H_0) carried as node fields, which means extending the
  up sweep's writes set.

Tests (14 new, 126 total):
- SDE: `dot`/`solve` round-trip; `dts` sums to T; `forward` with zero
  noise matches the exact explicit-Euler reference on OU; zero-drift
  Brownian variance ~ T (Monte Carlo); the σ-branch and the a-branch
  agree when a = σσᵀ; jit composes.
- Gaussian BFFG up: 3-node root posterior matches the closed-form
  `(H, F) = sum 1/(τ²+l_iσ²), sum y_i/(τ²+l_iσ²)`; posterior mean
  matches F/H; under outer @jax.jit and `lax.scan`.
- Cross-check: on a star tree with τ²→0 and σ²=1, BFFG root posterior
  mean equals phylo_mean (single-level equivalence; multi-level diverges
  because BFFG correctly propagates posterior precision and phylo_mean
  doesn't).
- Unconditional down: zero noise keeps parent value; depth-3 Brownian
  diffusion produces leaf variance ~ depth.
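
In one dimension the per-edge precision update and the star-tree closed form are easy to check by hand. The sketch below assumes that `var` in the update is the edge length and that `a = σ²`; `edge_filter` and `root_posterior` are hypothetical names, not the prebuilt sweeps.

```python
def edge_filter(F_T, H_T, edge_length, sigma2):
    """Scalar version of H_0 = (1 + H_T * l * sigma^2)^-1 * H_T, same for F."""
    scale = 1.0 + H_T * edge_length * sigma2
    return F_T / scale, H_T / scale


def root_posterior(ys, edge_lengths, tau2, sigma2):
    """Sum the filtered (F_0, H_0) of all leaves hanging off one root."""
    F_root = H_root = 0.0
    for y, length in zip(ys, edge_lengths):
        # Leaf observation with variance tau2: H_T = 1/tau2, F_T = y/tau2.
        F_0, H_0 = edge_filter(y / tau2, 1.0 / tau2, length, sigma2)
        F_root += F_0
        H_root += H_0
    return F_root, H_root


ys, ls, tau2, sigma2 = [1.0, 3.0], [0.5, 2.0], 0.1, 1.5
F_root, H_root = root_posterior(ys, ls, tau2, sigma2)

# Closed form quoted in the test description.
H_closed = sum(1.0 / (tau2 + l * sigma2) for l in ls)
F_closed = sum(y / (tau2 + l * sigma2) for y, l in zip(ys, ls))
posterior_mean = F_root / H_root
```

Substituting `H_T = 1/τ²` into the update gives `H_0 = 1/(τ² + l·σ²)`, which is why the two computations agree.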

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two L2 modules covering the kernel + LDDMM-dynamics half of the legacy
shape pipeline. Pure-math; no tree machinery touched.

- `prebuilt/shape_kernels.py` — verbatim port of `examples/shape.py`'s
  Laplace family (K_0..K_4) + Gaussian RBF, plus the matching `g_K0` /
  `g_K1` correlators. Each takes pairwise-difference tensors
  `(n, n, d)` and returns the `(n, n)` Gram matrix. `fibonacci_sphere`
  for ~uniform sphere sampling (pure numpy); `mesh_sphere` requires
  `trimesh` via the `[prebuilt-shape]` extra (lazy-imported with a
  clean error message pointing at the extra).

- `prebuilt/lddmm.py` — `lddmm_drift()` (zero drift, free landmark
  Brownian motion) and `lddmm_covariance(kernel, *, n, d)` (kernel-
  Gram tensor-product covariance). These weren't in the legacy
  `examples/shape.py` (the notebooks defined them ad hoc); they are
  collected here as the obvious LDDMM building blocks. They plug straight
  into `sde.forward` via the `a=...` parameter (no separate σ
  constructor needed; the forward path falls back to a cholesky of `a`).

Tests (18 new, 144 total):
- Shape kernels: r-floor at coincident landmarks; all Laplace kernels →
  α at r=0, Gaussian → α/2; Gram symmetry on real landmarks; monotone
  decay in distance; `g_K0` matches the closed form at r=1; fibonacci
  sphere returns unit vectors with near-zero mean position; `mesh_sphere`
  hits exact icosphere subdivision counts (target=42 → 42 verts);
  missing-trimesh path raises with a message pointing at `prebuilt-shape`.
- LDDMM: drift is zero; covariance shape (n, n), symmetric, PD on
  well-separated landmarks; diagonal equals α; zero-noise integration
  is constant; nearly-coincident landmarks trace highly correlated
  trajectories (corr > 0.99); jit composes.
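
For readers unfamiliar with Fibonacci-sphere sampling, here is one common golden-angle construction in plain Python. It is an assumption that the prebuilt `fibonacci_sphere` follows the same recipe; the details (offsets, ordering, array type) may differ.

```python
import math


def fibonacci_sphere_points(n):
    """Roughly uniform points on the unit sphere via the golden-angle spiral."""
    golden_angle = math.pi * (3.0 - math.sqrt(5.0))
    points = []
    for i in range(n):
        z = 1.0 - (2.0 * i + 1.0) / n        # evenly spaced heights in (-1, 1)
        r = math.sqrt(1.0 - z * z)           # radius of the circle at height z
        theta = golden_angle * i
        points.append((r * math.cos(theta), r * math.sin(theta), z))
    return points


pts = fibonacci_sphere_points(500)
norms = [math.sqrt(x * x + y * y + z * z) for x, y, z in pts]
centroid = [sum(p[k] for p in pts) / len(pts) for k in range(3)]
```

The two properties checked by the tests above, unit norm and near-zero mean position, follow directly from this construction.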

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Final stage: enforce the L1/L2 dependency boundary in tests and bring
the CI workflows in line with the v3 setup (uv-managed, Python 3.11+,
no docs).

- `tests/test_core_dependencies.py` (T-16):
  * Parses every `core/*.py` via `ast` (a regex would also match docstring text).
  * Asserts no file imports any of matplotlib / scipy / tqdm / ete3 /
    ete4 / flax / trimesh / plotly.
  * Whitelist sanity-check: every absolute import in `core/` must
    resolve to jax, numpy, hyperiax, or a stdlib module (detected by
    sys.base_prefix membership of the spec origin).
  * Manually verified to fire on a planted `import trimesh` in
    `core/errors.py`.

- `.github/workflows/pytest.yaml`: rewritten to use
  astral-sh/setup-uv@v4 + `uv sync --all-extras --group dev` + `uv run
  pytest`. Triggers extended to cover both `main` and `v3` branches.

- `.github/workflows/pypi_push.yaml`: switched from
  `python setup.py sdist` (setup.py was deleted in Stage 0) to
  `uv build`, which produces both sdist and wheel; `fetch-depth: 0`
  added so setuptools-scm can read the full git history for the
  version tag.

- `.github/workflows/build_docs.yaml`: deleted. The docs notebook tree
  was removed in Stage 0 and docs are out of scope for v3; reintroduce
  later when docs come back online.

153 tests total, all green (+9 from T-16).
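
An `ast`-based import check of the kind described for T-16 might look like the sketch below. The forbidden list is copied from the commit message; `imported_roots` and `forbidden_imports` are hypothetical names, and the real test additionally walks the files under `core/`.

```python
import ast

FORBIDDEN = {"matplotlib", "scipy", "tqdm", "ete3", "ete4",
             "flax", "trimesh", "plotly"}


def imported_roots(source):
    """Top-level package names of every absolute import in the source text."""
    roots = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            roots.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            roots.add(node.module.split(".")[0])
    return roots


def forbidden_imports(source):
    """Sorted list of forbidden packages that the source imports."""
    return sorted(imported_roots(source) & FORBIDDEN)


# The docstring mentions trimesh, but only real import statements count.
clean_src = '"""Docs that mention import trimesh in prose."""\nimport jax.numpy as jnp\n'
planted_src = clean_src + "import trimesh\nfrom scipy import linalg\n"

clean_hits = forbidden_imports(clean_src)
planted_hits = forbidden_imports(planted_src)
```

Because `ast.parse` only parses the text, the checked modules are never imported, so the check runs even when the forbidden packages are not installed.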

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…l down

Completes the BFFG variant that was deferred in Stage 8. Closed-form
only (no ODE-integrated linear-drift path): handles the standard
``b = 0`` / ``β = B = 0`` case which covers Brownian motion and any
SDE whose auxiliary diffusivity is linearly interpolated between
``tildea0`` and ``tildeaT`` per edge.

- `prebuilt/bffg_sde.py`:
  * `backward_filter(dts, params, c_T, v_T, F_T, H_T, tildea0, tildeaT)`
    — closed-form per-edge filter. The integrated ``tildea`` over the
    edge gives a tractable Φ-inverse ``I + H_T · integrated``, from
    which ``H_0 = Φ_inv⁻¹ H_T`` and ``F_0 = Φ_inv⁻¹ F_T``. ``c_0``
    follows the legacy correction
    ``c_T + ½ v_Tᵀ(H_T - H_0)v_T - ½ log|det Φ_inv(0)|``.
  * `forward_guided(x0, dts, dWs, b, sigma, params, a, F_T, H_T,
    tildea0, tildeaT)` — guided-bridge SDE via lax.scan. Drift gets the
    bridge correction ``a·(F(t) - H(t)·x)·dt``; ``logpsi`` accumulates
    the van der Meulen et al. correction terms.
  * `sde_up(n_steps, a)` — up sweep wrapping `backward_filter`. Reads
    ``edge_length, v_0, c_T, v_T, F_T, H_T`` from each child, sums
    ``c_0, F_0, H_0`` over children, writes back to the parent's
    ``c_T, F_T, H_T`` and computes ``v_T = H_T⁻¹ F_T``.
  * `sde_down_unconditional(n_steps, b, sigma, *, a=None)` — forward
    Euler-Maruyama on each edge; child's ``value`` is the full
    trajectory of shape (n_steps+1, n·d).
  * `sde_down_conditional(n_steps, b, sigma, a)` — forward-guided
    bridge using the up-sweep's ``(F_T, H_T)``; writes ``value`` and
    ``logpsi`` per node.
  * `propagate_v_T_to_v_0()` — small down sweep that copies each
    parent's posterior mean ``v_T`` to its children's linearization
    point ``v_0`` (the root's ``v_0`` must be initialized separately,
    typically by ``init_sde_leaves(..., root_value=...)``).
  * `init_sde_leaves(...)` — analogue of `init_gaussian_leaves` that
    also seeds ``v_T`` at leaves and (optionally) ``v_0, v_T``
    everywhere from a root prior.

Tests (11, total 164):
- `backward_filter` Brownian (tildea=I) ≡ Gaussian closed-form
  ``H_0 = H_T / (1 + H_T·T)``, ``F_0 = F_T / (1 + H_T·T)``.
- `forward_guided` Brownian: ``logpsi → 0`` when ``a == tildea`` and
  ``b == 0``; zero-noise trajectory relaxes toward ``H⁻¹F`` (bridge
  target).
- `sde_up` bit-for-bit equals `gaussian_up` on a 3-node Brownian tree;
  root posterior ``(H, F)`` matches the closed-form
  ``Σ 1/(τ²+l_iσ²)``, ``Σ y_i/(τ²+l_iσ²)``.
- `propagate_v_T_to_v_0` copies parent ``v_T`` to children's ``v_0``.
- `sde_down_unconditional` zero noise → constant trajectory; unit-σ
  Brownian leaf-terminal variance ≈ edge_length (Monte Carlo).
- End-to-end pipeline (up → propagate → conditional down) runs and
  produces finite values.
- jit + `lax.scan` composition.
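
The "zero-noise trajectory relaxes toward `H⁻¹F`" behaviour can be seen in a scalar Brownian sketch. It assumes `a = 1` and that the filter at time `t` is the closed form applied over the remaining time `T − t`; `guided_path` is a hypothetical name and the `logpsi` accumulation is omitted.

```python
def guided_path(x0, F_T, H_T, T, n_steps, noise=None):
    """Euler scheme for dx = (F(t) - H(t) x) dt + dW with a = 1.

    Assumed interpolation: H(t) = H_T / (1 + H_T (T - t)), same for F(t).
    """
    dt = T / n_steps
    x, path = x0, [x0]
    for k in range(n_steps):
        scale = 1.0 + H_T * (T - k * dt)
        F_t, H_t = F_T / scale, H_T / scale
        dw = 0.0 if noise is None else noise[k]
        x = x + (F_t - H_t * x) * dt + dw
        path.append(x)
    return path


F_T, H_T = 4.0, 2.0
target = F_T / H_T                      # bridge target H^-1 F
path = guided_path(x0=0.0, F_T=F_T, H_T=H_T, T=1.0, n_steps=100)
gaps = [abs(target - x) for x in path]
```

Under these assumptions the gap to the target shrinks at every step but does not vanish, since the terminal precision `H_T` is finite.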

Deferred for a follow-up: the ODE-integrated path for time-varying
linear drift `B(t), β(t)` (uses `jax.experimental.ode.odeint`); the
matching `gaussian_down_conditional` (needs `c_0, F_0, H_0` carried as
node fields, which would extend `gaussian_up`'s writes set).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Finishes the two pieces that Stage 8 + the SDE follow-up deliberately
deferred. Every BFFG variant from the legacy `examples/ABFFG.py` now has
a home in the new prebuilt layer.

ODE-integrated BFFG (general linear B, β):

- `bffg_sde.backward_filter` and `forward_guided` now dispatch
  internally on `B`/`β`. The closed-form path is unchanged; the new
  ODE path integrates the (H, F, c) system via
  `jax.experimental.ode.odeint` with τ = T − t, unpacks the per-step
  series `F_t`, `H_t` for use by `forward_guided`, and supports d > 1
  via tensor-product reshape (`F` flat (n·d,) ↔ matrix (n, d); `c` per
  d-column; trace term broadcast across d). Uniform `dts` assumed.
- `bffg_sde.forward_guided` ODE path indexes the precomputed `F_t`,
  `H_t` arrays at each Euler step and adds the linear-drift term
  `tildeb(t, x) = β(t) + B(t)·x` to the `logpsi` correction.
- `sde_up(n_steps, a, *, B=None, beta=None)` and
  `sde_down_conditional(n_steps, b, σ, a, *, B=None, beta=None)` now
  forward to the matching path. The conditional down re-runs the ODE
  filter inline per edge to recover `F_t`/`H_t` (avoids extending the
  up sweep's writes set; one extra `odeint` per edge per sample).
- A small `_bridge_step_body` helper factors the shared Euler-step
  body between the closed-form and ODE branches of `forward_guided`.

gaussian_down_conditional:

- New sweep in `bffg_gaussian`. For each non-root node, re-derives the
  per-edge filter intermediates `(c_0, F_0, H_0)` from the node's own
  posterior `(c_T, F_T, H_T)` and `edge_length` using the same math as
  `gaussian_up` (linearization at `v_T = solve(H_T, F_T)`), then draws
  the conditional Gaussian sample given the parent's current value and
  computes the per-edge `logw` correction
  `logphi_H(v_T_post, x, (I + H_T·Σ_T)⁻¹·H_T) − logU(x, c_0, F_0, H_0)`.
- No extension to `gaussian_up`'s writes set — the schema growth that
  was blocking this in the previous round is sidestepped by inline
  re-derivation. Costs one extra (n, n) solve per non-root node per
  sample.

Tests (+11; 175 total, all green):

ODE path (6):
- `backward_filter` with `B=0, β=0` reproduces closed-form `H_0`, `F_0`,
  `c_0` to 1e-4.
- `backward_filter` ODE per-step series ends at the boundary
  conditions (`H_t[-1] == H_T`).
- `forward_guided` ODE with `B=0, β=0` matches closed-form trajectory
  and `logpsi` to 1e-3.
- `sde_up` ODE matches closed-form sweep on a 3-node Brownian tree.
- Full up → propagate → conditional-down pipeline with damped
  `B(t) = -0.5·I` runs and leaves land near observations.
- ODE path composes under outer `@jax.jit`.

`gaussian_down_conditional` (5):
- Zero-noise + low-`obs_var` recovers leaf observations to 5e-2.
- Determinism under zero noise (same input → same output).
- Random noise produces finite values + finite `logw`.
- Outer `@jax.jit` composition.
- End-to-end up → conditional-down inside `jax.lax.scan`.
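
The "ODE path with `B=0, β=0` reproduces the closed form" check can be mirrored in one dimension. Differentiating the scalar closed form `H(τ) = H_T / (1 + a·H_T·τ)` gives `dH/dτ = −a·H²` and `dF/dτ = −a·H·F`; treating that as the ODE is an assumption of this sketch, and the fixed-step RK4 integrator stands in for `odeint`. The `c` component is left out.

```python
def rhs(H, F, a):
    """Right-hand side in reversed time tau = T - t (scalar, B = beta = 0)."""
    return -a * H * H, -a * H * F


def integrate_backward(H_T, F_T, a, T, n_steps=100):
    """Fixed-step RK4 from tau = 0 (edge end) to tau = T (edge start)."""
    h = T / n_steps
    H, F = H_T, F_T
    for _ in range(n_steps):
        k1H, k1F = rhs(H, F, a)
        k2H, k2F = rhs(H + 0.5 * h * k1H, F + 0.5 * h * k1F, a)
        k3H, k3F = rhs(H + 0.5 * h * k2H, F + 0.5 * h * k2F, a)
        k4H, k4F = rhs(H + h * k3H, F + h * k3F, a)
        H += h * (k1H + 2 * k2H + 2 * k3H + k4H) / 6
        F += h * (k1F + 2 * k2F + 2 * k3F + k4F) / 6
    return H, F


H_T, F_T, a, T = 2.0, 4.0, 1.0, 1.0
H_ode, F_ode = integrate_backward(H_T, F_T, a, T)
H_cf = H_T / (1.0 + a * H_T * T)
F_cf = F_T / (1.0 + a * H_T * T)
```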

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Stage A of the diffrax migration: only the ODE-integrated backward
filter is replaced — the SDE forward/forward_guided paths stay on
hand-rolled lax.scan because their pre-sampled-dW data model
(MCMC Crank-Nicolson on `tree.data['noise']`) is fundamentally at
odds with diffrax's BrownianPath abstraction.

The substitution:

  jax.experimental.ode.odeint  →  diffrax.diffeqsolve(
                                     diffrax.ODETerm(...),
                                     diffrax.Tsit5(),
                                     stepsize_controller=
                                       diffrax.PIDController(rtol=1e-7, atol=1e-9),
                                     ...)

Same packed state layout, same τ = T − t convention, same return shape.
The vector_field is rewritten to diffrax's `(t, y, args)` signature.

Concretely:
- `pyproject.toml`: new `[prebuilt-bffg]` extra pinning `diffrax>=0.6`.
  Closed-form path still has zero non-jax dependencies — only the ODE
  branch needs the extra.
- `prebuilt/bffg_sde.py::_backward_filter_ode`: diffrax import is lazy,
  raising a clean `ImportError` pointing at the extra when missing
  (same pattern as ete3 in `io.newick` and trimesh in `shape_kernels`).
- `tests/test_core_dependencies.py`: diffrax and equinox added to the
  forbidden-in-core list so a future accidental import gets caught.
- `tests/test_prebuilt_bffg_sde.py`: tolerances tightened by about two orders
  of magnitude (1e-3/1e-4 → 1e-5/1e-6) reflecting the integrator
  precision boost. Added a missing-diffrax regression test via
  `monkeypatch` on `builtins.__import__`.
- `README.md`: documents the new extra.
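
The lazy-import pattern and the missing-dependency test can be sketched with the standard library alone. Here `json` stands in for diffrax so the example runs anywhere; `require_optional` and `make_blocking_import` are hypothetical names, and the real test uses pytest's `monkeypatch` rather than manual patching.

```python
import builtins


def require_optional(module_name, extra):
    """Import lazily; on failure, point the user at the matching extra."""
    try:
        return __import__(module_name)
    except ImportError as err:
        raise ImportError(
            f"'{module_name}' is required here; install the '{extra}' extra."
        ) from err


def make_blocking_import(blocked, real_import):
    """Build an __import__ replacement that pretends `blocked` is missing."""
    def fake_import(name, *args, **kwargs):
        if name == blocked or name.startswith(blocked + "."):
            raise ImportError(f"No module named {name!r}")
        return real_import(name, *args, **kwargs)
    return fake_import


real_import = builtins.__import__
builtins.__import__ = make_blocking_import("json", real_import)
try:
    try:
        require_optional("json", "prebuilt-bffg")
        message = None
    except ImportError as err:
        message = str(err)
finally:
    builtins.__import__ = real_import   # always restore the real importer

restored = require_optional("json", "prebuilt-bffg")
```

Note that the helper calls `__import__` directly: `importlib.import_module` does not go through `builtins.__import__`, so patching the builtin would not intercept it.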

Measured numerical agreement between ODE (B=β=0) and closed-form on the
1D Brownian smoke case:
  H_0 abs err  : 1.8e-7  (was ~1e-4 with odeint)
  F_0 abs err  : 1.8e-7
  c_0 abs err  : 6.0e-7
  forward_guided trajectory & logpsi abs err: 0.0 (downstream of better F_t/H_t)

178 tests total, all green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Restructure the docs/source/notebooks/ progression so the reader-facing
narrative is built around inferential goals rather than around the names
of the underlying algorithms:

- 03 phylo_mean: closed-form point estimate (unchanged content; updated
  cross-references in the recap).
- 04 phylo_bayesian (rename of 04_gaussian_bffg): the same ancestral-state
  problem as 03, upgraded to a full Bayesian posterior. Gaussian BFFG is
  the closed-form smoother that delivers it; the Bayesian framing is the
  headline. Hyperparameters are taken as known here.
- 05 gaussian_mcmc (new placeholder): joint MCMC over latent states and
  hyperparameters under linear-Gaussian transitions. BFFG sits inside
  the kernel as an exact marginal-likelihood evaluator.
- 06 gaussian_mle (new placeholder): the MLE / MAP counterpart of 05 —
  jax.grad through the BFFG marginal likelihood plus optax. A v3-specific
  capability that the v2 mutable-tree design could not support.
- 07 sde_mcmc (rename of 05_sde_bffg_bridges): same MCMC machinery as 05
  but with non-linear SDE transitions. BFFG becomes an approximation that
  Metropolis corrects via logpsi.

The 2x2 conceptual matrix is now Gaussian vs SDE on one axis and
Bayesian (MCMC) vs MLE (grad) on the other; the SDE-MLE square is
deferred as future work.

Cross-references in notebooks 01-03 are updated to point at the new
filenames.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous logw formula compared a normalized Gaussian density at the
parent's sampled value to logU at the same point, evaluating two
quantities that don't match Theorem 14 of van der Meulen & Sommer (2025).
Rewrite as

  log w(x) = log φ(H⁻¹F; μ(x), Q(x)+H⁻¹) − log φ(H⁻¹F; Φx+β, Q̃+H⁻¹)

with μ(x)=x, Φ=I, β=0, and Q̃ = ℓ·a(v_T) the auxiliary linearised at
the canonical posterior mean (matching gaussian_up). In the pure
linear-Gaussian limit Q(x) = Q̃ so the two densities coincide and
sum(logw) ≡ 0 — the paper's p.16 remark, now verified at machine
precision. State-dependent a produces a genuine non-zero correction.
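
The weight formula above can be checked on a single edge with a few lines of NumPy. This is a sketch under the stated conventions (μ(x)=x), not the code in the prebuilt module; `gauss_logpdf` and `log_weight` are hypothetical helper names and the numeric values are arbitrary.

```python
import numpy as np


def gauss_logpdf(y, mean, cov):
    """Log density of N(mean, cov) evaluated at y."""
    d = y.shape[0]
    diff = y - mean
    _, logdet = np.linalg.slogdet(cov)
    quad = diff @ np.linalg.solve(cov, diff)
    return -0.5 * (d * np.log(2.0 * np.pi) + logdet + quad)


def log_weight(x, H, F, Q_x, Q_tilde, Phi, beta):
    """log w(x) for one edge: true transition vs. linear-Gaussian auxiliary."""
    H_inv = np.linalg.inv(H)
    point = H_inv @ F  # H^{-1} F, the shared evaluation point
    log_true = gauss_logpdf(point, x, Q_x + H_inv)
    log_aux = gauss_logpdf(point, Phi @ x + beta, Q_tilde + H_inv)
    return log_true - log_aux


d = 2
x = np.array([0.3, -0.1])
H = np.array([[2.0, 0.5], [0.5, 1.0]])
F = np.array([1.0, -0.5])
Q = 0.4 * np.eye(d)

# Linear-Gaussian limit: Q(x) equals Q-tilde, Phi = I, beta = 0.
w_linear = log_weight(x, H, F, Q, Q, np.eye(d), np.zeros(d))
# State-dependent covariance: Q(x) differs from Q-tilde.
w_nonlinear = log_weight(x, H, F, 1.5 * Q, Q, np.eye(d), np.zeros(d))
```

In the first call both densities are evaluated with identical arguments, so the weight vanishes; in the second the mismatch between Q(x) and Q̃ leaves a non-zero correction.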

Drop the obsolete (c_0, F_0, H_0) intermediates and the now-unused
logphi_H / logU imports. The sampling code itself is unchanged.

Add three regression tests: linear case sum(logw)=0 exactly; single-edge
hand-computed Theorem 14 matches the prebuilt to atol=1e-6; nonlinear a
gives sum(logw) with std > 0.1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tion

Replace the placeholder with a complete (23 cells) notebook framing BFFG
as a guided proposal for parameter inference, not just a smoother.
Structure:

  1. Setup + forward-simulate ground truth (depth-4 binary, 31 nodes,
     16 leaves; x_root = 0 fixed)
  2. Closed-form marginal log p(y|θ) via leaf joint Gaussian
  3. BFFG-guided forward map (init_gaussian_leaves → gaussian_up →
     gaussian_down_conditional) returning (x, sum(logw))
  4. Empirical Theorem 14 collapse: sum(logw) ≡ 0 across 500 z draws
  5. BFFG-guided MCMC kernel: pCN on z + RW on log θ, target =
     log g_r(0;θ) + sum(logw) + log π(log θ); RW_SCALE tuned to ~46% acceptance
  6. Multi-chain joblib (loky) + hand-written Gelman-Rubin
  7. Triple verification: BFFG-MCMC histogram, grid-marginalised
     analytic posterior, ground truth
  8. Nonlinear teaser: state-dependent a → sum(logw) acquires real
     spread, previews 07
  9. Recap + references
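
For step 6, a hand-written Gelman-Rubin diagnostic might look like the sketch below. It uses the classic (non-split) formulation for a scalar parameter; whether the notebook uses this exact variant is an assumption, and `gelman_rubin` is a hypothetical name.

```python
import numpy as np


def gelman_rubin(chains):
    """Classic R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    _, n = chains.shape
    chain_means = chains.mean(axis=1)
    W = chains.var(axis=1, ddof=1).mean()   # mean within-chain variance
    B = n * chain_means.var(ddof=1)         # between-chain variance
    var_hat = (n - 1) / n * W + B / n       # pooled variance estimate
    return float(np.sqrt(var_hat / W))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 2000))        # four chains on the same target
stuck = mixed + np.arange(4)[:, None]     # four chains with shifted means
r_mixed = gelman_rubin(mixed)
r_stuck = gelman_rubin(stuck)
```

Well-mixed chains give a value close to 1, while chains centred on different values push it clearly above 1.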

End-to-end run verified: pCN acceptance = 1.0000 (Theorem 14 signature),
R̂ ≈ 1.001, BFFG-MCMC means/medians/95% CIs match analytic grid to 3
decimal places for both σ² and τ².
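
The pCN acceptance of 1.0 follows from the structure of the proposal: pCN leaves the standard-normal prior on z invariant, so the accept ratio depends only on the change in the log-weight. A minimal NumPy sketch of that mechanism, with θ held fixed and hypothetical names (`pcn_step`, `acceptance_rate`), is:

```python
import numpy as np


def pcn_step(z, log_weight, beta, rng):
    """One preconditioned Crank-Nicolson step for a N(0, I) prior on z."""
    xi = rng.standard_normal(z.shape)
    z_prop = np.sqrt(1.0 - beta**2) * z + beta * xi
    # The prior terms cancel; only the log-weight difference remains.
    log_alpha = log_weight(z_prop) - log_weight(z)
    accept = np.log(rng.uniform()) < log_alpha
    return (z_prop if accept else z), bool(accept)


def acceptance_rate(log_weight, n_steps=500, dim=8, beta=0.3, seed=0):
    """Fraction of accepted pCN proposals over a short run."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(dim)
    n_acc = 0
    for _ in range(n_steps):
        z, acc = pcn_step(z, log_weight, beta, rng)
        n_acc += acc
    return n_acc / n_steps


# Constant log-weight, as in the linear-Gaussian case where sum(logw) is 0.
rate_flat = acceptance_rate(lambda z: 0.0)
# A non-constant log-weight (arbitrary choice) makes some proposals fail.
rate_tilted = acceptance_rate(lambda z: -2.0 * float(np.sum(z**2)))
```

With a constant log-weight every proposal is accepted, which is why an acceptance rate of exactly 1 is a useful signature that the weight has collapsed.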

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…iable tree

Replace the placeholder with a 22-cell notebook that demonstrates
hyperparameter estimation for the same linear-Gaussian model as notebook 05,
this time via maximum likelihood instead of MCMC.

The route is the EM identity

  ∇_θ log p(y | θ) = E_{x ~ p(x|y, θ)}[∇_θ log p(x, y | θ)]

with `gaussian_down_conditional` as an *exact* posterior sampler in the
linear-Gaussian regime (Theorem 14 collapse, verified in 05).
`stop_gradient` on the BFFG sample is what turns the autodiff into the
EM gradient instead of the (different) reparameterisation gradient —
one line, but load-bearing.
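
The role of `stop_gradient` can be illustrated on a one-latent toy model rather than a tree. The sketch below assumes x ~ N(0, σ²) and y | x ~ N(x, τ²) with θ = (log σ², log τ²); the exact posterior sampler plays the part of `gaussian_down_conditional`, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def log_joint(theta, x, y):
    """Complete-data log-likelihood log p(x, y | theta)."""
    s2, t2 = jnp.exp(theta[0]), jnp.exp(theta[1])
    lp_x = -0.5 * (jnp.log(2 * jnp.pi * s2) + x**2 / s2)
    lp_y = -0.5 * (jnp.log(2 * jnp.pi * t2) + (y - x) ** 2 / t2)
    return lp_x + lp_y


def log_marginal(theta, y):
    """Closed-form log p(y | theta): y ~ N(0, sigma^2 + tau^2)."""
    s = jnp.exp(theta[0]) + jnp.exp(theta[1])
    return -0.5 * (jnp.log(2 * jnp.pi * s) + y**2 / s)


def posterior_sample(theta, y, eps):
    """Exact draws from p(x | y, theta); note the dependence on theta."""
    s2, t2 = jnp.exp(theta[0]), jnp.exp(theta[1])
    mean = s2 * y / (s2 + t2)
    var = s2 * t2 / (s2 + t2)
    return mean + jnp.sqrt(var) * eps


def em_objective(theta, y, eps):
    # stop_gradient freezes the samples, so only log_joint is differentiated.
    x = jax.lax.stop_gradient(posterior_sample(theta, y, eps))
    return jnp.mean(log_joint(theta, x, y))


theta = jnp.log(jnp.array([0.5, 0.1]))
y = 1.0
eps = jax.random.normal(jax.random.PRNGKey(0), (20000,))
score_em = jax.grad(em_objective)(theta, y, eps)
score_exact = jax.grad(log_marginal)(theta, y)
```

Up to Monte Carlo noise, `score_em` matches `score_exact`. Removing the `stop_gradient` call would let autodiff flow through the sampler as well and yield a different quantity.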

Structure:

  1. Setup + simulate (same seed as 05 → same data)
  2. Complete-data log-likelihood: per-edge prior + per-leaf obs
  3. BFFG-as-posterior-sampler
  4. EM score estimator with stop_gradient
  5. Gradient parity vs closed-form marginal log-lik (M=500: |diff| < 0.1)
  6. Adam training loop (LR=0.05, 400 steps), loss + trajectory plots
  7. EM-MLE vs closed-form MLE vs grid argmin: all three agree within MC noise
  8. Bonus: jax.grad w.r.t. edge_length — autodiff traces the whole tree
  9. Sidebar: MAP is +1 line of log-prior

End-to-end verified: EM-MLE recovers (σ²=0.41, τ²=0.22), matching the
closed-form MLE (0.41, 0.20) and grid argmin (0.40, 0.21) within MC
noise. The shared offset from truth (0.5, 0.1) is the finite-sample MLE
bias on 16 observations, not the estimator.

Add `optax>=0.2` to the notebooks extras (with the rest of the
previously-uncommitted notebooks block) and refresh uv.lock.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
prebuilt/:
- Merge bffg_gaussian.py + bffg_sde.py into a single bffg.py.
- Drop _gaussian_density.py: inline logphi via
  jax.scipy.stats.multivariate_normal.logpdf; canonical_leaf_messages
  becomes a private helper in bffg.py.
- Delete lddmm.py and shape_kernels.py: dead in code, tests, and
  notebooks. Removes the prebuilt-shape extra (only served mesh_sphere).

API:
- Rename symmetric_topology arg height -> depth across tests and the
  docs landing page.

Env:
- Pin equinox<0.13.8 in prebuilt-bffg: 0.13.8 ships a broken wheel
  referencing a missing equinox.internal._loop submodule.

Docs:
- Bring in the sphinx scaffold (conf.py, index.rst, api/index.rst,
  Makefile) and per-prebuilt autosummary entries; drop the LDDMM /
  Shape-kernels rubrics.
- Refresh notebooks 01-07 against the v3 API.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- extras: merge prebuilt-bffg+prebuilt-mcmc into a single `prebuilt`,
  rename `notebooks` to `notebook`, drop the standalone `docs` extra
  and roll sphinx tooling into the `dev` group
- pin the equinox/jaxtyping/wadler-lindig/optimistix chain under
  `prebuilt` so `uv sync` no longer resolves a broken diffrax env
- setuptools-scm: tag_regex + fallback_version for the v3 line
- add ruff lint+format config and a pre-commit pipeline (ruff,
  whitespace, eof, yaml/toml, merge-conflict, large-file)
- reformat the whole repo through ruff format as the new baseline
- new `.github/workflows/docs.yaml`: strict sphinx build on push/PR,
  GitHub Pages deploy on push to v3
- rename root `makefile` → `Makefile`; move `docs/report/` out of
  the package tree and gitignore it
- README + CLAUDE.md: drop the dead plan-doc link, fix the wrong
  `prebuilt-shape` extra, document the new install commands
- requirements/*.txt + environment.yml for non-uv install paths
- docs/Makefile: route sphinx-build through uv run (macOS PATH fix)
- prebuilt/bffg: refactor
- core/{dispatch,sweep}: updates
- prebuilt/{__init__,sde}: tweaks
- notebooks 01/05/06/07: refreshed outputs
- tests: expand bffg_sde + sweep_up
- README, .gitignore updates

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…onsolidation

Packaging
- Drop the prebuilt / io extras (diffrax and jax_tqdm are no longer used);
  the core lib imports on bare jax + numpy. Simplify the sync target to
  `uv sync --group dev`. Delete legacy requirements/ and environment.yml.

Layout
- Move Newick parsing into core/builders.py and delete hyperiax/io/.
- Add hyperiax/utils/ (pure-JAX RK4 / Euler-Maruyama solvers), used by bffg
  in place of diffrax.
- Delete hyperiax/prebuilt/{mcmc,sde}.py; NumPyro is now the recommended
  MCMC path (see notebooks 05 / 06).
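
As an illustration of what a fixed-step RK4 solver in place of diffrax involves, the sketch below integrates the scalar backward equation dH/dt = σ² H² for a 1D Brownian motion from t = T down to t = 0 and compares it with the closed form H(0) = 1 / (v + σ² T). It is written in plain Python for portability rather than in JAX, and it is not the code in `hyperiax/utils/`; `rk4`, `v`, and the step count are assumptions made for the example.

```python
def rk4(f, y0, t0, t1, n_steps):
    """Classic fixed-step Runge-Kutta 4 from t0 to t1 (t1 < t0 is allowed)."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        k1 = f(t, y)
        k2 = f(t + 0.5 * h, y + 0.5 * h * k1)
        k3 = f(t + 0.5 * h, y + 0.5 * h * k2)
        k4 = f(t + h, y + h * k3)
        y = y + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        t = t + h
    return y


sigma2, T, v = 1.0, 1.0, 1.0   # diffusion variance, edge length, leaf variance
H_T = 1.0 / v                  # terminal precision at the leaf

# Integrate backward in time: the step h is negative.
H_0 = rk4(lambda t, H: sigma2 * H**2, H_T, T, 0.0, 100)
H_0_exact = 1.0 / (v + sigma2 * T)
abs_err = abs(H_0 - H_0_exact)
```

A comparison of this kind is the same style of check as an ODE versus closed-form smoke test on a Brownian edge.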

Core dispatch
- One up-sweep path: _up_dispatch walks the segment layout (pbuckets) for
  any topology. The equal-degree dense-vmap path is gone, and Topology
  drops gather_child_idx / level_non_leaf_indices with it.
- Add Children.map(fn) — a segment-preserving per-child vmap so a non-linear
  per-child transform feeds the same reduction surface as a plain field.
  ChildrenAxis.flat is now exposed; writes_children supported on any topology.
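
The idea behind a segment-based child reduction with a per-child map can be sketched with `jax.vmap` and `jax.ops.segment_sum`. This is a generic illustration on a small ragged tree, not the Hyperiax `_up_dispatch` implementation or its `pbuckets` layout; the arrays and `per_child` are made up for the example, and it performs a single flat reduction rather than a level-ordered sweep.

```python
import jax
import jax.numpy as jnp

# Ragged tree: node 0 has children {1, 2}, node 1 has children {3, 4, 5}.
parent = jnp.array([-1, 0, 0, 1, 1, 1])
values = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])

child_idx = jnp.array([1, 2, 3, 4, 5])   # every non-root node
segment_ids = parent[child_idx]          # parent of each child


def per_child(v):
    """Non-linear per-child transform applied before the reduction."""
    return v**2


# Map over the flat child axis, then reduce each parent's segment.
mapped = jax.vmap(per_child)(values[child_idx])
child_sum = jax.ops.segment_sum(
    mapped, segment_ids, num_segments=parent.shape[0]
)
```

Because the children are stored as one flat axis with a parent id per entry, nodes with two and three children go through the same code path, with no padding or degree-specific branch.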

BFFG rewrite (prebuilt/bffg.py)
- Rename gaussian_* -> discrete_*, sde_* -> continuous_*.
- Discrete backward filter implements the full linear-Gaussian auxiliary
  (Phi x + beta, Q) per Theorem 14 §6.1; discrete forward-guiding weight
  matches Theorem 14.3.
- Continuous backward filter caches the per-edge (H, F) trajectory via
  writes_children; vertex message lives in new prec_v / ptnl_v fields.
- Per-edge linearisation anchors (Algorithm 3 §7.1): each node stores
  anchor (+ anchor_pa on continuous edges); discrete_refine_anchor /
  continuous_refine_anchor refine to the BFFG posterior mean.
- Two-anchor sigma-tilde interpolation along continuous edges (BF + FG
  consistent).
- Track canonical-message log_norm so tree.log_norm[root] is the marginal
  log-evidence.

phylo_mean migrated to children.map; now runs on ragged trees.

Tests
- Consolidate into one test_sweep_up.py covering both topologies (the
  unified path makes them indistinguishable to user code).
- Drop the legacy gaussian_/sde_ BFFG tests; drop mcmc/sde tests; rename
  test_io_newick -> test_newick; add test_utils_{ode,sde}; adjust
  test_topology for the removed fields; flip test_prebuilt_phylo_mean's
  unequal-degree-raises check to an unequal-degree-correctness check.

CI + docs scaffold updated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- hyperiax/prebuilt/bffg.py: module-level docstring + Google-style
  Args/Returns/Notes for all 12 public symbols (discrete_/continuous_
  schema, init_*_tree, *_bf_sweep, *_forward_sweep, *_fg_sweep,
  *_refine_anchor). No semantic changes — autodoc-friendly polish.
- docs/source/api/index.rst: add discrete_refine_anchor and
  continuous_refine_anchor.
- Rename docs/source/notebooks/07_sde_mcmc.ipynb -> 07_sde_bffg.ipynb;
  update the toctree (notebooks/index.rst) and the inline references in
  notebooks 03 / 04 / 05 / 06.

Verified: pytest -q (164 passed), ruff check hyperiax/ tests/ (clean),
make -C docs html SPHINXOPTS=-W (build succeeded).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The repo was tracking both `Makefile` (uppercase) and `makefile` (lowercase)
as separate paths even though they point to the same physical file on the
case-insensitive default macOS / Windows filesystems. On Linux CI this
manifested as a duplicate entry, and locally it caused `git status` to
permanently show `M Makefile` because the uppercase index entry held stale
content while the lowercase entry tracked the up-to-date file.

Keep only `makefile` (lowercase) — it already has the current `uv sync
--group dev` content from the v3 release prep.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@gefanyang gefanyang changed the title v3: ground-up rewrite — immutable Tree, unified segment dispatch, BFFG refit v3: migrate main to immutable Tree, unified sweeps, and BFFG refit May 31, 2026
@gefanyang gefanyang merged commit b04864d into ComputationalEvolutionaryMorphometry:main Jun 2, 2026
3 checks passed