diff --git a/Project.toml b/Project.toml index 9a369fc2bd..13fe56b275 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" @@ -45,6 +46,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" +PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -115,6 +117,7 @@ DynamicQuantities = "^0.11.2, 0.12, 0.13, 1" EnumX = "1.0.4" ExprTools = "0.1.10" FMI = "0.14" +FillArrays = "1.13.0" FindFirstFunctions = "1" ForwardDiff = "0.10.3" FunctionWrappers = "1.1" @@ -142,6 +145,7 @@ OrdinaryDiffEq = "6.82.0" OrdinaryDiffEqCore = "1.15.0" OrdinaryDiffEqDefault = "1.2" OrdinaryDiffEqNonlinearSolve = "1.5.0" +PreallocationTools = "0.4.27" PrecompileTools = "1" Pyomo = "0.1.0" REPL = "1" diff --git a/docs/src/API/codegen.md b/docs/src/API/codegen.md index 4f31405174..3a813038e5 100644 --- a/docs/src/API/codegen.md +++ b/docs/src/API/codegen.md @@ -23,6 +23,9 @@ ModelingToolkit.build_explicit_observed_function ModelingToolkit.generate_control_function ModelingToolkit.generate_update_A ModelingToolkit.generate_update_b +ModelingToolkit.generate_semiquadratic_functions +ModelingToolkit.generate_semiquadratic_jacobian +ModelingToolkit.get_semiquadratic_W_sparsity ``` For functions such as jacobian calculation which require symbolic computation, there diff --git a/docs/src/API/problems.md b/docs/src/API/problems.md index 72147a7e09..d8cb7e940c 100644 --- a/docs/src/API/problems.md +++ b/docs/src/API/problems.md @@ -17,6 +17,8 @@ SciMLBase.ODEFunction SciMLBase.ODEProblem SciMLBase.DAEFunction SciMLBase.DAEProblem +ModelingToolkit.SemilinearODEFunction +ModelingToolkit.SemilinearODEProblem SciMLBase.SDEFunction SciMLBase.SDEProblem SciMLBase.DDEFunction diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index d23d173b9a..52611bcc32 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -99,6 +99,9 @@ const DQ = DynamicQuantities import DifferentiationInterface as DI using ADTypes: AutoForwardDiff import SciMLPublic: @public +import PreallocationTools +import PreallocationTools: DiffCache +import FillArrays export @derivatives @@ -288,6 +291,7 @@ export IntervalNonlinearProblem export OptimizationProblem, constraints export SteadyStateProblem export JumpProblem +export SemilinearODEFunction, SemilinearODEProblem export alias_elimination, flatten export connect, domain_connect, @connector, Connection, AnalysisPoint, Flow, Stream, instream diff --git a/src/problems/docs.jl b/src/problems/docs.jl index 17bc2c83c6..0e043f68f2 100644 --- a/src/problems/docs.jl +++ b/src/problems/docs.jl @@ -1,3 +1,6 @@ +struct SemilinearODEFunction{iip, spec} end +struct SemilinearODEProblem{iip, spec} end + const U0_P_DOCS = """ The order of unknowns is determined by `unknowns(sys)`. If the system is split [`is_split`](@ref) create an [`MTKParameters`](@ref) object. Otherwise, a parameter vector. @@ -92,6 +95,15 @@ function problem_ctors(prob, istd) end end +function problem_ctors(prob::Type{<:SemilinearODEProblem}, istd) + @assert istd + """ + SciMLBase.$prob(sys::System, op, tspan::NTuple{2}; kwargs...) + SciMLBase.$prob{iip}(sys::System, op, tspan::NTuple{2}; kwargs...) + SciMLBase.$prob{iip, specialize}(sys::System, op, tspan::NTuple{2}; stiff_A = true, stiff_B = false, stiff_C = false, kwargs...) + """ +end + function prob_fun_common_kwargs(T, istd) return """ - `check_compatibility`: Whether to check if the given system `sys` contains all the @@ -103,7 +115,8 @@ function prob_fun_common_kwargs(T, istd) """ end -function problem_docstring(prob, func, istd; init = true, extra_body = "") +function problem_docstring(prob, func, istd; init = true, extra_body = "", + extra_kwargs = "", extra_kwargs_desc = "") if func isa DataType func = "`$func`" end @@ -127,8 +140,9 @@ function problem_docstring(prob, func, istd; init = true, extra_body = "") $PROBLEM_KWARGS $(istd ? TIME_DEPENDENT_PROBLEM_KWARGS : "") $(prob_fun_common_kwargs(prob, istd)) - + $(extra_kwargs) All other keyword arguments are forwarded to the $func constructor. + $(extra_kwargs_desc) $PROBLEM_INTERNALS_HEADER @@ -186,6 +200,32 @@ If the `System` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting `BVProblem` must be solved using BVDAE solvers, such as Ascher. """ +const SEMILINEAR_EXTRA_BODY = """ +This is a special form of an ODE which uses a `SplitFunction` internally. The equations are +separated into linear, quadratic and general terms and phrased as matrix operations. See +[`calculate_semiquadratic_form`](@ref) for information on how the equations are split. This +formulation allows leveraging split ODE solvers such as `KenCarp4` and is useful for systems +where the stiff and non-stiff terms can be separated out in such a manner. Typically the linear +part of the equations is the stiff part, but the keywords `stiff_A`, `stiff_B` and `stiff_C` can +be used to control which parts are considered as stiff. +""" + +const SEMILINEAR_A_B_C_KWARGS = """ +- `stiff_A`: Whether the linear part of the equations should be part of the stiff function + in the split form. Has no effect if the equations have no linear part. +- `stiff_B`: Whether the quadratic part of the equations should be part of the stiff + function in the split form. Has no effect if the equations have no quadratic part. +- `stiff_C`: Whether the non-linear non-quadratic part of the equations should be part of + the stiff function in the split form. Has no effect if the equations have no such + non-linear non-quadratic part. +""" + +const SEMILINEAR_A_B_C_CONSTRAINT = """ +Note that all three of `stiff_A`, `stiff_B`, `stiff_C` cannot be identical, and at least +two of `A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref) must be +non-`nothing`. In other words, both of the functions in the split form must be non-empty. +""" + for (mod, prob, func, istd, kws) in [ (SciMLBase, :ODEProblem, ODEFunction, true, (;)), (SciMLBase, :SteadyStateProblem, ODEFunction, false, (;)), @@ -201,12 +241,23 @@ for (mod, prob, func, istd, kws) in [ (SciMLBase, :NonlinearProblem, NonlinearFunction, false, (;)), (SciMLBase, :NonlinearLeastSquaresProblem, NonlinearFunction, false, (;)), (SciMLBase, :SCCNonlinearProblem, NonlinearFunction, false, (; init = false)), - (SciMLBase, :OptimizationProblem, OptimizationFunction, false, (; init = false)) + (SciMLBase, :OptimizationProblem, OptimizationFunction, false, (; init = false)), + (ModelingToolkit, + :SemilinearODEProblem, + :SemilinearODEFunction, + true, + (; extra_body = SEMILINEAR_EXTRA_BODY, extra_kwargs = SEMILINEAR_A_B_C_KWARGS, + extra_kwargs_desc = SEMILINEAR_A_B_C_CONSTRAINT)) ] - @eval @doc problem_docstring($mod.$prob, $func, $istd) $mod.$prob + kwexpr = Expr(:parameters) + for (k, v) in pairs(kws) + push!(kwexpr.args, Expr(:kw, k, v)) + end + @eval @doc problem_docstring($kwexpr, $mod.$prob, $func, $istd) $mod.$prob end -function function_docstring(func, istd, optionals) +function function_docstring( + func, istd, optionals; extra_body = "", extra_kwargs = "", extra_kwargs_desc = "") return """ $func(sys::System; kwargs...) $func{iip}(sys::System; kwargs...) @@ -216,6 +267,8 @@ function function_docstring(func, istd, optionals) function should be in-place. `specialization` is a `SciMLBase.AbstractSpecalize` subtype indicating the level of specialization of the $func. + $(extra_body) + # Keyword arguments - `u0`: The `u0` vector for the corresponding problem, if available. Can be obtained @@ -232,8 +285,10 @@ function function_docstring(func, istd, optionals) sparse matrices. Also controls whether the mass matrix is sparse, wherever applicable. $(prob_fun_common_kwargs(func, istd)) $(process_optional_function_kwargs(optionals)) + $(extra_kwargs) All other keyword arguments are forwarded to the `$func` struct constructor. + $(extra_kwargs_desc) """ end @@ -324,20 +379,30 @@ function process_optional_function_kwargs(choices::Vector{Symbol}) join(map(Base.Fix1(getindex, OPTIONAL_FN_KWARGS_DICT), choices), "\n") end -for (mod, func, istd, optionals) in [ - (SciMLBase, :ODEFunction, true, [:jac, :tgrad]), - (SciMLBase, :ODEInputFunction, true, [:inputfn, :jac, :tgrad, :controljac]), - (SciMLBase, :DAEFunction, true, [:jac, :tgrad]), - (SciMLBase, :DDEFunction, true, Symbol[]), - (SciMLBase, :SDEFunction, true, [:jac, :tgrad]), - (SciMLBase, :SDDEFunction, true, Symbol[]), - (SciMLBase, :DiscreteFunction, true, Symbol[]), - (SciMLBase, :ImplicitDiscreteFunction, true, Symbol[]), - (SciMLBase, :NonlinearFunction, false, [:resid_prototype, :jac]), - (SciMLBase, :IntervalNonlinearFunction, false, Symbol[]), - (SciMLBase, :OptimizationFunction, false, [:jac, :grad, :hess, :cons_h, :cons_j]) +for (mod, func, istd, optionals, kws) in [ + (SciMLBase, :ODEFunction, true, [:jac, :tgrad], (;)), + (SciMLBase, :ODEInputFunction, true, [:inputfn, :jac, :tgrad, :controljac], (;)), + (SciMLBase, :DAEFunction, true, [:jac, :tgrad], (;)), + (SciMLBase, :DDEFunction, true, Symbol[], (;)), + (SciMLBase, :SDEFunction, true, [:jac, :tgrad], (;)), + (SciMLBase, :SDDEFunction, true, Symbol[], (;)), + (SciMLBase, :DiscreteFunction, true, Symbol[], (;)), + (SciMLBase, :ImplicitDiscreteFunction, true, Symbol[], (;)), + (SciMLBase, :NonlinearFunction, false, [:resid_prototype, :jac], (;)), + (SciMLBase, :IntervalNonlinearFunction, false, Symbol[], (;)), + (SciMLBase, :OptimizationFunction, false, [:jac, :grad, :hess, :cons_h, :cons_j], (;)), + (ModelingToolkit, + :SemilinearODEFunction, + true, + [:jac], + (; extra_body = SEMILINEAR_EXTRA_BODY, extra_kwargs = SEMILINEAR_A_B_C_KWARGS, + extra_kwargs_desc = SEMILINEAR_A_B_C_CONSTRAINT)) ] - @eval @doc function_docstring($mod.$func, $istd, $optionals) $mod.$func + kwexpr = Expr(:parameters) + for (k, v) in pairs(kws) + push!(kwexpr.args, Expr(:kw, k, v)) + end + @eval @doc function_docstring($kwexpr, $mod.$func, $istd, $optionals) $mod.$func end @doc """ diff --git a/src/problems/odeproblem.jl b/src/problems/odeproblem.jl index bc8b9cf701..87ee7af224 100644 --- a/src/problems/odeproblem.jl +++ b/src/problems/odeproblem.jl @@ -98,9 +98,159 @@ end maybe_codegen_scimlproblem(expression, SteadyStateProblem{iip}, args; kwargs...) end +@fallback_iip_specialize function SemilinearODEFunction{iip, specialize}( + sys::System; u0 = nothing, p = nothing, t = nothing, + semiquadratic_form = nothing, + stiff_A = true, stiff_B = false, stiff_C = false, + eval_expression = false, eval_module = @__MODULE__, + expression = Val{false}, sparse = false, check_compatibility = true, + jac = false, checkbounds = false, cse = true, initialization_data = nothing, + analytic = nothing, kwargs...) where {iip, specialize} + check_complete(sys, SemilinearODEFunction) + check_compatibility && check_compatible_system(SemilinearODEFunction, sys) + + if semiquadratic_form === nothing + semiquadratic_form = calculate_semiquadratic_form(sys; sparse) + sys = add_semiquadratic_parameters(sys, semiquadratic_form...) + end + + A, B, C = semiquadratic_form + M = calculate_massmatrix(sys) + _M = concrete_massmatrix(M; sparse, u0) + dvs = unknowns(sys) + + f1, f2 = generate_semiquadratic_functions( + sys, A, B, C; stiff_A, stiff_B, stiff_C, expression, wrap_gfw = Val{true}, + eval_expression, eval_module, kwargs...) + + if jac + Cjac = (C === nothing || !stiff_C) ? nothing : Symbolics.jacobian(C, dvs) + _jac = generate_semiquadratic_jacobian( + sys, A, B, C, Cjac; sparse, expression, + wrap_gfw = Val{true}, eval_expression, eval_module, kwargs...) + _W_sparsity = get_semiquadratic_W_sparsity( + sys, A, B, C, Cjac; stiff_A, stiff_B, stiff_C, mm = M) + W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse) + else + _jac = nothing + W_prototype = nothing + end + + observedfun = ObservedFunctionCache( + sys; expression, steady_state = false, eval_expression, eval_module, checkbounds, cse) + + args = (; f1) + kwargs = (; jac = _jac, jac_prototype = W_prototype) + f1 = maybe_codegen_scimlfn(expression, ODEFunction{iip, specialize}, args; kwargs...) + + args = (; f1, f2) + kwargs = (; + sys = sys, + jac = _jac, + mass_matrix = _M, + jac_prototype = W_prototype, + observed = observedfun, + analytic, + initialization_data) + + return maybe_codegen_scimlfn( + expression, SplitFunction{iip, specialize}, args; kwargs...) +end + +@fallback_iip_specialize function SemilinearODEProblem{iip, spec}( + sys::System, op, tspan; check_compatibility = true, u0_eltype = nothing, + expression = Val{false}, callback = nothing, sparse = false, + stiff_A = true, stiff_B = false, stiff_C = false, jac = false, kwargs...) where { + iip, spec} + check_complete(sys, SemilinearODEProblem) + check_compatibility && check_compatible_system(SemilinearODEProblem, sys) + + A, B, C = semiquadratic_form = calculate_semiquadratic_form(sys; sparse) + eqs = equations(sys) + dvs = unknowns(sys) + + sys = add_semiquadratic_parameters(sys, A, B, C) + if A !== nothing + linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME)) + else + linear_matrix_param = nothing + end + if B !== nothing + quadratic_forms = [unwrap(getproperty(sys, get_quadratic_form_name(i))) + for i in 1:length(eqs)] + diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME)) + else + quadratic_forms = diffcache_par = nothing + end + + op = to_varmap(op, dvs) + floatT = calculate_float_type(op, typeof(op)) + _u0_eltype = something(u0_eltype, floatT) + + guess = copy(guesses(sys)) + defs = copy(defaults(sys)) + if A !== nothing + guess[linear_matrix_param] = fill(NaN, size(A)) + defs[linear_matrix_param] = A + end + if B !== nothing + for (par, mat) in zip(quadratic_forms, B) + guess[par] = fill(NaN, size(mat)) + defs[par] = mat + end + cachelen = jac ? length(dvs) * length(eqs) : length(dvs) + defs[diffcache_par] = DiffCache(zeros(DiffEqBase.value(_u0_eltype), cachelen)) + end + @set! sys.guesses = guess + @set! sys.defaults = defs + + f, u0, p = process_SciMLProblem(SemilinearODEFunction{iip, spec}, sys, op; + t = tspan !== nothing ? tspan[1] : tspan, expression, check_compatibility, + semiquadratic_form, sparse, u0_eltype, stiff_A, stiff_B, stiff_C, jac, kwargs...) + + kwargs = process_kwargs(sys; expression, callback, kwargs...) + + args = (; f, u0, tspan, p) + maybe_codegen_scimlproblem(expression, SplitODEProblem{iip}, args; kwargs...) +end + +""" + $(TYPEDSIGNATURES) + +Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices +`A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref). +""" +function add_semiquadratic_parameters(sys::System, A, B, C) + eqs = equations(sys) + n = length(eqs) + var_to_name = copy(get_var_to_name(sys)) + if B !== nothing + for i in eachindex(B) + B[i] === nothing && continue + par = get_quadratic_form_param((n, n), i) + var_to_name[get_quadratic_form_name(i)] = par + sys = with_additional_constant_parameter(sys, par) + end + par = get_diffcache_param(Float64) + var_to_name[DIFFCACHE_PARAM_NAME] = par + sys = with_additional_nonnumeric_parameter(sys, par) + end + if A !== nothing + par = get_linear_matrix_param((n, n)) + var_to_name[LINEAR_MATRIX_PARAM_NAME] = par + sys = with_additional_constant_parameter(sys, par) + end + @set! sys.var_to_name = var_to_name + if get_parent(sys) !== nothing + @set! sys.parent = add_semiquadratic_parameters(get_parent(sys), A, B, C) + end + return sys +end + function check_compatible_system( T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction}, - Type{DAEProblem}, Type{SteadyStateProblem}}, + Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction}, + Type{SemilinearODEProblem}}, sys::System) check_time_dependent(sys, T) check_not_dde(sys) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index bac04bf153..e10a309369 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1186,6 +1186,10 @@ function is_array_of_symbolics(x) any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x) end +function is_array_of_symbolics(x::SparseMatrixCSC) + return is_array_of_symbolics(nonzeros(x)) +end + function namespace_expr( O, sys, n = (sys === nothing ? nothing : nameof(sys)); ivs = sys === nothing ? nothing : independent_variables(sys)) diff --git a/src/systems/codegen.jl b/src/systems/codegen.jl index 4a68f935e8..7d9b0ec709 100644 --- a/src/systems/codegen.jl +++ b/src/systems/codegen.jl @@ -9,6 +9,12 @@ const GENERATE_X_KWARGS = """ $EVAL_EXPR_MOD_KWARGS """ +const EXPERIMENTAL_WARNING = """ +!!! warn + + This API is experimental and may change in a future non-breaking release. +""" + """ $(TYPEDSIGNATURES) @@ -1142,35 +1148,15 @@ Return matrix `A` and vector `b` such that the system `sys` can be represented a - `sparse`: return a sparse `A`. """ function calculate_A_b(sys::System; sparse = false) - rhss = [eq.rhs for eq in full_equations(sys)] + rhss = [-eq.rhs for eq in full_equations(sys)] dvs = unknowns(sys) - A = Matrix{Any}(undef, length(rhss), length(dvs)) - b = Vector{Any}(undef, length(rhss)) - for (i, rhs) in enumerate(rhss) - # mtkcompile makes this `0 ~ rhs` which typically ends up giving - # unknowns negative coefficients. If given the equations `A * x ~ b` - # it will simplify to `0 ~ b - A * x`. Thus this negation usually leads - # to more comprehensible user API. - resid = -rhs - for (j, var) in enumerate(dvs) - p, q, islinear = Symbolics.linear_expansion(resid, var) - if !islinear - throw(ArgumentError("System is not linear. Equation $((0 ~ rhs)) is not linear in unknown $var.")) - end - A[i, j] = p - resid = q - end - # negate beucause `resid` is the residual on the LHS - b[i] = -resid - end - - @assert all(Base.Fix1(isassigned, A), eachindex(A)) - @assert all(Base.Fix1(isassigned, A), eachindex(b)) - - if sparse - A = SparseArrays.sparse(A) + A, b = semilinear_form(rhss, dvs) + if !sparse + A = collect(A) end + A = unwrap.(A) + b = unwrap.(-b) return A, b end @@ -1217,3 +1203,468 @@ function generate_update_b(sys::System, b::AbstractVector; expression = Val{true return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res; eval_expression, eval_module) end + +""" + $(TYPEDSIGNATURES) + +Represent the equations of the system `sys` in semiquadratic form. Returns 3 matrices +referred to as `A`, `B` and `C`. Hereon, `x` refers to the state vector. + +`A` contains coefficients for all linear terms in the equations. `A * x` is the linear +part of the RHS of the equations. `A` is `nothing` if none of the equations have +linear terms, in which case the corresponding term in the mathematical expression below +should be ignored. + +`B` is a vector of matrices where the `i`th matrix contains coefficients for the quadratic +terms in the `i`th equation's RHS. The quadratic terms in the `i`th equation are obtained +as `transpose(x) * B[i] * x`. Each matrix `B[i]` will be `nothing` if the `i`th equation +does not have quadratic term. in which case the corresponding term in the mathematical +expression below should be ignored. In case all `B[i]` are `nothing`, `B` will be +`nothing`. + +`C` is a vector of all non-linear and non-quadratic terms in the equations. `C` is +`nothing` if none of the equations have nonlinear and non-quadratic terms, in which case +the corresponding term in the mathematical expression below should be ignored. + +Mathematically, the right hand side of the `i`th equation is + +```math +\\mathtt{row}_i(\\mathbf{A})x + x^T(\\mathbf{B}_i)x + \\mathbf{C}_i +``` + +Note that any of `A`, `B` or `C` can be `nothing` if the coefficients/values are all zeros. + +## Keyword arguments + +- `sparse`: Return sparse matrices for `A`, `B` and `C`. +""" +function calculate_semiquadratic_form(sys::System; sparse = false) + rhss = [eq.rhs for eq in full_equations(sys)] + dvs = unknowns(sys) + A, B, x2, C = semiquadratic_form(rhss, dvs) + if nnz(B) == 0 + B = nothing + B2 = nothing + else + B2 = if sparse + Any[spzeros(Num, length(dvs), length(dvs)) for _ in 1:length(rhss)] + else + Any[zeros(Num, length(dvs), length(dvs)) for _ in 1:length(rhss)] + end + idxmap = vec([CartesianIndex(i, j) for j in 1:length(dvs) for i in 1:j]) + for (i, j, val) in zip(findnz(B)...) + B2[i][idxmap[j]] = val + end + for i in eachindex(B2) + if all(_iszero, B2[i]) + B2[i] = nothing + end + end + B2 = map(Broadcast.BroadcastFunction(unwrap), B2) + end + if nnz(A) == 0 + A = nothing + else + if !sparse + A = collect(A) + end + A = unwrap.(A) + end + if all(_iszero, C) + C = nothing + else + C = unwrap.(C) + end + + return A, B2, C +end + +const DIFFCACHE_PARAM_NAME = :__mtk_diffcache + +""" + $(TYPEDSIGNATURES) + +Return a symbolic variable representing a `PreallocationTools.DiffCache` with +floating-point type `T`. +""" +function get_diffcache_param(::Type{T}) where {T} + toconstant(Symbolics.variable( + DIFFCACHE_PARAM_NAME; T = DiffCache{Vector{T}, Vector{T}})) +end + +const LINEAR_MATRIX_PARAM_NAME = :linear_Aₘₜₖ + +""" + $(TYPEDSIGNATURES) + +Return a symbolic variable representing the `A` matrix returned from +[`calculate_semiquadratic_form`](@ref). +""" +function get_linear_matrix_param(size::NTuple{2, Int}) + m, n = size + unwrap(only(@constants $LINEAR_MATRIX_PARAM_NAME[1:m, 1:n])) +end + +""" + $(TYPEDSIGNATURES) + +Return the name of the `i`th matrix in `B` returned from +[`calculate_semiquadratic_form`](@ref). +""" +function get_quadratic_form_name(i::Int) + return Symbol(:quadratic_Bₘₜₖ_, i) +end + +""" + $(TYPEDSIGNATURES) + +Return a symbolic variable representing the `i`th matrix in `B` returned from +[`calculate_semiquadratic_form`](@ref). +""" +function get_quadratic_form_param(sz::NTuple{2, Int}, i::Int) + m, n = sz + name = get_quadratic_form_name(i) + unwrap(only(@constants $name[1:m, 1:n])) +end + +""" + $(TYPEDSIGNATURES) + +Return the parameter in `sys` corresponding to the one returned from +[`get_linear_matrix_param`](@ref), or `nothing` if `A === nothing`. +""" +function get_linear_matrix_param_from_sys(sys::System, A) + A === nothing && return nothing + return unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME)) +end + +""" + $(TYPEDSIGNATURES) + +Return the list of parameters in `sys` corresponding to the ones returned from +[`get_quadratic_form_param`](@ref) for each non-`nothing` matrix in `B`, or `nothing` if +`B === nothing`. If `B[i] === nothing`, the returned list will have `nothing` as the `i`th +entry. +""" +function get_quadratic_form_params_from_sys(sys::System, B) + B === nothing && return nothing + return map(eachindex(B)) do i + B[i] === nothing && return nothing + return unwrap(getproperty(sys, get_quadratic_form_name(i))) + end +end + +""" + $(TYPEDSIGNATURES) + +Return the parameter in `sys` corresponding to the one returned from +[`get_diffcache_param`](@ref). +""" +function get_diffcache_param_from_sys(sys::System, B) + B === nothing && return nothing + return unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME)) +end + +""" + $(TYPEDSIGNATURES) + +Generate `f1` and `f2` for [`SemilinearODEFunction`](@ref) (internally represented as a +`SplitFunction`). `A`, `B`, `C` are the matrices returned from +[`calculate_semiquadratic_form`](@ref). This expects that the system has the necessary +extra parmameters added by [`add_semiquadratic_parameters`](@ref). + +## Keyword Arguments + +$SEMILINEAR_A_B_C_KWARGS +$GENERATE_X_KWARGS + +All other keyword arguments are forwarded to [`build_function_wrapper`](@ref). +$SEMILINEAR_A_B_C_CONSTRAINT + +$EXPERIMENTAL_WARNING +""" +function generate_semiquadratic_functions(sys::System, A, B, C; stiff_A = true, + stiff_B = false, stiff_C = false, expression = Val{true}, wrap_gfw = Val{false}, + eval_expression = false, eval_module = @__MODULE__, kwargs...) + if A === nothing && B === nothing + throw(ArgumentError("Cannot generate split form for the system - it has no linear or quadratic part.")) + end + + if (stiff_A || A === nothing) && (stiff_B || B === nothing) && stiff_C + throw(ArgumentError("All of `A`, `B` and `C` cannot be stiff at the same time.")) + end + if (!stiff_A || A === nothing) && (!stiff_B || B === nothing) && !stiff_C + throw(ArgumentError("All of `A`, `B` and `C` cannot be non-stiff at the same time.")) + end + linear_matrix_param = get_linear_matrix_param_from_sys(sys, A) + quadratic_forms = get_quadratic_form_params_from_sys(sys, B) + diffcache_par = get_diffcache_param_from_sys(sys, B) + eqs = equations(sys) + dvs = unknowns(sys) + ps = reorder_parameters(sys) + iv = get_iv(sys) + # Codegen is a bit manual, and we're manually creating an efficient IIP function. + # Since we explicitly provide Symbolics.DEFAULT_OUTSYM, the `u` is actually the second + # argument. + iip_x = generated_argument_name(2) + oop_x = generated_argument_name(1) + + ## iip + f1_iip_ir = Assignment[] + f2_iip_ir = Assignment[] + # C + if C !== nothing + C_ir = stiff_C ? f1_iip_ir : f2_iip_ir + push!(C_ir, Assignment(:__tmp_C, SetArray(false, Symbolics.DEFAULT_OUTSYM, C))) + end + # B + if B !== nothing + B_ir = stiff_B ? f1_iip_ir : f2_iip_ir + B_vals = map(eachindex(eqs)) do i + B[i] === nothing && return nothing + tmp_buf = term( + PreallocationTools.get_tmp, diffcache_par, Symbolics.DEFAULT_OUTSYM) + tmp_buf = term(view, tmp_buf, 1:length(dvs)) + + result = term(*, term(transpose, iip_x), :__tmp_B_1) + # if both write to the same buffer, don't overwrite + if stiff_B == stiff_C && C !== nothing + result = term(+, result, term(getindex, Symbolics.DEFAULT_OUTSYM, i)) + end + intermediates = [ + Assignment(:__tmp_B_buffer, tmp_buf), + Assignment(:__tmp_B_1, + term(mul!, :__tmp_B_buffer, + term(UpperTriangular, quadratic_forms[i]), iip_x)) + ] + return AtIndex(i, Let(intermediates, result)) + end + filter!(x -> x !== nothing, B_vals) + push!(B_ir, Assignment(:__tmp_B, SetArray(false, Symbolics.DEFAULT_OUTSYM, B_vals))) + end + # A + if A !== nothing + A_ir = stiff_A ? f1_iip_ir : f2_iip_ir + retain_old = stiff_A == stiff_B && B !== nothing || + stiff_A == stiff_C && C !== nothing + push!(A_ir, + Assignment(:__tmp_A, + term(mul!, Symbolics.DEFAULT_OUTSYM, + linear_matrix_param, iip_x, true, retain_old))) + end + ## oop + f1_terms = [] + f2_terms = [] + if A !== nothing + push!(stiff_A ? f1_terms : f2_terms, term(*, linear_matrix_param, oop_x)) + end + if B !== nothing + B_elems = map(eachindex(eqs)) do i + B[i] === nothing && return 0 + term( + *, term(transpose, oop_x), term(UpperTriangular, quadratic_forms[i]), oop_x) + end + push!(stiff_B ? f1_terms : f2_terms, MakeArray(B_elems, oop_x)) + end + if C !== nothing + push!(stiff_C ? f1_terms : f2_terms, MakeArray(C, oop_x)) + end + f1_expr = length(f1_terms) == 1 ? only(f1_terms) : term(+, f1_terms...) + f2_expr = length(f2_terms) == 1 ? only(f2_terms) : term(+, f2_terms...) + + f1_iip = build_function_wrapper( + sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., iv; p_start = 3, + extra_assignments = f1_iip_ir, expression = Val{true}, kwargs...) + f2_iip = build_function_wrapper( + sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., iv; p_start = 3, + extra_assignments = f2_iip_ir, expression = Val{true}, kwargs...) + f1_oop = build_function_wrapper( + sys, f1_expr, dvs, ps..., iv; expression = Val{true}, kwargs...) + f2_oop = build_function_wrapper( + sys, f2_expr, dvs, ps..., iv; expression = Val{true}, kwargs...) + + f1 = maybe_compile_function(expression, wrap_gfw, (2, 3, is_split(sys)), + (f1_oop, f1_iip); eval_expression, eval_module) + f2 = maybe_compile_function(expression, wrap_gfw, (2, 3, is_split(sys)), + (f2_oop, f2_iip); eval_expression, eval_module) + + return f1, f2 +end + +""" + $(TYPEDSIGNATURES) + +Generate the jacobian of `f1` for [`SemilinearODEFunction`](@ref) (internally represented as a +`SplitFunction`). `A`, `B`, `C` are the matrices returned from +[`calculate_semiquadratic_form`](@ref). `Cjac` is the jacobian of `C` with respect to the +unknowns of the system, or `nothing` if `C === nothing`. This expects that the system has the +necessary extra parmameters added by [`add_semiquadratic_parameters`](@ref). + +## Keyword Arguments + +$SEMILINEAR_A_B_C_KWARGS +$GENERATE_X_KWARGS + +All other keyword arguments are forwarded to [`build_function_wrapper`](@ref). +$SEMILINEAR_A_B_C_CONSTRAINT + +$EXPERIMENTAL_WARNING +""" +function generate_semiquadratic_jacobian( + sys::System, A, B, C, Cjac; sparse = false, stiff_A = true, stiff_B = false, + stiff_C = false, expression = Val{true}, wrap_gfw = Val{false}, + eval_expression = false, eval_module = @__MODULE__, kwargs...) + if sparse + error("Sparse analytical jacobians for split ODEs is not implemented.") + end + if A === nothing && B === nothing + throw(ArgumentError("Cannot generate split form for the system - it has no linear or quadratic part.")) + end + + if (stiff_A || A === nothing) && (stiff_B || B === nothing) && stiff_C + throw(ArgumentError("All of `A`, `B` and `C` cannot be stiff at the same time.")) + end + if (!stiff_A || A === nothing) && (!stiff_B || B === nothing) && !stiff_C + throw(ArgumentError("All of `A`, `B` and `C` cannot be non-stiff at the same time.")) + end + linear_matrix_param = get_linear_matrix_param_from_sys(sys, A) + quadratic_forms = get_quadratic_form_params_from_sys(sys, B) + diffcache_par = get_diffcache_param_from_sys(sys, B) + eqs = equations(sys) + dvs = unknowns(sys) + ps = reorder_parameters(sys) + iv = get_iv(sys) + M = length(eqs) + N = length(dvs) + # Codegen is a bit manual, and we're manually creating an efficient IIP function. + # Since we explicitly provide Symbolics.DEFAULT_OUTSYM, the `u` is actually the second + # argument. + iip_x = generated_argument_name(2) + oop_x = generated_argument_name(1) + + # iip + iip_ir = Assignment[] + if A !== nothing && stiff_A + push!(iip_ir, + Assignment( + :__A_jac, term(copyto!, Symbolics.DEFAULT_OUTSYM, linear_matrix_param))) + end + if B !== nothing && stiff_B + cachebuf_name = :__B_cache + cachelen = M * N + push!(iip_ir, + Assignment(cachebuf_name, + term(PreallocationTools.get_tmp, diffcache_par, Symbolics.DEFAULT_OUTSYM))) + push!(iip_ir, + Assignment( + cachebuf_name, term(reshape, term(view, cachebuf_name, 1:cachelen), M, N))) + for (i, quadpar) in enumerate(quadratic_forms) + B[i] === nothing && continue + coeffvar = Symbol(:__B_matrix_, i) + # B + B' + push!(iip_ir, Assignment(coeffvar, term(UpperTriangular, quadpar))) + push!(iip_ir, Assignment(:__tmp_B_1, term(copyto!, cachebuf_name, coeffvar))) + # mul! with scalar `B` does addition + push!(iip_ir, + Assignment(:__tmp_B_2, + term(mul!, cachebuf_name, true, term(transpose, coeffvar), true, true))) + # view the row of the jacobian this will write to + target_name = Symbol(:__jac_row_, i) + push!( + iip_ir, Assignment(target_name, term(view, Symbolics.DEFAULT_OUTSYM, i, :))) + # (B + B') * x, written directly to the jacobian. Retain the value in the jacobian if we've already + # written to it. + retain_old = A !== nothing && stiff_A + push!(iip_ir, + Assignment(:__tmp_B_3, + term(mul!, target_name, cachebuf_name, iip_x, true, retain_old))) + end + end + if C !== nothing && stiff_C + @assert Cjac !== nothing + @assert size(Cjac) == (M, N) + if A !== nothing && stiff_A || B !== nothing && stiff_B + idxs = map(eachindex(Cjac)) do idx + _iszero(Cjac[idx]) && return nothing + AtIndex( + idx, term(+, term(getindex, Symbolics.DEFAULT_OUTSYM, idx), Cjac[idx])) + end + filter!(x -> x !== nothing, idxs) + push!(iip_ir, + Assignment( + :__tmp_C, SetArray(false, Symbolics.DEFAULT_OUTSYM, idxs, false))) + end + end + + # oop + terms = [] + if A !== nothing && stiff_A + push!(terms, linear_matrix_param) + end + if B !== nothing && stiff_B + B_terms = map(eachindex(quadratic_forms)) do i + B[i] === nothing && return term(FillArrays.Falses, 1, N) + var = quadratic_forms[i] + var = term(UpperTriangular, var) + return term(*, term(+, var, term(transpose, var)), oop_x) + end + push!(terms, term(transpose, term(hcat, B_terms...))) + end + if C !== nothing && stiff_C + push!(terms, MakeArray(Cjac, oop_x)) + end + oop_expr = length(terms) == 1 ? only(terms) : term(+, terms...) + + j_iip = build_function_wrapper( + sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., iv; p_start = 3, + extra_assignments = iip_ir, expression = Val{true}, kwargs...) + j_oop, _ = build_function_wrapper( + sys, oop_expr, dvs, ps..., iv; expression = Val{true}, kwargs...) + return maybe_compile_function(expression, wrap_gfw, (2, 3, is_split(sys)), + (j_oop, j_iip); eval_expression, eval_module) +end + +""" + $(TYPEDSIGNATURES) + +Return the sparsity pattern of the jacobian of `f1` for [`SemilinearODEFunction`](@ref) +(internally represented as a `SplitFunction`). `A`, `B`, `C` are the matrices returned from +[`calculate_semiquadratic_form`](@ref). `Cjac` is the jacobian of `C` with respect to the +unknowns of the system, or `nothing` if `C === nothing`. This expects that the system has the +necessary extra parmameters added by [`add_semiquadratic_parameters`](@ref). + +## Keyword Arguments + +$SEMILINEAR_A_B_C_KWARGS +$GENERATE_X_KWARGS +- `mm`: The mass matrix of `sys`. + +$SEMILINEAR_A_B_C_CONSTRAINT + +$EXPERIMENTAL_WARNING +""" +function get_semiquadratic_W_sparsity( + sys::System, A, B, C, Cjac; stiff_A = true, stiff_B = false, + stiff_C = false, mm = calculate_massmatrix(sys)) + eqs = equations(sys) + dvs = unknowns(sys) + M = length(eqs) + N = length(dvs) + jac = spzeros(Num, M, N) + if stiff_A && A !== nothing + tmp = wrap.(A) + jac .+= tmp + end + if stiff_B && B !== nothing + for (i, mat) in enumerate(B) + mat === nothing && continue + jac[i, :] .+= (mat + transpose(mat)) * dvs + end + end + if stiff_C && C !== nothing + jac .+= Cjac + end + M_sparsity = mm isa UniformScaling ? sparse(I, M, N) : + SparseMatrixCSC{Bool, Int64}((!iszero).(mm)) + return (!_iszero).(jac) .| M_sparsity +end diff --git a/src/systems/codegen_utils.jl b/src/systems/codegen_utils.jl index dbbd7f85a8..d594a3902a 100644 --- a/src/systems/codegen_utils.jl +++ b/src/systems/codegen_utils.jl @@ -79,8 +79,8 @@ function array_variable_assignments(args...; argument_name = generated_argument_ # to help reduce allocations if first(idxs) < last(idxs) && vec(idxs) == first(idxs):last(idxs) idxs = first(idxs):last(idxs) - elseif vec(idxs) == last(idxs):-1:first(idxs) - idxs = last(idxs):-1:first(idxs) + elseif vec(idxs) == first(idxs):-1:last(idxs) + idxs = first(idxs):-1:last(idxs) else # Otherwise, turn the indexes into an `SArray` so they're stack-allocated idxs = SArray{Tuple{size(idxs)...}}(idxs) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 2ce1c7cffa..684cfd87df 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -63,6 +63,16 @@ struct IndexCache symbol_to_variable::Dict{Symbol, SymbolicParam} end +function Base.copy(ic::IndexCache) + IndexCache(copy(ic.unknown_idx), copy(ic.discrete_idx), copy(ic.callback_to_clocks), + copy(ic.tunable_idx), copy(ic.initials_idx), copy(ic.constant_idx), + copy(ic.nonnumeric_idx), copy(ic.observed_syms_to_timeseries), + copy(ic.dependent_pars_to_timeseries), copy(ic.discrete_buffer_sizes), + ic.tunable_buffer_size, ic.initials_buffer_size, + copy(ic.constant_buffer_sizes), copy(ic.nonnumeric_buffer_sizes), + copy(ic.symbol_to_variable)) +end + function IndexCache(sys::AbstractSystem) unks = unknowns(sys) unk_idxs = UnknownIndexMap() @@ -716,3 +726,55 @@ function subset_unknowns_observed( @set! ic.observed_syms_to_timeseries = observed_syms_to_timeseries return ic end + +function with_additional_constant_parameter(sys::AbstractSystem, par) + par = unwrap(par) + ps = copy(get_ps(sys)) + push!(ps, par) + @set! sys.ps = ps + is_split(sys) || return sys + + ic = copy(get_index_cache(sys)) + T = symtype(par) + bufidx = findfirst(buft -> buft.type == T, ic.constant_buffer_sizes) + if bufidx === nothing + push!(ic.constant_buffer_sizes, BufferTemplate(T, 1)) + bufidx = length(ic.constant_buffer_sizes) + idx_in_buf = 1 + else + buft = ic.constant_buffer_sizes[bufidx] + ic.constant_buffer_sizes[bufidx] = BufferTemplate(T, buft.length + 1) + idx_in_buf = buft.length + 1 + end + + ic.constant_idx[par] = ic.constant_idx[renamespace(sys, par)] = (bufidx, idx_in_buf) + @set! sys.index_cache = ic + + return sys +end + +function with_additional_nonnumeric_parameter(sys::AbstractSystem, par) + par = unwrap(par) + ps = copy(get_ps(sys)) + push!(ps, par) + @set! sys.ps = ps + is_split(sys) || return sys + + ic = copy(get_index_cache(sys)) + T = symtype(par) + bufidx = findfirst(buft -> buft.type == T, ic.nonnumeric_buffer_sizes) + if bufidx === nothing + push!(ic.nonnumeric_buffer_sizes, BufferTemplate(T, 1)) + bufidx = length(ic.nonnumeric_buffer_sizes) + idx_in_buf = 1 + else + buft = ic.nonnumeric_buffer_sizes[bufidx] + ic.nonnumeric_buffer_sizes[bufidx] = BufferTemplate(T, buft.length + 1) + idx_in_buf = buft.length + 1 + end + + ic.nonnumeric_idx[par] = ic.nonnumeric_idx[renamespace(sys, par)] = (bufidx, idx_in_buf) + @set! sys.index_cache = ic + + return sys +end diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index ecba542fd0..57a02dba0a 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -245,6 +245,13 @@ function recursive_unwrap(x::AbstractArray) symbolic_type(x) == ArraySymbolic() ? unwrap(x) : recursive_unwrap.(x) end +function recursive_unwrap(x::SparseMatrixCSC) + I, J, V = findnz(x) + V = recursive_unwrap(V) + m, n = size(x) + return sparse(I, J, V, m, n) +end + recursive_unwrap(x) = unwrap(x) function recursive_unwrap(x::AbstractDict) diff --git a/src/utils.jl b/src/utils.jl index e3e5e4de9f..89b269b13c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -387,6 +387,7 @@ function vars(exprs; op = Differential) vars!(Set(), unwrap(exprs); op) end end +vars(exprs::SparseMatrixCSC; op = Differential) = vars(nonzeros(exprs); op) vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op) function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) diff --git a/test/runtests.jl b/test/runtests.jl index 47230c9539..988961f4cc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -100,6 +100,7 @@ end @safetestset "Subsystem replacement" include("substitute_component.jl") @safetestset "Linearization Tests" include("linearize.jl") @safetestset "LinearProblem Tests" include("linearproblem.jl") + @safetestset "SemilinearODEProblem tests" include("semilinearodeproblem.jl") end end diff --git a/test/semilinearodeproblem.jl b/test/semilinearodeproblem.jl new file mode 100644 index 0000000000..024d8d4e58 --- /dev/null +++ b/test/semilinearodeproblem.jl @@ -0,0 +1,211 @@ +using ModelingToolkit +using OrdinaryDiffEq +using LinearAlgebra +using Test +using ModelingToolkit: t_nounits as t, D_nounits as D + +# from https://docs.sciml.ai/SciMLBenchmarksOutput/dev/AstroChem/nelson/ +@testset "Astrochem model" begin + function Nelson!(du, u, p, t) + T, Av, Go, n_H, shield = p + + # 1: H2 + du[1] = -1.2e-17 * u[1] + + n_H * (1.9e-6 * u[2] * u[3]) / (T^0.54) - + n_H * 4e-16 * u[1] * u[12] - + n_H * 7e-15 * u[1] * u[5] + + n_H * 1.7e-9 * u[10] * u[2] + + n_H * 2e-9 * u[2] * u[6] + + n_H * 2e-9 * u[2] * u[14] + + n_H * 8e-10 * u[2] * u[8] + sin(u[1]) / 1e20 # artificial nonlinear term for testing + + # 2: H3+ + du[2] = 1.2e-17 * u[1] + + n_H * (-1.9e-6 * u[3] * u[2]) / (T^0.54) - + n_H * 1.7e-9 * u[10] * u[2] - + n_H * 2e-9 * u[2] * u[6] - + n_H * 2e-9 * u[2] * u[14] - + n_H * 8e-10 * u[2] * u[8] + + # 3: e + du[3] = n_H * (-1.4e-10 * u[3] * u[12]) / (T^0.61) - + n_H * (3.8e-10 * u[13] * u[3]) / (T^0.65) - + n_H * (3.3e-5 * u[11] * u[3]) / T + + 1.2e-17 * u[1] - + n_H * (1.9e-6 * u[3] * u[2]) / (T^0.54) + + 6.8e-18 * u[4] - + n_H * (9e-11 * u[3] * u[5]) / (T^0.64) + + 3e-10 * Go * exp(-3 * Av) * u[6] + + n_H * 2e-9 * u[2] * u[13] + +2.0e-10 * Go * exp(-1.9 * Av) * u[14] + + # 4: He + du[4] = n_H * (9e-11 * u[3] * u[5]) / (T^0.64) - + 6.8e-18 * u[4] + + n_H * 7e-15 * u[1] * u[5] + + n_H * 1.6e-9 * u[10] * u[5] + + # 5: He+ + du[5] = 6.8e-18 * u[4] - + n_H * (9e-11 * u[3] * u[5]) / (T^0.64) - + n_H * 7e-15 * u[1] * u[5] - + n_H * 1.6e-9 * u[10] * u[5] + + # 6: C + du[6] = n_H * (1.4e-10 * u[3] * u[12]) / (T^0.61) - + n_H * 2e-9 * u[2] * u[6] - + n_H * 5.8e-12 * (T^0.5) * u[9] * u[6] + + 1e-9 * Go * exp(-1.5 * Av) * u[7] - + 3e-10 * Go * exp(-3 * Av) * u[6] + + 1e-10 * Go * exp(-3 * Av) * u[10] * shield + + # 7: CHx + du[7] = n_H * (-2e-10) * u[7] * u[8] + + n_H * 4e-16 * u[1] * u[12] + + n_H * 2e-9 * u[2] * u[6] - + 1e-9 * Go * u[7] * exp(-1.5 * Av) + + # 8: O + du[8] = n_H * (-2e-10) * u[7] * u[8] + + n_H * 1.6e-9 * u[10] * u[5] - + n_H * 8e-10 * u[2] * u[8] + + 5e-10 * Go * exp(-1.7 * Av) * u[9] + + 1e-10 * Go * exp(-3 * Av) * u[10] * shield + + # 9: OHx + du[9] = n_H * (-1e-9) * u[9] * u[12] + + n_H * 8e-10 * u[2] * u[8] - + n_H * 5.8e-12 * (T^0.5) * u[9] * u[6] - + 5e-10 * Go * exp(-1.7 * Av) * u[9] + + # 10: CO + du[10] = n_H * (3.3e-5 * u[11] * u[3]) / T + + n_H * 2e-10 * u[7] * u[8] - + n_H * 1.7e-9 * u[10] * u[2] - + n_H * 1.6e-9 * u[10] * u[5] + + n_H * 5.8e-12 * (T^0.5) * u[9] * u[6] - + 1e-10 * Go * exp(-3 * Av) * u[10] + + 1.5e-10 * Go * exp(-2.5 * Av) * u[11] * shield + + # 11: HCO+ + du[11] = n_H * (-3.3e-5 * u[11] * u[3]) / T + + n_H * 1e-9 * u[9] * u[12] + + n_H * 1.7e-9 * u[10] * u[2] - + 1.5e-10 * Go * exp(-2.5 * Av) * u[11] + + # 12: C+ + du[12] = n_H * (-1.4e-10 * u[3] * u[12]) / (T^0.61) - + n_H * 4e-16 * u[1] * u[12] - + n_H * 1e-9 * u[9] * u[12] + + n_H * 1.6e-9 * u[10] * u[5] + + 3e-10 * Go * exp(-3 * Av) * u[6] + + # 13: M+ + du[13] = n_H * (-3.8e-10 * u[13] * u[3]) / (T^0.65) + + n_H * 2e-9 * u[2] * u[14] + + 2.0e-10 * Go * exp(-1.9 * Av) * u[14] + + # 14: M + du[14] = n_H * (3.8e-10 * u[13] * u[3]) / (T^0.65) - + n_H * 2e-9 * u[2] * u[14] - + 2.0e-10 * Go * exp(-1.9 * Av) * u[14] + end + + # Set the Timespan, Parameters, and Initial Conditions + seconds_per_year = 3600 * 24 * 365 + tspan = (0.0, 30000 * seconds_per_year) # ~30 thousand yrs + + params = (10, # T + 2, # Av + 1.7, # Go + 611, # n_H + 1) # shield + + u0 = [0.5, # 1: H2 + 9.059e-9, # 2: H3+ + 2.0e-4, # 3: e + 0.1, # 4: He + 7.866e-7, # 5: He+ + 0.0, # 6: C + 0.0, # 7: CHx + 0.0004, # 8: O + 0.0, # 9: OHx + 0.0, # 10: CO + 0.0, # 11: HCO+ + 0.0002, # 12: C+ + 2.0e-7, # 13: M+ + 2.0e-7] # 14: M + + prob = ODEProblem(Nelson!, u0, tspan, params) + sys = mtkcompile(modelingtoolkitize(prob)) + A, B, C = ModelingToolkit.calculate_semiquadratic_form(sys) + @test A !== nothing + @test B !== nothing + @test C !== nothing + x = unknowns(sys) + linear_expr = A * x + linear_fun, = generate_custom_function(sys, linear_expr; expression = Val{false}) + quadratic_expr = reduce(vcat, [x' * _B * x for _B in B if _B !== nothing]) + quadratic_fun, = generate_custom_function(sys, quadratic_expr; expression = Val{false}) + nonlinear_expr = C + nonlinear_fun, = generate_custom_function(sys, nonlinear_expr; expression = Val{false}) + prob = ODEProblem(sys, nothing, tspan) + linear_val = linear_fun(prob.u0, prob.p, 0.0) + quadratic_val = quadratic_fun(prob.u0, prob.p, 0.0) + nonlinear_val = nonlinear_fun(prob.u0, prob.p, 0.0) + refsol = solve(prob, Vern9(); abstol = 1e-14, reltol = 1e-14) + + @testset "stiff_A: $stiff_A, stiff_B: $stiff_B, stiff_C: $stiff_C" for (stiff_A, stiff_B, stiff_C) in Iterators.product( + [false, true], [false, true], [false, true]) + kwargs = (; stiff_A, stiff_B, stiff_C) + if stiff_A == stiff_B == stiff_C + if stiff_A + @test_throws ["All of", "cannot be stiff"] SemilinearODEProblem( + sys, nothing, tspan; kwargs...) + @test_throws ["All of", "cannot be stiff"] SemilinearODEFunction( + sys; kwargs...) + else + @test_throws ["All of", "cannot be non-stiff"] SemilinearODEProblem( + sys, nothing, tspan; kwargs...) + @test_throws ["All of", "cannot be non-stiff"] SemilinearODEFunction( + sys; kwargs...) + end + continue + end + + reference_f1 = zeros(length(u0)) + reference_f2 = zeros(length(u0)) + mul!(stiff_A ? reference_f1 : reference_f2, I, linear_val, true, true) + mul!(stiff_B ? reference_f1 : reference_f2, I, quadratic_val, true, true) + mul!(stiff_C ? reference_f1 : reference_f2, I, nonlinear_val, true, true) + + @testset "Standard" begin + prob = SemilinearODEProblem(sys, nothing, tspan; kwargs...) + @test prob.f.f1(prob.u0, prob.p, 0.0)≈reference_f1 atol=1e-10 + @test prob.f.f2(prob.u0, prob.p, 0.0)≈reference_f2 atol=1e-10 + sol = solve(prob, KenCarp47()) + @test SciMLBase.successful_retcode(sol) + @test refsol(sol.t).u≈sol.u atol=1e-8 rtol=1e-8 + end + + @testset "Symbolic jacobian" begin + prob = SemilinearODEProblem(sys, nothing, tspan; jac = true, kwargs...) + @test prob.f.f1.jac !== nothing + sol = solve(prob, KenCarp47()) + @test SciMLBase.successful_retcode(sol) + @test refsol(sol.t).u≈sol.u atol=1e-8 rtol=1e-8 + end + + @testset "Sparse" begin + prob = SemilinearODEProblem(sys, nothing, tspan; sparse = true, kwargs...) + sol = solve(prob, KenCarp47()) + @test SciMLBase.successful_retcode(sol) + @test refsol(sol.t).u≈sol.u atol=1e-8 rtol=1e-8 + end + + @testset "Sparsejac" begin + @test_throws ["not implemented"] SemilinearODEProblem( + sys, nothing, tspan; jac = true, sparse = true, kwargs...) + end + end +end