Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqVerner/src/OrdinaryDiffEqVerner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ PrecompileTools.@compile_workload begin
solver_list = nothing
end

export Vern6, Vern7, Vern8, Vern9
export Vern6, Vern7, Vern8, Vern9, RKV76IIa
export AutoVern6, AutoVern7, AutoVern8, AutoVern9

end
5 changes: 4 additions & 1 deletion lib/OrdinaryDiffEqVerner/src/alg_utils.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
isfsal(alg::Vern7) = false
isfsal(alg::Vern8) = false
isfsal(alg::Vern9) = false
isfsal(alg::RKV76IIa) = false

alg_order(alg::Vern6) = 6
alg_order(alg::Vern7) = 7
alg_order(alg::Vern8) = 8
alg_order(alg::Vern9) = 9
alg_order(alg::RKV76IIa) = 7

alg_stability_size(alg::Vern6) = 4.8553
alg_stability_size(alg::Vern7) = 4.6400
alg_stability_size(alg::Vern8) = 5.8641
alg_stability_size(alg::Vern9) = 4.4762
alg_stability_size(alg::RKV76IIa) = 4.910807773 # From the file: Real Stability Interval is nearly [ -4.910807773, 0]

SciMLBase.has_lazy_interpolation(alg::Union{Vern6, Vern7, Vern8, Vern9}) = true
SciMLBase.has_lazy_interpolation(alg::Union{Vern6, Vern7, Vern8, Vern9, RKV76IIa}) = true
26 changes: 26 additions & 0 deletions lib/OrdinaryDiffEqVerner/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,29 @@ To gain access to stiff algorithms you might have to install additional librarie
such as `OrdinaryDiffEqRosenbrock`.
"""
AutoVern9(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern9(lazy = lazy), alg; kwargs...)


@doc explicit_rk_docstring(
"Verner's RKV76.IIa 7/6 method. Most efficient 10-stage conventional pair of orders 6 and 7 with interpolants.",
"RKV76IIa",
references = "@misc{verner2024rkv76iia,
title={RKV76.IIa - A 'most efficient' Runge--Kutta (10:7(6)) pair},
author={Verner, James H},
year={2024},
url={https://www.sfu.ca/~jverner/RKV76.IIa.Efficient.000003389335684.240711.FLOAT6040OnWeb}
}",
extra_keyword_description = """- `lazy`: determines if the lazy interpolant is used.
""",
extra_keyword_default = "lazy = true")
Base.@kwdef struct RKV76IIa{StageLimiter, StepLimiter, Thread} <:
OrdinaryDiffEqAdaptiveAlgorithm
stage_limiter!::StageLimiter = trivial_limiter!
step_limiter!::StepLimiter = trivial_limiter!
thread::Thread = False()
lazy::Bool = true
end
@truncate_stacktrace RKV76IIa 3
# for backwards compatibility
function RKV76IIa(stage_limiter!, step_limiter! = trivial_limiter!; lazy = true)
RKV76IIa(stage_limiter!, step_limiter!, False(), lazy)
end
8 changes: 8 additions & 0 deletions lib/OrdinaryDiffEqVerner/src/interp_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,11 @@ function SciMLBase.interp_summary(::Type{cacheType},
}}
dense ? "specialized 9th order lazy interpolation" : "1st order linear"
end

function SciMLBase.interp_summary(::Type{cacheType},
dense::Bool) where {
cacheType <:
Union{RKV76IIaCache, RKV76IIaConstantCache
}}
dense ? "specialized 7th order lazy interpolation" : "1st order linear"
end
66 changes: 66 additions & 0 deletions lib/OrdinaryDiffEqVerner/src/verner_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,69 @@ function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits},
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Vern9ConstantCache(alg.lazy)
end

@cache struct RKV76IIaCache{uType, rateType, uNoUnitsType, TabType, StageLimiter, StepLimiter,
Thread} <:
OrdinaryDiffEqMutableCache
u::uType
uprev::uType
k1::rateType
k2::rateType
k3::rateType
k4::rateType
k5::rateType
k6::rateType
k7::rateType
k8::rateType
k9::rateType
k10::rateType
utilde::uType
tmp::uType
rtmp::rateType
atmp::uNoUnitsType
tab::TabType
stage_limiter!::StageLimiter
step_limiter!::StepLimiter
thread::Thread
lazy::Bool
end

# fake values since non-FSAL method
get_fsalfirstlast(cache::RKV76IIaCache, u) = (nothing, nothing)

function alg_cache(alg::RKV76IIa, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tab = RKV76IIaTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
k3 = k2
k4 = zero(rate_prototype)
k5 = zero(rate_prototype)
k6 = zero(rate_prototype)
k7 = zero(rate_prototype)
k8 = k3
k9 = zero(rate_prototype)
k10 = k4
utilde = zero(u)
tmp = zero(u)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype)
RKV76IIaCache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, rtmp, atmp, tab,
alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy)
end

struct RKV76IIaConstantCache{TabType} <: OrdinaryDiffEqConstantCache
tab::TabType
lazy::Bool
end

function alg_cache(alg::RKV76IIa, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tab = RKV76IIaTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
RKV76IIaConstantCache(tab, alg.lazy)
end
126 changes: 126 additions & 0 deletions lib/OrdinaryDiffEqVerner/src/verner_rk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1264,3 +1264,129 @@ end
end
return nothing
end

function initialize!(integrator, cache::RKV76IIaConstantCache)
integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
alg = unwrap_alg(integrator, false)
cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 10)
integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)

# Avoid undefined entries if k is an array of arrays
integrator.fsallast = zero(integrator.fsalfirst)
integrator.k[1] = zero(integrator.fsalfirst)
@inbounds for i in 2:integrator.kshortsize-1
integrator.k[i] = zero(integrator.fsalfirst)
end
integrator.k[integrator.kshortsize] = zero(integrator.fsallast)
end

@muladd function perform_step!(integrator, cache::RKV76IIaConstantCache, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack c1, c2, c3, c4, c5, c6, c7, c8, c9, c10,
a21, a31, a32, a41, a42, a43, a51, a52, a53, a54,
a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76,
a81, a82, a83, a84, a85, a86, a87,
a91, a92, a93, a94, a95, a96, a97, a98,
a101, a102, a103, a104, a105, a106, a107, a108, a109,
b1, b2, b3, b4, b5, b6, b7, b8, b9, b10,
bh1, bh2, bh3, bh4, bh5, bh6, bh7, bh8, bh9, bh10 = cache.tab

k1 = f(uprev, p, t)
k2 = f(uprev + dt * a21 * k1, p, t + c2 * dt)
k3 = f(uprev + dt * (a31 * k1 + a32 * k2), p, t + c3 * dt)
k4 = f(uprev + dt * (a41 * k1 + a42 * k2 + a43 * k3), p, t + c4 * dt)
k5 = f(uprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4), p, t + c5 * dt)
k6 = f(uprev + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5), p, t + c6 * dt)
k7 = f(uprev + dt * (a71 * k1 + a72 * k2 + a73 * k3 + a74 * k4 + a75 * k5 + a76 * k6), p, t + c7 * dt)
k8 = f(uprev + dt * (a81 * k1 + a82 * k2 + a83 * k3 + a84 * k4 + a85 * k5 + a86 * k6 + a87 * k7), p, t + c8 * dt)
k9 = f(uprev + dt * (a91 * k1 + a92 * k2 + a93 * k3 + a94 * k4 + a95 * k5 + a96 * k6 + a97 * k7 + a98 * k8), p, t + c9 * dt)
k10 = f(uprev + dt * (a101 * k1 + a102 * k2 + a103 * k3 + a104 * k4 + a105 * k5 + a106 * k6 + a107 * k7 + a108 * k8 + a109 * k9), p, t + c10 * dt)

u = uprev + dt * (b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 + b5 * k5 + b6 * k6 + b7 * k7 + b8 * k8 + b9 * k9 + b10 * k10)

OrdinaryDiffEqCore.increment_nf!(integrator.stats, 10)

if integrator.opts.adaptive
uhat = uprev + dt * (bh1 * k1 + bh2 * k2 + bh3 * k3 + bh4 * k4 + bh5 * k5 + bh6 * k6 + bh7 * k7 + bh8 * k8 + bh9 * k9 + bh10 * k10)
atmp = calculate_residuals(u .- uhat, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end

integrator.k[1] = k1
integrator.k[2] = k2
integrator.k[3] = k3
integrator.k[4] = k4
integrator.k[5] = k5
integrator.k[6] = k6
integrator.k[7] = k7
integrator.k[8] = k8
integrator.k[9] = k9
integrator.k[10] = k10

integrator.u = u
end

function initialize!(integrator, cache::RKV76IIaCache)
alg = unwrap_alg(integrator, false)
cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 10)
@unpack k = integrator
resize!(k, integrator.kshortsize)
k[1] = cache.k1
k[2] = cache.k2
k[3] = cache.k3
k[4] = cache.k4
k[5] = cache.k5
k[6] = cache.k6
k[7] = cache.k7
k[8] = cache.k8
k[9] = cache.k9
k[10] = cache.k10
end

@muladd function perform_step!(integrator, cache::RKV76IIaCache, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack c1, c2, c3, c4, c5, c6, c7, c8, c9, c10,
a21, a31, a32, a41, a42, a43, a51, a52, a53, a54,
a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76,
a81, a82, a83, a84, a85, a86, a87,
a91, a92, a93, a94, a95, a96, a97, a98,
a101, a102, a103, a104, a105, a106, a107, a108, a109,
b1, b2, b3, b4, b5, b6, b7, b8, b9, b10,
bh1, bh2, bh3, bh4, bh5, bh6, bh7, bh8, bh9, bh10 = cache.tab
@unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, atmp = cache
@unpack thread = cache

f(k1, uprev, p, t)
@.. broadcast=false thread=thread tmp = uprev + dt * a21 * k1
f(k2, tmp, p, t + c2 * dt)
@.. broadcast=false thread=thread tmp = uprev + dt * (a31 * k1 + a32 * k2)
f(k3, tmp, p, t + c3 * dt)
@.. broadcast=false thread=thread tmp = uprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)
f(k4, tmp, p, t + c4 * dt)
@.. broadcast=false thread=thread tmp = uprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)
f(k5, tmp, p, t + c5 * dt)
@.. broadcast=false thread=thread tmp = uprev + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5)
f(k6, tmp, p, t + c6 * dt)
@.. broadcast=false thread=thread tmp = uprev + dt * (a71 * k1 + a72 * k2 + a73 * k3 + a74 * k4 + a75 * k5 + a76 * k6)
f(k7, tmp, p, t + c7 * dt)
@.. broadcast=false thread=thread tmp = uprev + dt * (a81 * k1 + a82 * k2 + a83 * k3 + a84 * k4 + a85 * k5 + a86 * k6 + a87 * k7)
f(k8, tmp, p, t + c8 * dt)
@.. broadcast=false thread=thread tmp = uprev + dt * (a91 * k1 + a92 * k2 + a93 * k3 + a94 * k4 + a95 * k5 + a96 * k6 + a97 * k7 + a98 * k8)
f(k9, tmp, p, t + c9 * dt)
@.. broadcast=false thread=thread tmp = uprev + dt * (a101 * k1 + a102 * k2 + a103 * k3 + a104 * k4 + a105 * k5 + a106 * k6 + a107 * k7 + a108 * k8 + a109 * k9)
f(k10, tmp, p, t + c10 * dt)

@.. broadcast=false thread=thread u = uprev + dt * (b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 + b5 * k5 + b6 * k6 + b7 * k7 + b8 * k8 + b9 * k9 + b10 * k10)

OrdinaryDiffEqCore.increment_nf!(integrator.stats, 10)

if integrator.opts.adaptive
@.. broadcast=false thread=thread utilde = uprev + dt * (bh1 * k1 + bh2 * k2 + bh3 * k3 + bh4 * k4 + bh5 * k5 + bh6 * k6 + bh7 * k7 + bh8 * k8 + bh9 * k9 + bh10 * k10)
@.. broadcast=false thread=thread atmp = u - utilde
calculate_residuals!(atmp, atmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t, thread)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
end
Loading
Loading