Skip to content

feat: add SemilinearODEFunction and SemilinearODEProblem #3739

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
Expand All @@ -45,6 +46,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -115,6 +117,7 @@ DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
EnumX = "1.0.4"
ExprTools = "0.1.10"
FMI = "0.14"
FillArrays = "1.13.0"
FindFirstFunctions = "1"
ForwardDiff = "0.10.3"
FunctionWrappers = "1.1"
Expand Down Expand Up @@ -142,6 +145,7 @@ OrdinaryDiffEq = "6.82.0"
OrdinaryDiffEqCore = "1.15.0"
OrdinaryDiffEqDefault = "1.2"
OrdinaryDiffEqNonlinearSolve = "1.5.0"
PreallocationTools = "0.4.27"
PrecompileTools = "1"
Pyomo = "0.1.0"
REPL = "1"
Expand Down
3 changes: 3 additions & 0 deletions docs/src/API/codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ ModelingToolkit.build_explicit_observed_function
ModelingToolkit.generate_control_function
ModelingToolkit.generate_update_A
ModelingToolkit.generate_update_b
ModelingToolkit.generate_semiquadratic_functions
ModelingToolkit.generate_semiquadratic_jacobian
ModelingToolkit.get_semiquadratic_W_sparsity
```

For functions such as jacobian calculation which require symbolic computation, there
Expand Down
2 changes: 2 additions & 0 deletions docs/src/API/problems.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ SciMLBase.ODEFunction
SciMLBase.ODEProblem
SciMLBase.DAEFunction
SciMLBase.DAEProblem
ModelingToolkit.SemilinearODEFunction
ModelingToolkit.SemilinearODEProblem
SciMLBase.SDEFunction
SciMLBase.SDEProblem
SciMLBase.DDEFunction
Expand Down
4 changes: 4 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ const DQ = DynamicQuantities
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff
import SciMLPublic: @public
import PreallocationTools
import PreallocationTools: DiffCache
import FillArrays

export @derivatives

Expand Down Expand Up @@ -288,6 +291,7 @@ export IntervalNonlinearProblem
export OptimizationProblem, constraints
export SteadyStateProblem
export JumpProblem
export SemilinearODEFunction, SemilinearODEProblem
export alias_elimination, flatten
export connect, domain_connect, @connector, Connection, AnalysisPoint, Flow, Stream,
instream
Expand Down
101 changes: 83 additions & 18 deletions src/problems/docs.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
struct SemilinearODEFunction{iip, spec} end
struct SemilinearODEProblem{iip, spec} end

const U0_P_DOCS = """
The order of unknowns is determined by `unknowns(sys)`. If the system is split
[`is_split`](@ref) create an [`MTKParameters`](@ref) object. Otherwise, a parameter vector.
Expand Down Expand Up @@ -92,6 +95,15 @@ function problem_ctors(prob, istd)
end
end

function problem_ctors(prob::Type{<:SemilinearODEProblem}, istd)
@assert istd
"""
SciMLBase.$prob(sys::System, op, tspan::NTuple{2}; kwargs...)
SciMLBase.$prob{iip}(sys::System, op, tspan::NTuple{2}; kwargs...)
SciMLBase.$prob{iip, specialize}(sys::System, op, tspan::NTuple{2}; stiff_A = true, stiff_B = false, stiff_C = false, kwargs...)
"""
end

function prob_fun_common_kwargs(T, istd)
return """
- `check_compatibility`: Whether to check if the given system `sys` contains all the
Expand All @@ -103,7 +115,8 @@ function prob_fun_common_kwargs(T, istd)
"""
end

function problem_docstring(prob, func, istd; init = true, extra_body = "")
function problem_docstring(prob, func, istd; init = true, extra_body = "",
extra_kwargs = "", extra_kwargs_desc = "")
if func isa DataType
func = "`$func`"
end
Expand All @@ -127,8 +140,9 @@ function problem_docstring(prob, func, istd; init = true, extra_body = "")
$PROBLEM_KWARGS
$(istd ? TIME_DEPENDENT_PROBLEM_KWARGS : "")
$(prob_fun_common_kwargs(prob, istd))

$(extra_kwargs)
All other keyword arguments are forwarded to the $func constructor.
$(extra_kwargs_desc)

$PROBLEM_INTERNALS_HEADER

Expand Down Expand Up @@ -186,6 +200,32 @@ If the `System` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
"""

const SEMILINEAR_EXTRA_BODY = """
This is a special form of an ODE which uses a `SplitFunction` internally. The equations are
separated into linear, quadratic and general terms and phrased as matrix operations. See
[`calculate_semiquadratic_form`](@ref) for information on how the equations are split. This
formulation allows leveraging split ODE solvers such as `KenCarp4` and is useful for systems
where the stiff and non-stiff terms can be separated out in such a manner. Typically the linear
part of the equations is the stiff part, but the keywords `stiff_A`, `stiff_B` and `stiff_C` can
be used to control which parts are considered as stiff.
"""

const SEMILINEAR_A_B_C_KWARGS = """
- `stiff_A`: Whether the linear part of the equations should be part of the stiff function
in the split form. Has no effect if the equations have no linear part.
- `stiff_B`: Whether the quadratic part of the equations should be part of the stiff
function in the split form. Has no effect if the equations have no quadratic part.
- `stiff_C`: Whether the non-linear non-quadratic part of the equations should be part of
the stiff function in the split form. Has no effect if the equations have no such
non-linear non-quadratic part.
"""

const SEMILINEAR_A_B_C_CONSTRAINT = """
Note that all three of `stiff_A`, `stiff_B`, `stiff_C` cannot be identical, and at least
two of `A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref) must be
non-`nothing`. In other words, both of the functions in the split form must be non-empty.
"""

for (mod, prob, func, istd, kws) in [
(SciMLBase, :ODEProblem, ODEFunction, true, (;)),
(SciMLBase, :SteadyStateProblem, ODEFunction, false, (;)),
Expand All @@ -201,12 +241,23 @@ for (mod, prob, func, istd, kws) in [
(SciMLBase, :NonlinearProblem, NonlinearFunction, false, (;)),
(SciMLBase, :NonlinearLeastSquaresProblem, NonlinearFunction, false, (;)),
(SciMLBase, :SCCNonlinearProblem, NonlinearFunction, false, (; init = false)),
(SciMLBase, :OptimizationProblem, OptimizationFunction, false, (; init = false))
(SciMLBase, :OptimizationProblem, OptimizationFunction, false, (; init = false)),
(ModelingToolkit,
:SemilinearODEProblem,
:SemilinearODEFunction,
true,
(; extra_body = SEMILINEAR_EXTRA_BODY, extra_kwargs = SEMILINEAR_A_B_C_KWARGS,
extra_kwargs_desc = SEMILINEAR_A_B_C_CONSTRAINT))
]
@eval @doc problem_docstring($mod.$prob, $func, $istd) $mod.$prob
kwexpr = Expr(:parameters)
for (k, v) in pairs(kws)
push!(kwexpr.args, Expr(:kw, k, v))
end
@eval @doc problem_docstring($kwexpr, $mod.$prob, $func, $istd) $mod.$prob
end

function function_docstring(func, istd, optionals)
function function_docstring(
func, istd, optionals; extra_body = "", extra_kwargs = "", extra_kwargs_desc = "")
return """
$func(sys::System; kwargs...)
$func{iip}(sys::System; kwargs...)
Expand All @@ -216,6 +267,8 @@ function function_docstring(func, istd, optionals)
function should be in-place. `specialization` is a `SciMLBase.AbstractSpecalize`
subtype indicating the level of specialization of the $func.

$(extra_body)

# Keyword arguments

- `u0`: The `u0` vector for the corresponding problem, if available. Can be obtained
Expand All @@ -232,8 +285,10 @@ function function_docstring(func, istd, optionals)
sparse matrices. Also controls whether the mass matrix is sparse, wherever applicable.
$(prob_fun_common_kwargs(func, istd))
$(process_optional_function_kwargs(optionals))
$(extra_kwargs)

All other keyword arguments are forwarded to the `$func` struct constructor.
$(extra_kwargs_desc)
"""
end

Expand Down Expand Up @@ -324,20 +379,30 @@ function process_optional_function_kwargs(choices::Vector{Symbol})
join(map(Base.Fix1(getindex, OPTIONAL_FN_KWARGS_DICT), choices), "\n")
end

for (mod, func, istd, optionals) in [
(SciMLBase, :ODEFunction, true, [:jac, :tgrad]),
(SciMLBase, :ODEInputFunction, true, [:inputfn, :jac, :tgrad, :controljac]),
(SciMLBase, :DAEFunction, true, [:jac, :tgrad]),
(SciMLBase, :DDEFunction, true, Symbol[]),
(SciMLBase, :SDEFunction, true, [:jac, :tgrad]),
(SciMLBase, :SDDEFunction, true, Symbol[]),
(SciMLBase, :DiscreteFunction, true, Symbol[]),
(SciMLBase, :ImplicitDiscreteFunction, true, Symbol[]),
(SciMLBase, :NonlinearFunction, false, [:resid_prototype, :jac]),
(SciMLBase, :IntervalNonlinearFunction, false, Symbol[]),
(SciMLBase, :OptimizationFunction, false, [:jac, :grad, :hess, :cons_h, :cons_j])
for (mod, func, istd, optionals, kws) in [
(SciMLBase, :ODEFunction, true, [:jac, :tgrad], (;)),
(SciMLBase, :ODEInputFunction, true, [:inputfn, :jac, :tgrad, :controljac], (;)),
(SciMLBase, :DAEFunction, true, [:jac, :tgrad], (;)),
(SciMLBase, :DDEFunction, true, Symbol[], (;)),
(SciMLBase, :SDEFunction, true, [:jac, :tgrad], (;)),
(SciMLBase, :SDDEFunction, true, Symbol[], (;)),
(SciMLBase, :DiscreteFunction, true, Symbol[], (;)),
(SciMLBase, :ImplicitDiscreteFunction, true, Symbol[], (;)),
(SciMLBase, :NonlinearFunction, false, [:resid_prototype, :jac], (;)),
(SciMLBase, :IntervalNonlinearFunction, false, Symbol[], (;)),
(SciMLBase, :OptimizationFunction, false, [:jac, :grad, :hess, :cons_h, :cons_j], (;)),
(ModelingToolkit,
:SemilinearODEFunction,
true,
[:jac],
(; extra_body = SEMILINEAR_EXTRA_BODY, extra_kwargs = SEMILINEAR_A_B_C_KWARGS,
extra_kwargs_desc = SEMILINEAR_A_B_C_CONSTRAINT))
]
@eval @doc function_docstring($mod.$func, $istd, $optionals) $mod.$func
kwexpr = Expr(:parameters)
for (k, v) in pairs(kws)
push!(kwexpr.args, Expr(:kw, k, v))
end
@eval @doc function_docstring($kwexpr, $mod.$func, $istd, $optionals) $mod.$func
end

@doc """
Expand Down
152 changes: 151 additions & 1 deletion src/problems/odeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,159 @@ end
maybe_codegen_scimlproblem(expression, SteadyStateProblem{iip}, args; kwargs...)
end

@fallback_iip_specialize function SemilinearODEFunction{iip, specialize}(
sys::System; u0 = nothing, p = nothing, t = nothing,
semiquadratic_form = nothing,
stiff_A = true, stiff_B = false, stiff_C = false,
eval_expression = false, eval_module = @__MODULE__,
expression = Val{false}, sparse = false, check_compatibility = true,
jac = false, checkbounds = false, cse = true, initialization_data = nothing,
analytic = nothing, kwargs...) where {iip, specialize}
check_complete(sys, SemilinearODEFunction)
check_compatibility && check_compatible_system(SemilinearODEFunction, sys)

if semiquadratic_form === nothing
semiquadratic_form = calculate_semiquadratic_form(sys; sparse)
sys = add_semiquadratic_parameters(sys, semiquadratic_form...)
end

A, B, C = semiquadratic_form
M = calculate_massmatrix(sys)
_M = concrete_massmatrix(M; sparse, u0)
dvs = unknowns(sys)

f1, f2 = generate_semiquadratic_functions(
sys, A, B, C; stiff_A, stiff_B, stiff_C, expression, wrap_gfw = Val{true},
eval_expression, eval_module, kwargs...)

if jac
Cjac = (C === nothing || !stiff_C) ? nothing : Symbolics.jacobian(C, dvs)
_jac = generate_semiquadratic_jacobian(
sys, A, B, C, Cjac; sparse, expression,
wrap_gfw = Val{true}, eval_expression, eval_module, kwargs...)
_W_sparsity = get_semiquadratic_W_sparsity(
sys, A, B, C, Cjac; stiff_A, stiff_B, stiff_C, mm = M)
W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
else
_jac = nothing
W_prototype = nothing
end

observedfun = ObservedFunctionCache(
sys; expression, steady_state = false, eval_expression, eval_module, checkbounds, cse)

args = (; f1)
kwargs = (; jac = _jac, jac_prototype = W_prototype)
f1 = maybe_codegen_scimlfn(expression, ODEFunction{iip, specialize}, args; kwargs...)

args = (; f1, f2)
kwargs = (;
sys = sys,
jac = _jac,
mass_matrix = _M,
jac_prototype = W_prototype,
observed = observedfun,
analytic,
initialization_data)

return maybe_codegen_scimlfn(
expression, SplitFunction{iip, specialize}, args; kwargs...)
end

@fallback_iip_specialize function SemilinearODEProblem{iip, spec}(
sys::System, op, tspan; check_compatibility = true, u0_eltype = nothing,
expression = Val{false}, callback = nothing, sparse = false,
stiff_A = true, stiff_B = false, stiff_C = false, jac = false, kwargs...) where {
iip, spec}
check_complete(sys, SemilinearODEProblem)
check_compatibility && check_compatible_system(SemilinearODEProblem, sys)

A, B, C = semiquadratic_form = calculate_semiquadratic_form(sys; sparse)
eqs = equations(sys)
dvs = unknowns(sys)

sys = add_semiquadratic_parameters(sys, A, B, C)
if A !== nothing
linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME))
else
linear_matrix_param = nothing
end
if B !== nothing
quadratic_forms = [unwrap(getproperty(sys, get_quadratic_form_name(i)))
for i in 1:length(eqs)]
diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
else
quadratic_forms = diffcache_par = nothing
end

op = to_varmap(op, dvs)
floatT = calculate_float_type(op, typeof(op))
_u0_eltype = something(u0_eltype, floatT)

guess = copy(guesses(sys))
defs = copy(defaults(sys))
if A !== nothing
guess[linear_matrix_param] = fill(NaN, size(A))
defs[linear_matrix_param] = A
end
if B !== nothing
for (par, mat) in zip(quadratic_forms, B)
guess[par] = fill(NaN, size(mat))
defs[par] = mat
end
cachelen = jac ? length(dvs) * length(eqs) : length(dvs)
defs[diffcache_par] = DiffCache(zeros(DiffEqBase.value(_u0_eltype), cachelen))
end
@set! sys.guesses = guess
@set! sys.defaults = defs

f, u0, p = process_SciMLProblem(SemilinearODEFunction{iip, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, expression, check_compatibility,
semiquadratic_form, sparse, u0_eltype, stiff_A, stiff_B, stiff_C, jac, kwargs...)

kwargs = process_kwargs(sys; expression, callback, kwargs...)

args = (; f, u0, tspan, p)
maybe_codegen_scimlproblem(expression, SplitODEProblem{iip}, args; kwargs...)
end

"""
$(TYPEDSIGNATURES)

Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices
`A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref).
"""
function add_semiquadratic_parameters(sys::System, A, B, C)
eqs = equations(sys)
n = length(eqs)
var_to_name = copy(get_var_to_name(sys))
if B !== nothing
for i in eachindex(B)
B[i] === nothing && continue
par = get_quadratic_form_param((n, n), i)
var_to_name[get_quadratic_form_name(i)] = par
sys = with_additional_constant_parameter(sys, par)
end
par = get_diffcache_param(Float64)
var_to_name[DIFFCACHE_PARAM_NAME] = par
sys = with_additional_nonnumeric_parameter(sys, par)
end
if A !== nothing
par = get_linear_matrix_param((n, n))
var_to_name[LINEAR_MATRIX_PARAM_NAME] = par
sys = with_additional_constant_parameter(sys, par)
end
@set! sys.var_to_name = var_to_name
if get_parent(sys) !== nothing
@set! sys.parent = add_semiquadratic_parameters(get_parent(sys), A, B, C)
end
return sys
end

function check_compatible_system(
T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
Type{DAEProblem}, Type{SteadyStateProblem}},
Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
Type{SemilinearODEProblem}},
sys::System)
check_time_dependent(sys, T)
check_not_dde(sys)
Expand Down
Loading
Loading