Skip to content

Commit 77a86a5

Browse files
Merge pull request #3768 from AayushSabharwal/as/v9-complex-odeproblem
[v9] fix: enable support for complex ODEProblem again
2 parents 4a3d069 + 278fe05 commit 77a86a5

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

src/systems/index_cache.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ function IndexCache(sys::AbstractSystem)
388388
observed_syms_to_timeseries,
389389
dependent_pars_to_timeseries,
390390
disc_buffer_templates,
391-
BufferTemplate(Real, tunable_buffer_size),
392-
BufferTemplate(Real, initials_buffer_size),
391+
BufferTemplate(Number, tunable_buffer_size),
392+
BufferTemplate(Number, initials_buffer_size),
393393
const_buffer_sizes,
394394
nonnumeric_buffer_sizes,
395395
symbol_to_variable

src/systems/problem_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,7 @@ end
969969
$(TYPEDEF)
970970
971971
A callable struct to use as the `get_updated_u0` field of `InitializationMetadata`.
972-
Returns the value to use for the `u0` of the problem.
972+
Returns the value to use for the `u0` of the problem.
973973
974974
# Fields
975975
@@ -1175,7 +1175,7 @@ function float_type_from_varmap(varmap, floatT = Bool)
11751175

11761176
if v isa AbstractArray
11771177
floatT = promote_type(floatT, eltype(v))
1178-
elseif v isa Real
1178+
elseif v isa Number
11791179
floatT = promote_type(floatT, typeof(v))
11801180
end
11811181
end
@@ -1447,7 +1447,7 @@ function check_inputmap_keys(sys, u0map, pmap)
14471447
end
14481448

14491449
const BAD_KEY_MESSAGE = """
1450-
Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned.
1450+
Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned.
14511451
The following keys are invalid:
14521452
"""
14531453

test/complex.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ModelingToolkit
2+
using OrdinaryDiffEq
23
using ModelingToolkit: t_nounits as t
34
using Test
45

@@ -14,3 +15,30 @@ using Test
1415
end
1516
@named mixed = ComplexModel()
1617
@test length(equations(mixed)) == 2
18+
19+
@testset "Complex ODEProblem" begin
20+
using ModelingToolkit: t_nounits as t, D_nounits as D
21+
22+
vars = @variables x(t) y(t) z(t)
23+
pars = @parameters a b
24+
25+
eqs = [
26+
D(x) ~ y - x,
27+
D(y) ~ -x * z + b * abs(z),
28+
D(z) ~ x * y - a
29+
]
30+
@named modlorenz = System(eqs, t)
31+
sys = structural_simplify(modlorenz)
32+
33+
ic = ModelingToolkit.get_index_cache(sys)
34+
@test ic.tunable_buffer_size.type == Number
35+
36+
u0 = ComplexF64[-4.0, 5.0, 0.0] .+ randn(ComplexF64, 3)
37+
p = ComplexF64[5.0, 0.1]
38+
dict = merge(Dict(unknowns(sys) .=> u0), Dict(parameters(sys) .=> p))
39+
prob = ODEProblem(sys, dict, (0.0, 1.0))
40+
41+
sol = solve(prob, Tsit5(), saveat = 0.1)
42+
43+
@test sol.u[1] isa Vector{ComplexF64}
44+
end

0 commit comments

Comments
 (0)