Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
name = "GeometricIntegratorsDiffEq"
uuid = "5a33fad7-5ce4-5983-9f5d-5f26ceab5c96"
version = "1.1.0"
version = "1.2.0"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
GeometricIntegrators = "dcce2d33-59f6-5b8d-9047-0defad88ae06"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1"

[compat]
DiffEqBase = "6.62"
DiffEqBase = "6.62, 7"
ExplicitImports = "1.14.0"
GeometricIntegrators = "0.15, 0.16"
Reexport = "0.2, 1"
SciMLBase = "2"
SciMLBase = "2, 3"
SciMLLogging = "1"
julia = "1.10"

[extras]
Expand Down
3 changes: 2 additions & 1 deletion src/GeometricIntegratorsDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module GeometricIntegratorsDiffEq

using Reexport: Reexport, @reexport
@reexport using DiffEqBase: DiffEqBase
using SciMLBase: SciMLBase, ReturnCode, check_keywords, warn_compat
using SciMLBase: SciMLBase, ReturnCode
using SciMLLogging: @SciMLMessage

using GeometricIntegrators: GeometricIntegrators, CrankNicolson, Crouzeix,
ExplicitEuler, ExplicitMidpoint, Gauss, Heun2, Heun3, ImplicitEuler,
Expand Down
70 changes: 53 additions & 17 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,46 @@ function DiffEqBase.__solve(
}
)

if verbose
warned = !isempty(kwargs) && check_keywords(alg, kwargs, warnlist)
if !(prob.f isa DiffEqBase.AbstractParameterizedFunction) && isstiff
if DiffEqBase.has_tgrad(prob.f)
@warn "Explicit t-gradient given to this stiff solver is ignored."
warned = true
end
if DiffEqBase.has_jac(prob.f)
@warn "Explicit Jacobian given to this stiff solver is ignored."
warned = true
end
# `verbose` may be a `Bool` (DiffEqBase v6) or a `DEVerbosity` / other
# `AbstractVerbositySpecifier` (DiffEqBase v7+). Route each warning through
# `@SciMLMessage` so a silent spec (e.g. `DEVerbosity(None())`) actually
# suppresses it — a `verbose !== false` guard can't, since a silent
# `DEVerbosity` is not `false`. The `:mismatched_input_output_type` toggle
# is the right DEVerbosity bucket for these: every message here reports an
# input the GeometricIntegrators solver cannot honor. For `Bool` verbose,
# `@SciMLMessage` ignores the toggle name and just maps true→WarnLevel,
# false→Silent.
warned = false
for (kw, val) in kwargs
if kw in warnlist && val !== nothing
@SciMLMessage(
string("The ", kw, " argument is ignored by ", alg, "."),
verbose, :mismatched_input_output_type
)
warned = true
end
warned && warn_compat()
end
if !(prob.f isa DiffEqBase.AbstractParameterizedFunction) && isstiff
if DiffEqBase.has_tgrad(prob.f)
@SciMLMessage(
"Explicit t-gradient given to this stiff solver is ignored.",
verbose, :mismatched_input_output_type
)
warned = true
end
if DiffEqBase.has_jac(prob.f)
@SciMLMessage(
"Explicit Jacobian given to this stiff solver is ignored.",
verbose, :mismatched_input_output_type
)
warned = true
end
end
if warned
@SciMLMessage(
"https://docs.sciml.ai/DiffEqDocs/stable/basics/compatibility_chart/",
verbose, :mismatched_input_output_type
)
end

if callback !== nothing
Expand Down Expand Up @@ -68,15 +95,20 @@ function DiffEqBase.__solve(
# Create function wrapper for GeometricIntegrators API
# GeometricIntegrators expects: v(v, t, q, params)
# DiffEqBase provides: f(du, u, p, t) for inplace or f(u, p, t) for out-of-place
# SciMLBase v3's default AutoSpecialize wraps f in a FunctionWrappersWrapper that
# only accepts the exact argument types captured at problem construction (e.g.
# `Matrix{Float64}`), so passing a `Base.ReshapedArray` to it errors. Unwrap the
# underlying user-defined function before invoking it.
raw_f = SciMLBase.unwrapped_f(prob.f.f)
if !isinplace && u isa AbstractArray
v! = (v, t, q, params) -> (v .= vec(prob.f(reshape(q, sizeu), p, t)); nothing)
v! = (v, t, q, params) -> (v .= vec(raw_f(reshape(q, sizeu), p, t)); nothing)
elseif !(u isa Vector{Float64})
v! = (
v, t, q,
params,
) -> (prob.f(reshape(v, sizeu), reshape(q, sizeu), p, t); nothing)
) -> (raw_f(reshape(v, sizeu), reshape(q, sizeu), p, t); nothing)
else
v! = (v, t, q, params) -> prob.f(v, q, p, t)
v! = (v, t, q, params) -> raw_f(v, q, p, t)
end

ode = GeometricIntegrators.ODEProblem(v!, prob.tspan, dt, vec(prob.u0))
Expand Down Expand Up @@ -111,17 +143,21 @@ function DiffEqBase.__solve(

v! = (v, t, q, p_state, params) -> (v .= p_state) # dq/dt = p

# Unwrap past SciMLBase v3 AutoSpecialize FunctionWrappers for the acceleration
# function so it accepts argument types other than those captured at construction.
raw_f1 = SciMLBase.unwrapped_f(prob.f.f1.f)

# Handle both inplace and out-of-place problems
if isinplace
f! = (
f_out, t, q, p_state,
params,
) -> (prob.f.f1.f(f_out, p_state, q, p, t); nothing) # dp/dt = f1(p, q)
) -> (raw_f1(f_out, p_state, q, p, t); nothing) # dp/dt = f1(p, q)
else
f! = (
f_out, t, q, p_state,
params,
) -> (f_out .= prob.f.f1.f(p_state, q, p, t); nothing)
) -> (f_out .= raw_f1(p_state, q, p, t); nothing)
end

pode = GeometricIntegrators.PODEProblem(
Expand Down
Loading