Stochasticad constant rate #606
Implements StochasticAD-compatible gradients for jump-only ConstantRateJump processes, following Chris's "start with the constant-rate SSA" direction.
ext/JumpProcessesStochasticADExt.jl (loaded with StochasticAD + Distributions):
- constant_rate_ssa_final_state: exact iterative SSA, differentiable. Reformulated as a fixed-length Bernoulli-per-event loop (occurs ~ Bernoulli(1-exp(-rate*(T-t))), truncated-exponential time, stick-breaking channel choice, multiplicative-select freezing) so the event-count derivative survives -- a naive `while t<T` SSA drops it. Supports state-dependent rates; rates recomputed after each event.
- poisson_count_final_state: closed-form helper for the state-independent homogeneous-Poisson case (guarded against state-dependent misuse).
src/aggregators/ssajump.jl: make the SSA rate cache generic over the rate type (via return-type inference) instead of forcing Float64, so StochasticTriple rates pass through the existing Direct/SSAStepper rate-cache path (Float64 unchanged). This alone does not yield gradients -- the stock while-t<T boundary still drops the event-count derivative; the reformulation above is what gives correct gradients.
src/JumpProcesses.jl: stubs + exports (no StochasticAD in src/).
test/stochasticad_tests.jl: pure-death + birth-death analytic-gradient tests, Poisson baseline, guards, and a generic-rate-cache test, in an isolated GROUP=StochasticAD env.
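The Bernoulli-per-event reformulation described in this commit can be sketched on a pure-death process (rate = μ·u). This is only a primal, plain-`Float64` illustration that the fixed-length loop samples the same final-state distribution as a naive `while`-style SSA; it does not use StochasticAD, and the function names, parameter values, and the `total > 0` guard are hypothetical choices for the example, not the extension's code.

```julia
using Random

# Naive Gillespie loop for pure death (rate = μ * u). The `t < T` check
# branches on the event time, which is what hides the event count from AD.
function naive_pure_death(rng, μ, u0, T)
    t, u = 0.0, u0
    while u > 0
        t += randexp(rng) / (μ * u)   # exponential waiting time
        t < T || break
        u -= 1
    end
    return u
end

# Fixed-length reformulation: every iteration asks "does another event
# occur before T?" as a Bernoulli draw instead of comparing t < T.
function bernoulli_pure_death(rng, μ, u0, T; nmax)
    t, u, alive = 0.0, u0, 1
    for _ in 1:nmax
        total = μ * u
        p_occ = 1 - exp(-total * (T - t))       # P(next event before T)
        occurs = alive * (rand(rng) < p_occ)    # 0 or 1; stays 0 once frozen
        # Waiting time conditioned on landing before T (truncated
        # exponential via inverse CDF); guarded for total == 0.
        τ = total > 0 ? -log1p(-rand(rng) * p_occ) / total : 0.0
        t += occurs * τ    # multiplicative select: no update when occurs == 0
        u -= occurs
        alive = occurs
    end
    return u
end

sample_mean(f, n) = sum(f() for _ in 1:n) / n

rng = MersenneTwister(1)
μ, u0, T, n = 0.5, 100, 1.0, 20_000
mean_naive = sample_mean(() -> naive_pure_death(rng, μ, u0, T), n)
mean_fixed = sample_mean(() -> bernoulli_pure_death(rng, μ, u0, T; nmax = u0), n)
println((mean_naive, mean_fixed, u0 * exp(-μ * T)))
```

For pure death, `nmax = u0` is enough because at most `u0` events can occur, so both sample means should agree with `u0 * exp(-μ * T)` up to Monte Carlo noise. In the extension the Bernoulli draw would be a tracked StochasticAD random variable rather than a plain comparison.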
Hi @ChrisRackauckas and @isaacsas, please let me know when you get a chance to review the changes. Thanks!
Turn the prototype differentiable SSA implementation into a first-class solve algorithm, with the generic rate cache now used by both the standard Direct path and the bounded StochasticAD-compatible path.
API:
* Add BoundedSSA(; nmax), a StochasticAD-compatible SSA solve algorithm.
* Add the BoundedSSA struct and export it from src/SSA_stepper.jl.
* Add the corresponding __solve implementation in the extension.
* Use a dedicated integrator so the bounded path is independent of SSAStepper's Float64 time handling and never compares t < T.
* Rename constant_rate_ssa_final_state to bounded_ssa_final_state.
* Add saturation_probability for choosing and diagnosing nmax. The truncation bias is exactly P(N > nmax), so it is measured rather than hidden.
* Drop poisson_count_final_state, since the bounded path subsumes the previous presampling approach.
Direct aggregator refactor:
* Factor the raw per-channel rate fill from time_to_next_jump into fill_cur_rates!.
* Reuse fill_cur_rates! in both the stock Direct/DirectFW path and the bounded loop.
* This lets StochasticTriple rates pass through the existing rate machinery.
* The refactor is behavior-preserving: Direct and DirectFW were verified bit-identical across seeds.
Scope:
* Support is limited to ConstantRateJump for now.
* MassActionJump is deferred because evalrxrate is not triple-generic: it has a ::R return assertion and an order > 1 specpop <= 0 boolean.
* MassActionJump rate constants also flow through param_mapper(p).
* Unsupported cases now fail with an informative error.
Tests:
* Rewrite stochasticad_tests.jl for the new BoundedSSA API.
* Test solve-path differentiation through BoundedSSA.
* Compare primal means against SSAStepper.
* Add fill_cur_rates_regression.jl to guard the direct.jl refactor.
* Cover mass-action, multi-channel, fwrapper, and exact Direct == DirectFW behavior.
* Fix runtests activate_stochasticad_env to use a cwd-independent path.
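For a feel of how nmax relates to the truncation probability P(N > nmax), here is a minimal sketch for the special case where the event count is Poisson with a known mean m (an assumption that only holds for state-independent rates). `poisson_tail` is a hypothetical helper written for this illustration; it is not the `saturation_probability` function added in this commit.

```julia
# P(N > nmax) for N ~ Poisson(m), summing the pmf term by term.
# Note: exp(-m) underflows for very large m, so this is only a sketch.
function poisson_tail(m, nmax)
    term = exp(-m)          # P(N = 0)
    cdf = term
    for k in 1:nmax
        term *= m / k       # P(N = k) from P(N = k - 1)
        cdf += term
    end
    return max(0.0, 1 - cdf)
end

# Smallest nmax whose tail probability falls below a tolerance.
function smallest_nmax(m; tol = 1e-6)
    nmax = 0
    while poisson_tail(m, nmax) > tol
        nmax += 1
    end
    return nmax
end

println(smallest_nmax(10.0))
```

With state-dependent rates the count is no longer Poisson, which is presumably why the commit measures saturation empirically instead.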
…Sahakyan03/JumpProcesses.jl into stochasticad-constant-rate
The grouped-tests.yml refactor (SciML#604) drives groups from test/test_groups.toml, which did not carry over the StochasticAD group the old Tests.yml matrix declared -- so CI was not running the BoundedSSA tests. runtests.jl already dispatches GROUP=StochasticAD into the isolated env (no special runner needed).
@RomanSahakyan03 thanks! I'd prefer if you didn't change the Are you able to set up your solver to give full paths and allow evaluation of those paths at set times, or to at least support (maybe require?) If capping the number of steps is really needed, could you instead assume a maximum total propensity bound? This would allow you to pre-sample the number of events and event firing times, and to just have a null event to account for the difference between the bound and the current total propensity (keeping everything StochasticAD compatible). While it is still not going to be possible to provide such a bound for most chemical systems, for systems where the amount of each chemical is rigorously bounded such a bound would generally hold, so it seems like a more generally applicable solver that avoids the bias in having a
Thanks @isaacsas, that makes sense. I'll remove the changes to the For the output, I'll make it return the full path and make the solver work with The bounded-propensity idea also sounds better than So I'll update the PR in that direction: separate StochasticAD code path,
Reworked per review (isaacsas): Direct path reverted to upstream (untouched); StochasticAD is a separate DiffEqBase.solve dispatch + bounded_ssa_path helper. Uniformization/thinning against a constant total-propensity bound (rate_bound) with null events -- unbiased, iterates (rates recomputed per event), StochasticAD-compatible (parameter-free candidate times, no boolean on triple time). Full path / saveat support reusing the package's _process_saveat (same as SimpleTauLeaping); sol(t) via ConstantInterpolation. Drops the prior nmax approach, generic rate cache, saturation_probability, and the vestigial regression test. Tests 16/16; Direct regression 30/30. MassActionJump deferred.
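The uniformization/thinning scheme summarized above can be sketched on a pure-death process (rate = μ·u). This is a primal, plain-`Float64` illustration under the assumption that the caller provides a valid constant bound; `uniformized_pure_death`, `path_at`, and the parameter values are hypothetical names for this example and are not the extension's `bounded_ssa_path` implementation.

```julia
using Random

# Uniformization (thinning) for pure death with rate μ * u, assuming the
# caller supplies a constant bound `rate_bound >= μ * u0` on the total rate.
function uniformized_pure_death(rng, μ, u0, T; rate_bound)
    μ * u0 <= rate_bound || error("rate_bound must dominate the total propensity")
    t, u = 0.0, u0
    times, states = [0.0], [u0]
    while true
        t += randexp(rng) / rate_bound      # candidate time: does not depend on μ
        t < T || break
        if rand(rng) < μ * u / rate_bound   # real event with prob. total / bound
            u -= 1
            push!(times, t)
            push!(states, u)
        end                                  # otherwise a null event: state unchanged
    end
    return times, states
end

# Piecewise-constant evaluation of a path at a requested save time s >= 0.
path_at(times, states, s) = states[searchsortedlast(times, s)]

rng = MersenneTwister(1)
μ, u0, T = 0.5, 100, 1.0
times, states = uniformized_pure_death(rng, μ, u0, T; rate_bound = μ * u0)
println([path_at(times, states, s) for s in 0.0:0.25:1.0])
```

Because the candidate times come from a Poisson process with the fixed rate `rate_bound`, the `t < T` comparison involves no parameter-dependent quantity; only the accept/null decision depends on μ. In a StochasticAD setting that decision would be a tracked Bernoulli rather than the plain `if` used here.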
Summary
Adds an optional `StochasticAD` extension that makes `derivative_estimate` work on expectations over jump-only `ConstantRateJump` (SSA) problems, following @ChrisRackauckas's suggestion to start with the constant-rate case.
New API (in `ext/JumpProcessesStochasticADExt.jl`, loaded only when StochasticAD + Distributions are present):
- `constant_rate_ssa_final_state(jprob, p; nmax, tspan, return_saturation=false)`: the main, exact iterative SSA method, differentiable and supporting state-dependent rates.
- `poisson_count_final_state(jprob, p; tspan)`: a closed-form helper for the narrow state-independent (homogeneous-Poisson) case, with guards against misuse.
Also: `src/aggregators/ssajump.jl` makes the SSA rate cache generic over the rate type; `src/JumpProcesses.jl` adds stubs + exports (no StochasticAD in `src/`); tests run in an isolated `GROUP=StochasticAD` environment.
Checklist
contributor guidelines, in particular the SciML Style Guide and COLPRAC.
Additional context
Why a reformulation (the key point). A naive Gillespie loop differentiated with StochasticAD returns `0` whenever the parameter acts through the event times, e.g. pure death, `rate = μ·u`, where a larger `μ` simply packs in more deaths before `T`. The `while t < T` termination branches on the primal time, so the event count (the parameter-dependent quantity) is frozen and the derivative is silently dropped. `constant_rate_ssa_final_state` fixes this exactly by rewriting each "is there a next event before `T`?" as a tracked `Bernoulli(1 - exp(-total_rate·(T-t)))`, run over a fixed-length loop, with the event time drawn from the truncated exponential, the reaction chosen by stick-breaking `Bernoulli`s, and the state updated by multiplicative-select masking (no branching on a triple, no triple-valued indexing). This has the same trajectory distribution as the SSA (it is not a τ-leap approximation), recovers analytic gradients, and handles state-dependent rates and multiple channels.
Validation (vs. closed-form analytic expectations):
- d/dμ E[u(T)]: -60.58 ± 0.43 (analytic: -60.65)
- d/dλ: ≈ 0.81 (analytic: 0.86)
- d/dμ: ≈ -40 (analytic: -41.1)
On the generic rate cache. Per the request to let
`StochasticTriple` rates flow through the existing `Direct`/`SSAStepper` path, `build_jump_aggregation` now sizes `cur_rates` by the inferred rate type instead of `typeof(t)` (via `Base.promote_op`, so it never executes the rate at build time; `Float64` behavior is unchanged and the existing `Constant Rate`/SSA tests pass). Honest caveat: this change alone does not produce correct gradients, because the stock `while t < T` boundary still drops the event-count derivative. Correct gradients come from the reformulation above.
Scope / limitations (this first step).
`ConstantRateJump` only (not `VariableRateJump`); jump-only problems (no continuous drift); additive affects (the net change is inferred from `affect!` and checked); a fixed `nmax` bound on the number of events (with an optional saturation flag). The earlier fixed-grid `VariableRateJump`/PDMP approach (#596) is intentionally set aside in favor of this narrower, exact constant-rate first step.
Dependency note.
`StochasticAD` is an optional `[weakdeps]` extension, and its tests run in a dedicated `GROUP=StochasticAD` environment (no OrdinaryDiffEq) so they don't perturb the main test suite. (The currently-resolvable `StochasticAD` is fairly old; happy to discuss whether that's a concern for taking this as a dependency.)
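As a cross-check on the pure-death validation figure quoted under "Validation", the closed-form gradient follows from the linear mean equation for a death process with rate μ·u. The values u_0 = 100, μ = 0.5, T = 1 used in the last line are an assumption: the PR text does not state the test parameters, and these are simply one choice consistent with the quoted analytic value of -60.65.

```latex
\begin{aligned}
\frac{d}{dt}\,\mathbb{E}[u(t)] &= -\mu\,\mathbb{E}[u(t)]
  \;\Longrightarrow\; \mathbb{E}[u(T)] = u_0\,e^{-\mu T},\\
\frac{\partial}{\partial\mu}\,\mathbb{E}[u(T)] &= -T\,u_0\,e^{-\mu T},\\
\left.\frac{\partial}{\partial\mu}\,\mathbb{E}[u(T)]\right|_{u_0=100,\ \mu=0.5,\ T=1}
  &= -100\,e^{-0.5} \approx -60.65 .
\end{aligned}
```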