From 0bb71999e973a0c31a5b142feb2446c1d2210a6e Mon Sep 17 00:00:00 2001 From: Roman <116185879+RomanSahakyan03@users.noreply.github.com> Date: Fri, 12 Jun 2026 10:29:41 +0400 Subject: [PATCH 1/5] Add StochasticAD support for constant-rate jump processes and related tests --- .github/workflows/Tests.yml | 3 ++ Project.toml | 5 ++ ext/JumpProcessesStochasticADExt.jl | 75 +++++++++++++++++++++++++++++ src/JumpProcesses.jl | 6 +++ test/runtests.jl | 15 ++++++ test/stochasticad/Project.toml | 4 ++ test/stochasticad_tests.jl | 42 ++++++++++++++++ 7 files changed, 150 insertions(+) create mode 100644 ext/JumpProcessesStochasticADExt.jl create mode 100644 test/stochasticad/Project.toml create mode 100644 test/stochasticad_tests.jl diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 3f223b3de..7c85736d7 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -30,9 +30,12 @@ jobs: - InterfaceI - InterfaceII - QA + - StochasticAD exclude: - version: "pre" group: QA + - version: "pre" + group: StochasticAD uses: "SciML/.github/.github/workflows/tests.yml@v1" with: julia-version: "${{ matrix.version }}" diff --git a/Project.toml b/Project.toml index 9b2401234..49087d49f 100644 --- a/Project.toml +++ b/Project.toml @@ -22,12 +22,15 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [weakdeps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" +StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac" [extensions] JumpProcessesKernelAbstractionsExt = ["Adapt", "KernelAbstractions"] JumpProcessesOrdinaryDiffEqCoreExt = "OrdinaryDiffEqCore" +JumpProcessesStochasticADExt = ["StochasticAD", "Distributions"] [compat] ADTypes = "1" @@ -37,6 +40,7 @@ ArrayInterface = "7.15" DataStructures = "0.18, 0.19" DiffEqBase = "6.192, 7" DiffEqCallbacks = "4.7" +Distributions = "0.25" DocStringExtensions = "0.9" ExplicitImports = "1" FastBroadcast = "0.3, 1" @@ -58,6 +62,7 @@ SciMLBase = "2.115, 3.1" StableRNGs = "1" StaticArrays = "1.9.8" Statistics = "1" +StochasticAD = "0.1" StochasticDiffEq = "6.82, 7" SymbolicIndexingInterface = "0.3.36" Test = "1" diff --git a/ext/JumpProcessesStochasticADExt.jl b/ext/JumpProcessesStochasticADExt.jl new file mode 100644 index 000000000..742028c0f --- /dev/null +++ b/ext/JumpProcessesStochasticADExt.jl @@ -0,0 +1,75 @@ +module JumpProcessesStochasticADExt + +# Optional StochasticAD support for JumpProcesses: differentiate expectations +# over CONSTANT-RATE jump processes. +# +# Constant-rate jumps with state-independent rates form a (sum of) homogeneous +# Poisson process(es), so the StochasticAD-friendly route -- which also matches +# the standard "pre-sample times and use `tstops`" trick -- is to sample the +# per-channel event count `Nₖ ~ Poisson(λₖ·ΔT)` directly. StochasticAD +# differentiates the Poisson sampler, so no jump-time rootfinding or +# Float64-typed propensity cache is involved. + +using JumpProcesses +using StochasticAD +using Distributions: Poisson + +# minimal integrator-like object so a jump's `affect!` can be applied to a +# scratch state to read off its (additive) net effect. +mutable struct ShimIntegrator{U, P, T} + u::U + p::P + t::T +end + +""" + constant_rate_final_state(jprob, p; tspan = jprob.prob.tspan) -> u(tf) + +Final state of a constant-rate jump process, computed in a way that composes with +StochasticAD's `derivative_estimate`/`stochastic_triple`. For each +`ConstantRateJump` in `jprob`, the event count over `tspan` is sampled as +`Nₖ ~ Poisson(λₖ · ΔT)` (which StochasticAD differentiates directly) and the +final state is `u0 + Σₖ Nₖ · Δuₖ`. + +Wrap in `derivative_estimate` to get gradients of an expectation: + +```julia +derivative_estimate(p0) do p + observable(constant_rate_final_state(jprob, p)) +end +``` + +Exactness conditions: + + - every `ConstantRateJump` rate must be **state-independent**, so `λₖ` is + constant over `[t0, tf]`. The rate is read once as `jump.rate(u0, p, t0)`. + (State-dependent rates couple the event count to `p` through the jump + *times*, which a fixed pre-sample cannot capture — out of scope here.) + - each `affect!` must apply a **constant additive** net change to the state. + `Δuₖ` is inferred by applying `jump.affect!` to a copy of `u0`. + +The differentiation parameter must enter through the `p` argument of the rate +functions (it is passed straight to `jump.rate(u0, p, t0)`). +""" +function JumpProcesses.constant_rate_final_state(jprob, p; tspan = jprob.prob.tspan) + u0 = jprob.prob.u0 + jumps = jprob.constant_jumps + t0, tf = tspan + ΔT = tf - t0 + n = length(u0) + + # net additive change per channel (Float64), inferred from the affect + Δ = map(jumps) do jump + ushim = collect(float.(u0)) + jump.affect!(ShimIntegrator(ushim, p, t0)) + ushim .- float.(u0) + end + + # per-channel event counts; the StochasticTriple flows in via the rate + N = map(jump -> rand(Poisson(jump.rate(u0, p, t0) * ΔT)), jumps) + + # u(tf) = u0 + Σₖ Nₖ Δₖ (order-independent for additive jumps) + return [float(u0[i]) + sum(N[k] * Δ[k][i] for k in eachindex(jumps)) for i in 1:n] +end + +end # module diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 87c8c2c26..098fb97f4 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -112,6 +112,12 @@ include("aggregators/aggregated_api.jl") include("variable_rate.jl") export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW +# StochasticAD support. Stub; the method is provided by the package extension +# `ext/JumpProcessesStochasticADExt.jl`, which loads only when StochasticAD and +# Distributions are both available. No StochasticAD code lives in `src/`. +function constant_rate_final_state end +export constant_rate_final_state + """ Aggregator to indicate that individual jumps should also be handled via the leaping algorithm that is passed to solve. diff --git a/test/runtests.jl b/test/runtests.jl index 9a2dfc246..06f01c3c8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,16 @@ function activate_gpu_env() Pkg.instantiate() end +# Isolated environment for the StochasticAD extension tests. StochasticAD pins +# old transitive deps (e.g. ForwardDiff 0.10) that conflict with the modern +# OrdinaryDiffEq stack, so it is kept out of the main test target and run here +# in its own project (no ODE solver needed -- the extension never calls `solve`). +function activate_stochasticad_env() + Pkg.activate("stochasticad") + Pkg.develop(PackageSpec(path = dirname(@__DIR__))) + Pkg.instantiate() +end + @time begin if GROUP == "QA" @time @safetestset "QA Tests" begin include("qa.jl") end @@ -63,6 +73,11 @@ end @time @safetestset "GPU Tau Leaping test" begin include("gpu/regular_jumps.jl") end end + if GROUP == "StochasticAD" + activate_stochasticad_env() + @time @safetestset "StochasticAD Extension Tests" begin include("stochasticad_tests.jl") end + end + if GROUP == "Correctness" activate_gpu_env() end diff --git a/test/stochasticad/Project.toml b/test/stochasticad/Project.toml new file mode 100644 index 000000000..a9253003c --- /dev/null +++ b/test/stochasticad/Project.toml @@ -0,0 +1,4 @@ +[deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/stochasticad_tests.jl b/test/stochasticad_tests.jl new file mode 100644 index 000000000..7c2bf9b9e --- /dev/null +++ b/test/stochasticad_tests.jl @@ -0,0 +1,42 @@ +using JumpProcesses, StochasticAD, Distributions +using Statistics, Random, Test + +# Tests for the optional StochasticAD extension: differentiating expectations +# over constant-rate jump processes. Needs only JumpProcesses + StochasticAD + +# Distributions (no ODE solver), so it runs in its own isolated environment. + +@testset "constant-rate jump gradient" begin + # Two constant (state-independent) channels on a scalar count: + # birth rate λ1 (u[1] += 1), emigration rate λ2 (u[1] -= 1). + # u(T) = u0 + N1 - N2, Nk ~ Poisson(λk T) => E[u(T)] = u0 + (λ1 - λ2)T + # d/dλ1 E = T , d/dλ2 E = -T + T = 2.0 + j1 = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] += 1; nothing)) + j2 = ConstantRateJump((u, p, t) -> p[2], integ -> (integ.u[1] -= 1; nothing)) + dprob = DiscreteProblem([10], (0.0, T)) + jprob = JumpProblem(dprob, Direct(), j1, j2) + + @test length(jprob.constant_jumps) == 2 + + obs = u -> u[1] + λ0 = [3.0, 1.0] + + for k in 1:2 + N = 2000 + s = Vector{Float64}(undef, N) + for i in 1:N + Random.seed!(i) + s[i] = derivative_estimate(λ0[k]) do λk + p = [j == k ? λk : oftype(λk, λ0[j]) for j in 1:2] + obs(constant_rate_final_state(jprob, p)) + end + end + target = k == 1 ? T : -T + @test isapprox(mean(s), target; atol = 0.05) + end + + # mean state is also recovered (E[u(T)] = 10 + (3-1)*2 = 14) + Random.seed!(1) + means = mean(constant_rate_final_state(jprob, λ0)[1] for _ in 1:4000) + @test isapprox(means, 14.0; atol = 0.5) +end From 697f635d754e6fcf3f1c4be647265dff9f6a1fda Mon Sep 17 00:00:00 2001 From: Roman <116185879+RomanSahakyan03@users.noreply.github.com> Date: Sun, 14 Jun 2026 01:03:06 +0400 Subject: [PATCH 2/5] Add StochasticAD support for constant-rate jump SSA differentiation Implements StochasticAD-compatible gradients for jump-only ConstantRateJump processes, following Chris's "start with the constant-rate SSA" direction. ext/JumpProcessesStochasticADExt.jl (loaded with StochasticAD + Distributions): - constant_rate_ssa_final_state: exact iterative SSA, differentiable. Reformulated as a fixed-length Bernoulli-per-event loop (occurs ~ Bernoulli(1-exp(-rate*(T-t))), truncated-exponential time, stick-breaking channel choice, multiplicative-select freezing) so the event-count derivative survives -- a naive `while t u(tf) + constant_rate_ssa_final_state(jprob, p; nmax, tspan = jprob.prob.tspan, + return_saturation = false) -Final state of a constant-rate jump process, computed in a way that composes with -StochasticAD's `derivative_estimate`/`stochastic_triple`. For each -`ConstantRateJump` in `jprob`, the event count over `tspan` is sampled as -`Nₖ ~ Poisson(λₖ · ΔT)` (which StochasticAD differentiates directly) and the -final state is `u0 + Σₖ Nₖ · Δuₖ`. +Final state at `tspan[2]` of a **jump-only** `ConstantRateJump` process, computed +so that StochasticAD's `derivative_estimate`/`stochastic_triple` give correct +gradients — **including state-dependent rates** such as `rate(u,p,t) = p[1]*u[1]`. -Wrap in `derivative_estimate` to get gradients of an expectation: +Wrap in `derivative_estimate` (one scalar partial at a time): ```julia -derivative_estimate(p0) do p - observable(constant_rate_final_state(jprob, p)) +derivative_estimate(p0[k]) do pk + pv = [j == k ? pk : oftype(pk, p0[j]) for j in eachindex(p0)] + observable(constant_rate_ssa_final_state(jprob, pv; nmax = 500)) end ``` -Exactness conditions: +# Method + +A naive Gillespie loop (`t += randexp()/rate; t < T || break`) drops the +event-count derivative, because `t < T` branches on the primal time. Instead this +runs a **fixed-length** loop of `nmax` potential events and replaces the +time-comparison with a tracked `Bernoulli`: + + - `p_occ = 1 - exp(-total_rate·(T - t))`, `occurs ~ Bernoulli(p_occ)` — the + parameter-dependent probability the next event lands before `T`; + - the event time is the **truncated** exponential `-log(1 - U·p_occ)/total`; + - the channel is chosen by **stick-breaking** conditional `Bernoulli`s (the last + channel deterministic), avoiding `Categorical` 0/0 and triple-valued indexing; + - state updates use **multiplicative-select masking** (`step = active*occurs`), + so once an event fails to occur the trajectory is frozen — no branching on a + triple-valued decision. + +This is **exact** (the trajectory distribution equals the SSA's), not a τ-leap +approximation. Rates are recomputed from the current (triple) state every event. + +# Arguments / scope + + - `nmax` (required): fixed upper bound on the number of events. The loop always + runs `nmax` steps; once the chain breaks the rest are masked. If the true + count can exceed `nmax` the result is biased — choose `nmax` large enough that + saturation is negligible (see `return_saturation`). + - `return_saturation = true` returns `(u, active)`; `active != 0` means the + trajectory still had `nmax` consecutive events (i.e. it may be truncated). + Use this on the primal (Float64 `p`) to estimate the saturation probability. + +# Limitations + + - `ConstantRateJump` only (not `VariableRateJump`); jump-only, no continuous drift. + - Additive affects only (`Δuₖ` inferred from `affect!`, checked for state-independence). + - The differentiation parameter must enter through the `p` argument of the rates. +""" +function JumpProcesses.constant_rate_ssa_final_state(jprob, p; nmax, + tspan = jprob.prob.tspan, return_saturation = false) + u0 = jprob.prob.u0 + jumps = jprob.constant_jumps + t0, tf = tspan + K = length(jumps) + n = length(u0) + + # additive net change per channel (Float64), checked for state-independence + Δ = [_additive_change(jumps[k], u0, p, t0) for k in 1:K] - - every `ConstantRateJump` rate must be **state-independent**, so `λₖ` is - constant over `[t0, tf]`. The rate is read once as `jump.rate(u0, p, t0)`. - (State-dependent rates couple the event count to `p` through the jump - *times*, which a fixed pre-sample cannot capture — out of scope here.) - - each `affect!` must apply a **constant additive** net change to the state. - `Δuₖ` is inferred by applying `jump.affect!` to a copy of `u0`. + z = 0 * sum(p) # triple zero + u = [float(u0[i]) + z for i in 1:n] # triple-typed state + t = z + active = 1 + z # 1 while the event chain is unbroken + + for _ in 1:nmax + rates = [jumps[k].rate(u, p, t) for k in 1:K] # state-dependent OK + total = sum(rates) + Δt = tf - t + pocc = 1 - exp(-total * Δt) # P(next event before T) + occurs = rand(Bernoulli(pocc)) + step = active * occurs + + # which channel: stick-breaking conditional Bernoullis + multiplicative + # select; last channel deterministic, suffix-sum denominator (in [0,1)). + notchosen = 1 + z + sel = [z for _ in 1:n] + for k in 1:K + if k < K + denom = sum(rates[j] for j in k:K) + 1e-300 + chose = rand(Bernoulli(rates[k] / denom)) + else + chose = 1 + z + end + take = notchosen * chose + sel = [sel[i] + take * Δ[k][i] for i in 1:n] + notchosen = notchosen * (1 - chose) + end + + U = rand() + τ = -log(1 - U * pocc) / (total + 1e-300) # truncated-exp time + t = t + step * τ + u = [u[i] + step * sel[i] for i in 1:n] + active = active * occurs + end + + return return_saturation ? (u, active) : u +end + +# =========================================================================== +# NARROW helper: state-independent homogeneous-Poisson closed form +# =========================================================================== -The differentiation parameter must enter through the `p` argument of the rate -functions (it is passed straight to `jump.rate(u0, p, t0)`). """ -function JumpProcesses.constant_rate_final_state(jprob, p; tspan = jprob.prob.tspan) + poisson_count_final_state(jprob, p; tspan = jprob.prob.tspan) -> u(tf) + +Closed-form final state for the special case of **state-independent** constant +rates (a sum of homogeneous Poisson processes): `Nₖ ~ Poisson(λₖ·ΔT)`, +`u(tf) = u0 + Σₖ Nₖ·Δuₖ`. StochasticAD differentiates the `Poisson` sampler +directly. + +This is **only** for state-independent homogeneous-Poisson additive jump systems. +It is **not** the iterative SSA path and does **not** cover state-dependent +`ConstantRateJump` rates — use [`constant_rate_ssa_final_state`](@ref) for those. +The rate is read once at `(u0, p, t0)`; a state-independence check guards misuse. +""" +function JumpProcesses.poisson_count_final_state(jprob, p; tspan = jprob.prob.tspan) u0 = jprob.prob.u0 jumps = jprob.constant_jumps t0, tf = tspan ΔT = tf - t0 n = length(u0) - # net additive change per channel (Float64), inferred from the affect - Δ = map(jumps) do jump - ushim = collect(float.(u0)) - jump.affect!(ShimIntegrator(ushim, p, t0)) - ushim .- float.(u0) + base_shift = float.(collect(u0)) .+ one(eltype(float.(collect(u0)))) + for jump in jumps + isapprox(_val(jump.rate(u0, p, t0)), _val(jump.rate(base_shift, p, t0))) || error( + "poisson_count_final_state requires state-INDEPENDENT rates; a rate " * + "changed with the state. Use constant_rate_ssa_final_state for " * + "state-dependent ConstantRateJumps.") end - # per-channel event counts; the StochasticTriple flows in via the rate + Δ = map(jump -> _additive_change(jump, u0, p, t0), jumps) N = map(jump -> rand(Poisson(jump.rate(u0, p, t0) * ΔT)), jumps) - - # u(tf) = u0 + Σₖ Nₖ Δₖ (order-independent for additive jumps) return [float(u0[i]) + sum(N[k] * Δ[k][i] for k in eachindex(jumps)) for i in 1:n] end diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 098fb97f4..b796b49ec 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -112,11 +112,12 @@ include("aggregators/aggregated_api.jl") include("variable_rate.jl") export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW -# StochasticAD support. Stub; the method is provided by the package extension +# StochasticAD support. Stubs; methods are provided by the package extension # `ext/JumpProcessesStochasticADExt.jl`, which loads only when StochasticAD and # Distributions are both available. No StochasticAD code lives in `src/`. -function constant_rate_final_state end -export constant_rate_final_state +function constant_rate_ssa_final_state end +function poisson_count_final_state end +export constant_rate_ssa_final_state, poisson_count_final_state """ Aggregator to indicate that individual jumps should also be handled via the leaping diff --git a/src/aggregators/ssajump.jl b/src/aggregators/ssajump.jl index 90c260c97..f032222e5 100644 --- a/src/aggregators/ssajump.jl +++ b/src/aggregators/ssajump.jl @@ -110,6 +110,24 @@ Adds a `tstop` to the integrator at the next jump time. nothing end +# Element type for the SSA rate cache: the time type promoted with the *inferred* +# rate output type(s). Ordinary Float64 rates give Float64 (unchanged); a non-Float64 +# rate type (e.g. a StochasticAD StochasticTriple) is preserved rather than forced +# into Float64. We use return-type inference (`Base.promote_op`) rather than calling +# the rate, so this never executes the rate at build time (which could index a +# `NullParameters` `p`, have side effects, etc.) and falls back to the time type if +# inference is not a concrete type. +function ssa_rate_eltype(u, p, t, majumps, rates) + R = typeof(t) + if get_num_majumps(majumps) > 0 + R = promote_type(R, Base.promote_op(evalrxrate, typeof(u), Int, typeof(majumps))) + end + if !isempty(rates) + R = promote_type(R, Base.promote_op(first(rates), typeof(u), typeof(p), typeof(t))) + end + return isconcretetype(R) ? R : typeof(t) +end + """ build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; kwargs...) @@ -127,14 +145,22 @@ function build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rate Vector{Vector{Pair{Int, eltype(u)}}}()) end + # Rate-cache element type: the time type promoted with the actual rate output + # type(s). For ordinary Float64 rates this stays Float64 (behavior unchanged); + # if a rate returns a non-Float64 number (e.g. a StochasticAD StochasticTriple), + # the cache holds that type instead of forcing it into Float64, so such rates + # can pass through the existing SSA path. Plain promotion/conversion -- no + # StochasticAD dependency in src/. + RT = ssa_rate_eltype(u, p, t, majumps, rates) + # current jump rates, allows mass action rates and constant jumps - cur_rates = Vector{typeof(t)}(undef, get_num_majumps(majumps) + length(rates)) + cur_rates = Vector{RT}(undef, get_num_majumps(majumps) + length(rates)) - sum_rate = zero(typeof(t)) + sum_rate = convert(RT, zero(t)) next_jump = 0 - next_jump_time = typemax(typeof(t)) - jump_agg_type(next_jump, next_jump_time, end_time, cur_rates, sum_rate, - majumps, rates, affects!, save_positions, rng; kwargs...) + next_jump_time = convert(RT, typemax(typeof(t))) + jump_agg_type(next_jump, next_jump_time, convert(RT, end_time), cur_rates, + sum_rate, majumps, rates, affects!, save_positions, rng; kwargs...) end """ diff --git a/test/stochasticad_tests.jl b/test/stochasticad_tests.jl index 7c2bf9b9e..0e2812793 100644 --- a/test/stochasticad_tests.jl +++ b/test/stochasticad_tests.jl @@ -1,42 +1,106 @@ using JumpProcesses, StochasticAD, Distributions using Statistics, Random, Test -# Tests for the optional StochasticAD extension: differentiating expectations -# over constant-rate jump processes. Needs only JumpProcesses + StochasticAD + -# Distributions (no ODE solver), so it runs in its own isolated environment. - -@testset "constant-rate jump gradient" begin - # Two constant (state-independent) channels on a scalar count: - # birth rate λ1 (u[1] += 1), emigration rate λ2 (u[1] -= 1). - # u(T) = u0 + N1 - N2, Nk ~ Poisson(λk T) => E[u(T)] = u0 + (λ1 - λ2)T - # d/dλ1 E = T , d/dλ2 E = -T - T = 2.0 - j1 = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] += 1; nothing)) - j2 = ConstantRateJump((u, p, t) -> p[2], integ -> (integ.u[1] -= 1; nothing)) - dprob = DiscreteProblem([10], (0.0, T)) - jprob = JumpProblem(dprob, Direct(), j1, j2) - - @test length(jprob.constant_jumps) == 2 - - obs = u -> u[1] - λ0 = [3.0, 1.0] - - for k in 1:2 - N = 2000 - s = Vector{Float64}(undef, N) - for i in 1:N - Random.seed!(i) - s[i] = derivative_estimate(λ0[k]) do λk - p = [j == k ? λk : oftype(λk, λ0[j]) for j in 1:2] - obs(constant_rate_final_state(jprob, p)) +# StochasticAD-compatible differentiation for jump-only ConstantRateJump SSA +# problems. Needs only JumpProcesses + StochasticAD + Distributions (no ODE +# solver), so it runs in its own isolated environment (GROUP=StochasticAD). + +# per-partial StochasticAD gradient with fixed seeds (reproducible) +function sad_partial(f, p0, k; N) + s = Vector{Float64}(undef, N) + for i in 1:N + Random.seed!(i) + s[i] = derivative_estimate(p0[k]) do pk + p = [j == k ? pk : oftype(pk, p0[j]) for j in eachindex(p0)] + f(p) + end + end + return mean(s), std(s) / sqrt(N) +end + +@testset "StochasticAD constant-rate jumps" begin + + # --- Test A: pure death, STATE-DEPENDENT rate (the case Chris cares about) --- + # rate = μ·u[1] ; E[u(T)] = u0·e^{-μT} ; d/dμ E[u(T)] = -T·u0·e^{-μT} + # A naive `while t p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), death) + analytic = -T * u0 * exp(-μ0 * T) + g, se = sad_partial([μ0], 1; N = 10000) do p + constant_rate_ssa_final_state(jprob, p; nmax = 200)[1] + end + @test abs(g - analytic) < 4 * se # event-count derivative captured + @test abs(g) > 1.0 # explicitly NOT the zero a naive SSA gives + end + + # --- Test B: birth-death, multi-channel + state-dependent --- + # birth λ, death μ·u[1] ; E[u(T)] = λ/μ + (u0-λ/μ)e^{-μT} + @testset "birth-death (multi-channel, state-dependent)" begin + T, u0, λ0, μ0 = 1.0, 50, 10.0, 0.3 + birth = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] += 1; nothing)) + death = ConstantRateJump((u, p, t) -> p[2] * u[1], integ -> (integ.u[1] -= 1; nothing)) + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), birth, death) + a, b = λ0 / μ0, exp(-μ0 * T) + analytic = [(1 - b) / μ0, -λ0 / μ0^2 * (1 - b) + (u0 - a) * (-T * b)] + for k in 1:2 + g, se = sad_partial([λ0, μ0], k; N = 10000) do p + constant_rate_ssa_final_state(jprob, p; nmax = 400)[1] end + @test abs(g - analytic[k]) < 4 * se end - target = k == 1 ? T : -T - @test isapprox(mean(s), target; atol = 0.05) end - # mean state is also recovered (E[u(T)] = 10 + (3-1)*2 = 14) - Random.seed!(1) - means = mean(constant_rate_final_state(jprob, λ0)[1] for _ in 1:4000) - @test isapprox(means, 14.0; atol = 0.5) + # --- Test C: homogeneous Poisson baseline (state-INDEPENDENT) --- + # two constant channels ; u(T) = u0 + N1 - N2 ; d/dλ1 = T, d/dλ2 = -T + @testset "homogeneous Poisson baseline" begin + T, u0 = 2.0, 10 + j1 = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] += 1; nothing)) + j2 = ConstantRateJump((u, p, t) -> p[2], integ -> (integ.u[1] -= 1; nothing)) + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), j1, j2) + λ0 = [3.0, 1.0] + # closed-form helper (exact) + for k in 1:2 + g, _ = sad_partial(λ0, k; N = 3000) do p + poisson_count_final_state(jprob, p)[1] + end + @test isapprox(g, k == 1 ? T : -T; atol = 0.05) + end + # the iterative SSA method must agree with the baseline on this case too + for k in 1:2 + g, se = sad_partial(λ0, k; N = 10000) do p + constant_rate_ssa_final_state(jprob, p; nmax = 200)[1] + end + @test abs(g - (k == 1 ? T : -T)) < 4 * se + end + end + + # --- guards: misuse should error, not silently mislead --- + @testset "guards" begin + T = 1.0 + # state-dependent rate via the Poisson shortcut must error + death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + jp = JumpProblem(DiscreteProblem([10], (0.0, T)), Direct(), death) + @test_throws ErrorException poisson_count_final_state(jp, [0.5]) + end + + # --- generic rate cache (Chris's literal request): a triple-valued rate must + # pass through the existing Direct rate-cache without Float64(::StochasticTriple). + # SCOPE: this only makes the cache generic. It does NOT by itself give correct + # gradients — the stock `while t < T` event boundary still drops the event-count + # derivative, and a full stock solve would next hit the SSAStepper integrator's + # Float64 time. Correct gradients use constant_rate_ssa_final_state (above). + @testset "generic rate cache (no Float64(::StochasticTriple))" begin + st = stochastic_triple(identity, 0.5) + jump = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + rates, affects = JumpProcesses.get_jump_info_tuples((jump,)) + agg = JumpProcesses.build_jump_aggregation( + JumpProcesses.DirectJumpAggregation, [10], [st], 0.0, 1.0, nothing, + rates, affects, (false, false), JumpProcesses.DEFAULT_RNG) + @test eltype(agg.cur_rates) <: StochasticAD.StochasticTriple # cache is generic, not Float64 + sr, _ = JumpProcesses.time_to_next_jump(agg, [10], [st], 0.0) + @test sr isa StochasticAD.StochasticTriple # filled without Float64 error + end end From e3f9a56e26984333a6efe081783a5b8b8a9465f4 Mon Sep 17 00:00:00 2001 From: Roman Sahakyan Date: Wed, 17 Jun 2026 15:20:31 +0400 Subject: [PATCH 3/5] Promote constant-rate StochasticAD support to BoundedSSA Turn the prototype differentiable SSA implementation into a first-class solve algorithm, with the generic rate cache now used by both the standard Direct path and the bounded StochasticAD-compatible path. API: * Add BoundedSSA(; nmax), a StochasticAD-compatible SSA solve algorithm. * Add the BoundedSSA struct and export it from src/SSA_stepper.jl. * Add the corresponding __solve implementation in the extension. * Use a dedicated integrator so the bounded path is independent of SSAStepper's Float64 time handling and never compares t < T. * Rename constant_rate_ssa_final_state to bounded_ssa_final_state. * Add saturation_probability for choosing and diagnosing nmax. The truncation bias is exactly P(N > nmax), so it is measured rather than hidden. * Drop poisson_count_final_state, since the bounded path subsumes the previous presampling approach. Direct aggregator refactor: * Factor the raw per-channel rate fill from time_to_next_jump into fill_cur_rates!. * Reuse fill_cur_rates! in both the stock Direct/DirectFW path and the bounded loop. * This lets StochasticTriple rates pass through the existing rate machinery. * The refactor is behavior-preserving: Direct and DirectFW were verified bit-identical across seeds. Scope: * Support is limited to ConstantRateJump for now. * MassActionJump is deferred because evalrxrate is not triple-generic: it has a ::R return assertion and an order > 1 specpop <= 0 boolean. * MassActionJump rate constants also flow through param_mapper(p). * Unsupported cases now fail with an informative error. Tests: * Rewrite stochasticad_tests.jl for the new BoundedSSA API. * Test solve-path differentiation through BoundedSSA. * Compare primal means against SSAStepper. * Add fill_cur_rates_regression.jl to guard the direct.jl refactor. * Cover mass-action, multi-channel, fwrapper, and exact Direct == DirectFW behavior. * Fix runtests activate_stochasticad_env to use a cwd-independent path. --- ext/JumpProcessesStochasticADExt.jl | 246 ++++++++++++++-------------- src/JumpProcesses.jl | 13 +- src/SSA_stepper.jl | 51 ++++++ src/aggregators/direct.jl | 75 ++++----- test/fill_cur_rates_regression.jl | 95 +++++++++++ test/runtests.jl | 5 +- test/stochasticad/Project.toml | 2 + test/stochasticad_tests.jl | 92 ++++++++--- 8 files changed, 392 insertions(+), 187 deletions(-) create mode 100644 test/fill_cur_rates_regression.jl diff --git a/ext/JumpProcessesStochasticADExt.jl b/ext/JumpProcessesStochasticADExt.jl index 1804fa9a8..69264b671 100644 --- a/ext/JumpProcessesStochasticADExt.jl +++ b/ext/JumpProcessesStochasticADExt.jl @@ -1,36 +1,41 @@ module JumpProcessesStochasticADExt # StochasticAD-compatible differentiation for jump-only `ConstantRateJump` SSA -# problems. Two clearly-scoped entry points: +# problems — the implementation behind the `BoundedSSA` algorithm and the +# `bounded_ssa_final_state` / `saturation_probability` entry points. # -# * `constant_rate_ssa_final_state` (main) — exact iterative SSA, supports -# state-dependent `ConstantRateJump` rates (rates recomputed after each -# event). Reformulated as a fixed-length, `Bernoulli`-per-event program so -# StochasticAD sees every discrete decision (a naive `while t < T` loop drops -# the event-count derivative). -# * `poisson_count_final_state` (narrow helper) — closed-form shortcut valid -# ONLY for state-independent (homogeneous-Poisson) additive jumps. +# Why this exists: the stock `solve(jprob, SSAStepper())` cannot be differentiated +# with StochasticAD. It decides the number of events with a +# `while integrator.t < integrator.tstop < end_time` loop — a boolean predicate on +# (triple-valued) time, which StochasticAD forbids by design — so the event-count +# derivative (the dominant term for state-dependent rates) is dropped and a +# state-dependent rate gives a gradient of 0. We instead run a fixed-length loop of +# at most `nmax` jump attempts, representing every discrete decision through +# stochastic primitives / masks rather than Julia branches. Exact up to +# `P(N > nmax)`. # -# Scope: jump-only `DiscreteProblem`s, additive affects, no continuous drift, no -# `VariableRateJump`, no `Tsit5`, no rootfinding. Not a replacement for the -# JumpProcesses aggregator internals — this reads `jump.rate`/`affect!` only. +# Scope: jump-only `DiscreteProblem`s with `ConstantRateJump`s (state-dependent +# rates OK) and additive affects. No `MassActionJump` (see `_constant_rate_channels` +# for why), no `VariableRateJump`, no continuous drift, no rootfinding. using JumpProcesses using StochasticAD -using Distributions: Bernoulli, Poisson +using Distributions: Bernoulli +using DiffEqBase +using Random -# minimal integrator-like object so a jump's `affect!` can be applied to a -# scratch state to read off its net effect. +# primal value of a (possibly triple) scalar. +_val(x) = x +_val(x::StochasticAD.StochasticTriple) = StochasticAD.value(x) + +# minimal integrator-like object so a jump's `affect!` can be applied to a scratch +# state to read off its net effect. mutable struct ShimIntegrator{U, P, T} u::U p::P t::T end -# primal value of a (possibly triple) scalar, for state-dependence checks. -_val(x) = x -_val(x::StochasticAD.StochasticTriple) = StochasticAD.value(x) - # apply `affect!` to a scratch copy of `ubase` and return the net state change. function _net_change(affect!, ubase, p, t0) u = collect(ubase) @@ -38,118 +43,111 @@ function _net_change(affect!, ubase, p, t0) return u .- ubase end -# infer a jump's net state change, verifying it is *additive* (the change is the -# same from two different base states). Additive affects are required: the final -# state is built as `u0 + Σ Nₖ Δuₖ` / by adding `Δuₖ` on each event. +# infer a jump's net state change, verifying it is *additive* (same change from two +# different base states). The final state is built by adding `Δuₖ` on each event, +# so non-additive (state-dependent) affects are out of scope. function _additive_change(jump, u0, p, t0) base = float.(collect(u0)) Δ = _net_change(jump.affect!, base, p, t0) Δ2 = _net_change(jump.affect!, base .+ one(eltype(base)), p, t0) isapprox(Δ, Δ2) || error( - "this method supports only additive affects (constant net state change), " * + "BoundedSSA supports only additive affects (a constant net state change), " * "but a jump's affect! gave a state-dependent change ($Δ vs $Δ2 from a " * "shifted state). Arbitrary mutating affects are out of scope.") return Δ end -# =========================================================================== -# MAIN: exact iterative constant-rate SSA, StochasticAD-differentiable -# =========================================================================== +# Resolve a JumpProblem into the per-channel data the bounded loop needs: +# `(rates_tuple, Δs)` where `rates_tuple` is the tuple of `ConstantRateJump` rate +# functions (fed to the shared `JumpProcesses.fill_cur_rates!`) and `Δs[k]` is +# channel `k`'s additive net state change. +function _constant_rate_channels(jprob, u0, p, t0) + maj = jprob.massaction_jump + (maj === nothing || JumpProcesses.get_num_majumps(maj) == 0) || error( + "BoundedSSA does not yet support MassActionJump. `evalrxrate` is not " * + "triple-generic: its `::R` return assertion pins the rate to the " * + "`scaled_rates` element type, and the order>1 branch tests `specpop <= 0` " * + "(a boolean on triple-valued species). Mass-action rate constants also " * + "flow through `param_mapper(p)`. Build the model with ConstantRateJumps, " * + "or track the mass-action follow-up.") + vj = jprob.variable_jumps + (vj === nothing || isempty(vj)) || error( + "BoundedSSA supports jump-only constant-rate problems only; it does not " * + "support VariableRateJumps.") + cjumps = jprob.constant_jumps + (cjumps === nothing || isempty(cjumps)) && error( + "BoundedSSA requires at least one ConstantRateJump.") + rates_tuple, _ = JumpProcesses.get_jump_info_tuples(cjumps) + Δs = [_additive_change(j, u0, p, t0) for j in cjumps] + return rates_tuple, Δs +end """ - constant_rate_ssa_final_state(jprob, p; nmax, tspan = jprob.prob.tspan, - return_saturation = false) + bounded_ssa_final_state(jprob, p; nmax, tspan = jprob.prob.tspan, + return_saturation = false) + +Final state at `tspan[2]` of a jump-only `ConstantRateJump` process, computed so +that StochasticAD's `derivative_estimate`/`stochastic_triple` give correct +gradients — including state-dependent rates such as `rate(u,p,t) = p[1]*u[1]`. +This is the differentiable core behind [`BoundedSSA`](@ref). -Final state at `tspan[2]` of a **jump-only** `ConstantRateJump` process, computed -so that StochasticAD's `derivative_estimate`/`stochastic_triple` give correct -gradients — **including state-dependent rates** such as `rate(u,p,t) = p[1]*u[1]`. +Per-channel rates are computed via the same `JumpProcesses.fill_cur_rates!` helper +the `Direct` aggregator uses, so a triple-valued rate passes through the existing +rate machinery. See [`BoundedSSA`](@ref) for the method and scope, and +[`saturation_probability`](@ref) for sizing `nmax`. -Wrap in `derivative_estimate` (one scalar partial at a time): +`return_saturation = true` returns `(u, active)`; a non-zero primal `active` means +the trajectory used all `nmax` events and may be truncated. ```julia derivative_estimate(p0[k]) do pk pv = [j == k ? pk : oftype(pk, p0[j]) for j in eachindex(p0)] - observable(constant_rate_ssa_final_state(jprob, pv; nmax = 500)) + bounded_ssa_final_state(jprob, pv; nmax = 500)[1] end ``` - -# Method - -A naive Gillespie loop (`t += randexp()/rate; t < T || break`) drops the -event-count derivative, because `t < T` branches on the primal time. Instead this -runs a **fixed-length** loop of `nmax` potential events and replaces the -time-comparison with a tracked `Bernoulli`: - - - `p_occ = 1 - exp(-total_rate·(T - t))`, `occurs ~ Bernoulli(p_occ)` — the - parameter-dependent probability the next event lands before `T`; - - the event time is the **truncated** exponential `-log(1 - U·p_occ)/total`; - - the channel is chosen by **stick-breaking** conditional `Bernoulli`s (the last - channel deterministic), avoiding `Categorical` 0/0 and triple-valued indexing; - - state updates use **multiplicative-select masking** (`step = active*occurs`), - so once an event fails to occur the trajectory is frozen — no branching on a - triple-valued decision. - -This is **exact** (the trajectory distribution equals the SSA's), not a τ-leap -approximation. Rates are recomputed from the current (triple) state every event. - -# Arguments / scope - - - `nmax` (required): fixed upper bound on the number of events. The loop always - runs `nmax` steps; once the chain breaks the rest are masked. If the true - count can exceed `nmax` the result is biased — choose `nmax` large enough that - saturation is negligible (see `return_saturation`). - - `return_saturation = true` returns `(u, active)`; `active != 0` means the - trajectory still had `nmax` consecutive events (i.e. it may be truncated). - Use this on the primal (Float64 `p`) to estimate the saturation probability. - -# Limitations - - - `ConstantRateJump` only (not `VariableRateJump`); jump-only, no continuous drift. - - Additive affects only (`Δuₖ` inferred from `affect!`, checked for state-independence). - - The differentiation parameter must enter through the `p` argument of the rates. """ -function JumpProcesses.constant_rate_ssa_final_state(jprob, p; nmax, +function JumpProcesses.bounded_ssa_final_state(jprob, p; nmax, tspan = jprob.prob.tspan, return_saturation = false) - u0 = jprob.prob.u0 - jumps = jprob.constant_jumps - t0, tf = tspan - K = length(jumps) - n = length(u0) + prob = jprob.prob + u0 = prob.u0 + t0, tf = first(tspan), last(tspan) - # additive net change per channel (Float64), checked for state-independence - Δ = [_additive_change(jumps[k], u0, p, t0) for k in 1:K] + rates_tuple, Δs = _constant_rate_channels(jprob, u0, p, t0) + K = length(Δs) + n = length(u0) - z = 0 * sum(p) # triple zero - u = [float(u0[i]) + z for i in 1:n] # triple-typed state - t = z - active = 1 + z # 1 while the event chain is unbroken + z = 0 * sum(p) # triple zero (value 0) carrying p's type + u = [float(u0[i]) + z for i in 1:n] # triple-typed state + t = float(t0) + z + active = 1 + z # 1 while the event chain is unbroken for _ in 1:nmax - rates = [jumps[k].rate(u, p, t) for k in 1:K] # state-dependent OK - total = sum(rates) - Δt = tf - t - pocc = 1 - exp(-total * Δt) # P(next event before T) + # raw per-channel rates via the shared aggregator helper (triples flow through) + cur = [z for _ in 1:K] + JumpProcesses.fill_cur_rates!(cur, u, p, t, nothing, rates_tuple) + total = sum(cur) + + Δt = tf - t + pocc = 1 - exp(-total * Δt) # P(next event before tf) occurs = rand(Bernoulli(pocc)) - step = active * occurs + step = active * occurs # which channel: stick-breaking conditional Bernoullis + multiplicative - # select; last channel deterministic, suffix-sum denominator (in [0,1)). + # select; last channel deterministic, suffix-sum denominator in [0, 1). notchosen = 1 + z sel = [z for _ in 1:n] - for k in 1:K - if k < K - denom = sum(rates[j] for j in k:K) + 1e-300 - chose = rand(Bernoulli(rates[k] / denom)) - else - chose = 1 + z - end + @inbounds for k in 1:K + chose = k < K ? + rand(Bernoulli(cur[k] / (sum(cur[j] for j in k:K) + 1e-300))) : + (1 + z) take = notchosen * chose - sel = [sel[i] + take * Δ[k][i] for i in 1:n] + Δk = Δs[k] + sel = [sel[i] + take * Δk[i] for i in 1:n] notchosen = notchosen * (1 - chose) end U = rand() - τ = -log(1 - U * pocc) / (total + 1e-300) # truncated-exp time + τ = -log(1 - U * pocc) / (total + 1e-300) # truncated-exponential time t = t + step * τ u = [u[i] + step * sel[i] for i in 1:n] active = active * occurs @@ -158,41 +156,41 @@ function JumpProcesses.constant_rate_ssa_final_state(jprob, p; nmax, return return_saturation ? (u, active) : u end -# =========================================================================== -# NARROW helper: state-independent homogeneous-Poisson closed form -# =========================================================================== - """ - poisson_count_final_state(jprob, p; tspan = jprob.prob.tspan) -> u(tf) + saturation_probability(jprob, p; nmax, tspan = jprob.prob.tspan, ntrials = 1000) -Closed-form final state for the special case of **state-independent** constant -rates (a sum of homogeneous Poisson processes): `Nₖ ~ Poisson(λₖ·ΔT)`, -`u(tf) = u0 + Σₖ Nₖ·Δuₖ`. StochasticAD differentiates the `Poisson` sampler -directly. - -This is **only** for state-independent homogeneous-Poisson additive jump systems. -It is **not** the iterative SSA path and does **not** cover state-dependent -`ConstantRateJump` rates — use [`constant_rate_ssa_final_state`](@ref) for those. -The rate is read once at `(u0, p, t0)`; a state-independence check guards misuse. +Monte-Carlo estimate of `P(N > nmax)` — the probability the process has more than +`nmax` events on `tspan`, i.e. the bias of the bounded SSA path. Call with +ordinary (`Float64`) parameters `p`; size `nmax` so this is negligible. """ -function JumpProcesses.poisson_count_final_state(jprob, p; tspan = jprob.prob.tspan) - u0 = jprob.prob.u0 - jumps = jprob.constant_jumps - t0, tf = tspan - ΔT = tf - t0 - n = length(u0) - - base_shift = float.(collect(u0)) .+ one(eltype(float.(collect(u0)))) - for jump in jumps - isapprox(_val(jump.rate(u0, p, t0)), _val(jump.rate(base_shift, p, t0))) || error( - "poisson_count_final_state requires state-INDEPENDENT rates; a rate " * - "changed with the state. Use constant_rate_ssa_final_state for " * - "state-dependent ConstantRateJumps.") +function JumpProcesses.saturation_probability(jprob, p; nmax, + tspan = jprob.prob.tspan, ntrials = 1000) + nsat = 0 + for _ in 1:ntrials + _, active = JumpProcesses.bounded_ssa_final_state(jprob, p; nmax, tspan, + return_saturation = true) + (_val(active) != 0) && (nsat += 1) end + return nsat / ntrials +end - Δ = map(jump -> _additive_change(jump, u0, p, t0), jumps) - N = map(jump -> rand(Poisson(jump.rate(u0, p, t0) * ΔT)), jumps) - return [float(u0[i]) + sum(N[k] * Δ[k][i] for k in eachindex(jumps)) for i in 1:n] +# solve(jprob, BoundedSSA(; nmax)): run the bounded path and return a minimal +# (start, end) solution. `sol.u[end]` is the differentiable final state. +function DiffEqBase.__solve(jprob::JumpProblem, alg::BoundedSSA; + seed = nothing, tspan = jprob.prob.tspan, kwargs...) + seed === nothing || Random.seed!(seed) + prob = jprob.prob + u_final = JumpProcesses.bounded_ssa_final_state(jprob, prob.p; nmax = alg.nmax, + tspan = tspan) + # promote u0 to the (possibly triple) final-state type without needing a + # convert(::StochasticTriple, ::Float64): multiply by a clean zero. + u0p = [u_final[i] * 0 + float(prob.u0[i]) for i in eachindex(prob.u0)] + ts = [float(first(tspan)), float(last(tspan))] + us = [u0p, u_final] + DiffEqBase.build_solution(prob, alg, ts, us; + calculate_error = false, + stats = DiffEqBase.Stats(0), + interp = DiffEqBase.ConstantInterpolation(ts, us)) end end # module diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index b796b49ec..cf7ea989b 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -114,10 +114,13 @@ export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW # StochasticAD support. Stubs; methods are provided by the package extension # `ext/JumpProcessesStochasticADExt.jl`, which loads only when StochasticAD and -# Distributions are both available. No StochasticAD code lives in `src/`. -function constant_rate_ssa_final_state end -function poisson_count_final_state end -export constant_rate_ssa_final_state, poisson_count_final_state +# Distributions are both available. No StochasticAD code lives in `src/`. The +# `BoundedSSA` algorithm struct itself lives in `src` (see SSA_stepper.jl) so it +# is always referenceable/documentable; only its `solve` implementation is in the +# extension. +function bounded_ssa_final_state end +function saturation_probability end +export bounded_ssa_final_state, saturation_probability """ Aggregator to indicate that individual jumps should also be handled via the leaping @@ -134,7 +137,7 @@ include("solve.jl") export init, solve, solve! include("SSA_stepper.jl") -export SSAStepper +export SSAStepper, BoundedSSA # leaping: include("simple_regular_solve.jl") diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index d7b6c4aef..35accf32a 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -59,6 +59,57 @@ for details. struct SSAStepper <: DiffEqBase.AbstractDEAlgorithm end SciMLBase.allows_late_binding_tstops(::SSAStepper) = true +""" + BoundedSSA(; nmax) + +A StochasticAD-compatible SSA algorithm for **jump-only** `ConstantRateJump` +`DiscreteProblem`s, enabling correct gradients via StochasticAD's +`derivative_estimate`/`stochastic_triple`. + +The stock `SSAStepper` cannot be differentiated with StochasticAD: it decides the +number of events with a `while integrator.t < integrator.tstop < end_time` loop, +i.e. a boolean predicate on (triple-valued) time, which StochasticAD forbids by +design — so the event-count derivative is dropped (a state-dependent rate yields +a gradient of `0`). `BoundedSSA` instead runs a **fixed-length loop of at most +`nmax` jump attempts**, replacing the data-dependent control flow with stochastic +primitives (a tracked `Bernoulli` for "is there a next event before the end +time?", a truncated exponential for the time, stick-breaking `Bernoulli`s for the +channel, and multiplicative-select masking to freeze the trajectory once the +event chain breaks). This is exact up to `P(N > nmax)`, the probability of more +than `nmax` events. + +With ordinary `Float64` parameters `solve(jprob, BoundedSSA(; nmax))` is just a +bounded SSA simulation; with StochasticAD triples it differentiates. + +# Keyword arguments + + - `nmax` (required): the fixed upper bound on the number of jump events. If the + true count can exceed `nmax` the result is biased; size it from + [`saturation_probability`](@ref). + +# Scope / limitations + + - `ConstantRateJump`s only (state-dependent rates are supported); jump-only, no + continuous drift, no `VariableRateJump`. `MassActionJump` is not yet supported + (`evalrxrate` is not triple-generic and mass-action rate constants flow + through `param_mapper`). + - Additive affects only (the net change is inferred from `affect!` and checked). + - The differentiation parameter must enter through `prob.p`. + - The implementation lives in the `JumpProcessesStochasticADExt` extension, so + `StochasticAD` and `Distributions` must both be loaded to `solve` with it. + +See also [`bounded_ssa_final_state`](@ref) (the differentiable core) and +[`saturation_probability`](@ref). +""" +struct BoundedSSA{N} <: DiffEqBase.AbstractDEAlgorithm + nmax::N +end +function BoundedSSA(; nmax = nothing) + nmax === nothing && error("BoundedSSA requires the keyword argument `nmax` " * + "(a fixed upper bound on the number of jump events).") + BoundedSSA{typeof(nmax)}(nmax) +end + """ $(TYPEDEF) diff --git a/src/aggregators/direct.jl b/src/aggregators/direct.jl index ab33dd842..47e1c0f5f 100644 --- a/src/aggregators/direct.jl +++ b/src/aggregators/direct.jl @@ -68,33 +68,47 @@ end ######################## SSA specific helper routines ######################## +# Fill `cur_rates` with the *raw* (non-cumulative) per-channel rates: mass-action +# rates first (indices `1:get_num_majumps(majumps)`), then constant-jump rates. +# Shared by `Direct`'s `time_to_next_jump` (which then forms the running +# cumulative sum used for channel sampling) and by the StochasticAD bounded SSA +# path (which uses the raw rates directly). Tuple rates use the type-stable +# recursive `fill_cur_rates`; function-wrapper rates use a plain loop. This is the +# one place per-channel rates are computed, so a generic (e.g. StochasticTriple) +# rate type flows through it unchanged. +@inline function fill_cur_rates!(cur_rates, u, p, t, majumps, rates::Tuple) + nma = get_num_majumps(majumps) + @inbounds for i in 1:nma + cur_rates[i] = evalrxrate(u, i, majumps) + end + isempty(rates) || fill_cur_rates(u, p, t, cur_rates, nma + 1, rates...) + nothing +end + +@inline function fill_cur_rates!(cur_rates, u, p, t, majumps, rates::AbstractArray) + nma = get_num_majumps(majumps) + @inbounds for i in 1:nma + cur_rates[i] = evalrxrate(u, i, majumps) + end + @inbounds for k in eachindex(rates) + cur_rates[nma + k] = rates[k](u, p, t) + end + nothing +end + # tuple-based constant jumps function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, t) where {T, S, F1 <: Tuple} - prev_rate = zero(t) - new_rate = zero(t) cur_rates = p.cur_rates + fill_cur_rates!(cur_rates, u, params, t, p.ma_jumps, p.rates) - # mass action rates - majumps = p.ma_jumps - idx = get_num_majumps(majumps) - @inbounds for i in 1:idx - new_rate = evalrxrate(u, i, majumps) - cur_rates[i] = add_fast(new_rate, prev_rate) + # form the running cumulative sum used by `generate_jumps!` for channel sampling + prev_rate = zero(t) + @inbounds for i in eachindex(cur_rates) + cur_rates[i] = add_fast(cur_rates[i], prev_rate) prev_rate = cur_rates[i] end - # constant jump rates - rates = p.rates - if !isempty(rates) - idx += 1 - fill_cur_rates(u, params, t, cur_rates, idx, rates...) - @inbounds for i in idx:length(cur_rates) - cur_rates[i] = add_fast(cur_rates[i], prev_rate) - prev_rate = cur_rates[i] - end - end - @inbounds sum_rate = cur_rates[end] sum_rate, randexp(p.rng) / sum_rate end @@ -113,29 +127,16 @@ end # function wrapper-based constant jumps function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, t) where {T, S, F1 <: AbstractArray} - prev_rate = zero(t) - new_rate = zero(t) cur_rates = p.cur_rates + fill_cur_rates!(cur_rates, u, params, t, p.ma_jumps, p.rates) - # mass action rates - majumps = p.ma_jumps - idx = get_num_majumps(majumps) - @inbounds for i in 1:idx - new_rate = evalrxrate(u, i, majumps) - cur_rates[i] = add_fast(new_rate, prev_rate) + # form the running cumulative sum used by `generate_jumps!` for channel sampling + prev_rate = zero(t) + @inbounds for i in eachindex(cur_rates) + cur_rates[i] = add_fast(cur_rates[i], prev_rate) prev_rate = cur_rates[i] end - # constant jump rates - idx += 1 - rates = p.rates - @inbounds for i in 1:length(p.rates) - new_rate = rates[i](u, params, t) - cur_rates[idx] = add_fast(new_rate, prev_rate) - prev_rate = cur_rates[idx] - idx += 1 - end - @inbounds sum_rate = cur_rates[end] sum_rate, randexp(p.rng) / sum_rate end diff --git a/test/fill_cur_rates_regression.jl b/test/fill_cur_rates_regression.jl new file mode 100644 index 000000000..7f84423ea --- /dev/null +++ b/test/fill_cur_rates_regression.jl @@ -0,0 +1,95 @@ +using JumpProcesses +using Statistics, Random, Test + +# Regression for the shared `fill_cur_rates!` refactor (src/aggregators/direct.jl): +# the raw rate-fill was factored out of `time_to_next_jump` so the StochasticAD +# bounded SSA path can reuse it. The stock Direct / DirectFW SSA path must be +# UNCHANGED by that. This covers mass-action rates (`evalrxrate`), multi-channel +# cumulative selection (`searchsortedfirst`), the mixed mass-action + constant +# ordering, and the function-wrapper (`DirectFW`) path. Uses only JumpProcesses +# (no OrdinaryDiffEq), so it runs in the isolated StochasticAD env alongside the +# change it guards. + +# final value of species 1 over `nT` independent, seeded trajectories +function final_species1(jprob, alg; nT, seedbase = 0) + xs = Vector{Float64}(undef, nT) + for i in 1:nT + xs[i] = solve(jprob, alg; seed = seedbase + i).u[end][1] + end + return xs +end + +@testset "fill_cur_rates! stock-path regression (Direct/DirectFW)" begin + T = 1.0 + + # --- mass-action, single channel: X --(μ)--> 0, rate μ·X ----------------- + # E[X(T)] = X0·e^{-μT} + @testset "mass-action pure death (Direct)" begin + X0, μ = 100, 0.5 + maj = MassActionJump([μ], [[1 => 1]], [[1 => -1]]) + jp = JumpProblem(DiscreteProblem([X0], (0.0, T)), Direct(), maj) + analytic = X0 * exp(-μ * T) + xs = final_species1(jp, SSAStepper(); nT = 5000) + @test abs(mean(xs) - analytic) < 4 * std(xs) / sqrt(length(xs)) + end + + # --- mass-action, multi-channel: A <-(k1,k2)-> B (both first order) ------- + # linear system => the mean is exact: A_ss = k2·N/(k1+k2), + # E[A(T)] = A_ss + (A0 - A_ss)·e^{-(k1+k2)T}. Exercises evalrxrate over two + # channels and the cumulative-rate channel selection. + @testset "mass-action reversible two-channel (Direct)" begin + A0, B0, k1, k2 = 80, 20, 0.8, 0.4 + N = A0 + B0 + maj = MassActionJump([k1, k2], + [[1 => 1], [2 => 1]], + [[1 => -1, 2 => 1], [1 => 1, 2 => -1]]) + jp = JumpProblem(DiscreteProblem([A0, B0], (0.0, T)), Direct(), maj) + Ass = k2 * N / (k1 + k2) + analytic = Ass + (A0 - Ass) * exp(-(k1 + k2) * T) + xs = final_species1(jp, SSAStepper(); nT = 5000) + @test abs(mean(xs) - analytic) < 4 * std(xs) / sqrt(length(xs)) + @test all(x -> 0 <= x <= N, xs) # conservation A+B=N keeps A in [0,N] + end + + # --- mixed mass-action + constant jump: birth (constant λ) + death (μ·X) -- + # fills mass-action rate first, then the constant-jump rate, in fill_cur_rates!. + # E[X(T)] = λ/μ + (X0 - λ/μ)·e^{-μT} + @testset "mixed mass-action death + constant birth (Direct)" begin + X0, λ, μ = 50, 10.0, 0.3 + birth = ConstantRateJump((u, p, t) -> λ, integ -> (integ.u[1] += 1; nothing)) + death = MassActionJump([μ], [[1 => 1]], [[1 => -1]]) + jp = JumpProblem(DiscreteProblem([X0], (0.0, T)), Direct(), birth, death) + analytic = λ / μ + (X0 - λ / μ) * exp(-μ * T) + xs = final_species1(jp, SSAStepper(); nT = 5000) + @test abs(mean(xs) - analytic) < 4 * std(xs) / sqrt(length(xs)) + end + + # --- same mixed model via the function-wrapper aggregator ----------------- + # exercises the AbstractArray fill_cur_rates! / time_to_next_jump methods. + @testset "mixed model via DirectFW (fwrapper path)" begin + X0, λ, μ = 50, 10.0, 0.3 + birth = ConstantRateJump((u, p, t) -> λ, integ -> (integ.u[1] += 1; nothing)) + death = MassActionJump([μ], [[1 => 1]], [[1 => -1]]) + jp = JumpProblem(DiscreteProblem([X0], (0.0, T)), DirectFW(), birth, death) + analytic = λ / μ + (X0 - λ / μ) * exp(-μ * T) + xs = final_species1(jp, SSAStepper(); nT = 5000) + @test abs(mean(xs) - analytic) < 4 * std(xs) / sqrt(length(xs)) + end + + # --- Direct and DirectFW must produce IDENTICAL trajectories -------------- + # same rates, same cumulative sums, same RNG draws => bit-identical paths. + # Deterministic check (no MC tolerance): the sharpest guard that both + # fill_cur_rates! methods compute the same per-channel and cumulative rates. + @testset "Direct == DirectFW (identical trajectories)" begin + X0, λ, μ = 40, 8.0, 0.5 + mkbirth() = ConstantRateJump((u, p, t) -> λ, integ -> (integ.u[1] += 1; nothing)) + mkdeath() = MassActionJump([μ], [[1 => 1]], [[1 => -1]]) + jp_d = JumpProblem(DiscreteProblem([X0], (0.0, T)), Direct(), mkbirth(), mkdeath()) + jp_f = JumpProblem(DiscreteProblem([X0], (0.0, T)), DirectFW(), mkbirth(), mkdeath()) + for i in 1:25 + sd = solve(jp_d, SSAStepper(); seed = i) + sf = solve(jp_f, SSAStepper(); seed = i) + @test sd.u[end] == sf.u[end] + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 06f01c3c8..011374567 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,7 +14,7 @@ end # OrdinaryDiffEq stack, so it is kept out of the main test target and run here # in its own project (no ODE solver needed -- the extension never calls `solve`). function activate_stochasticad_env() - Pkg.activate("stochasticad") + Pkg.activate(joinpath(@__DIR__, "stochasticad")) Pkg.develop(PackageSpec(path = dirname(@__DIR__))) Pkg.instantiate() end @@ -75,6 +75,9 @@ end if GROUP == "StochasticAD" activate_stochasticad_env() + # Guards the shared `fill_cur_rates!` refactor the bounded path depends on. + # JumpProcesses-only (no OrdinaryDiffEq), so it runs in this isolated env. + @time @safetestset "fill_cur_rates! Stock-Path Regression" begin include("fill_cur_rates_regression.jl") end @time @safetestset "StochasticAD Extension Tests" begin include("stochasticad_tests.jl") end end diff --git a/test/stochasticad/Project.toml b/test/stochasticad/Project.toml index a9253003c..3d752c661 100644 --- a/test/stochasticad/Project.toml +++ b/test/stochasticad/Project.toml @@ -1,4 +1,6 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/stochasticad_tests.jl b/test/stochasticad_tests.jl index 0e2812793..0f18f4400 100644 --- a/test/stochasticad_tests.jl +++ b/test/stochasticad_tests.jl @@ -2,8 +2,9 @@ using JumpProcesses, StochasticAD, Distributions using Statistics, Random, Test # StochasticAD-compatible differentiation for jump-only ConstantRateJump SSA -# problems. Needs only JumpProcesses + StochasticAD + Distributions (no ODE -# solver), so it runs in its own isolated environment (GROUP=StochasticAD). +# problems (the BoundedSSA path). Needs only JumpProcesses + StochasticAD + +# Distributions (no ODE solver), so it runs in its own isolated environment +# (GROUP=StochasticAD). # per-partial StochasticAD gradient with fixed seeds (reproducible) function sad_partial(f, p0, k; N) @@ -18,9 +19,9 @@ function sad_partial(f, p0, k; N) return mean(s), std(s) / sqrt(N) end -@testset "StochasticAD constant-rate jumps" begin +@testset "StochasticAD constant-rate jumps (BoundedSSA)" begin - # --- Test A: pure death, STATE-DEPENDENT rate (the case Chris cares about) --- + # --- Test A: pure death, STATE-DEPENDENT rate (the case that matters) --- # rate = μ·u[1] ; E[u(T)] = u0·e^{-μT} ; d/dμ E[u(T)] = -T·u0·e^{-μT} # A naive `while t 1.0 # explicitly NOT the zero a naive SSA gives @@ -47,7 +48,7 @@ end analytic = [(1 - b) / μ0, -λ0 / μ0^2 * (1 - b) + (u0 - a) * (-T * b)] for k in 1:2 g, se = sad_partial([λ0, μ0], k; N = 10000) do p - constant_rate_ssa_final_state(jprob, p; nmax = 400)[1] + bounded_ssa_final_state(jprob, p; nmax = 400)[1] end @test abs(g - analytic[k]) < 4 * se end @@ -61,29 +62,80 @@ end j2 = ConstantRateJump((u, p, t) -> p[2], integ -> (integ.u[1] -= 1; nothing)) jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), j1, j2) λ0 = [3.0, 1.0] - # closed-form helper (exact) - for k in 1:2 - g, _ = sad_partial(λ0, k; N = 3000) do p - poisson_count_final_state(jprob, p)[1] - end - @test isapprox(g, k == 1 ? T : -T; atol = 0.05) - end - # the iterative SSA method must agree with the baseline on this case too for k in 1:2 g, se = sad_partial(λ0, k; N = 10000) do p - constant_rate_ssa_final_state(jprob, p; nmax = 200)[1] + bounded_ssa_final_state(jprob, p; nmax = 200)[1] end @test abs(g - (k == 1 ? T : -T)) < 4 * se end end + # --- Test D: BoundedSSA solve path differentiates (native API) --- + # solve(jprob, BoundedSSA(; nmax)) must give the same gradient as the core. + # Constructs the JumpProblem with triple parameters, exercising the generic + # rate cache during construction and the BoundedSSA __solve. + @testset "BoundedSSA solve path" begin + T, u0, μ0 = 1.0, 100, 0.5 + death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + analytic = -T * u0 * exp(-μ0 * T) + # primal: solve returns a sensible final state + jp0 = JumpProblem(DiscreteProblem([u0], (0.0, T), [μ0]), Direct(), death) + sol = solve(jp0, BoundedSSA(; nmax = 200)) + @test length(sol.u[end]) == 1 + @test 0 <= sol.u[end][1] <= u0 + # differentiate through solve + g, se = sad_partial([μ0], 1; N = 4000) do p + jp = JumpProblem(DiscreteProblem([u0], (0.0, T), p), Direct(), death) + solve(jp, BoundedSSA(; nmax = 200)).u[end][1] + end + @test abs(g - analytic) < 4 * se + end + + # --- Test E: primal distribution matches the stock SSA --- + # at large nmax the bounded SSA's mean must match solve(jprob, SSAStepper()). + @testset "primal mean matches stock SSAStepper" begin + T, u0, μ0 = 1.0, 100, 0.4 + death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T), [μ0]), Direct(), death) + nT = 4000 + bounded = Vector{Float64}(undef, nT) + stock = Vector{Float64}(undef, nT) + for i in 1:nT + Random.seed!(i) + bounded[i] = bounded_ssa_final_state(jprob, [μ0]; nmax = 400)[1] + Random.seed!(10_000 + i) + stock[i] = solve(jprob, SSAStepper()).u[end][1] + end + analytic = u0 * exp(-μ0 * T) + @test abs(mean(bounded) - analytic) < 4 * std(bounded) / sqrt(nT) + # bounded vs stock means agree within combined MC error + se = sqrt(var(bounded) / nT + var(stock) / nT) + @test abs(mean(bounded) - mean(stock)) < 4 * se + end + + # --- Test F: saturation diagnostic --- + @testset "saturation_probability" begin + T, u0, λ0 = 1.0, 0, 20.0 + birth = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] += 1; nothing)) + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), birth) + # ~Poisson(20) events: nmax=5 is far too small, nmax=200 is ample + @test saturation_probability(jprob, [λ0]; nmax = 5, ntrials = 2000) > 0.5 + @test saturation_probability(jprob, [λ0]; nmax = 200, ntrials = 2000) < 0.01 + end + # --- guards: misuse should error, not silently mislead --- @testset "guards" begin T = 1.0 - # state-dependent rate via the Poisson shortcut must error - death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) - jp = JumpProblem(DiscreteProblem([10], (0.0, T)), Direct(), death) - @test_throws ErrorException poisson_count_final_state(jp, [0.5]) + # MassActionJump not yet supported + majump = MassActionJump([0.5], [[1 => 1]], [[1 => -1]]) + jp_ma = JumpProblem(DiscreteProblem([10], (0.0, T), [0.5]), Direct(), majump) + @test_throws ErrorException bounded_ssa_final_state(jp_ma, [0.5]; nmax = 50) + # missing nmax on the algorithm + @test_throws ErrorException BoundedSSA() + # non-additive (state-dependent) affect + weird = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] *= 2; nothing)) + jp_w = JumpProblem(DiscreteProblem([10], (0.0, T)), Direct(), weird) + @test_throws ErrorException bounded_ssa_final_state(jp_w, [0.5]; nmax = 50) end # --- generic rate cache (Chris's literal request): a triple-valued rate must @@ -91,7 +143,7 @@ end # SCOPE: this only makes the cache generic. It does NOT by itself give correct # gradients — the stock `while t < T` event boundary still drops the event-count # derivative, and a full stock solve would next hit the SSAStepper integrator's - # Float64 time. Correct gradients use constant_rate_ssa_final_state (above). + # Float64 time. Correct gradients use BoundedSSA / bounded_ssa_final_state. @testset "generic rate cache (no Float64(::StochasticTriple))" begin st = stochastic_triple(identity, 0.5) jump = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) From 2f611e6525dd6c3ae1e960a8b0541d0db23fbabe Mon Sep 17 00:00:00 2001 From: Roman Sahakyan Date: Wed, 17 Jun 2026 15:29:56 +0400 Subject: [PATCH 4/5] CI: register StochasticAD test group so grouped-tests runs it The grouped-tests.yml refactor (#604) drives groups from test/test_groups.toml, which did not carry over the StochasticAD group the old Tests.yml matrix declared -- so CI was not running the BoundedSSA tests. runtests.jl already dispatches GROUP=StochasticAD into the isolated env (no special runner needed). --- test/test_groups.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_groups.toml b/test/test_groups.toml index 443c5cbb1..620d9fdca 100644 --- a/test/test_groups.toml +++ b/test/test_groups.toml @@ -6,3 +6,6 @@ versions = ["lts", "1", "pre"] [QA] versions = ["lts", "1"] + +[StochasticAD] +versions = ["lts", "1"] From c222c6dd5856acad6185956c988680a2a64d08ed Mon Sep 17 00:00:00 2001 From: Roman Sahakyan Date: Sun, 21 Jun 2026 12:38:50 +0400 Subject: [PATCH 5/5] StochasticAD: bounded-propensity BoundedSSA, separate from Direct Reworked per review (isaacsas): Direct path reverted to upstream (untouched); StochasticAD is a separate DiffEqBase.solve dispatch + bounded_ssa_path helper. Uniformization/thinning against a constant total-propensity bound (rate_bound) with null events -- unbiased, iterates (rates recomputed per event), StochasticAD-compatible (parameter-free candidate times, no boolean on triple time). Full path / saveat support reusing the package's _process_saveat (same as SimpleTauLeaping); sol(t) via ConstantInterpolation. Drops the prior nmax approach, generic rate cache, saturation_probability, and the vestigial regression test. Tests 16/16; Direct regression 30/30. MassActionJump deferred. --- ext/JumpProcessesStochasticADExt.jl | 246 ++++++++++++++-------------- src/JumpProcesses.jl | 11 +- src/SSA_stepper.jl | 83 ++++++---- src/aggregators/direct.jl | 75 +++++---- src/aggregators/ssajump.jl | 36 +--- test/fill_cur_rates_regression.jl | 95 ----------- test/runtests.jl | 3 - test/stochasticad_tests.jl | 155 +++++++----------- 8 files changed, 286 insertions(+), 418 deletions(-) delete mode 100644 test/fill_cur_rates_regression.jl diff --git a/ext/JumpProcessesStochasticADExt.jl b/ext/JumpProcessesStochasticADExt.jl index 69264b671..9d7607e58 100644 --- a/ext/JumpProcessesStochasticADExt.jl +++ b/ext/JumpProcessesStochasticADExt.jl @@ -2,32 +2,35 @@ module JumpProcessesStochasticADExt # StochasticAD-compatible differentiation for jump-only `ConstantRateJump` SSA # problems — the implementation behind the `BoundedSSA` algorithm and the -# `bounded_ssa_final_state` / `saturation_probability` entry points. +# `bounded_ssa_path` entry point. # # Why this exists: the stock `solve(jprob, SSAStepper())` cannot be differentiated -# with StochasticAD. It decides the number of events with a -# `while integrator.t < integrator.tstop < end_time` loop — a boolean predicate on -# (triple-valued) time, which StochasticAD forbids by design — so the event-count -# derivative (the dominant term for state-dependent rates) is dropped and a -# state-dependent rate gives a gradient of 0. We instead run a fixed-length loop of -# at most `nmax` jump attempts, representing every discrete decision through -# stochastic primitives / masks rather than Julia branches. Exact up to -# `P(N > nmax)`. +# with StochasticAD. It advances time with a `while integrator.t < integrator.tstop +# < end_time` loop — a boolean predicate on (triple-valued) time, which StochasticAD +# forbids by design — so the event-count derivative (the dominant term for +# state-dependent rates) is dropped and a state-dependent rate gives a gradient of 0. +# +# Instead we use UNIFORMIZATION (thinning) against a constant total-propensity bound +# `Λ = rate_bound`: candidate event times are a homogeneous Poisson process of rate +# `Λ` (parameter-free, so the loop never branches on a triple and times stay +# Float64); at each candidate the event is accepted with a tracked +# `Bernoulli(total_rate(u)/Λ)` (else a null event), and the channel is chosen by +# stick-breaking. All parameter dependence flows through the accept/channel +# Bernoullis, so the gradient is captured; it is unbiased given a valid `Λ`, and +# `saveat` is exact because the candidate times are fixed Float64. +# +# This code is fully SEPARATE from the `Direct` aggregator: it reads `jump.rate` / +# `jump.affect!` only and never touches `time_to_next_jump` or the Direct rate cache. # # Scope: jump-only `DiscreteProblem`s with `ConstantRateJump`s (state-dependent -# rates OK) and additive affects. No `MassActionJump` (see `_constant_rate_channels` -# for why), no `VariableRateJump`, no continuous drift, no rootfinding. +# rates OK) and additive affects. No `MassActionJump`, no `VariableRateJump`. using JumpProcesses using StochasticAD -using Distributions: Bernoulli +using Distributions: Bernoulli, Poisson using DiffEqBase using Random -# primal value of a (possibly triple) scalar. -_val(x) = x -_val(x::StochasticAD.StochasticTriple) = StochasticAD.value(x) - # minimal integrator-like object so a jump's `affect!` can be applied to a scratch # state to read off its net effect. mutable struct ShimIntegrator{U, P, T} @@ -44,8 +47,8 @@ function _net_change(affect!, ubase, p, t0) end # infer a jump's net state change, verifying it is *additive* (same change from two -# different base states). The final state is built by adding `Δuₖ` on each event, -# so non-additive (state-dependent) affects are out of scope. +# different base states). Additive affects are required: the state is built by +# adding `Δuₖ` on each event. function _additive_change(jump, u0, p, t0) base = float.(collect(u0)) Δ = _net_change(jump.affect!, base, p, t0) @@ -57,140 +60,145 @@ function _additive_change(jump, u0, p, t0) return Δ end -# Resolve a JumpProblem into the per-channel data the bounded loop needs: -# `(rates_tuple, Δs)` where `rates_tuple` is the tuple of `ConstantRateJump` rate -# functions (fed to the shared `JumpProcesses.fill_cur_rates!`) and `Δs[k]` is -# channel `k`'s additive net state change. -function _constant_rate_channels(jprob, u0, p, t0) +function _check_supported(jprob) + jprob.prob isa DiscreteProblem || error( + "BoundedSSA only supports JumpProblems defined over DiscreteProblems " * + "(pure jumps, no continuous drift).") maj = jprob.massaction_jump (maj === nothing || JumpProcesses.get_num_majumps(maj) == 0) || error( - "BoundedSSA does not yet support MassActionJump. `evalrxrate` is not " * - "triple-generic: its `::R` return assertion pins the rate to the " * - "`scaled_rates` element type, and the order>1 branch tests `specpop <= 0` " * - "(a boolean on triple-valued species). Mass-action rate constants also " * - "flow through `param_mapper(p)`. Build the model with ConstantRateJumps, " * - "or track the mass-action follow-up.") + "BoundedSSA does not yet support MassActionJump; build the model with " * + "ConstantRateJumps.") vj = jprob.variable_jumps (vj === nothing || isempty(vj)) || error( "BoundedSSA supports jump-only constant-rate problems only; it does not " * "support VariableRateJumps.") - cjumps = jprob.constant_jumps - (cjumps === nothing || isempty(cjumps)) && error( + cj = jprob.constant_jumps + (cj === nothing || isempty(cj)) && error( "BoundedSSA requires at least one ConstantRateJump.") - rates_tuple, _ = JumpProcesses.get_jump_info_tuples(cjumps) - Δs = [_additive_change(j, u0, p, t0) for j in cjumps] - return rates_tuple, Δs + nothing end -""" - bounded_ssa_final_state(jprob, p; nmax, tspan = jprob.prob.tspan, - return_saturation = false) - -Final state at `tspan[2]` of a jump-only `ConstantRateJump` process, computed so -that StochasticAD's `derivative_estimate`/`stochastic_triple` give correct -gradients — including state-dependent rates such as `rate(u,p,t) = p[1]*u[1]`. -This is the differentiable core behind [`BoundedSSA`](@ref). - -Per-channel rates are computed via the same `JumpProcesses.fill_cur_rates!` helper -the `Direct` aggregator uses, so a triple-valued rate passes through the existing -rate machinery. See [`BoundedSSA`](@ref) for the method and scope, and -[`saturation_probability`](@ref) for sizing `nmax`. - -`return_saturation = true` returns `(u, active)`; a non-zero primal `active` means -the trajectory used all `nmax` events and may be truncated. - -```julia -derivative_estimate(p0[k]) do pk - pv = [j == k ? pk : oftype(pk, p0[j]) for j in eachindex(p0)] - bounded_ssa_final_state(jprob, pv; nmax = 500)[1] -end -``` -""" -function JumpProcesses.bounded_ssa_final_state(jprob, p; nmax, - tspan = jprob.prob.tspan, return_saturation = false) - prob = jprob.prob - u0 = prob.u0 +# Internal uniformization driver: returns `(tsave, usave)` at the resolved save +# schedule. Uses `JumpProcesses._process_saveat` (from src/simple_regular_solve.jl) +# for the interior save times + save_start/save_end flags, so saveat semantics match +# SimpleTauLeaping and the rest of the package. The save loop mirrors that solver's +# push idiom; all save-time comparisons are on parameter-free Float64 candidate times. +function _bounded_ssa(jprob, p, Λ, tspan, saveat, save_start, save_end) + _check_supported(jprob) + u0 = jprob.prob.u0 + jumps = jprob.constant_jumps t0, tf = first(tspan), last(tspan) + ΔT = tf - t0 + K = length(jumps) + n = length(u0) + + saveat_times, ss, se = JumpProcesses._process_saveat(saveat, (t0, tf), + save_start, save_end) + + Δ = [_additive_change(jumps[k], u0, p, t0) for k in 1:K] # Float64 net change/channel + z = 0 * sum(p) # triple zero carrying p's type + u = [float(u0[i]) + z for i in 1:n] + + tsave = typeof(t0)[] + usave = typeof(u)[] + if ss + push!(tsave, t0) + push!(usave, copy(u)) + end - rates_tuple, Δs = _constant_rate_channels(jprob, u0, p, t0) - K = length(Δs) - n = length(u0) - - z = 0 * sum(p) # triple zero (value 0) carrying p's type - u = [float(u0[i]) + z for i in 1:n] # triple-typed state - t = float(t0) + z - active = 1 + z # 1 while the event chain is unbroken - - for _ in 1:nmax - # raw per-channel rates via the shared aggregator helper (triples flow through) - cur = [z for _ in 1:K] - JumpProcesses.fill_cur_rates!(cur, u, p, t, nothing, rates_tuple) - total = sum(cur) + # candidate events ~ homogeneous Poisson(Λ) on [t0, tf]. PARAMETER-FREE: Λ is a + # constant, so M and the times carry no derivative and never branch on a triple. + M = rand(Poisson(Λ * ΔT)) + ctimes = sort!(t0 .+ ΔT .* rand(M)) + + save_idx = 1 + for m in 1:M + tm = @inbounds ctimes[m] + # record interior save times crossed before this candidate (Float64 compares) + while save_idx <= length(saveat_times) && @inbounds(saveat_times[save_idx]) < tm + push!(tsave, @inbounds saveat_times[save_idx]) + push!(usave, copy(u)) + save_idx += 1 + end - Δt = tf - t - pocc = 1 - exp(-total * Δt) # P(next event before tf) - occurs = rand(Bernoulli(pocc)) - step = active * occurs + rates = [jumps[k].rate(u, p, tm) for k in 1:K] # recomputed at current state + total = sum(rates) + accept = rand(Bernoulli(total / Λ)) # thinning: real vs null event - # which channel: stick-breaking conditional Bernoullis + multiplicative - # select; last channel deterministic, suffix-sum denominator in [0, 1). + # which channel: stick-breaking conditional Bernoullis (last deterministic) notchosen = 1 + z sel = [z for _ in 1:n] - @inbounds for k in 1:K + for k in 1:K chose = k < K ? - rand(Bernoulli(cur[k] / (sum(cur[j] for j in k:K) + 1e-300))) : + rand(Bernoulli(rates[k] / (sum(rates[j] for j in k:K) + 1e-300))) : (1 + z) take = notchosen * chose - Δk = Δs[k] - sel = [sel[i] + take * Δk[i] for i in 1:n] + sel = [sel[i] + take * Δ[k][i] for i in 1:n] notchosen = notchosen * (1 - chose) end - U = rand() - τ = -log(1 - U * pocc) / (total + 1e-300) # truncated-exponential time - t = t + step * τ - u = [u[i] + step * sel[i] for i in 1:n] - active = active * occurs + u = [u[i] + accept * sel[i] for i in 1:n] # apply only on a real event end - - return return_saturation ? (u, active) : u + while save_idx <= length(saveat_times) + push!(tsave, @inbounds saveat_times[save_idx]) + push!(usave, copy(u)) + save_idx += 1 + end + if se + push!(tsave, tf) + push!(usave, copy(u)) + end + return tsave, usave end """ - saturation_probability(jprob, p; nmax, tspan = jprob.prob.tspan, ntrials = 1000) + bounded_ssa_path(jprob, p; rate_bound, saveat = tf, save_start = nothing, + save_end = nothing, tspan = jprob.prob.tspan) + +Differentiable core behind [`BoundedSSA`](@ref). Simulates the jump-only +`ConstantRateJump` process by uniformization against the constant total-propensity +bound `rate_bound`, and returns the (StochasticAD-differentiable) state at each save +time as a `Vector` of state vectors. + +`saveat`/`save_start`/`save_end` follow the usual JumpProcesses conventions (via the +same `_process_saveat` as `SimpleTauLeaping`): `saveat` is a `Number` step or a +collection of times; endpoints are controlled by `save_start`/`save_end`. Wrap in +`derivative_estimate` for gradients, e.g. of the terminal state: -Monte-Carlo estimate of `P(N > nmax)` — the probability the process has more than -`nmax` events on `tspan`, i.e. the bias of the bounded SSA path. Call with -ordinary (`Float64`) parameters `p`; size `nmax` so this is negligible. +```julia +derivative_estimate(p0[k]) do pk + pv = [j == k ? pk : oftype(pk, p0[j]) for j in eachindex(p0)] + bounded_ssa_path(jprob, pv; rate_bound = Λ, saveat = [tf])[end][1] +end +``` + +See [`BoundedSSA`](@ref) for the method and the meaning/validity of `rate_bound`. """ -function JumpProcesses.saturation_probability(jprob, p; nmax, - tspan = jprob.prob.tspan, ntrials = 1000) - nsat = 0 - for _ in 1:ntrials - _, active = JumpProcesses.bounded_ssa_final_state(jprob, p; nmax, tspan, - return_saturation = true) - (_val(active) != 0) && (nsat += 1) - end - return nsat / ntrials +function JumpProcesses.bounded_ssa_path(jprob, p; rate_bound, + saveat = last(jprob.prob.tspan), save_start = nothing, save_end = nothing, + tspan = jprob.prob.tspan) + _, usave = _bounded_ssa(jprob, p, rate_bound, tspan, saveat, save_start, save_end) + return usave end -# solve(jprob, BoundedSSA(; nmax)): run the bounded path and return a minimal -# (start, end) solution. `sol.u[end]` is the differentiable final state. -function DiffEqBase.__solve(jprob::JumpProblem, alg::BoundedSSA; - seed = nothing, tspan = jprob.prob.tspan, kwargs...) +# solve(jprob, BoundedSSA(; rate_bound); saveat, save_start, save_end): run the +# uniformization path and return a solution whose `u[i]` is the differentiable state +# at `t[i]`. `sol(t)` works via piecewise-constant interpolation (as with SSAStepper). +# Defined as `solve` (like SimpleTauLeaping), since BoundedSSA is self-contained and +# does not use the integrator/init machinery. +function DiffEqBase.solve(jump_prob::JumpProblem, alg::BoundedSSA; + seed = nothing, saveat = nothing, save_start = nothing, save_end = nothing, + tspan = jump_prob.prob.tspan, kwargs...) seed === nothing || Random.seed!(seed) - prob = jprob.prob - u_final = JumpProcesses.bounded_ssa_final_state(jprob, prob.p; nmax = alg.nmax, - tspan = tspan) - # promote u0 to the (possibly triple) final-state type without needing a - # convert(::StochasticTriple, ::Float64): multiply by a clean zero. - u0p = [u_final[i] * 0 + float(prob.u0[i]) for i in eachindex(prob.u0)] - ts = [float(first(tspan)), float(last(tspan))] - us = [u0p, u_final] + prob = jump_prob.prob + ts, us = _bounded_ssa(jump_prob, prob.p, alg.rate_bound, tspan, saveat, + save_start, save_end) DiffEqBase.build_solution(prob, alg, ts, us; + dense = true, + interp = DiffEqBase.ConstantInterpolation(ts, us), calculate_error = false, stats = DiffEqBase.Stats(0), - interp = DiffEqBase.ConstantInterpolation(ts, us)) + retcode = DiffEqBase.ReturnCode.Success) end end # module diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index cf7ea989b..b88ff4d55 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -112,15 +112,14 @@ include("aggregators/aggregated_api.jl") include("variable_rate.jl") export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW -# StochasticAD support. Stubs; methods are provided by the package extension +# StochasticAD support. Stub; the method is provided by the package extension # `ext/JumpProcessesStochasticADExt.jl`, which loads only when StochasticAD and # Distributions are both available. No StochasticAD code lives in `src/`. The # `BoundedSSA` algorithm struct itself lives in `src` (see SSA_stepper.jl) so it -# is always referenceable/documentable; only its `solve` implementation is in the -# extension. -function bounded_ssa_final_state end -function saturation_probability end -export bounded_ssa_final_state, saturation_probability +# is always referenceable/documentable; only its `solve` implementation (and this +# differentiable-path helper) are in the extension. +function bounded_ssa_path end +export bounded_ssa_path """ Aggregator to indicate that individual jumps should also be handled via the leaping diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index 35accf32a..797755a45 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -60,54 +60,71 @@ struct SSAStepper <: DiffEqBase.AbstractDEAlgorithm end SciMLBase.allows_late_binding_tstops(::SSAStepper) = true """ - BoundedSSA(; nmax) + BoundedSSA(; rate_bound) A StochasticAD-compatible SSA algorithm for **jump-only** `ConstantRateJump` -`DiscreteProblem`s, enabling correct gradients via StochasticAD's -`derivative_estimate`/`stochastic_triple`. - -The stock `SSAStepper` cannot be differentiated with StochasticAD: it decides the -number of events with a `while integrator.t < integrator.tstop < end_time` loop, -i.e. a boolean predicate on (triple-valued) time, which StochasticAD forbids by -design — so the event-count derivative is dropped (a state-dependent rate yields -a gradient of `0`). `BoundedSSA` instead runs a **fixed-length loop of at most -`nmax` jump attempts**, replacing the data-dependent control flow with stochastic -primitives (a tracked `Bernoulli` for "is there a next event before the end -time?", a truncated exponential for the time, stick-breaking `Bernoulli`s for the -channel, and multiplicative-select masking to freeze the trajectory once the -event chain breaks). This is exact up to `P(N > nmax)`, the probability of more -than `nmax` events. - -With ordinary `Float64` parameters `solve(jprob, BoundedSSA(; nmax))` is just a -bounded SSA simulation; with StochasticAD triples it differentiates. +`DiscreteProblem`s, giving correct gradients via StochasticAD's +`derivative_estimate`/`stochastic_triple` — with `saveat` support, so the whole +sampled path is differentiable, not only the terminal state. + +The stock `SSAStepper` cannot be differentiated with StochasticAD: it advances +time with a `while integrator.t < integrator.tstop < end_time` loop, i.e. a +boolean predicate on (triple-valued) time, which StochasticAD forbids by design — +so the event-count derivative is dropped (a state-dependent rate yields a gradient +of `0`). `BoundedSSA` instead uses **uniformization (thinning)** against a fixed +total-propensity bound `Λ = rate_bound`: + + - candidate event times form a homogeneous Poisson process of rate `Λ` on the + time span — these are **parameter-free**, so the loop never branches on a + triple and the times stay `Float64`; + - at each candidate the current total propensity `a(u)` is recomputed and the + event is *accepted* with a tracked `Bernoulli(a(u)/Λ)` (otherwise it is a + **null event** absorbing the slack `Λ - a(u)`); + - the firing channel is chosen by stick-breaking `Bernoulli`s. + +All parameter dependence flows through the accept / channel `Bernoulli`s, so the +gradient is captured. This is **unbiased** (no step cap) whenever `Λ` is a valid +bound, and `saveat` is exact because the candidate times are fixed `Float64`. + +With ordinary `Float64` parameters `solve(jprob, BoundedSSA(; rate_bound))` is an +ordinary (uniformization) SSA simulation; with StochasticAD triples it +differentiates. # Keyword arguments - - `nmax` (required): the fixed upper bound on the number of jump events. If the - true count can exceed `nmax` the result is biased; size it from - [`saturation_probability`](@ref). + - `rate_bound` (required): a constant `Λ` upper-bounding the **total** propensity + `Σₖ rateₖ(u, p, t)` over the whole trajectory (and over the parameter + perturbation). Valid for systems with rigorously bounded populations; a looser + bound only costs efficiency (more null events), not accuracy. If `Λ` is + violated the accept probability exceeds 1 and sampling errors — pick it with + margin. + +# `solve` options + + - `saveat`: times (a vector, or a `Number` step) at which to return the solution, + with `save_start`/`save_end` controlling the endpoints (same conventions as + `SimpleTauLeaping`, via `_process_saveat`); defaults to `[t0, tf]`. `sol.u[i]` is + the differentiable state at `sol.t[i]`, and `sol(t)` interpolates (piecewise + constant, as with `SSAStepper`). # Scope / limitations - - `ConstantRateJump`s only (state-dependent rates are supported); jump-only, no - continuous drift, no `VariableRateJump`. `MassActionJump` is not yet supported - (`evalrxrate` is not triple-generic and mass-action rate constants flow - through `param_mapper`). + - `ConstantRateJump`s only (state-dependent rates supported); jump-only, no + continuous drift, no `VariableRateJump`. `MassActionJump` is not yet supported. - Additive affects only (the net change is inferred from `affect!` and checked). - The differentiation parameter must enter through `prob.p`. - The implementation lives in the `JumpProcessesStochasticADExt` extension, so `StochasticAD` and `Distributions` must both be loaded to `solve` with it. -See also [`bounded_ssa_final_state`](@ref) (the differentiable core) and -[`saturation_probability`](@ref). +See also [`bounded_ssa_path`](@ref), the differentiable core this wraps. """ -struct BoundedSSA{N} <: DiffEqBase.AbstractDEAlgorithm - nmax::N +struct BoundedSSA{B} <: DiffEqBase.AbstractDEAlgorithm + rate_bound::B end -function BoundedSSA(; nmax = nothing) - nmax === nothing && error("BoundedSSA requires the keyword argument `nmax` " * - "(a fixed upper bound on the number of jump events).") - BoundedSSA{typeof(nmax)}(nmax) +function BoundedSSA(; rate_bound = nothing) + rate_bound === nothing && error("BoundedSSA requires the keyword argument " * + "`rate_bound` (a constant upper bound on the total propensity).") + BoundedSSA{typeof(rate_bound)}(rate_bound) end """ diff --git a/src/aggregators/direct.jl b/src/aggregators/direct.jl index 47e1c0f5f..ab33dd842 100644 --- a/src/aggregators/direct.jl +++ b/src/aggregators/direct.jl @@ -68,47 +68,33 @@ end ######################## SSA specific helper routines ######################## -# Fill `cur_rates` with the *raw* (non-cumulative) per-channel rates: mass-action -# rates first (indices `1:get_num_majumps(majumps)`), then constant-jump rates. -# Shared by `Direct`'s `time_to_next_jump` (which then forms the running -# cumulative sum used for channel sampling) and by the StochasticAD bounded SSA -# path (which uses the raw rates directly). Tuple rates use the type-stable -# recursive `fill_cur_rates`; function-wrapper rates use a plain loop. This is the -# one place per-channel rates are computed, so a generic (e.g. StochasticTriple) -# rate type flows through it unchanged. -@inline function fill_cur_rates!(cur_rates, u, p, t, majumps, rates::Tuple) - nma = get_num_majumps(majumps) - @inbounds for i in 1:nma - cur_rates[i] = evalrxrate(u, i, majumps) - end - isempty(rates) || fill_cur_rates(u, p, t, cur_rates, nma + 1, rates...) - nothing -end - -@inline function fill_cur_rates!(cur_rates, u, p, t, majumps, rates::AbstractArray) - nma = get_num_majumps(majumps) - @inbounds for i in 1:nma - cur_rates[i] = evalrxrate(u, i, majumps) - end - @inbounds for k in eachindex(rates) - cur_rates[nma + k] = rates[k](u, p, t) - end - nothing -end - # tuple-based constant jumps function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, t) where {T, S, F1 <: Tuple} + prev_rate = zero(t) + new_rate = zero(t) cur_rates = p.cur_rates - fill_cur_rates!(cur_rates, u, params, t, p.ma_jumps, p.rates) - # form the running cumulative sum used by `generate_jumps!` for channel sampling - prev_rate = zero(t) - @inbounds for i in eachindex(cur_rates) - cur_rates[i] = add_fast(cur_rates[i], prev_rate) + # mass action rates + majumps = p.ma_jumps + idx = get_num_majumps(majumps) + @inbounds for i in 1:idx + new_rate = evalrxrate(u, i, majumps) + cur_rates[i] = add_fast(new_rate, prev_rate) prev_rate = cur_rates[i] end + # constant jump rates + rates = p.rates + if !isempty(rates) + idx += 1 + fill_cur_rates(u, params, t, cur_rates, idx, rates...) + @inbounds for i in idx:length(cur_rates) + cur_rates[i] = add_fast(cur_rates[i], prev_rate) + prev_rate = cur_rates[i] + end + end + @inbounds sum_rate = cur_rates[end] sum_rate, randexp(p.rng) / sum_rate end @@ -127,16 +113,29 @@ end # function wrapper-based constant jumps function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, t) where {T, S, F1 <: AbstractArray} + prev_rate = zero(t) + new_rate = zero(t) cur_rates = p.cur_rates - fill_cur_rates!(cur_rates, u, params, t, p.ma_jumps, p.rates) - # form the running cumulative sum used by `generate_jumps!` for channel sampling - prev_rate = zero(t) - @inbounds for i in eachindex(cur_rates) - cur_rates[i] = add_fast(cur_rates[i], prev_rate) + # mass action rates + majumps = p.ma_jumps + idx = get_num_majumps(majumps) + @inbounds for i in 1:idx + new_rate = evalrxrate(u, i, majumps) + cur_rates[i] = add_fast(new_rate, prev_rate) prev_rate = cur_rates[i] end + # constant jump rates + idx += 1 + rates = p.rates + @inbounds for i in 1:length(p.rates) + new_rate = rates[i](u, params, t) + cur_rates[idx] = add_fast(new_rate, prev_rate) + prev_rate = cur_rates[idx] + idx += 1 + end + @inbounds sum_rate = cur_rates[end] sum_rate, randexp(p.rng) / sum_rate end diff --git a/src/aggregators/ssajump.jl b/src/aggregators/ssajump.jl index f032222e5..90c260c97 100644 --- a/src/aggregators/ssajump.jl +++ b/src/aggregators/ssajump.jl @@ -110,24 +110,6 @@ Adds a `tstop` to the integrator at the next jump time. nothing end -# Element type for the SSA rate cache: the time type promoted with the *inferred* -# rate output type(s). Ordinary Float64 rates give Float64 (unchanged); a non-Float64 -# rate type (e.g. a StochasticAD StochasticTriple) is preserved rather than forced -# into Float64. We use return-type inference (`Base.promote_op`) rather than calling -# the rate, so this never executes the rate at build time (which could index a -# `NullParameters` `p`, have side effects, etc.) and falls back to the time type if -# inference is not a concrete type. -function ssa_rate_eltype(u, p, t, majumps, rates) - R = typeof(t) - if get_num_majumps(majumps) > 0 - R = promote_type(R, Base.promote_op(evalrxrate, typeof(u), Int, typeof(majumps))) - end - if !isempty(rates) - R = promote_type(R, Base.promote_op(first(rates), typeof(u), typeof(p), typeof(t))) - end - return isconcretetype(R) ? R : typeof(t) -end - """ build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; kwargs...) @@ -145,22 +127,14 @@ function build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rate Vector{Vector{Pair{Int, eltype(u)}}}()) end - # Rate-cache element type: the time type promoted with the actual rate output - # type(s). For ordinary Float64 rates this stays Float64 (behavior unchanged); - # if a rate returns a non-Float64 number (e.g. a StochasticAD StochasticTriple), - # the cache holds that type instead of forcing it into Float64, so such rates - # can pass through the existing SSA path. Plain promotion/conversion -- no - # StochasticAD dependency in src/. - RT = ssa_rate_eltype(u, p, t, majumps, rates) - # current jump rates, allows mass action rates and constant jumps - cur_rates = Vector{RT}(undef, get_num_majumps(majumps) + length(rates)) + cur_rates = Vector{typeof(t)}(undef, get_num_majumps(majumps) + length(rates)) - sum_rate = convert(RT, zero(t)) + sum_rate = zero(typeof(t)) next_jump = 0 - next_jump_time = convert(RT, typemax(typeof(t))) - jump_agg_type(next_jump, next_jump_time, convert(RT, end_time), cur_rates, - sum_rate, majumps, rates, affects!, save_positions, rng; kwargs...) + next_jump_time = typemax(typeof(t)) + jump_agg_type(next_jump, next_jump_time, end_time, cur_rates, sum_rate, + majumps, rates, affects!, save_positions, rng; kwargs...) end """ diff --git a/test/fill_cur_rates_regression.jl b/test/fill_cur_rates_regression.jl deleted file mode 100644 index 7f84423ea..000000000 --- a/test/fill_cur_rates_regression.jl +++ /dev/null @@ -1,95 +0,0 @@ -using JumpProcesses -using Statistics, Random, Test - -# Regression for the shared `fill_cur_rates!` refactor (src/aggregators/direct.jl): -# the raw rate-fill was factored out of `time_to_next_jump` so the StochasticAD -# bounded SSA path can reuse it. The stock Direct / DirectFW SSA path must be -# UNCHANGED by that. This covers mass-action rates (`evalrxrate`), multi-channel -# cumulative selection (`searchsortedfirst`), the mixed mass-action + constant -# ordering, and the function-wrapper (`DirectFW`) path. Uses only JumpProcesses -# (no OrdinaryDiffEq), so it runs in the isolated StochasticAD env alongside the -# change it guards. - -# final value of species 1 over `nT` independent, seeded trajectories -function final_species1(jprob, alg; nT, seedbase = 0) - xs = Vector{Float64}(undef, nT) - for i in 1:nT - xs[i] = solve(jprob, alg; seed = seedbase + i).u[end][1] - end - return xs -end - -@testset "fill_cur_rates! stock-path regression (Direct/DirectFW)" begin - T = 1.0 - - # --- mass-action, single channel: X --(μ)--> 0, rate μ·X ----------------- - # E[X(T)] = X0·e^{-μT} - @testset "mass-action pure death (Direct)" begin - X0, μ = 100, 0.5 - maj = MassActionJump([μ], [[1 => 1]], [[1 => -1]]) - jp = JumpProblem(DiscreteProblem([X0], (0.0, T)), Direct(), maj) - analytic = X0 * exp(-μ * T) - xs = final_species1(jp, SSAStepper(); nT = 5000) - @test abs(mean(xs) - analytic) < 4 * std(xs) / sqrt(length(xs)) - end - - # --- mass-action, multi-channel: A <-(k1,k2)-> B (both first order) ------- - # linear system => the mean is exact: A_ss = k2·N/(k1+k2), - # E[A(T)] = A_ss + (A0 - A_ss)·e^{-(k1+k2)T}. Exercises evalrxrate over two - # channels and the cumulative-rate channel selection. - @testset "mass-action reversible two-channel (Direct)" begin - A0, B0, k1, k2 = 80, 20, 0.8, 0.4 - N = A0 + B0 - maj = MassActionJump([k1, k2], - [[1 => 1], [2 => 1]], - [[1 => -1, 2 => 1], [1 => 1, 2 => -1]]) - jp = JumpProblem(DiscreteProblem([A0, B0], (0.0, T)), Direct(), maj) - Ass = k2 * N / (k1 + k2) - analytic = Ass + (A0 - Ass) * exp(-(k1 + k2) * T) - xs = final_species1(jp, SSAStepper(); nT = 5000) - @test abs(mean(xs) - analytic) < 4 * std(xs) / sqrt(length(xs)) - @test all(x -> 0 <= x <= N, xs) # conservation A+B=N keeps A in [0,N] - end - - # --- mixed mass-action + constant jump: birth (constant λ) + death (μ·X) -- - # fills mass-action rate first, then the constant-jump rate, in fill_cur_rates!. - # E[X(T)] = λ/μ + (X0 - λ/μ)·e^{-μT} - @testset "mixed mass-action death + constant birth (Direct)" begin - X0, λ, μ = 50, 10.0, 0.3 - birth = ConstantRateJump((u, p, t) -> λ, integ -> (integ.u[1] += 1; nothing)) - death = MassActionJump([μ], [[1 => 1]], [[1 => -1]]) - jp = JumpProblem(DiscreteProblem([X0], (0.0, T)), Direct(), birth, death) - analytic = λ / μ + (X0 - λ / μ) * exp(-μ * T) - xs = final_species1(jp, SSAStepper(); nT = 5000) - @test abs(mean(xs) - analytic) < 4 * std(xs) / sqrt(length(xs)) - end - - # --- same mixed model via the function-wrapper aggregator ----------------- - # exercises the AbstractArray fill_cur_rates! / time_to_next_jump methods. - @testset "mixed model via DirectFW (fwrapper path)" begin - X0, λ, μ = 50, 10.0, 0.3 - birth = ConstantRateJump((u, p, t) -> λ, integ -> (integ.u[1] += 1; nothing)) - death = MassActionJump([μ], [[1 => 1]], [[1 => -1]]) - jp = JumpProblem(DiscreteProblem([X0], (0.0, T)), DirectFW(), birth, death) - analytic = λ / μ + (X0 - λ / μ) * exp(-μ * T) - xs = final_species1(jp, SSAStepper(); nT = 5000) - @test abs(mean(xs) - analytic) < 4 * std(xs) / sqrt(length(xs)) - end - - # --- Direct and DirectFW must produce IDENTICAL trajectories -------------- - # same rates, same cumulative sums, same RNG draws => bit-identical paths. - # Deterministic check (no MC tolerance): the sharpest guard that both - # fill_cur_rates! methods compute the same per-channel and cumulative rates. - @testset "Direct == DirectFW (identical trajectories)" begin - X0, λ, μ = 40, 8.0, 0.5 - mkbirth() = ConstantRateJump((u, p, t) -> λ, integ -> (integ.u[1] += 1; nothing)) - mkdeath() = MassActionJump([μ], [[1 => 1]], [[1 => -1]]) - jp_d = JumpProblem(DiscreteProblem([X0], (0.0, T)), Direct(), mkbirth(), mkdeath()) - jp_f = JumpProblem(DiscreteProblem([X0], (0.0, T)), DirectFW(), mkbirth(), mkdeath()) - for i in 1:25 - sd = solve(jp_d, SSAStepper(); seed = i) - sf = solve(jp_f, SSAStepper(); seed = i) - @test sd.u[end] == sf.u[end] - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 011374567..b95d78225 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,9 +75,6 @@ end if GROUP == "StochasticAD" activate_stochasticad_env() - # Guards the shared `fill_cur_rates!` refactor the bounded path depends on. - # JumpProcesses-only (no OrdinaryDiffEq), so it runs in this isolated env. - @time @safetestset "fill_cur_rates! Stock-Path Regression" begin include("fill_cur_rates_regression.jl") end @time @safetestset "StochasticAD Extension Tests" begin include("stochasticad_tests.jl") end end diff --git a/test/stochasticad_tests.jl b/test/stochasticad_tests.jl index 0f18f4400..2dab8c016 100644 --- a/test/stochasticad_tests.jl +++ b/test/stochasticad_tests.jl @@ -1,10 +1,10 @@ using JumpProcesses, StochasticAD, Distributions using Statistics, Random, Test -# StochasticAD-compatible differentiation for jump-only ConstantRateJump SSA -# problems (the BoundedSSA path). Needs only JumpProcesses + StochasticAD + -# Distributions (no ODE solver), so it runs in its own isolated environment -# (GROUP=StochasticAD). +# Tests for the optional StochasticAD extension: differentiating expectations over +# jump-only ConstantRateJump processes via BoundedSSA (uniformization/thinning). +# Needs only JumpProcesses + StochasticAD + Distributions (no ODE solver), so it +# runs in its own isolated environment (GROUP=StochasticAD). # per-partial StochasticAD gradient with fixed seeds (reproducible) function sad_partial(f, p0, k; N) @@ -19,28 +19,33 @@ function sad_partial(f, p0, k; N) return mean(s), std(s) / sqrt(N) end -@testset "StochasticAD constant-rate jumps (BoundedSSA)" begin +# primal Monte-Carlo mean + standard error +function pmean(f; N) + s = [f() for _ in 1:N] + return mean(s), std(s) / sqrt(N) +end - # --- Test A: pure death, STATE-DEPENDENT rate (the case that matters) --- - # rate = μ·u[1] ; E[u(T)] = u0·e^{-μT} ; d/dμ E[u(T)] = -T·u0·e^{-μT} - # A naive `while t p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), death) analytic = -T * u0 * exp(-μ0 * T) g, se = sad_partial([μ0], 1; N = 10000) do p - bounded_ssa_final_state(jprob, p; nmax = 200)[1] + bounded_ssa_path(jprob, p; rate_bound = Λ, saveat = [T])[end][1] end @test abs(g - analytic) < 4 * se # event-count derivative captured @test abs(g) > 1.0 # explicitly NOT the zero a naive SSA gives end - # --- Test B: birth-death, multi-channel + state-dependent --- - # birth λ, death μ·u[1] ; E[u(T)] = λ/μ + (u0-λ/μ)e^{-μT} + # --- Test B: birth-death, multi-channel + state-dependent ------------------ + # birth λ, death μ·u ; E[u(T)] = λ/μ + (u0-λ/μ)e^{-μT} @testset "birth-death (multi-channel, state-dependent)" begin - T, u0, λ0, μ0 = 1.0, 50, 10.0, 0.3 + T, u0, λ0, μ0, Λ = 1.0, 50, 10.0, 0.3, 60.0 # total = λ + μ·u ≤ ~28 ≪ Λ birth = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] += 1; nothing)) death = ConstantRateJump((u, p, t) -> p[2] * u[1], integ -> (integ.u[1] -= 1; nothing)) jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), birth, death) @@ -48,111 +53,75 @@ end analytic = [(1 - b) / μ0, -λ0 / μ0^2 * (1 - b) + (u0 - a) * (-T * b)] for k in 1:2 g, se = sad_partial([λ0, μ0], k; N = 10000) do p - bounded_ssa_final_state(jprob, p; nmax = 400)[1] + bounded_ssa_path(jprob, p; rate_bound = Λ, saveat = [T])[end][1] end @test abs(g - analytic[k]) < 4 * se end end - # --- Test C: homogeneous Poisson baseline (state-INDEPENDENT) --- - # two constant channels ; u(T) = u0 + N1 - N2 ; d/dλ1 = T, d/dλ2 = -T - @testset "homogeneous Poisson baseline" begin - T, u0 = 2.0, 10 - j1 = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] += 1; nothing)) - j2 = ConstantRateJump((u, p, t) -> p[2], integ -> (integ.u[1] -= 1; nothing)) - jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), j1, j2) - λ0 = [3.0, 1.0] - for k in 1:2 - g, se = sad_partial(λ0, k; N = 10000) do p - bounded_ssa_final_state(jprob, p; nmax = 200)[1] + # --- Test C: saveat returns the path at intermediate times ----------------- + # pure death: E[u(s)] = u0·e^{-μs} at every save time s. + @testset "saveat path (intermediate times)" begin + T, u0, μ0, Λ = 1.0, 100, 0.5, 60.0 + death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), death) + sat = [0.0, 0.5, 1.0] + for (idx, s) in enumerate(sat) + m, se = pmean(N = 4000) do + bounded_ssa_path(jprob, [μ0]; rate_bound = Λ, saveat = sat)[idx][1] end - @test abs(g - (k == 1 ? T : -T)) < 4 * se + # `s = 0` is deterministic (u(0) = u0, se = 0); the `+ 1e-9` keeps the + # zero-variance point from failing `0 < 0`. + @test abs(m - u0 * exp(-μ0 * s)) <= 4 * se + 1e-9 end end - # --- Test D: BoundedSSA solve path differentiates (native API) --- - # solve(jprob, BoundedSSA(; nmax)) must give the same gradient as the core. - # Constructs the JumpProblem with triple parameters, exercising the generic - # rate cache during construction and the BoundedSSA __solve. - @testset "BoundedSSA solve path" begin - T, u0, μ0 = 1.0, 100, 0.5 + # --- Test D: primal mean matches the stock SSA ----------------------------- + @testset "primal mean matches stock SSAStepper" begin + T, u0, μ0, Λ = 1.0, 100, 0.4, 60.0 death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) - analytic = -T * u0 * exp(-μ0 * T) - # primal: solve returns a sensible final state + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T), [μ0]), Direct(), death) + mb, seb = pmean(N = 4000) do + bounded_ssa_path(jprob, [μ0]; rate_bound = Λ, saveat = [T])[end][1] + end + ms, ses = pmean(N = 4000) do + solve(jprob, SSAStepper()).u[end][1] + end + @test abs(mb - u0 * exp(-μ0 * T)) < 4 * seb + @test abs(mb - ms) < 4 * sqrt(seb^2 + ses^2) # agree within combined MC error + end + + # --- Test E: BoundedSSA solve path (native interface) ---------------------- + @testset "BoundedSSA solve path + saveat" begin + T, u0, μ0, Λ = 1.0, 100, 0.5, 60.0 + death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + # primal: solve returns the full path at saveat jp0 = JumpProblem(DiscreteProblem([u0], (0.0, T), [μ0]), Direct(), death) - sol = solve(jp0, BoundedSSA(; nmax = 200)) - @test length(sol.u[end]) == 1 + sol = solve(jp0, BoundedSSA(; rate_bound = Λ); saveat = [0.0, 0.5, 1.0]) + @test length(sol.u) == 3 + @test sol.u[1][1] == u0 # state at t0 is the initial condition @test 0 <= sol.u[end][1] <= u0 - # differentiate through solve + # differentiate through solve (constructs jprob with triple parameters) + analytic = -T * u0 * exp(-μ0 * T) g, se = sad_partial([μ0], 1; N = 4000) do p jp = JumpProblem(DiscreteProblem([u0], (0.0, T), p), Direct(), death) - solve(jp, BoundedSSA(; nmax = 200)).u[end][1] + solve(jp, BoundedSSA(; rate_bound = Λ); saveat = [T]).u[end][1] end @test abs(g - analytic) < 4 * se end - # --- Test E: primal distribution matches the stock SSA --- - # at large nmax the bounded SSA's mean must match solve(jprob, SSAStepper()). - @testset "primal mean matches stock SSAStepper" begin - T, u0, μ0 = 1.0, 100, 0.4 - death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) - jprob = JumpProblem(DiscreteProblem([u0], (0.0, T), [μ0]), Direct(), death) - nT = 4000 - bounded = Vector{Float64}(undef, nT) - stock = Vector{Float64}(undef, nT) - for i in 1:nT - Random.seed!(i) - bounded[i] = bounded_ssa_final_state(jprob, [μ0]; nmax = 400)[1] - Random.seed!(10_000 + i) - stock[i] = solve(jprob, SSAStepper()).u[end][1] - end - analytic = u0 * exp(-μ0 * T) - @test abs(mean(bounded) - analytic) < 4 * std(bounded) / sqrt(nT) - # bounded vs stock means agree within combined MC error - se = sqrt(var(bounded) / nT + var(stock) / nT) - @test abs(mean(bounded) - mean(stock)) < 4 * se - end - - # --- Test F: saturation diagnostic --- - @testset "saturation_probability" begin - T, u0, λ0 = 1.0, 0, 20.0 - birth = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] += 1; nothing)) - jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), birth) - # ~Poisson(20) events: nmax=5 is far too small, nmax=200 is ample - @test saturation_probability(jprob, [λ0]; nmax = 5, ntrials = 2000) > 0.5 - @test saturation_probability(jprob, [λ0]; nmax = 200, ntrials = 2000) < 0.01 - end - - # --- guards: misuse should error, not silently mislead --- + # --- guards: misuse should error, not silently mislead --------------------- @testset "guards" begin T = 1.0 # MassActionJump not yet supported majump = MassActionJump([0.5], [[1 => 1]], [[1 => -1]]) jp_ma = JumpProblem(DiscreteProblem([10], (0.0, T), [0.5]), Direct(), majump) - @test_throws ErrorException bounded_ssa_final_state(jp_ma, [0.5]; nmax = 50) - # missing nmax on the algorithm + @test_throws ErrorException bounded_ssa_path(jp_ma, [0.5]; rate_bound = 10.0) + # missing rate_bound on the algorithm @test_throws ErrorException BoundedSSA() # non-additive (state-dependent) affect weird = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] *= 2; nothing)) jp_w = JumpProblem(DiscreteProblem([10], (0.0, T)), Direct(), weird) - @test_throws ErrorException bounded_ssa_final_state(jp_w, [0.5]; nmax = 50) - end - - # --- generic rate cache (Chris's literal request): a triple-valued rate must - # pass through the existing Direct rate-cache without Float64(::StochasticTriple). - # SCOPE: this only makes the cache generic. It does NOT by itself give correct - # gradients — the stock `while t < T` event boundary still drops the event-count - # derivative, and a full stock solve would next hit the SSAStepper integrator's - # Float64 time. Correct gradients use BoundedSSA / bounded_ssa_final_state. - @testset "generic rate cache (no Float64(::StochasticTriple))" begin - st = stochastic_triple(identity, 0.5) - jump = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) - rates, affects = JumpProcesses.get_jump_info_tuples((jump,)) - agg = JumpProcesses.build_jump_aggregation( - JumpProcesses.DirectJumpAggregation, [10], [st], 0.0, 1.0, nothing, - rates, affects, (false, false), JumpProcesses.DEFAULT_RNG) - @test eltype(agg.cur_rates) <: StochasticAD.StochasticTriple # cache is generic, not Float64 - sr, _ = JumpProcesses.time_to_next_jump(agg, [10], [st], 0.0) - @test sr isa StochasticAD.StochasticTriple # filled without Float64 error + @test_throws ErrorException bounded_ssa_path(jp_w, [0.5]; rate_bound = 10.0) end end