Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/API/model_building.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ add_accumulations
noise_to_brownians
convert_system_indepvar
subset_tunables
respecialize
```

## Hybrid systems
Expand Down
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
hasmisc, getmisc, state_priority,
subset_tunables
export liouville_transform, change_independent_variable, substitute_component,
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables,
respecialize
export PDESystem
export Differential, expand_derivatives, @derivatives
export Equation, ConstrainedEquation
Expand Down
19 changes: 19 additions & 0 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ function SymbolicAffect(affect::SymbolicAffect; kwargs...)
end
SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...)

function Symbolics.fast_substitute(aff::SymbolicAffect, rules)
substituter = Base.Fix2(fast_substitute, rules)
SymbolicAffect(map(substituter, aff.affect), map(substituter, aff.alg_eqs),
map(substituter, aff.discrete_parameters))
end

struct AffectSystem
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
system::AbstractSystem
Expand All @@ -36,6 +42,19 @@ struct AffectSystem
discretes::Vector
end

function Symbolics.fast_substitute(aff::AffectSystem, rules)
substituter = Base.Fix2(fast_substitute, rules)
sys = aff.system
@set! sys.eqs = map(substituter, get_eqs(sys))
@set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys))
@set! sys.defaults = Dict([k => substituter(v) for (k, v) in defaults(sys)])
@set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)])
@set! sys.unknowns = map(substituter, get_unknowns(sys))
@set! sys.ps = map(substituter, get_ps(sys))
AffectSystem(sys, map(substituter, aff.unknowns),
map(substituter, aff.parameters), map(substituter, aff.discretes))
end

function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...)
AffectSystem(spec.affect; alg_eqs = vcat(spec.alg_eqs, alg_eqs), iv,
discrete_parameters = spec.discrete_parameters, kwargs...)
Expand Down
140 changes: 140 additions & 0 deletions src/systems/diffeqs/basic_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,3 +706,143 @@ function convert_system_indepvar(sys::System, t; name = nameof(sys))
@set! sys.var_to_name = var_to_name
return sys
end

"""
$(TYPEDSIGNATURES)

Shorthand for `respecialize(sys, []; all = true)`
"""
respecialize(sys::AbstractSystem) = respecialize(sys, []; all = true)

"""
$(TYPEDSIGNATURES)

Specialize nonnumeric parameters in `sys` by changing their symtype to a concrete type.
`mapping` is an iterable, where each element can be a parameter or a pair mapping a parameter
to a value. If the element is a parameter, it must have a default. Each specified parameter
is updated to have the symtype of the value associated with it (either in `mapping` or in
the defaults). This operation can only be performed on nonnumeric, non-array parameters. The
defaults of respecialized parameters are set to the associated values.

This operation can only be performed on `complete`d systems.

# Keyword arguments

- `all`: Specialize all nonnumeric parameters in the system. This will error if any such
parameter does not have a default.
"""
function respecialize(sys::AbstractSystem, mapping; all = false)
if !iscomplete(sys)
error("""
This operation can only be performed on completed systems. Use `complete(sys)` or
`mtkcompile(sys)`.
""")
end
if !is_split(sys)
error("""
This operation can only be performed on split systems. Use `complete(sys)` or
`mtkcompile(sys)` with the `split = true` keyword argument.
""")
end

new_ps = copy(get_ps(sys))
@set! sys.ps = new_ps

extras = []
if all
for x in filter(!is_variable_numeric, get_ps(sys))
if any(y -> isequal(x, y) || y isa Pair && isequal(x, y[1]), mapping) ||
symbolic_type(x) === ArraySymbolic() ||
iscall(x) && operation(x) === getindex
continue
end
push!(extras, x)
end
end
ps_to_specialize = Iterators.flatten((extras, mapping))

defs = copy(defaults(sys))
@set! sys.defaults = defs
final_defs = copy(defs)
evaluate_varmap!(final_defs, ps_to_specialize)

subrules = Dict()

for element in ps_to_specialize
if element isa Pair
k, v = element
else
k = element
v = get(final_defs, k, nothing)
@assert v !== nothing """
Parameter $k needs an associated value to be respecialized.
"""
@assert symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) """
Parameter $k needs an associated value to be respecialized. Found symbolic \
default $v.
"""
end

k = unwrap(k)
T = typeof(v)

@assert !is_variable_numeric(k) """
Numeric types cannot be respecialized - tried to respecialize $k.
"""
@assert symbolic_type(k) !== ArraySymbolic() """
Cannot respecialize array symbolics - tried to respecialize $k.
"""
@assert !iscall(k) || operation(k) !== getindex """
Cannot respecialized scalarized array variables - tried to respecialize $k.
"""
idx = findfirst(isequal(k), get_ps(sys))
@assert idx !== nothing """
Parameter $k does not exist in the system.
"""

if iscall(k)
op = operation(k)
args = arguments(k)
new_p = SymbolicUtils.term(op, args...; type = T)
else
new_p = SymbolicUtils.Sym{T}(getname(k))
end

get_ps(sys)[idx] = new_p
defaults(sys)[new_p] = v
subrules[unwrap(k)] = unwrap(new_p)
end

substituter = Base.Fix2(fast_substitute, subrules)
@set! sys.eqs = map(substituter, get_eqs(sys))
@set! sys.observed = map(substituter, get_observed(sys))
@set! sys.initialization_eqs = map(substituter, get_initialization_eqs(sys))
if get_noise_eqs(sys) !== nothing
@set! sys.noise_eqs = map(substituter, get_noise_eqs(sys))
end
@set! sys.assertions = Dict([substituter(k) => v for (k, v) in assertions(sys)])
@set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys))
@set! sys.defaults = Dict([substituter(k) => substituter(v) for (k, v) in defaults(sys)])
@set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)])
@set! sys.continuous_events = map(get_continuous_events(sys)) do cev
SymbolicContinuousCallback(
map(substituter, cev.conditions), substituter(cev.affect),
substituter(cev.affect_neg), substituter(cev.initialize),
substituter(cev.finalize), cev.rootfind,
cev.reinitializealg, cev.zero_crossing_id)
end
@set! sys.discrete_events = map(get_discrete_events(sys)) do dev
SymbolicDiscreteCallback(map(substituter, dev.conditions), substituter(dev.affect),
substituter(dev.initialize), substituter(dev.finalize), dev.reinitializealg)
end
if get_schedule(sys) !== nothing
sched = get_schedule(sys)
@set! sys.schedule = Schedule(
sched.var_sccs, AnyDict(k => substituter(v) for (k, v) in sched.dummy_sub))
end
@set! sys.constraints = map(substituter, get_constraints(sys))
@set! sys.tstops = map(substituter, get_tstops(sys))
@set! sys.costs = Vector{Union{Real, BasicSymbolic}}(map(substituter, get_costs(sys)))
sys = complete(sys; split = is_split(sys))
return sys
end
6 changes: 6 additions & 0 deletions src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ function ImperativeAffect(; f, kwargs...)
ImperativeAffect(f; kwargs...)
end

function Symbolics.fast_substitute(aff::ImperativeAffect, rules)
substituter = Base.Fix2(fast_substitute, rules)
ImperativeAffect(aff.f, map(substituter, aff.obs), aff.obs_syms,
map(substituter, aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks)
end

function Base.show(io::IO, mfa::ImperativeAffect)
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")
Expand Down
1 change: 1 addition & 0 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,7 @@ with a constant value.
"""
function float_type_from_varmap(varmap, floatT = Bool)
for (k, v) in varmap
is_variable_floatingpoint(k) || continue
symbolic_type(v) == NotSymbolic() || continue
is_array_of_symbolics(v) && continue

Expand Down
21 changes: 21 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,27 @@ function is_floatingpoint_symtype(T::Type)
T <: AbstractArray && is_floatingpoint_symtype(eltype(T))
end

"""
$(TYPEDSIGNATURES)

Check if `sym` represents a symbolic number or array of numbers.
"""
function is_variable_numeric(sym)
sym = unwrap(sym)
T = symtype(sym)
is_numeric_symtype(T)
end

"""
$(TYPEDSIGNATURES)

Check if `T` is an appropriate symtype for a symbolic variable representing a number or
array of numbers.
"""
function is_numeric_symtype(T::Type)
return T <: Number || T <: AbstractArray && is_numeric_symtype(eltype(T))
end

"""
$(TYPEDSIGNATURES)

Expand Down
66 changes: 66 additions & 0 deletions test/basic_transformations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using ModelingToolkit, OrdinaryDiffEq, DataInterpolations, DynamicQuantities, Test
using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput
using SymbolicUtils: symtype

@independent_variables t
D = Differential(t)
Expand Down Expand Up @@ -328,3 +329,68 @@ end
D(x) ~ y]
@test issetequal(equations(asys), eqs)
end

abstract type AbstractFoo end

struct Bar <: AbstractFoo end
struct Baz <: AbstractFoo end

foofn(x) = 4
@register_symbolic foofn(x::AbstractFoo)

@testset "`respecialize`" begin
@parameters p::AbstractFoo p2(t)::AbstractFoo = p q[1:2]::AbstractFoo r
rp,
rp2 = let
only(@parameters p::Bar),
SymbolicUtils.term(operation(p2), arguments(p2)...; type = Baz)
end
@variables x(t) = 1.0
@named sys1 = System([D(x) ~ foofn(p) + foofn(p2) + x], t, [x], [p, p2, q, r])

@test_throws ["completed systems"] respecialize(sys1)
@test_throws ["completed systems"] respecialize(sys1, [])
@test_throws ["split systems"] respecialize(complete(sys1; split = false))
@test_throws ["split systems"] respecialize(complete(sys1; split = false), [])

sys = complete(sys1)

@test_throws ["Parameter p", "associated value"] respecialize(sys)
@test_throws ["Parameter p", "associated value"] respecialize(sys, [p])

@test_throws ["Parameter p2", "symbolic default"] respecialize(sys, [p2])

sys2 = respecialize(sys, [p => Bar()])
@test ModelingToolkit.iscomplete(sys2)
@test ModelingToolkit.is_split(sys2)
ps = ModelingToolkit.get_ps(sys2)
idx = findfirst(isequal(rp), ps)
@test defaults(sys2)[rp] == Bar()
@test symtype(ps[idx]) <: Bar
ic = ModelingToolkit.get_index_cache(sys2)
@test any(x -> x.type == Bar && x.length == 1, ic.nonnumeric_buffer_sizes)
prob = ODEProblem(sys2, [p2 => Bar(), q => [Bar(), Bar()], r => 1], (0.0, 1.0))
@test any(x -> x isa Vector{Bar} && length(x) == 1, prob.p.nonnumeric)

defaults(sys)[p2] = Baz()
sys2 = respecialize(sys, [p => Bar()]; all = true)
@test ModelingToolkit.iscomplete(sys2)
@test ModelingToolkit.is_split(sys2)
ps = ModelingToolkit.get_ps(sys2)
idx = findfirst(isequal(rp2), ps)
@test defaults(sys2)[rp2] == Baz()
@test symtype(ps[idx]) <: Baz
ic = ModelingToolkit.get_index_cache(sys2)
@test any(x -> x.type == Baz && x.length == 1, ic.nonnumeric_buffer_sizes)
delete!(defaults(sys), p2)
prob = ODEProblem(sys2, [q => [Bar(), Bar()], r => 1], (0.0, 1.0))
@test any(x -> x isa Vector{Bar} && length(x) == 1, prob.p.nonnumeric)
@test any(x -> x isa Vector{Baz} && length(x) == 1, prob.p.nonnumeric)

@test_throws ["Numeric types cannot be respecialized"] respecialize(sys, [r => 1])
@test_throws ["array symbolics"] respecialize(sys, [q => Bar[Bar(), Bar()]])
@test_throws ["scalarized array"] respecialize(sys, [q[1] => Bar()])

@parameters foo::AbstractFoo
@test_throws ["does not exist"] respecialize(sys, [foo => Bar()])
end
2 changes: 1 addition & 1 deletion test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1568,7 +1568,7 @@ end

cmd = `$(Base.julia_cmd()) --project=$(@__DIR__) -e $code`
proc = run(cmd, stdin, stdout, stderr; wait = false)
sleep(120)
sleep(180)
@test !process_running(proc)
kill(proc, Base.SIGKILL)
end
Expand Down
Loading