Stochasticad constant rate #606

Open
RomanSahakyan03 wants to merge 7 commits into SciML:master from RomanSahakyan03:stochasticad-constant-rate


RomanSahakyan03 commented Jun 13, 2026

Summary

Adds an optional StochasticAD extension that makes derivative_estimate work on expectations over jump-only ConstantRateJump (SSA) problems, following @ChrisRackauckas's suggestion to start with the constant-rate case.

New API (in ext/JumpProcessesStochasticADExt.jl, loaded only when StochasticAD + Distributions are present):

  • constant_rate_ssa_final_state(jprob, p; nmax, tspan, return_saturation=false) — the main, exact iterative SSA method, differentiable and supporting state-dependent rates.
  • poisson_count_final_state(jprob, p; tspan) — a closed-form helper for the narrow state-independent (homogeneous-Poisson) case, with guards against misuse.

Also: src/aggregators/ssajump.jl makes the SSA rate cache generic over the rate type; src/JumpProcesses.jl adds stubs + exports (no StochasticAD in src/); tests run in an isolated GROUP=StochasticAD environment.

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
  • Any new documentation only uses public API

Additional context

Why a reformulation (the key point). A naive Gillespie loop differentiated with StochasticAD returns 0 whenever the parameter acts through the event times — e.g. pure death, rate = μ·u, where a larger μ simply packs in more deaths before T. The while t < T termination branches on the primal time, so the event count (the parameter-dependent quantity) is frozen and the derivative is silently dropped.
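The effect can be seen without any AD machinery by driving a plain Gillespie loop with fixed random numbers and nudging μ slightly. The sketch below is plain Python, not the PR's Julia code, and the values u0 = 100, μ = 0.5, T = 1 are assumptions chosen for illustration. It is only an analogy for what StochasticAD sees on a naive loop: the per-sample response to a small change in μ is zero, even though the expectation clearly depends on μ.

```python
import math
import random

def naive_pure_death(mu, u0, T, uniforms):
    """Plain Gillespie loop for pure death (rate = mu * u), driven by fixed uniforms."""
    t, u = 0.0, u0
    for U in uniforms:
        if u == 0:
            break
        t += -math.log(U) / (mu * u)   # exponential waiting time via inverse transform
        if t >= T:                      # termination branches on the primal time
            break
        u -= 1
    return u

# Assumed illustrative values, not taken from the PR's tests.
u0, mu, T, h = 100, 0.5, 1.0, 1e-9
rng = random.Random(0)

pathwise = []
for _ in range(200):
    # 1 - random() lies in (0, 1], so log() never sees zero.
    uniforms = [1.0 - rng.random() for _ in range(u0)]
    base = naive_pure_death(mu, u0, T, uniforms)
    bumped = naive_pure_death(mu + h, u0, T, uniforms)
    pathwise.append((bumped - base) / h)

analytic = -u0 * T * math.exp(-mu * T)   # d/dmu of E[u(T)] = u0 * exp(-mu * T)
print("mean per-sample difference quotient:", sum(pathwise) / len(pathwise))
print("analytic derivative of the mean:    ", analytic)
```

A tiny change in μ shifts every event time a little but almost never moves one across T, so the event count on each sample is unchanged, while the derivative of the mean is about -60.65 for these assumed values.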

constant_rate_ssa_final_state fixes this exactly by rewriting each "is there a next event before T?" as a tracked Bernoulli(1 - exp(-total_rate·(T-t))), run over a fixed-length loop, with the event time drawn from the truncated exponential, the reaction chosen by stick-breaking Bernoullis, and the state updated by multiplicative-select masking (no branching on a triple, no triple-valued indexing). This has the same trajectory distribution as the SSA (it is not a τ-leap approximation), recovers analytic gradients, and handles state-dependent rates and multiple channels.
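A minimal primal-only sketch of that reformulation, for the single-channel pure-death case, might look as follows. It is written in plain Python rather than the extension's Julia, the function name `bounded_pure_death` and the parameter values are hypothetical, and it samples with ordinary random numbers instead of tracked StochasticAD distributions. With one channel there is no stick-breaking step to show.

```python
import math
import random

def bounded_pure_death(mu, u0, T, nmax, rng):
    """Fixed-length loop: each step asks 'is there another event before T?'."""
    t, u = 0.0, u0
    alive = 1   # becomes 0 once a step reports "no further event", then stays 0
    for _ in range(nmax):
        rate = mu * u
        p_event = 1.0 - math.exp(-rate * (T - t))
        occurs = alive * (1 if rng.random() < p_event else 0)
        if rate > 0.0:
            # Waiting time from the exponential truncated to (0, T - t].
            dt = -math.log(1.0 - rng.random() * p_event) / rate
        else:
            dt = 0.0   # primal-only guard against dividing by a zero rate
        # Masked updates: a frozen path multiplies its increments by zero.
        t += occurs * dt
        u -= occurs
        alive = occurs
    return u

# Assumed illustrative values; nmax = u0 is enough because at most u0 deaths can occur.
u0, mu, T = 100, 0.5, 1.0
rng = random.Random(1)
samples = [bounded_pure_death(mu, u0, T, nmax=u0, rng=rng) for _ in range(2000)]
print("sample mean:", sum(samples) / len(samples))
print("exact mean: ", u0 * math.exp(-mu * T))
```

The loop length never depends on the sampled times, which is the property the description relies on; in the real extension the two random draws would be the tracked Bernoulli and truncated-exponential samples.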

Validation (vs. closed-form analytic expectations):

| Test | StochasticAD | Analytic |
|---|---|---|
| pure death, d/dμ E[u(T)] | -60.58 ± 0.43 | -60.65 |
| birth–death, d/dλ | ≈ 0.81 | 0.86 |
| birth–death, d/dμ | ≈ -40 | -41.1 |
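For readers who want to reproduce the Analytic column, the closed forms can be evaluated directly. The PR does not state the test parameters or the exact birth–death model, so everything specific here is an assumption: pure death with u0 = 100, μ = 0.5, T = 1, and a birth–death model with constant birth rate λ and per-capita death rate μ at u0 = 50, λ = 10, μ = 0.3, T = 1. These values were picked only because they give numbers close to the table.

```python
import math

def pure_death_grad(u0, mu, T):
    # E[u(T)] = u0 * exp(-mu*T)  =>  d/dmu E[u(T)] = -u0 * T * exp(-mu*T)
    return -u0 * T * math.exp(-mu * T)

def immigration_death_grads(u0, lam, mu, T):
    # Assumed model: births at constant rate lam, deaths at rate mu * u, so
    # E[u(T)] = u0*exp(-mu*T) + (lam/mu) * (1 - exp(-mu*T)).
    decay = math.exp(-mu * T)
    d_lam = (1.0 - decay) / mu
    d_mu = -u0 * T * decay + lam * (mu * T * decay - (1.0 - decay)) / mu**2
    return d_lam, d_mu

# Assumed parameter values (not given in the PR).
print("pure death d/dmu:   ", pure_death_grad(100, 0.5, 1.0))
d_lam, d_mu = immigration_death_grads(50, 10.0, 0.3, 1.0)
print("birth-death d/dlam: ", d_lam)
print("birth-death d/dmu:  ", d_mu)
```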

On the generic rate cache. Per the request to let StochasticTriple rates flow through the existing Direct/SSAStepper path, build_jump_aggregation now sizes cur_rates by the inferred rate type instead of typeof(t) (via Base.promote_op, so it never executes the rate at build time; Float64 behavior is unchanged and the existing Constant Rate/SSA tests pass). Honest caveat: this change alone does not produce correct gradients — the stock while t < T boundary still drops the event-count derivative. Correct gradients come from the reformulation above.

Scope / limitations (this first step). ConstantRateJump only (not VariableRateJump); jump-only problems (no continuous drift); additive affects (the net change is inferred from affect! and checked); a fixed nmax bound on the number of events (with an optional saturation flag). The earlier fixed-grid VariableRateJump/PDMP approach (#596) is intentionally set aside in favor of this narrower, exact constant-rate first step.

Dependency note. StochasticAD is an optional [weakdeps] extension, and its tests run in a dedicated GROUP=StochasticAD environment (no OrdinaryDiffEq) so they don't perturb the main test suite. (The currently-resolvable StochasticAD is fairly old — happy to discuss whether that's a concern for taking this as a dependency.)

   Implements StochasticAD-compatible gradients for jump-only ConstantRateJump processes, following Chris's "start with the constant-rate SSA" direction.

ext/JumpProcessesStochasticADExt.jl (loaded with StochasticAD + Distributions):
   - constant_rate_ssa_final_state: exact iterative SSA, differentiable. Reformulated as a fixed-length Bernoulli-per-event loop (occurs ~ Bernoulli(1-exp(-rate*(T-t))), truncated-exponential time, stick-breaking channel choice, multiplicative-select freezing) so the event-count derivative survives -- a naive `while t<T` SSA drops it. Supports state-dependent rates; rates recomputed after each event.
   - poisson_count_final_state: closed-form helper for the state-independent homogeneous-Poisson case (guarded against state-dependent misuse).

   src/aggregators/ssajump.jl: make the SSA rate cache generic over the rate type (via return-type inference) instead of forcing Float64, so StochasticTriple rates pass through the existing Direct/SSAStepper rate-cache path (Float64 unchanged). This alone does not yield gradients -- the stock while-t<T boundary still drops the event-count derivative; the reformulation above is what gives correct gradients.

   src/JumpProcesses.jl: stubs + exports (no StochasticAD in src/).
   test/stochasticad_tests.jl: pure-death + birth-death analytic-gradient tests, Poisson baseline, guards, and a generic-rate-cache test, in an isolated GROUP=StochasticAD env.
@RomanSahakyan03

Author

Hi @ChrisRackauckas and @isaacsas, please let me know when you get a chance to review the changes. Thanks!

Turn the prototype differentiable SSA implementation into a first-class solve algorithm, with the generic rate cache now used by both the standard Direct path and the bounded StochasticAD-compatible path.

API:

* Add BoundedSSA(; nmax), a StochasticAD-compatible SSA solve algorithm.
* Add the BoundedSSA struct and export it from src/SSA_stepper.jl.
* Add the corresponding __solve implementation in the extension.
* Use a dedicated integrator so the bounded path is independent of SSAStepper's Float64 time handling and never compares t < T.
* Rename constant_rate_ssa_final_state to bounded_ssa_final_state.
* Add saturation_probability for choosing and diagnosing nmax. The truncation bias is exactly P(N > nmax), so it is measured rather than hidden.
* Drop poisson_count_final_state, since the bounded path subsumes the previous presampling approach.
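To make the P(N > nmax) statement concrete, here is a small sketch for the simplest case only, where events arrive at a constant total rate so that N is Poisson with mean rate·T. It is plain Python, the helper names `poisson_tail` and `smallest_nmax` are hypothetical, and it is not the implementation of saturation_probability, which measures saturation for the actual state-dependent process.

```python
import math

def poisson_tail(mean, nmax):
    """P(N > nmax) for N ~ Poisson(mean), i.e. the chance the event cap is hit."""
    term = math.exp(-mean)   # P(N = 0)
    cdf = term
    for k in range(1, nmax + 1):
        term *= mean / k     # P(N = k) from P(N = k - 1)
        cdf += term
    return max(0.0, 1.0 - cdf)

def smallest_nmax(mean, tol):
    """Smallest cap whose saturation probability is at most tol."""
    nmax = 0
    while poisson_tail(mean, nmax) > tol:
        nmax += 1
    return nmax

# Assumed example: total rate 50 over a unit time window.
cap = smallest_nmax(50.0, 1e-6)
print("nmax:", cap, "tail:", poisson_tail(50.0, cap))
```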

Direct aggregator refactor:

* Factor the raw per-channel rate fill from time_to_next_jump into fill_cur_rates!.
* Reuse fill_cur_rates! in both the stock Direct/DirectFW path and the bounded loop.
* This lets StochasticTriple rates pass through the existing rate machinery.
* The refactor is behavior-preserving: Direct and DirectFW were verified bit-identical across seeds.

Scope:

* Support is limited to ConstantRateJump for now.
* MassActionJump is deferred because evalrxrate is not triple-generic: it has a ::R return assertion and an order > 1 specpop <= 0 boolean.
* MassActionJump rate constants also flow through param_mapper(p).
* Unsupported cases now fail with an informative error.

Tests:

* Rewrite stochasticad_tests.jl for the new BoundedSSA API.
* Test solve-path differentiation through BoundedSSA.
* Compare primal means against SSAStepper.
* Add fill_cur_rates_regression.jl to guard the direct.jl refactor.
* Cover mass-action, multi-channel, fwrapper, and exact Direct == DirectFW behavior.
* Fix runtests activate_stochasticad_env to use a cwd-independent path.
The grouped-tests.yml refactor (SciML#604) drives groups from test/test_groups.toml, which did not carry over the StochasticAD group the old Tests.yml matrix declared -- so CI was not running the BoundedSSA tests. runtests.jl already dispatches GROUP=StochasticAD into the isolated env (no special runner needed).
@isaacsas

Member

@RomanSahakyan03 thanks!

I'd prefer if you didn't change the Direct code path or helper functions, and just added dispatches or separate helpers for the StochasticAD version. The Direct version has been extensively benchmarked, and I wouldn't want to risk its performance, so changes there would need a lot of benchmarking to be accepted.

Are you able to set up your solver to give full paths and allow evaluation of those paths at set times, or to at least support (maybe require?) saveat and return the solution at fixed times? This would be a lot more useful for users than just providing the solution at a fixed terminal time, and with a bias due to capping the number of steps.

If capping the number of steps is really needed, could you instead assume a maximum total propensity bound? This would allow you to pre-sample the number of events and event firing times, and to just have a null event to account for the difference between the bound and the current total propensity (keeping everything StochasticAD compatible). While it is still not going to be possible to provide such a bound for most chemical systems, for systems where the amount of each chemical is rigorously bounded such a bound would generally hold, so it seems like a more generally applicable solver that avoids the bias in having an nmax parameter.

@RomanSahakyan03

Author

Thanks @isaacsas, that makes sense.

I’ll remove the changes to the Direct path from this PR and keep the StochasticAD implementation separate.

For the output, I'll make it return the full path and make the solver work with saveat.

The bounded-propensity idea also sounds better than nmax. I’ll look into replacing the current capped-step approach with a thinning/uniformization version using a user-provided total propensity bound, with null events for the remaining rate. That should avoid the nmax bias when a valid bound is available.

So I'll update the PR in that direction: separate StochasticAD code path, saveat support, and bounded-propensity thinning instead of the current nmax-based version.

Reworked per review (isaacsas): Direct path reverted to upstream (untouched); StochasticAD is a separate DiffEqBase.solve dispatch + bounded_ssa_path helper. Uniformization/thinning against a constant total-propensity bound (rate_bound) with null events -- unbiased, iterates (rates recomputed per event), StochasticAD-compatible (parameter-free candidate times, no boolean on triple time). Full path / saveat support reusing the package's _process_saveat (same as SimpleTauLeaping); sol(t) via ConstantInterpolation. Drops the prior nmax approach, generic rate cache, saturation_probability, and the vestigial regression test. Tests 16/16; Direct regression 30/30. MassActionJump deferred.
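
The uniformization scheme described in this commit can be sketched, for pure death, roughly as follows. This is a primal-only illustration in plain Python: the names `thinned_pure_death_path` and `evaluate` and the parameter values are hypothetical, the accept/reject step uses an ordinary `if` where the extension would need a tracked Bernoulli, and the lookup function only mimics what a constant interpolation at saveat times would return.

```python
import bisect
import math
import random

def thinned_pure_death_path(mu, u0, T, rate_bound, rng):
    """Candidates arrive at the constant rate `rate_bound`; each is real or null."""
    times, states = [0.0], [u0]
    t, u = 0.0, u0
    while True:
        t += rng.expovariate(rate_bound)   # candidate time, independent of mu
        if t >= T:                          # compares a parameter-free time only
            break
        rate = mu * u
        assert rate <= rate_bound, "rate_bound must dominate the total propensity"
        if rng.random() < rate / rate_bound:   # real event with prob rate / bound
            u -= 1
            times.append(t)
            states.append(u)
        # otherwise a null event: the state is left unchanged
    return times, states

def evaluate(times, states, saveat):
    """Piecewise-constant lookup: the state after the last jump at or before s."""
    return [states[bisect.bisect_right(times, s) - 1] for s in saveat]

# Assumed illustrative values; mu * u0 bounds the rate because u never grows.
u0, mu, T = 100, 0.5, 1.0
saveat = [0.0, 0.5, 1.0]
rng = random.Random(2)
paths = [evaluate(*thinned_pure_death_path(mu, u0, T, mu * u0, rng), saveat)
         for _ in range(2000)]
means = [sum(p[i] for p in paths) / len(paths) for i in range(len(saveat))]
print("sample means at saveat:", means)
print("exact means at saveat: ", [u0 * math.exp(-mu * s) for s in saveat])
```

Because the candidate times do not depend on μ, there is no event cap and no truncation bias; the cost is that a loose bound spends most candidates on null events.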