Skip to content

feat: implement isapprox for systems #3777

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,9 @@ function refreshed_metadata(meta::Base.ImmutableDict)
end
newmeta = Base.ImmutableDict(newmeta, k => v)
end
if !haskey(newmeta, MutableCacheKey)
newmeta = Base.ImmutableDict(newmeta, MutableCacheKey => MutableCacheT())
end
return newmeta
end

Expand Down
4 changes: 4 additions & 0 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ function SciMLBase.late_binding_update_u0_p(
end
newp = setp_oop(sys, syms)(newp, vals)
else
allsyms = nothing
# if `p` is not provided or is symbolic
p === missing || eltype(p) <: Pair || return newu0, newp
(newu0 === nothing || isempty(newu0)) && return newu0, newp
Expand All @@ -755,6 +756,9 @@ function SciMLBase.late_binding_update_u0_p(
if eltype(p) <: Pair
syms = []
vals = []
if allsyms === nothing
allsyms = all_symbols(sys)
end
for (k, v) in p
v === nothing && continue
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
Expand Down
51 changes: 50 additions & 1 deletion src/systems/system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
end
metadata = meta
end
metadata = Base.ImmutableDict(metadata, MutableCacheKey => MutableCacheT())
metadata = refreshed_metadata(metadata)
System(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, noise_eqs, jumps, constraints,
costs, consolidate, dvs, ps, brownians, iv, observed, Equation[],
var_to_name, name, description, defaults, guesses, systems, initialization_eqs,
Expand Down Expand Up @@ -1097,3 +1097,52 @@ function supports_initialization(sys::System)
return isempty(jumps(sys)) && _iszero(cost(sys)) &&
isempty(constraints(sys))
end

safe_eachrow(::Nothing) = nothing
safe_eachrow(x::AbstractArray) = eachrow(x)

safe_issetequal(::Nothing, ::Nothing) = true
safe_issetequal(::Nothing, x) = false
safe_issetequal(x, ::Nothing) = false
safe_issetequal(x, y) = issetequal(x, y)

"""
$(TYPEDSIGNATURES)

Check if two systems are about equal, to the extent that ModelingToolkit.jl supports. Note
that if this returns `true`, the systems are not guaranteed to be exactly equivalent
(unless `sysa === sysb`) but are highly likely to represent a similar mathematical problem.
If this returns `false`, the systems are very likely to be different.
"""
function Base.isapprox(sysa::System, sysb::System)
sysa === sysb && return true
return nameof(sysa) == nameof(sysb) &&
isequal(get_iv(sysa), get_iv(sysb)) &&
issetequal(get_eqs(sysa), get_eqs(sysb)) &&
safe_issetequal(
safe_eachrow(get_noise_eqs(sysa)), safe_eachrow(get_noise_eqs(sysb))) &&
issetequal(get_jumps(sysa), get_jumps(sysb)) &&
issetequal(get_constraints(sysa), get_constraints(sysb)) &&
issetequal(get_costs(sysa), get_costs(sysb)) &&
isequal(get_consolidate(sysa), get_consolidate(sysb)) &&
issetequal(get_unknowns(sysa), get_unknowns(sysb)) &&
issetequal(get_ps(sysa), get_ps(sysb)) &&
issetequal(get_brownians(sysa), get_brownians(sysb)) &&
issetequal(get_observed(sysa), get_observed(sysb)) &&
issetequal(get_parameter_dependencies(sysa), get_parameter_dependencies(sysb)) &&
isequal(get_description(sysa), get_description(sysb)) &&
isequal(get_defaults(sysa), get_defaults(sysb)) &&
isequal(get_guesses(sysa), get_guesses(sysb)) &&
issetequal(get_initialization_eqs(sysa), get_initialization_eqs(sysb)) &&
issetequal(get_continuous_events(sysa), get_continuous_events(sysb)) &&
issetequal(get_discrete_events(sysa), get_discrete_events(sysb)) &&
isequal(get_connector_type(sysa), get_connector_type(sysb)) &&
isequal(get_assertions(sysa), get_assertions(sysb)) &&
isequal(get_metadata(sysa), get_metadata(sysb)) &&
isequal(get_is_dde(sysa), get_is_dde(sysb)) &&
issetequal(get_tstops(sysa), get_tstops(sysb)) &&
safe_issetequal(get_ignored_connections(sysa), get_ignored_connections(sysb)) &&
isequal(get_is_initializesystem(sysa), get_is_initializesystem(sysb)) &&
isequal(get_is_discrete(sysa), get_is_discrete(sysb)) &&
isequal(get_isscheduled(sysa), get_isscheduled(sysb))
end
1 change: 1 addition & 0 deletions test/serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ str = String(take!(io))

sys = include_string(@__MODULE__, str)
rc2 = expand_connections(rc_model)
@test isapprox(sys, rc2)
@test issetequal(equations(sys), equations(rc2))
@test issetequal(unknowns(sys), unknowns(rc2))
@test issetequal(parameters(sys), parameters(rc2))
Expand Down
Loading