From b1d459265fbe956495199f177e3c80e66fb12980 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Aug 2025 18:17:21 +0530 Subject: [PATCH 1/6] refactor: add `is_variable_numeric`, `is_numeric_symtype` --- src/utils.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index c6f3777c2b..11c6436bf3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -838,6 +838,27 @@ function is_floatingpoint_symtype(T::Type) T <: AbstractArray && is_floatingpoint_symtype(eltype(T)) end +""" + $(TYPEDSIGNATURES) + +Check if `sym` represents a symbolic number or array of numbers. +""" +function is_variable_numeric(sym) + sym = unwrap(sym) + T = symtype(sym) + is_numeric_symtype(T) +end + +""" + $(TYPEDSIGNATURES) + +Check if `T` is an appropriate symtype for a symbolic variable representing a number or +array of numbers. +""" +function is_numeric_symtype(T::Type) + return T <: Number || T <: AbstractArray && is_numeric_symtype(eltype(T)) +end + """ $(TYPEDSIGNATURES) From 5cb05a5204657d22b1995a5b08b43e63937447b3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Aug 2025 20:14:14 +0530 Subject: [PATCH 2/6] feat: add `Symbolics.fast_substitute` for affects --- src/systems/callbacks.jl | 19 +++++++++++++++++++ src/systems/imperative_affect.jl | 6 ++++++ 2 files changed, 25 insertions(+) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index f344f33d39..2d5ba6340d 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -25,6 +25,12 @@ function SymbolicAffect(affect::SymbolicAffect; kwargs...) end SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...) +function Symbolics.fast_substitute(aff::SymbolicAffect, rules) + substituter = Base.Fix2(fast_substitute, rules) + SymbolicAffect(map(substituter, aff.affect), map(substituter, aff.alg_eqs), + map(substituter, aff.discrete_parameters)) +end + struct AffectSystem """The internal implicit discrete system whose equations are solved to obtain values after the affect.""" system::AbstractSystem @@ -36,6 +42,19 @@ struct AffectSystem discretes::Vector end +function Symbolics.fast_substitute(aff::AffectSystem, rules) + substituter = Base.Fix2(fast_substitute, rules) + sys = aff.system + @set! sys.eqs = map(substituter, get_eqs(sys)) + @set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys)) + @set! sys.defaults = Dict([k => substituter(v) for (k, v) in defaults(sys)]) + @set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)]) + @set! sys.unknowns = map(substituter, get_unknowns(sys)) + @set! sys.ps = map(substituter, get_ps(sys)) + AffectSystem(sys, map(substituter, aff.unknowns), + map(substituter, aff.parameters), map(substituter, aff.discretes)) +end + function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...) AffectSystem(spec.affect; alg_eqs = vcat(spec.alg_eqs, alg_eqs), iv, discrete_parameters = spec.discrete_parameters, kwargs...) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index f3d45e258a..1c43022f4b 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -67,6 +67,12 @@ function ImperativeAffect(; f, kwargs...) ImperativeAffect(f; kwargs...) end +function Symbolics.fast_substitute(aff::ImperativeAffect, rules) + substituter = Base.Fix2(fast_substitute, rules) + ImperativeAffect(aff.f, map(substituter, aff.obs), aff.obs_syms, + map(substituter, aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks) +end + function Base.show(io::IO, mfa::ImperativeAffect) obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ") mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ") From 45afbc873bf200d37ecd120557d300ec8f5856d0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Aug 2025 20:14:27 +0530 Subject: [PATCH 3/6] fix: handle edge case in `float_type_from_varmap` --- src/systems/problem_utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index efcee3283e..2a1208586b 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1219,6 +1219,7 @@ with a constant value. """ function float_type_from_varmap(varmap, floatT = Bool) for (k, v) in varmap + is_variable_floatingpoint(k) || continue symbolic_type(v) == NotSymbolic() || continue is_array_of_symbolics(v) && continue From c2c603a3a6204900f997f557fba538768d6b7815 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Aug 2025 18:17:36 +0530 Subject: [PATCH 4/6] feat: add `respecialize` --- docs/src/API/model_building.md | 1 + src/ModelingToolkit.jl | 3 +- src/systems/diffeqs/basic_transformations.jl | 140 +++++++++++++++++++ 3 files changed, 143 insertions(+), 1 deletion(-) diff --git a/docs/src/API/model_building.md b/docs/src/API/model_building.md index c9d7c2f249..64ea81786f 100644 --- a/docs/src/API/model_building.md +++ b/docs/src/API/model_building.md @@ -227,6 +227,7 @@ add_accumulations noise_to_brownians convert_system_indepvar subset_tunables +respecialize ``` ## Hybrid systems diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index bb2ba0285d..8b3c4084c7 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -268,7 +268,8 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc hasmisc, getmisc, state_priority, subset_tunables export liouville_transform, change_independent_variable, substitute_component, - add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables + add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables, + respecialize export PDESystem export Differential, expand_derivatives, @derivatives export Equation, ConstrainedEquation diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index b8268e884e..16b06562a1 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -706,3 +706,143 @@ function convert_system_indepvar(sys::System, t; name = nameof(sys)) @set! sys.var_to_name = var_to_name return sys end + +""" + $(TYPEDSIGNATURES) + +Shorthand for `respecialize(sys, []; all = true)` +""" +respecialize(sys::AbstractSystem) = respecialize(sys, []; all = true) + +""" + $(TYPEDSIGNATURES) + +Specialize nonnumeric parameters in `sys` by changing their symtype to a concrete type. +`mapping` is an iterable, where each element can be a parameter or a pair mapping a parameter +to a value. If the element is a parameter, it must have a default. Each specified parameter +is updated to have the symtype of the value associated with it (either in `mapping` or in +the defaults). This operation can only be performed on nonnumeric, non-array parameters. The +defaults of respecialized parameters are set to the associated values. + +This operation can only be performed on `complete`d systems. + +# Keyword arguments + +- `all`: Specialize all nonnumeric parameters in the system. This will error if any such + parameter does not have a default. +""" +function respecialize(sys::AbstractSystem, mapping; all = false) + if !iscomplete(sys) + error(""" + This operation can only be performed on completed systems. Use `complete(sys)` or + `mtkcompile(sys)`. + """) + end + if !is_split(sys) + error(""" + This operation can only be performed on split systems. Use `complete(sys)` or + `mtkcompile(sys)` with the `split = true` keyword argument. + """) + end + + new_ps = copy(get_ps(sys)) + @set! sys.ps = new_ps + + extras = [] + if all + for x in filter(!is_variable_numeric, get_ps(sys)) + if any(y -> isequal(x, y) || y isa Pair && isequal(x, y[1]), mapping) || + symbolic_type(x) === ArraySymbolic() || + iscall(x) && operation(x) === getindex + continue + end + push!(extras, x) + end + end + ps_to_specialize = Iterators.flatten((extras, mapping)) + + defs = copy(defaults(sys)) + @set! sys.defaults = defs + final_defs = copy(defs) + evaluate_varmap!(final_defs, ps_to_specialize) + + subrules = Dict() + + for element in ps_to_specialize + if element isa Pair + k, v = element + else + k = element + v = get(final_defs, k, nothing) + @assert v !== nothing """ + Parameter $k needs an associated value to be respecialized. + """ + @assert symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) """ + Parameter $k needs an associated value to be respecialized. Found symbolic \ + default $v. + """ + end + + k = unwrap(k) + T = typeof(v) + + @assert !is_variable_numeric(k) """ + Numeric types cannot be respecialized - tried to respecialize $k. + """ + @assert symbolic_type(k) !== ArraySymbolic() """ + Cannot respecialize array symbolics - tried to respecialize $k. + """ + @assert !iscall(k) || operation(k) !== getindex """ + Cannot respecialized scalarized array variables - tried to respecialize $k. + """ + idx = findfirst(isequal(k), get_ps(sys)) + @assert idx !== nothing """ + Parameter $k does not exist in the system. + """ + + if iscall(k) + op = operation(k) + args = arguments(k) + new_p = SymbolicUtils.term(op, args...; type = T) + else + new_p = SymbolicUtils.Sym{T}(getname(k)) + end + + get_ps(sys)[idx] = new_p + defaults(sys)[new_p] = v + subrules[unwrap(k)] = unwrap(new_p) + end + + substituter = Base.Fix2(fast_substitute, subrules) + @set! sys.eqs = map(substituter, get_eqs(sys)) + @set! sys.observed = map(substituter, get_observed(sys)) + @set! sys.initialization_eqs = map(substituter, get_initialization_eqs(sys)) + if get_noise_eqs(sys) !== nothing + @set! sys.noise_eqs = map(substituter, get_noise_eqs(sys)) + end + @set! sys.assertions = Dict([substituter(k) => v for (k, v) in assertions(sys)]) + @set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys)) + @set! sys.defaults = Dict([substituter(k) => substituter(v) for (k, v) in defaults(sys)]) + @set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)]) + @set! sys.continuous_events = map(get_continuous_events(sys)) do cev + SymbolicContinuousCallback( + map(substituter, cev.conditions), substituter(cev.affect), + substituter(cev.affect_neg), substituter(cev.initialize), + substituter(cev.finalize), cev.rootfind, + cev.reinitializealg, cev.zero_crossing_id) + end + @set! sys.discrete_events = map(get_discrete_events(sys)) do dev + SymbolicDiscreteCallback(map(substituter, dev.conditions), substituter(dev.affect), + substituter(dev.initialize), substituter(dev.finalize), dev.reinitializealg) + end + if get_schedule(sys) !== nothing + sched = get_schedule(sys) + @set! sys.schedule = Schedule( + sched.var_sccs, AnyDict(k => substituter(v) for (k, v) in sched.dummy_sub)) + end + @set! sys.constraints = map(substituter, get_constraints(sys)) + @set! sys.tstops = map(substituter, get_tstops(sys)) + @set! sys.costs = Vector{Union{Real, BasicSymbolic}}(map(substituter, get_costs(sys))) + sys = complete(sys; split = is_split(sys)) + return sys +end From 89324337da508bad62c87f109780c9c93562cb6e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Aug 2025 18:17:43 +0530 Subject: [PATCH 5/6] test: test `respecialize` --- test/basic_transformations.jl | 66 +++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/test/basic_transformations.jl b/test/basic_transformations.jl index bafb5cf9e2..dc9d71f300 100644 --- a/test/basic_transformations.jl +++ b/test/basic_transformations.jl @@ -1,5 +1,6 @@ using ModelingToolkit, OrdinaryDiffEq, DataInterpolations, DynamicQuantities, Test using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput +using SymbolicUtils: symtype @independent_variables t D = Differential(t) @@ -328,3 +329,68 @@ end D(x) ~ y] @test issetequal(equations(asys), eqs) end + +abstract type AbstractFoo end + +struct Bar <: AbstractFoo end +struct Baz <: AbstractFoo end + +foofn(x) = 4 +@register_symbolic foofn(x::AbstractFoo) + +@testset "`respecialize`" begin + @parameters p::AbstractFoo p2(t)::AbstractFoo = p q[1:2]::AbstractFoo r + rp, + rp2 = let + only(@parameters p::Bar), + SymbolicUtils.term(operation(p2), arguments(p2)...; type = Baz) + end + @variables x(t) = 1.0 + @named sys1 = System([D(x) ~ foofn(p) + foofn(p2) + x], t, [x], [p, p2, q, r]) + + @test_throws ["completed systems"] respecialize(sys1) + @test_throws ["completed systems"] respecialize(sys1, []) + @test_throws ["split systems"] respecialize(complete(sys1; split = false)) + @test_throws ["split systems"] respecialize(complete(sys1; split = false), []) + + sys = complete(sys1) + + @test_throws ["Parameter p", "associated value"] respecialize(sys) + @test_throws ["Parameter p", "associated value"] respecialize(sys, [p]) + + @test_throws ["Parameter p2", "symbolic default"] respecialize(sys, [p2]) + + sys2 = respecialize(sys, [p => Bar()]) + @test ModelingToolkit.iscomplete(sys2) + @test ModelingToolkit.is_split(sys2) + ps = ModelingToolkit.get_ps(sys2) + idx = findfirst(isequal(rp), ps) + @test defaults(sys2)[rp] == Bar() + @test symtype(ps[idx]) <: Bar + ic = ModelingToolkit.get_index_cache(sys2) + @test any(x -> x.type == Bar && x.length == 1, ic.nonnumeric_buffer_sizes) + prob = ODEProblem(sys2, [p2 => Bar(), q => [Bar(), Bar()], r => 1], (0.0, 1.0)) + @test any(x -> x isa Vector{Bar} && length(x) == 1, prob.p.nonnumeric) + + defaults(sys)[p2] = Baz() + sys2 = respecialize(sys, [p => Bar()]; all = true) + @test ModelingToolkit.iscomplete(sys2) + @test ModelingToolkit.is_split(sys2) + ps = ModelingToolkit.get_ps(sys2) + idx = findfirst(isequal(rp2), ps) + @test defaults(sys2)[rp2] == Baz() + @test symtype(ps[idx]) <: Baz + ic = ModelingToolkit.get_index_cache(sys2) + @test any(x -> x.type == Baz && x.length == 1, ic.nonnumeric_buffer_sizes) + delete!(defaults(sys), p2) + prob = ODEProblem(sys2, [q => [Bar(), Bar()], r => 1], (0.0, 1.0)) + @test any(x -> x isa Vector{Bar} && length(x) == 1, prob.p.nonnumeric) + @test any(x -> x isa Vector{Baz} && length(x) == 1, prob.p.nonnumeric) + + @test_throws ["Numeric types cannot be respecialized"] respecialize(sys, [r => 1]) + @test_throws ["array symbolics"] respecialize(sys, [q => Bar[Bar(), Bar()]]) + @test_throws ["scalarized array"] respecialize(sys, [q[1] => Bar()]) + + @parameters foo::AbstractFoo + @test_throws ["does not exist"] respecialize(sys, [foo => Bar()]) +end From 9ed9aa758c759c8d49155679a915e8db0a88fcf1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 25 Aug 2025 10:12:18 +0530 Subject: [PATCH 6/6] test: increase timeout for test --- test/odesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/odesystem.jl b/test/odesystem.jl index 3b9693241c..c7247fb120 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1568,7 +1568,7 @@ end cmd = `$(Base.julia_cmd()) --project=$(@__DIR__) -e $code` proc = run(cmd, stdin, stdout, stderr; wait = false) - sleep(120) + sleep(180) @test !process_running(proc) kill(proc, Base.SIGKILL) end