diff --git a/Project.toml b/Project.toml index 2689a8214..7e3d2f303 100644 --- a/Project.toml +++ b/Project.toml @@ -22,12 +22,15 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [weakdeps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" +StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac" [extensions] JumpProcessesKernelAbstractionsExt = ["Adapt", "KernelAbstractions"] JumpProcessesOrdinaryDiffEqCoreExt = "OrdinaryDiffEqCore" +JumpProcessesStochasticADExt = ["StochasticAD", "Distributions"] [compat] ADTypes = "1.22.0" @@ -58,6 +61,7 @@ SciMLBase = "3.18.0, 3.1" StableRNGs = "1" StaticArrays = "1.9.18" Statistics = "1" +StochasticAD = "0.1" StochasticDiffEq = "7.0.0, 7" SymbolicIndexingInterface = "0.3.48" Test = "1" diff --git a/ext/JumpProcessesStochasticADExt.jl b/ext/JumpProcessesStochasticADExt.jl new file mode 100644 index 000000000..9d7607e58 --- /dev/null +++ b/ext/JumpProcessesStochasticADExt.jl @@ -0,0 +1,204 @@ +module JumpProcessesStochasticADExt + +# StochasticAD-compatible differentiation for jump-only `ConstantRateJump` SSA +# problems — the implementation behind the `BoundedSSA` algorithm and the +# `bounded_ssa_path` entry point. +# +# Why this exists: the stock `solve(jprob, SSAStepper())` cannot be differentiated +# with StochasticAD. It advances time with a `while integrator.t < integrator.tstop +# < end_time` loop — a boolean predicate on (triple-valued) time, which StochasticAD +# forbids by design — so the event-count derivative (the dominant term for +# state-dependent rates) is dropped and a state-dependent rate gives a gradient of 0. +# +# Instead we use UNIFORMIZATION (thinning) against a constant total-propensity bound +# `Λ = rate_bound`: candidate event times are a homogeneous Poisson process of rate +# `Λ` (parameter-free, so the loop never branches on a triple and times stay +# Float64); at each candidate the event is accepted with a tracked +# `Bernoulli(total_rate(u)/Λ)` (else a null event), and the channel is chosen by +# stick-breaking. All parameter dependence flows through the accept/channel +# Bernoullis, so the gradient is captured; it is unbiased given a valid `Λ`, and +# `saveat` is exact because the candidate times are fixed Float64. +# +# This code is fully SEPARATE from the `Direct` aggregator: it reads `jump.rate` / +# `jump.affect!` only and never touches `time_to_next_jump` or the Direct rate cache. +# +# Scope: jump-only `DiscreteProblem`s with `ConstantRateJump`s (state-dependent +# rates OK) and additive affects. No `MassActionJump`, no `VariableRateJump`. + +using JumpProcesses +using StochasticAD +using Distributions: Bernoulli, Poisson +using DiffEqBase +using Random + +# minimal integrator-like object so a jump's `affect!` can be applied to a scratch +# state to read off its net effect. +mutable struct ShimIntegrator{U, P, T} + u::U + p::P + t::T +end + +# apply `affect!` to a scratch copy of `ubase` and return the net state change. +function _net_change(affect!, ubase, p, t0) + u = collect(ubase) + affect!(ShimIntegrator(u, p, t0)) + return u .- ubase +end + +# infer a jump's net state change, verifying it is *additive* (same change from two +# different base states). Additive affects are required: the state is built by +# adding `Δuₖ` on each event. +function _additive_change(jump, u0, p, t0) + base = float.(collect(u0)) + Δ = _net_change(jump.affect!, base, p, t0) + Δ2 = _net_change(jump.affect!, base .+ one(eltype(base)), p, t0) + isapprox(Δ, Δ2) || error( + "BoundedSSA supports only additive affects (a constant net state change), " * + "but a jump's affect! gave a state-dependent change ($Δ vs $Δ2 from a " * + "shifted state). Arbitrary mutating affects are out of scope.") + return Δ +end + +function _check_supported(jprob) + jprob.prob isa DiscreteProblem || error( + "BoundedSSA only supports JumpProblems defined over DiscreteProblems " * + "(pure jumps, no continuous drift).") + maj = jprob.massaction_jump + (maj === nothing || JumpProcesses.get_num_majumps(maj) == 0) || error( + "BoundedSSA does not yet support MassActionJump; build the model with " * + "ConstantRateJumps.") + vj = jprob.variable_jumps + (vj === nothing || isempty(vj)) || error( + "BoundedSSA supports jump-only constant-rate problems only; it does not " * + "support VariableRateJumps.") + cj = jprob.constant_jumps + (cj === nothing || isempty(cj)) && error( + "BoundedSSA requires at least one ConstantRateJump.") + nothing +end + +# Internal uniformization driver: returns `(tsave, usave)` at the resolved save +# schedule. Uses `JumpProcesses._process_saveat` (from src/simple_regular_solve.jl) +# for the interior save times + save_start/save_end flags, so saveat semantics match +# SimpleTauLeaping and the rest of the package. The save loop mirrors that solver's +# push idiom; all save-time comparisons are on parameter-free Float64 candidate times. +function _bounded_ssa(jprob, p, Λ, tspan, saveat, save_start, save_end) + _check_supported(jprob) + u0 = jprob.prob.u0 + jumps = jprob.constant_jumps + t0, tf = first(tspan), last(tspan) + ΔT = tf - t0 + K = length(jumps) + n = length(u0) + + saveat_times, ss, se = JumpProcesses._process_saveat(saveat, (t0, tf), + save_start, save_end) + + Δ = [_additive_change(jumps[k], u0, p, t0) for k in 1:K] # Float64 net change/channel + z = 0 * sum(p) # triple zero carrying p's type + u = [float(u0[i]) + z for i in 1:n] + + tsave = typeof(t0)[] + usave = typeof(u)[] + if ss + push!(tsave, t0) + push!(usave, copy(u)) + end + + # candidate events ~ homogeneous Poisson(Λ) on [t0, tf]. PARAMETER-FREE: Λ is a + # constant, so M and the times carry no derivative and never branch on a triple. + M = rand(Poisson(Λ * ΔT)) + ctimes = sort!(t0 .+ ΔT .* rand(M)) + + save_idx = 1 + for m in 1:M + tm = @inbounds ctimes[m] + # record interior save times crossed before this candidate (Float64 compares) + while save_idx <= length(saveat_times) && @inbounds(saveat_times[save_idx]) < tm + push!(tsave, @inbounds saveat_times[save_idx]) + push!(usave, copy(u)) + save_idx += 1 + end + + rates = [jumps[k].rate(u, p, tm) for k in 1:K] # recomputed at current state + total = sum(rates) + accept = rand(Bernoulli(total / Λ)) # thinning: real vs null event + + # which channel: stick-breaking conditional Bernoullis (last deterministic) + notchosen = 1 + z + sel = [z for _ in 1:n] + for k in 1:K + chose = k < K ? + rand(Bernoulli(rates[k] / (sum(rates[j] for j in k:K) + 1e-300))) : + (1 + z) + take = notchosen * chose + sel = [sel[i] + take * Δ[k][i] for i in 1:n] + notchosen = notchosen * (1 - chose) + end + + u = [u[i] + accept * sel[i] for i in 1:n] # apply only on a real event + end + while save_idx <= length(saveat_times) + push!(tsave, @inbounds saveat_times[save_idx]) + push!(usave, copy(u)) + save_idx += 1 + end + if se + push!(tsave, tf) + push!(usave, copy(u)) + end + return tsave, usave +end + +""" + bounded_ssa_path(jprob, p; rate_bound, saveat = tf, save_start = nothing, + save_end = nothing, tspan = jprob.prob.tspan) + +Differentiable core behind [`BoundedSSA`](@ref). Simulates the jump-only +`ConstantRateJump` process by uniformization against the constant total-propensity +bound `rate_bound`, and returns the (StochasticAD-differentiable) state at each save +time as a `Vector` of state vectors. + +`saveat`/`save_start`/`save_end` follow the usual JumpProcesses conventions (via the +same `_process_saveat` as `SimpleTauLeaping`): `saveat` is a `Number` step or a +collection of times; endpoints are controlled by `save_start`/`save_end`. Wrap in +`derivative_estimate` for gradients, e.g. of the terminal state: + +```julia +derivative_estimate(p0[k]) do pk + pv = [j == k ? pk : oftype(pk, p0[j]) for j in eachindex(p0)] + bounded_ssa_path(jprob, pv; rate_bound = Λ, saveat = [tf])[end][1] +end +``` + +See [`BoundedSSA`](@ref) for the method and the meaning/validity of `rate_bound`. +""" +function JumpProcesses.bounded_ssa_path(jprob, p; rate_bound, + saveat = last(jprob.prob.tspan), save_start = nothing, save_end = nothing, + tspan = jprob.prob.tspan) + _, usave = _bounded_ssa(jprob, p, rate_bound, tspan, saveat, save_start, save_end) + return usave +end + +# solve(jprob, BoundedSSA(; rate_bound); saveat, save_start, save_end): run the +# uniformization path and return a solution whose `u[i]` is the differentiable state +# at `t[i]`. `sol(t)` works via piecewise-constant interpolation (as with SSAStepper). +# Defined as `solve` (like SimpleTauLeaping), since BoundedSSA is self-contained and +# does not use the integrator/init machinery. +function DiffEqBase.solve(jump_prob::JumpProblem, alg::BoundedSSA; + seed = nothing, saveat = nothing, save_start = nothing, save_end = nothing, + tspan = jump_prob.prob.tspan, kwargs...) + seed === nothing || Random.seed!(seed) + prob = jump_prob.prob + ts, us = _bounded_ssa(jump_prob, prob.p, alg.rate_bound, tspan, saveat, + save_start, save_end) + DiffEqBase.build_solution(prob, alg, ts, us; + dense = true, + interp = DiffEqBase.ConstantInterpolation(ts, us), + calculate_error = false, + stats = DiffEqBase.Stats(0), + retcode = DiffEqBase.ReturnCode.Success) +end + +end # module diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 87c8c2c26..b88ff4d55 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -112,6 +112,15 @@ include("aggregators/aggregated_api.jl") include("variable_rate.jl") export VariableRateAggregator, VR_FRM, VR_Direct, VR_DirectFW +# StochasticAD support. Stub; the method is provided by the package extension +# `ext/JumpProcessesStochasticADExt.jl`, which loads only when StochasticAD and +# Distributions are both available. No StochasticAD code lives in `src/`. The +# `BoundedSSA` algorithm struct itself lives in `src` (see SSA_stepper.jl) so it +# is always referenceable/documentable; only its `solve` implementation (and this +# differentiable-path helper) are in the extension. +function bounded_ssa_path end +export bounded_ssa_path + """ Aggregator to indicate that individual jumps should also be handled via the leaping algorithm that is passed to solve. @@ -127,7 +136,7 @@ include("solve.jl") export init, solve, solve! include("SSA_stepper.jl") -export SSAStepper +export SSAStepper, BoundedSSA # leaping: include("simple_regular_solve.jl") diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index d7b6c4aef..797755a45 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -59,6 +59,74 @@ for details. struct SSAStepper <: DiffEqBase.AbstractDEAlgorithm end SciMLBase.allows_late_binding_tstops(::SSAStepper) = true +""" + BoundedSSA(; rate_bound) + +A StochasticAD-compatible SSA algorithm for **jump-only** `ConstantRateJump` +`DiscreteProblem`s, giving correct gradients via StochasticAD's +`derivative_estimate`/`stochastic_triple` — with `saveat` support, so the whole +sampled path is differentiable, not only the terminal state. + +The stock `SSAStepper` cannot be differentiated with StochasticAD: it advances +time with a `while integrator.t < integrator.tstop < end_time` loop, i.e. a +boolean predicate on (triple-valued) time, which StochasticAD forbids by design — +so the event-count derivative is dropped (a state-dependent rate yields a gradient +of `0`). `BoundedSSA` instead uses **uniformization (thinning)** against a fixed +total-propensity bound `Λ = rate_bound`: + + - candidate event times form a homogeneous Poisson process of rate `Λ` on the + time span — these are **parameter-free**, so the loop never branches on a + triple and the times stay `Float64`; + - at each candidate the current total propensity `a(u)` is recomputed and the + event is *accepted* with a tracked `Bernoulli(a(u)/Λ)` (otherwise it is a + **null event** absorbing the slack `Λ - a(u)`); + - the firing channel is chosen by stick-breaking `Bernoulli`s. + +All parameter dependence flows through the accept / channel `Bernoulli`s, so the +gradient is captured. This is **unbiased** (no step cap) whenever `Λ` is a valid +bound, and `saveat` is exact because the candidate times are fixed `Float64`. + +With ordinary `Float64` parameters `solve(jprob, BoundedSSA(; rate_bound))` is an +ordinary (uniformization) SSA simulation; with StochasticAD triples it +differentiates. + +# Keyword arguments + + - `rate_bound` (required): a constant `Λ` upper-bounding the **total** propensity + `Σₖ rateₖ(u, p, t)` over the whole trajectory (and over the parameter + perturbation). Valid for systems with rigorously bounded populations; a looser + bound only costs efficiency (more null events), not accuracy. If `Λ` is + violated the accept probability exceeds 1 and sampling errors — pick it with + margin. + +# `solve` options + + - `saveat`: times (a vector, or a `Number` step) at which to return the solution, + with `save_start`/`save_end` controlling the endpoints (same conventions as + `SimpleTauLeaping`, via `_process_saveat`); defaults to `[t0, tf]`. `sol.u[i]` is + the differentiable state at `sol.t[i]`, and `sol(t)` interpolates (piecewise + constant, as with `SSAStepper`). + +# Scope / limitations + + - `ConstantRateJump`s only (state-dependent rates supported); jump-only, no + continuous drift, no `VariableRateJump`. `MassActionJump` is not yet supported. + - Additive affects only (the net change is inferred from `affect!` and checked). + - The differentiation parameter must enter through `prob.p`. + - The implementation lives in the `JumpProcessesStochasticADExt` extension, so + `StochasticAD` and `Distributions` must both be loaded to `solve` with it. + +See also [`bounded_ssa_path`](@ref), the differentiable core this wraps. +""" +struct BoundedSSA{B} <: DiffEqBase.AbstractDEAlgorithm + rate_bound::B +end +function BoundedSSA(; rate_bound = nothing) + rate_bound === nothing && error("BoundedSSA requires the keyword argument " * + "`rate_bound` (a constant upper bound on the total propensity).") + BoundedSSA{typeof(rate_bound)}(rate_bound) +end + """ $(TYPEDEF) diff --git a/test/runtests.jl b/test/runtests.jl index 9a2dfc246..b95d78225 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,16 @@ function activate_gpu_env() Pkg.instantiate() end +# Isolated environment for the StochasticAD extension tests. StochasticAD pins +# old transitive deps (e.g. ForwardDiff 0.10) that conflict with the modern +# OrdinaryDiffEq stack, so it is kept out of the main test target and run here +# in its own project (no ODE solver needed -- the extension never calls `solve`). +function activate_stochasticad_env() + Pkg.activate(joinpath(@__DIR__, "stochasticad")) + Pkg.develop(PackageSpec(path = dirname(@__DIR__))) + Pkg.instantiate() +end + @time begin if GROUP == "QA" @time @safetestset "QA Tests" begin include("qa.jl") end @@ -63,6 +73,11 @@ end @time @safetestset "GPU Tau Leaping test" begin include("gpu/regular_jumps.jl") end end + if GROUP == "StochasticAD" + activate_stochasticad_env() + @time @safetestset "StochasticAD Extension Tests" begin include("stochasticad_tests.jl") end + end + if GROUP == "Correctness" activate_gpu_env() end diff --git a/test/stochasticad/Project.toml b/test/stochasticad/Project.toml new file mode 100644 index 000000000..3d752c661 --- /dev/null +++ b/test/stochasticad/Project.toml @@ -0,0 +1,6 @@ +[deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StochasticAD = "e4facb34-4f7e-4bec-b153-e122c37934ac" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/stochasticad_tests.jl b/test/stochasticad_tests.jl new file mode 100644 index 000000000..2dab8c016 --- /dev/null +++ b/test/stochasticad_tests.jl @@ -0,0 +1,127 @@ +using JumpProcesses, StochasticAD, Distributions +using Statistics, Random, Test + +# Tests for the optional StochasticAD extension: differentiating expectations over +# jump-only ConstantRateJump processes via BoundedSSA (uniformization/thinning). +# Needs only JumpProcesses + StochasticAD + Distributions (no ODE solver), so it +# runs in its own isolated environment (GROUP=StochasticAD). + +# per-partial StochasticAD gradient with fixed seeds (reproducible) +function sad_partial(f, p0, k; N) + s = Vector{Float64}(undef, N) + for i in 1:N + Random.seed!(i) + s[i] = derivative_estimate(p0[k]) do pk + p = [j == k ? pk : oftype(pk, p0[j]) for j in eachindex(p0)] + f(p) + end + end + return mean(s), std(s) / sqrt(N) +end + +# primal Monte-Carlo mean + standard error +function pmean(f; N) + s = [f() for _ in 1:N] + return mean(s), std(s) / sqrt(N) +end + +@testset "StochasticAD constant-rate jumps (BoundedSSA, uniformization)" begin + + # --- Test A: pure death, STATE-DEPENDENT rate (the case that matters) ------ + # rate = μ·u ; E[u(T)] = u0·e^{-μT} ; d/dμ E[u(T)] = -T·u0·e^{-μT} + # A naive `while t p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), death) + analytic = -T * u0 * exp(-μ0 * T) + g, se = sad_partial([μ0], 1; N = 10000) do p + bounded_ssa_path(jprob, p; rate_bound = Λ, saveat = [T])[end][1] + end + @test abs(g - analytic) < 4 * se # event-count derivative captured + @test abs(g) > 1.0 # explicitly NOT the zero a naive SSA gives + end + + # --- Test B: birth-death, multi-channel + state-dependent ------------------ + # birth λ, death μ·u ; E[u(T)] = λ/μ + (u0-λ/μ)e^{-μT} + @testset "birth-death (multi-channel, state-dependent)" begin + T, u0, λ0, μ0, Λ = 1.0, 50, 10.0, 0.3, 60.0 # total = λ + μ·u ≤ ~28 ≪ Λ + birth = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] += 1; nothing)) + death = ConstantRateJump((u, p, t) -> p[2] * u[1], integ -> (integ.u[1] -= 1; nothing)) + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), birth, death) + a, b = λ0 / μ0, exp(-μ0 * T) + analytic = [(1 - b) / μ0, -λ0 / μ0^2 * (1 - b) + (u0 - a) * (-T * b)] + for k in 1:2 + g, se = sad_partial([λ0, μ0], k; N = 10000) do p + bounded_ssa_path(jprob, p; rate_bound = Λ, saveat = [T])[end][1] + end + @test abs(g - analytic[k]) < 4 * se + end + end + + # --- Test C: saveat returns the path at intermediate times ----------------- + # pure death: E[u(s)] = u0·e^{-μs} at every save time s. + @testset "saveat path (intermediate times)" begin + T, u0, μ0, Λ = 1.0, 100, 0.5, 60.0 + death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T)), Direct(), death) + sat = [0.0, 0.5, 1.0] + for (idx, s) in enumerate(sat) + m, se = pmean(N = 4000) do + bounded_ssa_path(jprob, [μ0]; rate_bound = Λ, saveat = sat)[idx][1] + end + # `s = 0` is deterministic (u(0) = u0, se = 0); the `+ 1e-9` keeps the + # zero-variance point from failing `0 < 0`. + @test abs(m - u0 * exp(-μ0 * s)) <= 4 * se + 1e-9 + end + end + + # --- Test D: primal mean matches the stock SSA ----------------------------- + @testset "primal mean matches stock SSAStepper" begin + T, u0, μ0, Λ = 1.0, 100, 0.4, 60.0 + death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + jprob = JumpProblem(DiscreteProblem([u0], (0.0, T), [μ0]), Direct(), death) + mb, seb = pmean(N = 4000) do + bounded_ssa_path(jprob, [μ0]; rate_bound = Λ, saveat = [T])[end][1] + end + ms, ses = pmean(N = 4000) do + solve(jprob, SSAStepper()).u[end][1] + end + @test abs(mb - u0 * exp(-μ0 * T)) < 4 * seb + @test abs(mb - ms) < 4 * sqrt(seb^2 + ses^2) # agree within combined MC error + end + + # --- Test E: BoundedSSA solve path (native interface) ---------------------- + @testset "BoundedSSA solve path + saveat" begin + T, u0, μ0, Λ = 1.0, 100, 0.5, 60.0 + death = ConstantRateJump((u, p, t) -> p[1] * u[1], integ -> (integ.u[1] -= 1; nothing)) + # primal: solve returns the full path at saveat + jp0 = JumpProblem(DiscreteProblem([u0], (0.0, T), [μ0]), Direct(), death) + sol = solve(jp0, BoundedSSA(; rate_bound = Λ); saveat = [0.0, 0.5, 1.0]) + @test length(sol.u) == 3 + @test sol.u[1][1] == u0 # state at t0 is the initial condition + @test 0 <= sol.u[end][1] <= u0 + # differentiate through solve (constructs jprob with triple parameters) + analytic = -T * u0 * exp(-μ0 * T) + g, se = sad_partial([μ0], 1; N = 4000) do p + jp = JumpProblem(DiscreteProblem([u0], (0.0, T), p), Direct(), death) + solve(jp, BoundedSSA(; rate_bound = Λ); saveat = [T]).u[end][1] + end + @test abs(g - analytic) < 4 * se + end + + # --- guards: misuse should error, not silently mislead --------------------- + @testset "guards" begin + T = 1.0 + # MassActionJump not yet supported + majump = MassActionJump([0.5], [[1 => 1]], [[1 => -1]]) + jp_ma = JumpProblem(DiscreteProblem([10], (0.0, T), [0.5]), Direct(), majump) + @test_throws ErrorException bounded_ssa_path(jp_ma, [0.5]; rate_bound = 10.0) + # missing rate_bound on the algorithm + @test_throws ErrorException BoundedSSA() + # non-additive (state-dependent) affect + weird = ConstantRateJump((u, p, t) -> p[1], integ -> (integ.u[1] *= 2; nothing)) + jp_w = JumpProblem(DiscreteProblem([10], (0.0, T)), Direct(), weird) + @test_throws ErrorException bounded_ssa_path(jp_w, [0.5]; rate_bound = 10.0) + end +end diff --git a/test/test_groups.toml b/test/test_groups.toml index 443c5cbb1..620d9fdca 100644 --- a/test/test_groups.toml +++ b/test/test_groups.toml @@ -6,3 +6,6 @@ versions = ["lts", "1", "pre"] [QA] versions = ["lts", "1"] + +[StochasticAD] +versions = ["lts", "1"]