Sparse mode performance, SparseHist input dispatch, and low-memory --noHessian mode#129

Open
bendavid wants to merge 18 commits into WMass:main from bendavid:sparsedev4

Conversation


@bendavid bendavid commented Apr 9, 2026

Adds support for sparse histogram input in the TensorWriter.

Significant performance optimizations for sparse mode in
both the TensorWriter and in the Fitter.

Performance optimizations should also give a factor of ~2
improvement for dense mode in the Fitter for large models.

Reworks --noHessian into a true low-memory mode: the dense
[npar, npar] covariance matrix is no longer allocated, the
Fitter initialization no longer has quadratic hot spots on large
external Hessians, and the postfit edmval and POI/NOI
uncertainties are still reported via Hessian-free conjugate
gradient solves. On the jpsi calibration tensor (108k params,
329M-nnz external sparse Hessian) Fitter.__init__ drops from
~370 s to ~5 s and peak memory is reduced by ~94 GB.

Note that this depends on WMass/wums#25

Three related groups of commits:

Group A — TensorWriter sparse dispatch (7 commits)

  • add option to treat input systematic histograms as difference with respect to nominal
  • add test for sparse mode
  • Support scipy sparse array inputs in TensorWriter and add as_difference option
  • Add multi-systematic dispatch in add_systematic and use wums.SparseHist
  • Add external likelihood term (gradient + hessian) support
  • Add efficient SparseHist multi-systematic dispatch in TensorWriter
  • Speed up TensorWriter for large multi-systematic SparseHist workloads

Group B — Sparse fast path performance (5 commits)

Up to ~20× HVP speedup on the jpsi calibration tensor (76800 bins,
108334 params, 62M-nnz logk, 329M-nnz external sparse Hessian):
HVP 6380 → 320 ms, loss+grad 3010 → 160 ms.

  • inputdata, parsing: prep for sparse fast path with CSR matvec
    — canonicalize sparse index ordering at load time, pre-build a
    CSRSparseMatrix view of logk, add --hvpMethod and
    --noJitCompile CLI options.

  • fitter: dynamic loss/grad/HVP wrappers with jit_compile +
    hvpMethod
    — replace class-level @tf.function decorators with
    instance-level wrappers built dynamically in _make_tf_functions,
    so jit and HVP autodiff mode can be controlled per-fit. Note that
    fwdrev HVP is intentionally never jit-compiled because
    tf.autodiff.ForwardAccumulator does not propagate JVPs through
    XLA-compiled subgraphs.

  • fitter: sparse fast path uses CSR matmul, no dense
    [nbins, nproc]
    — reformulate the sparse branch of
    _compute_yields_noBBB to use tf_sparse_csr.matmul for the
    inner contraction logk @ theta (~8× faster per call than the
    equivalent gather + segment_sum) and never materialize the dense
    [nbins, nproc] grid in the NLL/grad/HVP path. Also forces
    jit_compile=False in sparse mode (CSR matmul has no XLA kernel)
    and falls back to revrev when fwdrev is requested in sparse mode.

  • fitter: external sparse Hessian via CSR matmul — switch the
    external sparse-Hessian likelihood term to use CSR matmul. The
    registered gradient of sm.matmul is itself a single
    sm.matmul, so reverse-over-reverse autodiff no longer
    rematerializes a 2D gather/scatter chain in the second-order
    tape. On the jpsi 329M-nnz prefit Hessian this was the dominant
    HVP cost.

  • rabbit_fit, setup.sh: enable XLA multi-threaded Eigen on CPU
    set XLA_FLAGS=--xla_cpu_multi_thread_eigen=true so XLA's CPU
    emitter uses Eigen's multi-threaded routines for the dense
    matmuls jit_compile=True generates. ~1.3× speedup on dense
    large-model HVP/loss+grad on a many-core system. Set both in
    setup.sh (for sourced shells) and at the very top of
    bin/rabbit_fit.py before any TF import (for direct invocation).
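The append-only environment handling described in that last bullet can be sketched in Python (a minimal sketch under stated assumptions: the function name and the mapping argument are illustrative, only the flag itself is the one the PR sets; in rabbit_fit.py this must run before the first TensorFlow import):

```python
import os

def enable_xla_multithread_eigen(environ=os.environ):
    """Append --xla_cpu_multi_thread_eigen=true to XLA_FLAGS without
    clobbering flags the user already set, and without duplicating it
    when called twice (e.g. sourced setup.sh plus direct invocation)."""
    flag = "--xla_cpu_multi_thread_eigen=true"
    existing = environ.get("XLA_FLAGS", "")
    if flag not in existing.split():
        environ["XLA_FLAGS"] = (existing + " " + flag).strip()
    return environ["XLA_FLAGS"]
```

Running this both from the shell setup and at the top of the entry script is then safe, because the second call detects the flag and leaves XLA_FLAGS unchanged.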

Group C — --noHessian low-memory mode (5 commits)

Reworks --noHessian so it no longer allocates the dense
covariance matrix while still producing edmval and POI/NOI
uncertainties. The pipeline now runs end-to-end at O(npar)
memory for jpsi-scale problems where the full cov is infeasible
(~94 GB for 108k parameters in float64).

  • fitter, rabbit_fit: skip dense cov allocation under
    --noHessian
    — split prefit_covariance into a vector form
    (prefit_variance) and a tf.linalg.LinearOperatorDiag
    wrapper; always allocate a length-npar var_prefit vector
    and only allocate the dense self.cov tf.Variable when the
    postfit Hessian will actually be computed. defaultassign,
    randomize_parameters, and load_fitresult all handle
    self.cov is None. The rabbit_fit CLI now explicitly rejects
    every flag that would need the postfit covariance
    (--doImpacts, --computeVariations, --saveHists without
    --noChi2, --computeHistErrors[PerProcess],
    --computeHistCov, --computeHistImpacts,
    --computeHistGaussianImpacts, --externalPostfit).

  • fitter: speed up Fitter.__init__ on large external sparse
    Hessians
    — two structural fixes that reduce
    Fitter.__init__ on the jpsi calibration tensor from ~370 s
    to ~20 s. Replace the per-parameter np.where lookup of
    external term parameter names (O(n²), ~150 s) with a single
    dict lookup. Detect already-canonical sparse-Hessian indices
    and skip np.lexsort (~54 s) when the input is already
    sorted.

  • unify sparse-Hessian IO path; sort at write time, drop
    reorder calls
    — the TensorWriter now sorts the external
    sparse-Hessian indices into canonical row-major order at
    write time (matching hlogk_sparse / hnorm_sparse), so the
    reader and Fitter can use the same makesparsetensor helper
    and drop their defensive tf.sparse.reorder calls.
    Fitter.__init__ drops further to ~5 s on jpsi.

  • fitter: Hessian-free CG solve for is_linear case under
    --noHessian
    — the purely-quadratic is_linear fast path in
    Fitter.minimize() used to build the dense Hessian and do a
    Cholesky solve. Under --noHessian it now solves the normal
    equation H @ dx = -grad via scipy.sparse.linalg.cg with a
    LinearOperator backed by loss_val_grad_hessp, touching
    only O(npar) memory.

  • fitter, rabbit_fit: edmval + POI/NOI uncertainties under
    --noHessian
    — compute edmval and the POI+NOI rows of the
    covariance matrix via Hessian-free CG solves (one for edmval,
    one per POI/NOI index). No dense Hessian or covariance is
    ever materialized. The POI+NOI diagonal entries populate the
    parms_variances vector passed to add_parms_hist; other
    nuisances keep NaN variances, signalling that their postfit
    uncertainty was not computed. Verified against the Cholesky
    path on the small test tensor: edmval and POI/NOI
    uncertainties match to full precision.
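The Hessian-free solve pattern behind Group C can be sketched with scipy on a toy quadratic problem (an illustrative sketch, not the rabbit code: `hessp` here stands in for a single `loss_val_grad_hessp` HVP call, and the matrix `A` is random toy data):

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

# Toy quadratic NLL: 0.5 x^T A x + b^T x, so the Hessian is A and
# the gradient at x=0 is b.
rng = np.random.default_rng(0)
npar = 50
A = rng.standard_normal((npar, npar))
A = A @ A.T + npar * np.eye(npar)  # symmetric positive definite
b = rng.standard_normal(npar)

def hessp(v):
    # In the fitter this would be one HVP call; the dense Hessian is
    # never formed, so memory stays O(npar).
    return A @ v

H_op = LinearOperator((npar, npar), matvec=hessp, dtype=np.float64)

# Newton step: solve H dx = -grad using only matvecs.
dx, info = cg(H_op, -b)

# edm = 0.5 grad^T H^{-1} grad, recovered from the same solve since
# H^{-1} grad = -dx.
edmval = 0.5 * np.dot(b, -dx)
```

The same machinery gives per-parameter uncertainties: solving H y = e_i for a unit vector e_i yields the i-th column of the covariance, which is why only the POI/NOI entries (one CG solve each) are populated under --noHessian.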

bendavid and others added 12 commits April 9, 2026 08:16
…ce option

Add `as_difference` parameter to `add_systematic` to interpret input histograms
as differences from nominal. Add full scipy sparse array support for `add_process`
and `add_systematic`: in sparse mode, norm is stored as flat CSR and logk is
computed only at nonzero positions, avoiding full-size dense intermediates.
Extend test_sparse_fit.py to cover all modes including scipy sparse inputs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
add_systematic now detects extra axes in the input histogram beyond the
channel axes (or via an explicit syst_axes argument) and books one
systematic per bin combination on those extra axes, with auto-generated
names from the bin labels. Works for hist inputs as well as for SparseHist
inputs from wums, in both dense and sparse TensorWriter modes.

The local SparseHist implementation has been moved to wums.sparse_hist and
is re-exported here for convenience. SparseHist now always uses the with-flow
layout internally, and the writer extracts either the with-flow or no-flow
representation depending on the channel's flow setting.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TensorWriter.add_external_likelihood_term accepts a 1D hist for the
gradient and a 2D hist (or wums.SparseHist) for the hessian, both indexed
by hist.axis.StrCategory axes whose bin labels identify the parameters.
Both grad and hess (when provided together) must use the same parameter
list in the same order; the matrix is indexed by a single parameter list.
Multiple terms can be added with distinct names. Sparse hessians via
SparseHist preserve sparsity through the writer and the fit.

The terms are serialized under an external_terms HDF5 group, loaded back
in FitInputData, and resolved against the full fit parameter list (POIs +
systs) at Fitter init. Fitter._compute_external_nll adds an additive
g^T x_sub + 0.5 x_sub^T H x_sub contribution to the NLL, fully
differentiable through TF autodiff so all existing loss_val_grad and
hessian methods pick it up automatically.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The generic _get_systematic_slices loop calls h[slice_dict] once per
combination on the extra (systematic) axes, which for SparseHist input
is O(nnz) per slice and prohibitively slow when there are many extra
bins (e.g. ~108k corparms over a ~31M nnz SparseHist would take hours).

Add a fast path that pre-extracts the with-flow flat representation
once, computes a linear systematic index from the extra-axis
coordinates, sorts globally, and then yields contiguous per-bin runs.
Empty combinations yield an empty SparseHist over the kept axes so the
caller can still book the corresponding systematic name (allowing it
to be constrained by an external term even when the template variation
is identically zero). This is O(nnz log nnz) total instead of O(nnz)
per slice, and supports both single and asymmetric (up/down) inputs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
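The sort-globally-then-yield-runs idea from the commit above can be sketched in numpy (toy arrays; the function name and signature are illustrative, not the actual SparseHist internals):

```python
import numpy as np

def group_by_syst(syst_idx, flat_idx, values, nsyst):
    """Partition sparse entries into per-systematic runs using one
    global argsort, O(nnz log nnz) total, instead of one O(nnz) scan
    per systematic."""
    order = np.argsort(syst_idx, kind="stable")
    s_sorted = syst_idx[order]
    # Boundaries of each systematic's contiguous run in sorted order.
    starts = np.searchsorted(s_sorted, np.arange(nsyst), side="left")
    stops = np.searchsorted(s_sorted, np.arange(nsyst), side="right")
    for isyst, (a, b) in enumerate(zip(starts, stops)):
        sel = order[a:b]
        # Empty combinations still yield (empty) arrays so the caller
        # can book the systematic name regardless.
        yield isyst, flat_idx[sel], values[sel]

syst = np.array([2, 0, 2, 0, 3])
bins = np.array([10, 11, 12, 13, 14])
vals = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
groups = {i: (b.tolist(), v.tolist())
          for i, b, v in group_by_syst(syst, bins, vals, 4)}
```

Note that systematic 1 never appears in the input but still comes back as an empty group, mirroring how the writer books identically-zero template variations so they can be constrained by an external term.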
Several independent optimizations to the writer + write() path. On a
realistic 2-channel jpsi calibration tensor with ~108k corparm
systematics and a 330M-nnz external hessian, total wall time drops
from ~4m30s to ~1m13s.

1. Vectorized SparseHist multi-syst dispatch in add_systematic.
   New _add_systematics_sparsehist_batched does all per-entry math
   (channel flat index, norm lookup, sign-flip-protected logk) once
   over the full ~25M-entry array, partitions by linear systematic
   index via a single argsort + searchsorted, and bulk-inserts
   per-syst (indices, values) directly into dict_logkavg /
   dict_logkavg_indices. Empty bin combinations still get an entry
   and a corresponding book_systematic call so they appear in the
   fit parameter list and can be constrained externally. Triggered
   when the input is a single SparseHist with extra axes plus
   mirror=True, as_difference=True, no add_to_data_covariance.
   Per-channel booking goes from ~93s to ~9s.

2. Pre-allocate sparse assembly buffers in write(). The previous
   loop grew norm_sparse_* and logk_sparse_* via np.ndarray.resize
   once per (channel, process, syst), which is O(N^2) total because
   each resize allocates a new buffer and copies all elements. A
   quick first pass over the dict structures now computes the total
   nnz so the buffers can be allocated once and filled in place.

3. Replace list.index() with a dict in get_groups,
   get_constraintweights, get_noiidxs. The old code did
   systs.index(name) once per group member, giving O(nsysts*nmembers)
   behaviour: with 108k systs all in a single corparms group this was
   the dominant cost of write(), eating ~75 seconds.

4. Skip the unnecessary to_flat_csr sort in
   add_external_likelihood_term. For SparseHist hess input, access
   _flat_indices/_values directly and recover (rows, cols) via
   np.divmod, instead of going through to_flat_csr(flow=False) which
   sorts ~330M entries we then never read in order. ~30s saved.

5. Switch h5py compression from gzip to Blosc2 LZ4 in
   h5pyutils_write. ~5x faster on integer arrays at slightly better
   compression ratios. h5pyutils_read imports hdf5plugin so the
   filter is registered for read-back.

6. Add a compress=True parameter to writeFlatInChunks and have
   writeSparse pass compress=False for the values payload of an
   explicitly sparse tensor. Densely packed nonzero floats from real
   physics tensors compress only ~4% at 5x the write cost, so the
   compression is pure overhead there. Index buffers continue to
   compress (~10x ratio with negligible overhead).

Also adds a regression test in test_multi_systematic.py that
constructs a multi-syst SparseHist and asserts the batched fast path
produces bit-identical hnorm/hlogk to per-syst manual booking, with
log_normal + as_difference=True and entries that exercise the
logkepsilon sign-flip fallback.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
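The two-pass preallocation from item 2 can be sketched as follows (illustrative structure, not the actual writer dicts): a first pass sizes the output, so the buffers are allocated exactly once rather than regrown per (channel, process, syst).

```python
import numpy as np

def assemble_sparse(chunks):
    """Concatenate many (indices, values) chunks into flat buffers.
    Sizing the output up front avoids the O(N^2) total copy cost of
    growing via np.ndarray.resize once per chunk."""
    total_nnz = sum(len(v) for _, v in chunks)
    indices = np.empty(total_nnz, dtype=np.int64)
    values = np.empty(total_nnz, dtype=np.float64)
    pos = 0
    for idx, val in chunks:
        n = len(val)
        indices[pos:pos + n] = idx
        values[pos:pos + n] = val
        pos += n
    return indices, values

chunks = [
    (np.array([0, 3]), np.array([1.0, 2.0])),
    (np.array([], dtype=np.int64), np.array([])),  # empty booking
    (np.array([7]), np.array([5.0])),
]
idx, val = assemble_sparse(chunks)
```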
Three preparatory changes that the fitter changes in following
commits will rely on:

  * inputdata.py: in sparse mode, call tf.sparse.reorder on norm and
    logk at load time to canonicalize their indices into row-major
    order. The fitter sparse fast path reduces nonzero entries via
    row-keyed reductions, which want coalesced memory access on the
    sorted indices.

  * inputdata.py: pre-build a tf.linalg.sparse.CSRSparseMatrix view
    of logk so the fitter can use sm.matmul (a multi-threaded CSR
    kernel) for the inner contraction logk @ theta. SparseMatrixMatMul
    has no XLA kernel, so any tf.function calling it must be built
    with jit_compile=False; the fitter handles this in sparse mode.

  * parsing.py: add --hvpMethod {revrev,fwdrev} to choose the
    autodiff mode for the Hessian-vector product, and --noJitCompile
    to disable XLA jit_compile (on by default in dense mode).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace the class-level @tf.function decorators on loss_val,
loss_val_grad, and loss_val_grad_hessp_{fwdrev,revrev} with
instance-level wrappers built dynamically in _make_tf_functions()
at construction time. This lets jit_compile and the HVP autodiff
mode be controlled per-fit via --jitCompile / --hvpMethod without
class-level redefinition.

  * --jitCompile (on by default): wraps loss/grad and revrev HVP
    with tf.function(jit_compile=True). The fwdrev HVP wrapper is
    intentionally NOT jit-compiled because tf.autodiff.Forward-
    Accumulator does not propagate JVPs through XLA-compiled
    subgraphs (the JVP comes back as zero), regardless of inner/
    outer placement of jit_compile.

  * --hvpMethod {revrev,fwdrev}: selects which underlying HVP
    wrapper is bound to self.loss_val_grad_hessp.

The dynamic wrappers are also stripped and rebuilt in __deepcopy__,
since the FuncGraph state held by an already-traced tf.function
cannot be deepcopy'd. _compute_loss is collapsed to a one-liner
since its only job is to dispatch to _compute_nll.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
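The per-instance wrapper pattern can be sketched without TensorFlow (a hypothetical `FitterSketch` class; `_compile` stands in for `tf.function(fn, jit_compile=...)` and only records how each wrapper was built):

```python
class FitterSketch:
    """Sketch: build loss/HVP callables per instance so jit_compile
    and the HVP autodiff mode are constructor arguments rather than
    class-level decorators."""

    def __init__(self, jit_compile=True, hvp_method="revrev"):
        self.jit_compile = jit_compile
        self.hvp_method = hvp_method
        self._make_functions()

    def _compile(self, fn, jit):
        # Stand-in for tf.function(fn, jit_compile=jit).
        def wrapped(*args, **kwargs):
            return fn(*args, **kwargs)
        wrapped.jit = jit
        return wrapped

    def _make_functions(self):
        self.loss_val = self._compile(self._loss, jit=self.jit_compile)
        # fwdrev is never jit-compiled: ForwardAccumulator does not
        # propagate JVPs through XLA-compiled subgraphs.
        jit_hvp = self.jit_compile and self.hvp_method != "fwdrev"
        self.loss_val_grad_hessp = self._compile(self._hessp, jit=jit_hvp)

    def _loss(self, x):
        return x * x  # toy quadratic loss

    def _hessp(self, x, v):
        return 2.0 * v  # exact HVP of the toy loss
```

Because the wrappers live on the instance, a deepcopy can simply drop them and call `_make_functions()` again, which is how the real code avoids copying traced FuncGraph state.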
Reformulate the sparse branch of _compute_yields_noBBB so that the
NLL/grad/HVP path never materializes the dense [nbinsfull, nproc]
intermediate, and uses tf.linalg.sparse's CSR SparseMatrixMatMul
for the dominant inner contraction logk @ theta. The CSR kernel is
multi-threaded and ~8x faster per call than the equivalent
gather + unsorted_segment_sum that the previous form lowered to
under TF on CPU.

Changes:

  * _compute_yields_noBBB takes a new compute_norm flag. The dense
    [nbinsfull, nproc] normcentral grid is only built when an
    external caller actually wants per-process yields, or when
    binByBinStat "full" mode needs them for the analytic beta
    solution. The NLL/grad/HVP path passes compute_norm=False.

  * Sparse branch: replace tf.sparse.sparse_dense_matmul(logk, ...)
    with tf_sparse_csr.matmul(logk_csr, ...) on the pre-built CSR
    view from inputdata.py.

  * Sparse branch: collapse to per-bin yields via
    tf.math.unsorted_segment_sum on the modified sparse values
    keyed by bin index, equivalent to but cheaper than
    tf.sparse.reduce_sum at this scale.

  * _compute_yields_with_beta plumbs need_norm correctly so the
    bbb-lite path doesn't pay for the dense materialization.

  * _expected_yield_noBBB explicitly passes compute_norm=False.

  * _make_tf_functions: SparseMatrixMatMul has no XLA kernel, so
    force jit_compile=False on all wrappers in sparse mode
    regardless of the user's --jitCompile setting.

  * _make_tf_functions: tf.autodiff.ForwardAccumulator cannot
    trace tangents through SparseMatrixMatMul (no JVP rule for the
    CSR variant), so when --hvpMethod=fwdrev is requested in sparse
    mode, fall back to revrev with a warning.

Profile on the jpsi calibration tensor (76800 bins, 108334 params,
62M-nnz logk): HVP per call drops from ~6400 ms to ~320 ms (~20x
speedup), loss+grad from ~3000 ms to ~160 ms.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
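The reformulated sparse yield computation can be sketched with scipy's CSR type standing in for TF's CSRSparseMatrix (toy sizes and values; `np.add.at` plays the role of `unsorted_segment_sum`, and the multiplicative log-normal response is reduced to a single systematic axis):

```python
import numpy as np
from scipy.sparse import csr_array

nbins, nproc, nsyst = 4, 3, 2

# One row per nnz (bin, proc) entry of norm; only the bin key is kept.
bin_of_entry = np.array([0, 0, 1, 2, 3])
norm_vals = np.array([1.0, 2.0, 0.5, 4.0, 3.0])

# logk as CSR of shape [nnz, nsyst]: per-entry log response.
logk = csr_array(
    (np.array([0.1, -0.2, 0.3]),
     (np.array([0, 2, 4]), np.array([0, 1, 1]))),
    shape=(5, nsyst),
)
theta = np.array([1.0, 2.0])

# Inner contraction logk @ theta via a CSR matvec (the sm.matmul
# step), then collapse to per-bin yields with a segment sum keyed by
# bin index, never materializing the dense [nbins, nproc] grid.
snorm_mod = norm_vals * np.exp(logk @ theta)
yields = np.zeros(nbins)
np.add.at(yields, bin_of_entry, snorm_mod)
```

The dense formulation would scatter `snorm_mod` into a [nbins, nproc] array and sum over processes; the segment sum above produces the same per-bin yields directly from the nnz entries.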
Switch the external sparse-Hessian likelihood term to use
tf.linalg.sparse's CSR SparseMatrixMatMul instead of an element-wise
gather-based 0.5 x^T H x form. The CSR matmul kernel is multi-
threaded, and crucially its registered gradient is itself a single
sm.matmul call, so reverse-over-reverse autodiff no longer
rematerializes a 2D gather/scatter chain in the second-order tape.
On large external-Hessian problems this was the dominant HVP cost.

Changes:

  * Fitter.__init__ external_terms loop: replace the "hess_sparse"
    (rows, cols, vals) tuple with a "hess_csr" CSRSparseMatrix view
    of the canonically-sorted SparseTensor, built once per term.

  * _compute_external_nll: dispatch on "hess_csr" instead of
    "hess_sparse" and compute 0.5 * x_sub^T (H @ x_sub) via
    tf_sparse_csr.matmul.

Profile on the jpsi calibration tensor (329M-nnz prefit external
Hessian on 108332 of the 108334 fit parameters): the closed-form
external HVP path that previously dominated the second-order tape
collapses to a single CSR matvec per HVP call, contributing
negligibly to the per-call cost.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
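The external-term evaluation can be sketched with scipy (toy 3-parameter term; in the fitter `x` would be the subvector of fit parameters the term applies to and `H` the CSR view of the external Hessian):

```python
import numpy as np
from scipy.sparse import csr_array

# Toy external likelihood term: g^T x + 0.5 x^T H x with sparse H.
g = np.array([0.5, -1.0, 0.0])
H = csr_array(np.array([[2.0, 0.0, 0.0],
                        [0.0, 4.0, 1.0],
                        [0.0, 1.0, 3.0]]))

def external_nll(x):
    # A single sparse matvec. Its gradient is g + H x and a second
    # matvec gives the exact HVP H v, which is why routing this
    # through a matmul keeps the second-order tape cheap.
    return g @ x + 0.5 * (x @ (H @ x))

x = np.array([1.0, 2.0, 3.0])
```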
Set XLA_FLAGS=--xla_cpu_multi_thread_eigen=true so XLA's CPU emitter
uses Eigen's multi-threaded routines for the dense linear-algebra
ops generated by jit_compile=True. This is a free win on dense
fits with no downside on sparse mode (where the dominant ops have
no parallel CPU kernel anyway). Measured ~1.3x speedup on dense
large-model HVP and loss+grad on a many-core system:

  default                                 HVP 51.1 ms  L+G 31.2 ms
  --xla_cpu_multi_thread_eigen=true       HVP 39.1 ms  L+G 23.0 ms

The flag is set in two places:

  * setup.sh: exported when users source the rabbit setup script.
    Append-only so any user-set XLA_FLAGS survive.

  * bin/rabbit_fit.py: also set programmatically at the very top
    of the script (before any TF import) so users who launch
    rabbit_fit.py directly without sourcing setup.sh still get
    the speedup. Same append-only logic.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
bendavid and others added 2 commits April 9, 2026 17:08
The Fitter previously always allocated a dense [npar, npar]
covariance tf.Variable, regardless of whether the postfit Hessian
would actually be computed. For very large parameter counts this
is infeasible (~94 GB for 108k parameters in float64) and
prevents --noHessian from being usable as a low-memory mode.

Changes:

  * Fitter.__init__: read options.noHessian into self.compute_cov.
    Always allocate the new self.var_prefit vector tf.Variable
    (length npar). Only allocate the dense self.cov tf.Variable
    when compute_cov=True; otherwise self.cov is None.

  * prefit_covariance() is split into:
      - prefit_variance(unconstrained_err): returns the per-parameter
        variance vector
      - prefit_covariance(unconstrained_err): returns a
        tf.linalg.LinearOperatorDiag wrapping the variance vector,
        so callers that want a matrix-like interface get one without
        ever materializing the dense [npar, npar] form. Callers that
        actually need a dense tensor can call .to_dense().

  * defaultassign now updates var_prefit always and cov only when
    it exists.

  * randomize_parameters samples from var_prefit when cov is None
    (the existing diagonal fast path was already correct for the
    prefit case; only the source of the variances needed to change).

  * load_fitresult raises a clear error if an external covariance
    is provided when self.cov is None.

  * bin/rabbit_fit.py:
      - Prefit add_parms_hist reads ifitter.var_prefit instead of
        tf.linalg.diag_part(ifitter.cov), which would fail under
        --noHessian.
      - The --computeVariations prefit branch uses
        prefit_covariance(unconstrained_err=1.0).to_dense() to feed
        the temporary cov assign (since prefit_covariance now
        returns a LinearOperator).
      - The early --noHessian guard now rejects every flag that
        actually requires the postfit covariance: --doImpacts,
        --computeVariations, --saveHists (without --noChi2),
        --computeHistErrors[PerProcess], --computeHistCov,
        --computeHistImpacts, --computeHistGaussianImpacts, and
        --externalPostfit.

Verified on the small test tensor: under --noHessian fitter.cov
is None, var_prefit is a length-13 vector, and a plain fit
converges. The incompatible-flag combinations raise clean errors
with a single descriptive message.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
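The variance-vector/operator split can be sketched in plain numpy (a stand-in for tf.linalg.LinearOperatorDiag; the `constraints` encoding, with inf marking unconstrained parameters, is illustrative):

```python
import numpy as np

class DiagOperator:
    """Matrix-like view over a variance vector: O(npar) storage, with
    to_dense() reserved for callers that truly need the matrix."""

    def __init__(self, variance):
        self.variance = np.asarray(variance, dtype=np.float64)

    def matvec(self, v):
        return self.variance * v

    def diag_part(self):
        return self.variance

    def to_dense(self):
        return np.diag(self.variance)

def prefit_variance(constraints, unconstrained_err=1.0):
    # Constrained parameters get their constraint width squared;
    # unconstrained ones get unconstrained_err**2.
    return np.where(np.isfinite(constraints),
                    constraints**2,
                    unconstrained_err**2)

constraints = np.array([1.0, 0.5, np.inf])
cov = DiagOperator(prefit_variance(constraints, unconstrained_err=2.0))
```

Under --noHessian only the vector survives; callers like --computeVariations that need a dense temporary call `.to_dense()` explicitly, which keeps the allocation visible at the call site.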
Two structural fixes that reduce Fitter.__init__ from ~370 s to
~20 s on the jpsi calibration tensor (108k parameters, 329M-nnz
external sparse Hessian).

  * Replace per-parameter np.where lookup with a single dict.
    The old code did np.where(parms_str == p) for each of the
    external term's parameters against the full ~10^5-element
    parameter list — quadratic, ~150 s. Build a name->index
    dict once and look each parameter up in O(1).

  * Detect already-sorted sparse-Hessian indices and skip
    np.lexsort. tf.SparseTensor / sparse_tensor_to_csr_sparse_matrix
    require canonical row-major order. The TensorWriter does not
    guarantee this for sparse-Hessian external terms, but in
    practice the indices are often already sorted (e.g. when the
    source SparseHist has its _flat_indices in flat-index order
    and they get split via np.divmod(flat, n), which preserves the
    ordering). A single vectorized O(nnz) check skips the much
    slower np.lexsort (~54 s on 329M nnz) when the data is already
    canonical, falling back to lexsort otherwise.

The remaining ~13 s in Fitter.__init__ on the jpsi tensor is the
unavoidable cost of materializing the 329M-nnz arrays into TF
tensors (np.stack of the [nnz, 2] index buffer, tf.constant on
both index and value buffers, and the CSR conversion proper).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
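Both structural fixes can be sketched in numpy (illustrative helper names, not the rabbit source):

```python
import numpy as np

def name_to_index(parms):
    """O(n) dict build replacing per-name np.where scans, which are
    O(n) each and O(n^2) over all external-term parameters."""
    return {name: i for i, name in enumerate(parms)}

def is_canonical(rows, cols, ncols):
    """Vectorized O(nnz) check that 2D sparse indices are already in
    row-major order, so the much slower np.lexsort can be skipped."""
    keys = rows.astype(np.int64) * ncols + cols
    return bool(np.all(keys[1:] >= keys[:-1]))

rows = np.array([0, 0, 1, 2])
cols = np.array([1, 2, 0, 2])
```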

@davidwalter2 davidwalter2 left a comment


First bunch of comments

rabbit/fitter.py Outdated
# one common regularization strength parameter
self.tau = tf.Variable(1.0, trainable=True, name="tau", dtype=tf.float64)

# External likelihood terms (additive g^T x + 0.5 x^T H x contributions
Collaborator

Suggest to put this block into a standalone function that takes external_terms and dtype as argument and returns "external_terms" object. Suggest to put it in a new file e.g. external_likelihood.py

Collaborator Author

Ok yes, that makes sense actually

Collaborator Author

Done in fad47bc -- the external-term construction now lives in a new rabbit/external_likelihood.py module, and Fitter.__init__ calls external_likelihood.build_tf_external_terms(self.indata.external_terms, self.parms, self.indata.dtype). The matching scalar evaluator and the h5 reader live in the same module.

)
parser.add_argument(
"--noJitCompile",
dest="jitCompile",
Collaborator

Using the dest keyword can be confusing and so far we managed to do without it, could we keep that convention?

Collaborator Author

Yah ok let me see if that can be avoided

Collaborator Author

Done in fad47bc -- replaced --noJitCompile with a tri-state --jitCompile {auto,on,off} (default auto), no dest keyword needed.
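A tri-state option of this shape can be sketched with argparse (an illustrative sketch, not the merged code; `resolve_jit` is a hypothetical helper showing the "auto disables jit for sparse input" rule discussed below):

```python
import argparse

def build_parser():
    parser = argparse.ArgumentParser()
    # Tri-state replaces a negated --noJitCompile flag, so no dest=
    # remapping is needed and the stored value matches the flag name.
    parser.add_argument(
        "--jitCompile",
        choices=["auto", "on", "off"],
        default="auto",
        help="XLA jit_compile mode; auto disables jit on sparse input",
    )
    return parser

def resolve_jit(mode, sparse):
    if mode == "auto":
        return not sparse
    return mode == "on"

args = build_parser().parse_args(["--jitCompile", "on"])
```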

rabbit/fitter.py Outdated
# SparseMatrixMatMul has no XLA kernel, so any tf.function that
# uses it (via _compute_yields_noBBB in sparse mode) cannot be
# jit-compiled. Force jit_compile off in sparse mode regardless
# of the user's --jitCompile setting.
Collaborator

There is no "--jitCompile" setting but the option is "--noJitCompile", maybe got confused by the dest keyword?

Collaborator Author

Almost certainly because it was --jitCompile before and then got swapped when the default changed.

Collaborator Author

Done in fad47bc -- the option is now --jitCompile auto|on|off and Fitter.__init__ reads it as a string. Backwards-compatible with True/False from programmatic callers.

rabbit/fitter.py Outdated
# uses it (via _compute_yields_noBBB in sparse mode) cannot be
# jit-compiled. Force jit_compile off in sparse mode regardless
# of the user's --jitCompile setting.
jit = self.jit_compile and not self.indata.sparse
Collaborator

Suggest to add a warning with something like:
"jit_compile set but input data is sparse; jit compilation will be disabled."

Collaborator Author

Since this is the default it probably doesn't make sense to have a warning. Might make sense to have a default None / "auto" option for jitCompile rather than just on/off.

Collaborator Author

Done in fad47bc -- the new tri-state --jitCompile defaults to auto, which silently disables jit in sparse mode. The warning fires only when the user explicitly passes --jitCompile on while running on sparse input, since at that point they've asked for something XLA can't actually do.


return ln, lc, lbeta, lpenalty, beta

def _compute_external_nll(self):
Collaborator

This could also go in a standalone function, which would be more modular and reusable; that is the direction I think we should go. Arguments would be terms, params.

Collaborator Author

Done in fad47bc -- moved into rabbit.external_likelihood.compute_external_nll(terms, x, dtype). Fitter._compute_external_nll is now a one-line dispatch.

"(forward-over-reverse, via tf.autodiff.ForwardAccumulator) is an alternative.",
)
parser.add_argument(
"--noJitCompile",
Collaborator

How does the option interplay with --eager? Do we need both options or can this also not be controlled with --eager? At a minimum --eager should also trigger no jit compile

Collaborator Author

if --eager is used I think it will implicitly skip jitCompile because it won't even process the tf.functions

These are two different things though:
tf.function by itself switches from eager to graph mode. jit_compile recompiles the instructions within the graph (allowing things like add-multiply fusion, similar to what a C++ compiler would do). Without it, the graph still executes the tf operations one by one as written.

So yes, we need both options, and I think the current behaviour should be fine.

Collaborator Author

No code change here. --eager and --jitCompile control different things (eager mode skips graph building entirely, jit_compile is XLA fusion within an existing graph) so they remain orthogonal. --eager continues to bypass jit by virtue of skipping the tf.function wrappers.

# Write external likelihood terms
if self.external_terms:
ext_group = f.create_group("external_terms")
create_dataset(
Collaborator

are the names needed? They are already stored as the keys of the ext_group groups added below

Collaborator Author

Not sure...

Collaborator Author

Oh I see what you mean now. Yes, they probably don't need to be stored separately indeed.

Collaborator Author

Done in fad47bc -- the writer no longer creates hexternal_term_names; the reader now iterates the external_terms h5 group's subgroups directly via ext_group.items().

for s in f["hexternal_term_names"][...]
]
ext_group = f["external_terms"]
for tname in names:
Collaborator

Does the following work?

for tname, tg in ext_group.items():

It would be more pythonic IMO and no need for storing the names separately

Collaborator Author

Yes I think that should work

Collaborator Author

Done in fad47bc -- the reader now uses for tname, tg in ext_group.items() (in rabbit.external_likelihood.read_external_terms_from_h5) and the writer-side names list is gone.
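A minimal sketch of the resulting reader pattern, assuming h5py and a simplified per-term layout (each term group here carries only a grad dataset; the real groups also hold parameter indices and a sparse Hessian):

```python
import io

import h5py
import numpy as np


def read_external_terms(ext_group):
    # Iterate the subgroups directly -- their keys serve as the term
    # names, so no separate "hexternal_term_names" dataset is needed.
    terms = []
    for tname, tg in ext_group.items():
        terms.append({"name": tname, "grad": tg["grad"][...]})
    return terms


# Round-trip through an in-memory HDF5 file.
buf = io.BytesIO()
with h5py.File(buf, "w") as f:
    ext = f.create_group("external_terms")
    for name in ["calib", "lumi"]:
        tg = ext.create_group(name)
        tg.create_dataset("grad", data=np.arange(3.0))

with h5py.File(buf, "r") as f:
    terms = read_external_terms(f["external_terms"])
```

h5py groups iterate their members in alphanumeric key order by default, so the term order is deterministic without a stored name list.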


self.axis_procs = hist.axis.StrCategory(self.procs, name="processes")

# Load external likelihood terms (optional).
Collaborator

Could also maybe go into a standalone function, but not sure about that

Collaborator Author

Yeah, probably. The init of this is already not very modular, so maybe we can leave that for subsequent refactoring.

Collaborator Author

OK, actually, given the similar pattern in the fitter init, I agree this can be modularized.

Collaborator Author

Done in fad47bc -- moved into rabbit.external_likelihood.read_external_terms_from_h5(ext_group). FitInputData calls self.external_terms = read_external_terms_from_h5(f.get("external_terms")).

bendavid and others added 3 commits April 9, 2026 18:51
Three coupled changes that simplify the IO path for the external
sparse Hessian and trim the rest of Fitter.__init__ on large
problems.

  * tensorwriter.py: sort the external sparse-Hessian indices into
    canonical row-major order at write time, matching what the
    writer already does for the sparse logk and norm tensors. Use
    a single ravel_multi_index + argsort. Add a fast-path that
    detects when the input is already canonical via a vectorized
    O(nnz) check and skips the sort entirely (typical when the
    source is a SparseHist built from a scipy CSR / CSC, which
    iterates row-major by definition).

  * inputdata.py: hess_sparse is now read via the same
    makesparsetensor() helper used for the sparse norm and logk,
    yielding a tf.sparse.SparseTensor directly. The previous code
    manually unpacked the indices into (rows, cols, vals) tuple
    form which forced an unnecessary numpy roundtrip downstream.
    The defensive tf.sparse.reorder calls on norm and logk are
    also dropped: the writer already sorts these into canonical
    order, so the reorder was redundant.

  * fitter.py external term loop: receive the SparseTensor and
    feed it straight to tf_sparse_csr.CSRSparseMatrix without an
    additional reorder step (the writer already canonicalized).
    This drops the in-Python np.lexsort + np.stack + tf.constant
    roundtrip on the 329M-nnz jpsi external Hessian.

Effect on jpsi calibration tensor (108k params, 329M-nnz prefit
sparse Hessian) Fitter.__init__:

           pre-IO unification:  20.5 s
           post:                 5.3 s

The TensorWriter side is ~5 s slower than before the sort was
added (the unavoidable cost of validating canonical order on
329M nnz). For SparseHist inputs from scipy CSR/CSC the data is
always pre-sorted so the validation succeeds and no additional
sort is performed.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
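The write-time canonicalization in this commit can be sketched as follows (function name hypothetical; the writer applies the same pattern it already uses for the sparse logk and norm tensors):

```python
import numpy as np


def canonicalize_coo(indices, values, shape):
    """Sort COO data (indices [nnz, ndim], values [nnz]) into canonical
    row-major order, skipping the sort when already canonical.

    A single ravel_multi_index gives each entry its flat row-major
    position; a vectorized O(nnz) monotonicity check detects the common
    case (e.g. a SparseHist built from scipy CSR/CSC, which iterates
    row-major by definition) where no sort is needed.
    """
    flat = np.ravel_multi_index(tuple(indices.T), shape)
    if np.all(flat[1:] > flat[:-1]):
        # Already canonical (strictly increasing also rules out duplicates).
        return indices, values
    order = np.argsort(flat, kind="stable")
    return indices[order], values[order]


# Already-canonical input passes through without sorting...
idx = np.array([[0, 1], [1, 0], [1, 2]])
vals = np.array([1.0, 2.0, 3.0])
i2, v2 = canonicalize_coo(idx, vals, (2, 3))

# ...while shuffled input gets argsorted back into row-major order.
i3, v3 = canonicalize_coo(idx[::-1].copy(), vals[::-1].copy(), (2, 3))
```

On 329M nnz the check is one pass over a flat int array, which is the ~5 s validation cost mentioned above; the argsort only runs when the fast path fails.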
The is_linear fast path in Fitter.minimize() previously built the
full dense [npar, npar] Hessian via loss_val_grad_hess() and did
a Cholesky solve. That's incompatible with --noHessian mode, which
is supposed to avoid the O(npar^2) allocation entirely.

Add an alternative Hessian-free branch that solves the normal
equation H @ dx = -grad iteratively with scipy's conjugate gradient
solver, feeding it a LinearOperator backed by loss_val_grad_hessp.
For a purely quadratic NLL the Hessian is positive-definite and CG
converges to machine precision in at most npar iterations (far
fewer for well-conditioned problems). The Cholesky path is still
used when compute_cov is True, since it has the lower per-call
cost when allocating the dense Hessian is already acceptable.

Verified against the Cholesky path on a constructed linear test
(chisqFit + Ones POI model + normal systematics): converged
parameter values match exactly; only the postfit uncertainty
slots differ, which is expected because the noHessian run does
not compute the covariance.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
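A minimal scipy sketch of the Hessian-free branch, using a toy quadratic loss (in the fitter the matvec is backed by loss_val_grad_hessp via autodiff, never by a dense H):

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

# Toy quadratic NLL: L(x) = 0.5 x^T H x - b^T x, so grad = H x - b and
# the Hessian-vector product is simply v -> H v.
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 50))
H = A @ A.T + 50 * np.eye(50)  # positive-definite, well-conditioned
b = rng.standard_normal(50)

x = np.zeros(50)
grad = H @ x - b

# Only the hvp is exposed to the solver; H itself is never handed over.
op = LinearOperator((50, 50), matvec=lambda v: H @ v, dtype=np.float64)
dx, info = cg(op, -grad)  # solve the normal equation H dx = -grad
assert info == 0  # converged

x_new = x + dx  # for a purely quadratic loss this is the minimizer
```

For a positive-definite Hessian, CG converges in at most npar iterations and far fewer for well-conditioned problems, which is why the single Newton step of the is_linear path can be done without the O(npar^2) allocation.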
Under --noHessian we previously left edmval and all postfit
parameter uncertainties as NaN because the dense covariance is
never allocated. Now compute both via Hessian-free conjugate
gradient solves of

    H v = grad        ->  edmval = 0.5 * grad^T v
    H c_i = e_i       ->  c_i is the i-th row of cov

using scipy.sparse.linalg.cg with a LinearOperator backed by
self.loss_val_grad_hessp. No dense Hessian or covariance is ever
materialized; memory stays O(npar) instead of O(npar^2).

The cov rows are computed only for the parameters the user cares
about at this point -- the POIs (indices [0, npoi)) and the NOIs
(npoi + indata.noiidxs). Their diagonal entries give the postfit
standard deviations, which are populated into the parms_variances
vector passed to add_parms_hist. Non-POI / non-NOI nuisances keep
NaN variances, signalling that the postfit covariance for those
parameters was not computed.

Verified on the small test tensor: --noHessian now reports the
same edmval (7.068e-18 vs 7.068e-18) and the same POI/NOI
uncertainties (sig: 0.01436 +/- 0.01436, slope_signal: 2.01328
+/- 2.01328) as the full Cholesky path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
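The two Hessian-free solves above can be illustrated with a toy positive-definite Hessian (all names here are illustrative; the fitter exposes only the hvp through loss_val_grad_hessp, so memory stays O(npar)):

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

rng = np.random.default_rng(1)
A = rng.standard_normal((30, 30))
H = A @ A.T + 30 * np.eye(30)  # stand-in for the postfit Hessian
grad = rng.standard_normal(30)

op = LinearOperator((30, 30), matvec=lambda v: H @ v, dtype=np.float64)

# edmval = 0.5 * grad^T v  with  H v = grad
v, info = cg(op, grad)
edmval = 0.5 * grad @ v

# variance of parameter i: solve H c_i = e_i, then cov[i, i] = c_i[i]
interesting = [0, 7]  # e.g. the POI / NOI indices
variances = {}
for i in interesting:
    e = np.zeros(30)
    e[i] = 1.0
    c, info = cg(op, e)
    variances[i] = c[i]
```

Each uncertainty costs one CG solve, so restricting the cov rows to the POIs and NOIs keeps the total work proportional to the number of parameters actually reported.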
@bendavid bendavid changed the title Sparse mode performance and SparseHist input dispatch Sparse mode performance, SparseHist input dispatch, and low-memory --noHessian mode Apr 10, 2026
Address PR review feedback by collecting the three external-term
helpers into a single dedicated module:

  * read_external_terms_from_h5(ext_group) -- decode the on-disk
    "external_terms" h5 group into a list of raw per-term dicts.
    Iterate ext_group.items() directly so the writer no longer
    needs to store a separate "external_term_names" list.

  * build_tf_external_terms(terms, parms, dtype) -- promote the
    raw dicts to tf-side dicts (resolved indices, tf.constant
    grad, CSRSparseMatrix Hessian). Used by Fitter.__init__.

  * compute_external_nll(terms, x, dtype) -- evaluate
    sum_i (g_i^T x_sub + 0.5 x_sub^T H_i x_sub). Used by
    Fitter._compute_external_nll.

FitInputData, the Fitter init external-term loop, and
Fitter._compute_external_nll all collapse to one-line dispatches
into this module. The tensorwriter no longer writes the
hexternal_term_names dataset since the reader iterates the h5
subgroups directly.

Also rework the --jitCompile CLI option per review: replace the
"--noJitCompile" boolean+dest hack with a tri-state
"--jitCompile {auto,on,off}" with auto as the default. The Fitter
resolves the string in _make_tf_functions: "auto" silently
enables jit in dense mode and disables it in sparse mode (where
the CSR matmul kernels have no XLA implementation), "on" forces
it (warning + falling back to off in sparse mode), "off" disables
it unconditionally. Backwards compatibility: True/False are
still accepted from programmatic callers.

Smoke tested all 12 combinations (sparse/dense x auto/on/off
x cov/--noHessian); all converge.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
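A numpy/scipy sketch of the contract compute_external_nll implements (the real version in rabbit.external_likelihood operates on tf tensors with a CSRSparseMatrix Hessian; the dict keys used here are assumptions):

```python
import numpy as np
from scipy.sparse import csr_matrix


def compute_external_nll(terms, x):
    """Evaluate sum_i (g_i^T x_sub + 0.5 x_sub^T H_i x_sub), where
    x_sub is x restricted to each term's parameter indices."""
    nll = 0.0
    for t in terms:
        x_sub = x[t["idxs"]]
        nll += t["grad"] @ x_sub + 0.5 * x_sub @ (t["hess"] @ x_sub)
    return nll


# One toy term acting on parameters 1 and 3 of a 4-parameter vector.
term = {
    "idxs": np.array([1, 3]),
    "grad": np.array([1.0, -2.0]),
    "hess": csr_matrix(np.array([[2.0, 0.0], [0.0, 4.0]])),
}
x = np.array([0.0, 1.0, 0.0, 2.0])
val = compute_external_nll([term], x)  # (1 - 4) + 0.5 * (2 + 16) = 6.0
```

Because each Hessian is applied as a sparse matvec on the restricted subvector, the cost per term is O(nnz) rather than O(npar^2).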