diff --git a/lib/OrdinaryDiffEqVerner/src/OrdinaryDiffEqVerner.jl b/lib/OrdinaryDiffEqVerner/src/OrdinaryDiffEqVerner.jl index 389227688c..4fd7a60f84 100644 --- a/lib/OrdinaryDiffEqVerner/src/OrdinaryDiffEqVerner.jl +++ b/lib/OrdinaryDiffEqVerner/src/OrdinaryDiffEqVerner.jl @@ -80,7 +80,7 @@ PrecompileTools.@compile_workload begin solver_list = nothing end -export Vern6, Vern7, Vern8, Vern9 +export Vern6, Vern7, Vern8, Vern9, RKV76IIa export AutoVern6, AutoVern7, AutoVern8, AutoVern9 end diff --git a/lib/OrdinaryDiffEqVerner/src/alg_utils.jl b/lib/OrdinaryDiffEqVerner/src/alg_utils.jl index f38d2ec768..02e315d370 100644 --- a/lib/OrdinaryDiffEqVerner/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqVerner/src/alg_utils.jl @@ -1,15 +1,18 @@ isfsal(alg::Vern7) = false isfsal(alg::Vern8) = false isfsal(alg::Vern9) = false +isfsal(alg::RKV76IIa) = false alg_order(alg::Vern6) = 6 alg_order(alg::Vern7) = 7 alg_order(alg::Vern8) = 8 alg_order(alg::Vern9) = 9 +alg_order(alg::RKV76IIa) = 7 alg_stability_size(alg::Vern6) = 4.8553 alg_stability_size(alg::Vern7) = 4.6400 alg_stability_size(alg::Vern8) = 5.8641 alg_stability_size(alg::Vern9) = 4.4762 +alg_stability_size(alg::RKV76IIa) = 4.910807773 # From the file: Real Stability Interval is nearly [ -4.910807773, 0] -SciMLBase.has_lazy_interpolation(alg::Union{Vern6, Vern7, Vern8, Vern9}) = true +SciMLBase.has_lazy_interpolation(alg::Union{Vern6, Vern7, Vern8, Vern9, RKV76IIa}) = true diff --git a/lib/OrdinaryDiffEqVerner/src/algorithms.jl b/lib/OrdinaryDiffEqVerner/src/algorithms.jl index 19efb0d5bf..ec892e635a 100644 --- a/lib/OrdinaryDiffEqVerner/src/algorithms.jl +++ b/lib/OrdinaryDiffEqVerner/src/algorithms.jl @@ -153,3 +153,29 @@ To gain access to stiff algorithms you might have to install additional librarie such as `OrdinaryDiffEqRosenbrock`. """ AutoVern9(alg; lazy = true, kwargs...) = AutoAlgSwitch(Vern9(lazy = lazy), alg; kwargs...) + + +@doc explicit_rk_docstring( + "Verner's RKV76.IIa 7/6 method. Most efficient 10-stage conventional pair of orders 6 and 7 with interpolants.", + "RKV76IIa", + references = "@misc{verner2024rkv76iia, + title={RKV76.IIa - A 'most efficient' Runge--Kutta (10:7(6)) pair}, + author={Verner, James H}, + year={2024}, + url={https://www.sfu.ca/~jverner/RKV76.IIa.Efficient.000003389335684.240711.FLOAT6040OnWeb} + }", + extra_keyword_description = """- `lazy`: determines if the lazy interpolant is used. + """, + extra_keyword_default = "lazy = true") +Base.@kwdef struct RKV76IIa{StageLimiter, StepLimiter, Thread} <: + OrdinaryDiffEqAdaptiveAlgorithm + stage_limiter!::StageLimiter = trivial_limiter! + step_limiter!::StepLimiter = trivial_limiter! + thread::Thread = False() + lazy::Bool = true +end +@truncate_stacktrace RKV76IIa 3 +# for backwards compatibility +function RKV76IIa(stage_limiter!, step_limiter! = trivial_limiter!; lazy = true) + RKV76IIa(stage_limiter!, step_limiter!, False(), lazy) +end diff --git a/lib/OrdinaryDiffEqVerner/src/interp_func.jl b/lib/OrdinaryDiffEqVerner/src/interp_func.jl index 1cca0adba3..bc292b1eac 100644 --- a/lib/OrdinaryDiffEqVerner/src/interp_func.jl +++ b/lib/OrdinaryDiffEqVerner/src/interp_func.jl @@ -29,3 +29,11 @@ function SciMLBase.interp_summary(::Type{cacheType}, }} dense ? "specialized 9th order lazy interpolation" : "1st order linear" end + +function SciMLBase.interp_summary(::Type{cacheType}, + dense::Bool) where { + cacheType <: + Union{RKV76IIaCache, RKV76IIaConstantCache +}} + dense ? "specialized 7th order lazy interpolation" : "1st order linear" +end diff --git a/lib/OrdinaryDiffEqVerner/src/verner_caches.jl b/lib/OrdinaryDiffEqVerner/src/verner_caches.jl index 9a3444d101..5734464a31 100644 --- a/lib/OrdinaryDiffEqVerner/src/verner_caches.jl +++ b/lib/OrdinaryDiffEqVerner/src/verner_caches.jl @@ -269,3 +269,69 @@ function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} Vern9ConstantCache(alg.lazy) end + +@cache struct RKV76IIaCache{uType, rateType, uNoUnitsType, TabType, StageLimiter, StepLimiter, + Thread} <: + OrdinaryDiffEqMutableCache + u::uType + uprev::uType + k1::rateType + k2::rateType + k3::rateType + k4::rateType + k5::rateType + k6::rateType + k7::rateType + k8::rateType + k9::rateType + k10::rateType + utilde::uType + tmp::uType + rtmp::rateType + atmp::uNoUnitsType + tab::TabType + stage_limiter!::StageLimiter + step_limiter!::StepLimiter + thread::Thread + lazy::Bool +end + +# fake values since non-FSAL method +get_fsalfirstlast(cache::RKV76IIaCache, u) = (nothing, nothing) + +function alg_cache(alg::RKV76IIa, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = RKV76IIaTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + k1 = zero(rate_prototype) + k2 = zero(rate_prototype) + k3 = k2 + k4 = zero(rate_prototype) + k5 = zero(rate_prototype) + k6 = zero(rate_prototype) + k7 = zero(rate_prototype) + k8 = k3 + k9 = zero(rate_prototype) + k10 = k4 + utilde = zero(u) + tmp = zero(u) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) + rtmp = uEltypeNoUnits === eltype(u) ? utilde : zero(rate_prototype) + RKV76IIaCache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, rtmp, atmp, tab, + alg.stage_limiter!, alg.step_limiter!, alg.thread, alg.lazy) +end + +struct RKV76IIaConstantCache{TabType} <: OrdinaryDiffEqConstantCache + tab::TabType + lazy::Bool +end + +function alg_cache(alg::RKV76IIa, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, + dt, reltol, p, calck, + ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = RKV76IIaTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + RKV76IIaConstantCache(tab, alg.lazy) +end diff --git a/lib/OrdinaryDiffEqVerner/src/verner_rk_perform_step.jl b/lib/OrdinaryDiffEqVerner/src/verner_rk_perform_step.jl index 898be5af5d..1a7d3ac1b9 100644 --- a/lib/OrdinaryDiffEqVerner/src/verner_rk_perform_step.jl +++ b/lib/OrdinaryDiffEqVerner/src/verner_rk_perform_step.jl @@ -1264,3 +1264,129 @@ end end return nothing end + +function initialize!(integrator, cache::RKV76IIaConstantCache) + integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + alg = unwrap_alg(integrator, false) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 10) + integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) + + # Avoid undefined entries if k is an array of arrays + integrator.fsallast = zero(integrator.fsalfirst) + integrator.k[1] = zero(integrator.fsalfirst) + @inbounds for i in 2:integrator.kshortsize-1 + integrator.k[i] = zero(integrator.fsalfirst) + end + integrator.k[integrator.kshortsize] = zero(integrator.fsallast) +end + +@muladd function perform_step!(integrator, cache::RKV76IIaConstantCache, repeat_step = false) + @unpack t, dt, uprev, u, f, p = integrator + @unpack c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, + a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, + a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, + a81, a82, a83, a84, a85, a86, a87, + a91, a92, a93, a94, a95, a96, a97, a98, + a101, a102, a103, a104, a105, a106, a107, a108, a109, + b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, + bh1, bh2, bh3, bh4, bh5, bh6, bh7, bh8, bh9, bh10 = cache.tab + + k1 = f(uprev, p, t) + k2 = f(uprev + dt * a21 * k1, p, t + c2 * dt) + k3 = f(uprev + dt * (a31 * k1 + a32 * k2), p, t + c3 * dt) + k4 = f(uprev + dt * (a41 * k1 + a42 * k2 + a43 * k3), p, t + c4 * dt) + k5 = f(uprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4), p, t + c5 * dt) + k6 = f(uprev + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5), p, t + c6 * dt) + k7 = f(uprev + dt * (a71 * k1 + a72 * k2 + a73 * k3 + a74 * k4 + a75 * k5 + a76 * k6), p, t + c7 * dt) + k8 = f(uprev + dt * (a81 * k1 + a82 * k2 + a83 * k3 + a84 * k4 + a85 * k5 + a86 * k6 + a87 * k7), p, t + c8 * dt) + k9 = f(uprev + dt * (a91 * k1 + a92 * k2 + a93 * k3 + a94 * k4 + a95 * k5 + a96 * k6 + a97 * k7 + a98 * k8), p, t + c9 * dt) + k10 = f(uprev + dt * (a101 * k1 + a102 * k2 + a103 * k3 + a104 * k4 + a105 * k5 + a106 * k6 + a107 * k7 + a108 * k8 + a109 * k9), p, t + c10 * dt) + + u = uprev + dt * (b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 + b5 * k5 + b6 * k6 + b7 * k7 + b8 * k8 + b9 * k9 + b10 * k10) + + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 10) + + if integrator.opts.adaptive + uhat = uprev + dt * (bh1 * k1 + bh2 * k2 + bh3 * k3 + bh4 * k4 + bh5 * k5 + bh6 * k6 + bh7 * k7 + bh8 * k8 + bh9 * k9 + bh10 * k10) + atmp = calculate_residuals(u .- uhat, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + end + + integrator.k[1] = k1 + integrator.k[2] = k2 + integrator.k[3] = k3 + integrator.k[4] = k4 + integrator.k[5] = k5 + integrator.k[6] = k6 + integrator.k[7] = k7 + integrator.k[8] = k8 + integrator.k[9] = k9 + integrator.k[10] = k10 + + integrator.u = u +end + +function initialize!(integrator, cache::RKV76IIaCache) + alg = unwrap_alg(integrator, false) + cache.lazy ? (integrator.kshortsize = 10) : (integrator.kshortsize = 10) + @unpack k = integrator + resize!(k, integrator.kshortsize) + k[1] = cache.k1 + k[2] = cache.k2 + k[3] = cache.k3 + k[4] = cache.k4 + k[5] = cache.k5 + k[6] = cache.k6 + k[7] = cache.k7 + k[8] = cache.k8 + k[9] = cache.k9 + k[10] = cache.k10 +end + +@muladd function perform_step!(integrator, cache::RKV76IIaCache, repeat_step = false) + @unpack t, dt, uprev, u, f, p = integrator + @unpack c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, + a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, + a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, + a81, a82, a83, a84, a85, a86, a87, + a91, a92, a93, a94, a95, a96, a97, a98, + a101, a102, a103, a104, a105, a106, a107, a108, a109, + b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, + bh1, bh2, bh3, bh4, bh5, bh6, bh7, bh8, bh9, bh10 = cache.tab + @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, utilde, tmp, atmp = cache + @unpack thread = cache + + f(k1, uprev, p, t) + @.. broadcast=false thread=thread tmp = uprev + dt * a21 * k1 + f(k2, tmp, p, t + c2 * dt) + @.. broadcast=false thread=thread tmp = uprev + dt * (a31 * k1 + a32 * k2) + f(k3, tmp, p, t + c3 * dt) + @.. broadcast=false thread=thread tmp = uprev + dt * (a41 * k1 + a42 * k2 + a43 * k3) + f(k4, tmp, p, t + c4 * dt) + @.. broadcast=false thread=thread tmp = uprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4) + f(k5, tmp, p, t + c5 * dt) + @.. broadcast=false thread=thread tmp = uprev + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5) + f(k6, tmp, p, t + c6 * dt) + @.. broadcast=false thread=thread tmp = uprev + dt * (a71 * k1 + a72 * k2 + a73 * k3 + a74 * k4 + a75 * k5 + a76 * k6) + f(k7, tmp, p, t + c7 * dt) + @.. broadcast=false thread=thread tmp = uprev + dt * (a81 * k1 + a82 * k2 + a83 * k3 + a84 * k4 + a85 * k5 + a86 * k6 + a87 * k7) + f(k8, tmp, p, t + c8 * dt) + @.. broadcast=false thread=thread tmp = uprev + dt * (a91 * k1 + a92 * k2 + a93 * k3 + a94 * k4 + a95 * k5 + a96 * k6 + a97 * k7 + a98 * k8) + f(k9, tmp, p, t + c9 * dt) + @.. broadcast=false thread=thread tmp = uprev + dt * (a101 * k1 + a102 * k2 + a103 * k3 + a104 * k4 + a105 * k5 + a106 * k6 + a107 * k7 + a108 * k8 + a109 * k9) + f(k10, tmp, p, t + c10 * dt) + + @.. broadcast=false thread=thread u = uprev + dt * (b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 + b5 * k5 + b6 * k6 + b7 * k7 + b8 * k8 + b9 * k9 + b10 * k10) + + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 10) + + if integrator.opts.adaptive + @.. broadcast=false thread=thread utilde = uprev + dt * (bh1 * k1 + bh2 * k2 + bh3 * k3 + bh4 * k4 + bh5 * k5 + bh6 * k6 + bh7 * k7 + bh8 * k8 + bh9 * k9 + bh10 * k10) + @.. broadcast=false thread=thread atmp = u - utilde + calculate_residuals!(atmp, atmp, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t, thread) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + end +end diff --git a/lib/OrdinaryDiffEqVerner/src/verner_tableaus.jl b/lib/OrdinaryDiffEqVerner/src/verner_tableaus.jl index 5edc4f4a10..84f7212787 100644 --- a/lib/OrdinaryDiffEqVerner/src/verner_tableaus.jl +++ b/lib/OrdinaryDiffEqVerner/src/verner_tableaus.jl @@ -3892,3 +3892,178 @@ end btilde1, btilde8, btilde9, btilde10, btilde11, btilde12, btilde13, btilde14, btilde15, btilde16) end + +""" +RKV76.IIa - A 'most efficient' Runge--Kutta (10:7(6)) pair +From Verner's Website +""" +struct RKV76IIaTableau{T, T2} + c1::T2 + c2::T2 + c3::T2 + c4::T2 + c5::T2 + c6::T2 + c7::T2 + c8::T2 + c9::T2 + c10::T2 + a21::T + a31::T + a32::T + a41::T + a42::T + a43::T + a51::T + a52::T + a53::T + a54::T + a61::T + a62::T + a63::T + a64::T + a65::T + a71::T + a72::T + a73::T + a74::T + a75::T + a76::T + a81::T + a82::T + a83::T + a84::T + a85::T + a86::T + a87::T + a91::T + a92::T + a93::T + a94::T + a95::T + a96::T + a97::T + a98::T + a101::T + a102::T + a103::T + a104::T + a105::T + a106::T + a107::T + a108::T + a109::T + b1::T + b2::T + b3::T + b4::T + b5::T + b6::T + b7::T + b8::T + b9::T + b10::T + bh1::T + bh2::T + bh3::T + bh4::T + bh5::T + bh6::T + bh7::T + bh8::T + bh9::T + bh10::T +end + +function RKV76IIaTableau(T, T2) + c1 = convert(T2, BigFloat("0")) + c2 = convert(T2, BigFloat("0.069")) + c3 = convert(T2, BigFloat("0.118")) + c4 = convert(T2, BigFloat("0.177")) + c5 = convert(T2, BigFloat("0.501")) + c6 = convert(T2, BigFloat("0.7737799115305331003715765296862487670813")) + c7 = convert(T2, BigFloat("0.994")) + c8 = convert(T2, BigFloat("0.998")) + c9 = convert(T2, BigFloat("1")) + c10 = convert(T2, BigFloat("1")) + + # Butcher tableau A matrix + a21 = convert(T, BigFloat("0.069")) + a31 = convert(T, BigFloat("0.01710144927536231884057971014492753623188")) + a32 = convert(T, BigFloat("0.1008985507246376811594202898550724637681")) + a41 = convert(T, BigFloat("0.04425")) + a42 = convert(T, BigFloat("0")) + a43 = convert(T, BigFloat("0.13275")) + a51 = convert(T, BigFloat("0.7353445130709566216604424016087331226659")) + a52 = convert(T, BigFloat("0")) + a53 = convert(T, BigFloat("-2.830160657856937661591496696351623096811")) + a54 = convert(T, BigFloat("2.595816144785981039931054294742889974145")) + a61 = convert(T, BigFloat("-12.21580485360407974005910916471598682362")) + a62 = convert(T, BigFloat("0")) + a63 = convert(T, BigFloat("48.82665485823736062335980699373053427134")) + a64 = convert(T, BigFloat("-38.55615592319928364666616600329792491404")) + a65 = convert(T, BigFloat("2.719085830096535863737044703969626233400")) + a71 = convert(T, BigFloat("108.8614188704176574066699618897203578466")) + a72 = convert(T, BigFloat("0")) + a73 = convert(T, BigFloat("-432.4521181775777896358931629332707752654")) + a74 = convert(T, BigFloat("343.9115281800118289547200158889409233641")) + a75 = convert(T, BigFloat("-20.55041135925273709189369488701721016265")) + a76 = convert(T, BigFloat("1.223582486401040366396880041626704217305")) + a81 = convert(T, BigFloat("113.4755131883738522204615568160304033854")) + a82 = convert(T, BigFloat("0")) + a83 = convert(T, BigFloat("-450.8122021555997002820400438087344405365")) + a84 = convert(T, BigFloat("358.5132765190089889943579090008312808216")) + a85 = convert(T, BigFloat("-21.45046667648445540174055882443151176550")) + a86 = convert(T, BigFloat("1.274053318605952891766776667539031508649")) + a87 = convert(T, BigFloat("-0.002174193904638422805639851234763413667602")) + a91 = convert(T, BigFloat("115.6996223324232534824963925993127275021")) + a92 = convert(T, BigFloat("0")) + a93 = convert(T, BigFloat("-459.6635446100248030478961869239726305957")) + a94 = convert(T, BigFloat("365.5534717131745930309149378867953890507")) + a95 = convert(T, BigFloat("-21.88511586349784824146225495848432937529")) + a96 = convert(T, BigFloat("1.298718109698721459187976480852777474315")) + a97 = convert(T, BigFloat("-0.00005318700918481883515898878747322241917739")) + a98 = convert(T, BigFloat("-0.003098494764731864405706095716460833640254")) + a101 = convert(T, BigFloat("124.1543935612464600014576130437603883332")) + a102 = convert(T, BigFloat("0")) + a103 = convert(T, BigFloat("-493.2318713314597046194663569971348299332")) + a104 = convert(T, BigFloat("392.2086219315800762927575562172365337929")) + a105 = convert(T, BigFloat("-23.48641564290853341361596821616234280392")) + a106 = convert(T, BigFloat("1.362322948908907509911149920532561575254")) + a107 = convert(T, BigFloat("-0.007051467367205771043993968232310964220061")) + a108 = convert(T, BigFloat("0")) + a109 = convert(T, BigFloat("0")) + + # High order weights + b1 = convert(T, BigFloat("0.05163520172057869163393251056217968836723")) + b2 = convert(T, BigFloat("0")) + b3 = convert(T, BigFloat("0")) + b4 = convert(T, BigFloat("0.2767172535461648728769641534539952501983")) + b5 = convert(T, BigFloat("0.3374175285287150670818592701488271741753")) + b6 = convert(T, BigFloat("0.1884488267810967803491085059046161195540")) + b7 = convert(T, BigFloat("24.54134121634868026791753618430192161716")) + b8 = convert(T, BigFloat("-68.81190284469011946382716084194838780382")) + b9 = convert(T, BigFloat("44.41634281776488378396776021757684795437")) + b10 = convert(T, BigFloat("0")) + + # Low order weights + bh1 = convert(T, BigFloat("0.05089676583692947576073561095512200263213")) + bh2 = convert(T, BigFloat("0")) + bh3 = convert(T, BigFloat("0")) + bh4 = convert(T, BigFloat("0.2793777374763233901369432426263934138476")) + bh5 = convert(T, BigFloat("0.3281330142746535239936396881369403928344")) + bh6 = convert(T, BigFloat("0.224172121818615103358179483735013")) + bh7 = convert(T, BigFloat("0.7874574778015076584344903106189416715189")) + bh8 = convert(T, BigFloat("0")) + bh9 = convert(T, BigFloat("0")) + bh10 = convert(T, BigFloat("-0.6700371172080291516839883360724104817561")) + + RKV76IIaTableau(c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, + a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, + a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, + a81, a82, a83, a84, a85, a86, a87, + a91, a92, a93, a94, a95, a96, a97, a98, + a101, a102, a103, a104, a105, a106, a107, a108, a109, + b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, + bh1, bh2, bh3, bh4, bh5, bh6, bh7, bh8, bh9, bh10) +end diff --git a/lib/OrdinaryDiffEqVerner/test/rkv76iia_tests.jl b/lib/OrdinaryDiffEqVerner/test/rkv76iia_tests.jl new file mode 100644 index 0000000000..70e0a5da7e --- /dev/null +++ b/lib/OrdinaryDiffEqVerner/test/rkv76iia_tests.jl @@ -0,0 +1,161 @@ +using OrdinaryDiffEqVerner, OrdinaryDiffEqCore, DiffEqBase, Test +using LinearAlgebra +using OrdinaryDiffEqSSPRK, DiffEqDevTools, Test, Random +import OrdinaryDiffEqLowStorageRK +import ODEProblemLibrary: prob_ode_linear, prob_ode_2Dlinear, prob_ode_bigfloat2Dlinear +using Plots + +testTol = 0.3 +println(BigFloat(exp(-64.0))) +function print_solution_dts(prob, algorithms; abstol=1e-12, reltol=1e-12, show_solution=true, kwargs...) + println("\n" * "="^70) + println("Comparing dt usage across algorithms") + println("="^70) + + # println("\nProblem Details:") + # println(" Time span: ", prob.tspan) + # println(" Initial condition u0: ", prob.u0) + # println(" Problem type: ", typeof(prob.f).name.name) + + for alg in algorithms + alg_name = string(typeof(alg).name.name) + + try + sol = solve(prob, alg; abstol=abstol, reltol=reltol, kwargs...) + final_dt = sol.t[end] - sol.t[end-1] + t_end = sol.t[end] + + println("\nAlgorithm: $alg_name") + if show_solution + println(" Solution at t=$t_end: ", sol[end]) + end + println(" Final dt: ", final_dt) + catch e + println("\nAlgorithm: $alg_name") + println(" ERROR: ", e) + end + end + + println("\n" * "="^70) +end + +# ODE function definitions +f_1 = (u, p, t) -> cos(t) +prob_ode_sin = ODEProblem(ODEFunction(f_1; analytic = (u0, p, t) -> sin(t)), 0.0, (0.0, 1.0)) + +f_1 = (du, u, p, t) -> du[1] = cos(t) +prob_ode_sin_inplace = ODEProblem(ODEFunction(f_1; analytic = (u0, p, t) -> [sin(t)]), [0.0], + (0.0, 1.0)) + +f_2 = (u, p, t) -> sin(u) +prob_ode_nonlinear = ODEProblem( + ODEFunction(f_2; + analytic = (u0, p, t) -> 2 * acot(exp(-t) * + cot(0.5))), 1.0, + (0.0, 0.5)) + +f_2 = (du, u, p, t) -> du[1] = sin(u[1]) +prob_ode_nonlinear_inplace = ODEProblem( + ODEFunction(f_2; + analytic = (u0, p, t) -> [ + 2 * acot(exp(-t) * cot(0.5)) + ]), + [1.0], (0.0, 0.5)) + +f_ssp = (u, p, t) -> begin + sin(10t) * u * (1 - u) +end +test_problem_ssp = ODEProblem(f_ssp, 0.1, (0.0, 8.0)) +test_problem_ssp_long = ODEProblem(f_ssp, 0.1, (0.0, 1.e3)) + +function f!(du, u, p, t) + du[1] = -u[1] +end + +function f(u, p, t) + -u +end + +test_problems_only_time = [prob_ode_sin, prob_ode_sin_inplace] + + +t_end=1.0 +alg = OrdinaryDiffEqSSPRK.SSPRK22() +t_end = 64.0 + +setprecision(256) +prob_oop = ODEProblem(f, 1.0, (0.0, t_end)) +println("**** exp(-64) ****") +algorithms = [RKV76IIa(), Vern7()] +print_solution_dts(prob_oop, algorithms; abstol=1e-40, reltol=1e-40) +print_solution_dts(prob_oop, [alg]; dt=OrdinaryDiffEqSSPRK.ssp_coefficient(alg), dense=false,abstol=1e-40, reltol=1e-40) +println("Expected value: ", exp(BigFloat(-t_end))) +println("**** exp(-64) ****") + + +println("***************** sin in and out of place *********************") +algorithms = [RKV76IIa(), Vern7()] +for prob in test_problems_only_time + print_solution_dts(prob, algorithms; abstol=1e-12, reltol=1e-12) + print_solution_dts(prob, [alg]; dt=OrdinaryDiffEqSSPRK.ssp_coefficient(alg), dense=false) +end +println("************** sin in and out of place ************************") + + +test_problems_nonlinear = [prob_ode_nonlinear, prob_ode_nonlinear_inplace] +t_end = 1.e3 + +sol_oop = solve(test_problem_ssp_long, RKV76IIa(), abstol=1e-12, reltol=1e-12) +println("***************** test_problem_ssp_long *********************") +algorithms = [RKV76IIa(), Vern7(), alg] +print_solution_dts(test_problem_ssp_long, algorithms; abstol=1e-12, reltol=1e-12) +print_solution_dts(test_problem_ssp_long, [alg]; dt=OrdinaryDiffEqSSPRK.ssp_coefficient(alg), dense=false) +println("************** test_problem_ssp_long ************************") + +t_end = 64.0 + +setprecision(256) +prob_oop = ODEProblem(f, 1.0, (0.0, t_end)) +println("**** exp(-64) ****") +algorithms = [RKV76IIa(), Vern7()] +print_solution_dts(prob_oop, algorithms; abstol=1e-40, reltol=1e-40) +print_solution_dts(prob_oop, [alg]; dt=OrdinaryDiffEqSSPRK.ssp_coefficient(alg), dense=false) +println("Expected value: ", exp(BigFloat(-t_end))) +println("**** exp(-64) ****") + + +println("*** Testing Convergence of diff algorithm *** ") + +#alg=Vern7() +alg=RKV76IIa() +# dts = BigFloat(1) ./ 2 .^ (1:6) +dts = [8, 6, 4, 2, 1, 0.5, 0.25, 0.125] + +errors = zeros(BigFloat, length(dts)) +println("Testing order 7 for RKV76IIa()") +for (i, dt) in enumerate(dts) + sol = solve(prob_oop, alg, dt=dt, adaptive=false) + # Use BigFloat for error calculation + errors[i] = abs(BigFloat(sol[end]) - exp(BigFloat(-t_end))) + println("Computed Solution ", sol[end], " for dt = ", dt, ", error = ", errors[i]) +end + +for i in 2:length(errors) + order = log(BigFloat(errors[i-1])/BigFloat(errors[i])) / log(2) + println("Order between dt=", dts[i-1], " and dt=", dts[i], ": ", order) +end + +plot( + dts, errors; + xscale = :log10, yscale = :log10, + marker = :o, linewidth = 2, + xlabel = "dt", ylabel = "Error", + title = "Convergence of RKV76IIa", + label = "Observed Error" +) +# Make reference line pass through the 0.125 dt point +ref_idx = findfirst(x -> x == 0.125, dts) +ref_errors = errors[ref_idx] * (BigFloat.(dts) ./ BigFloat(0.125)).^7 +plot!(float.(dts), ref_errors; linestyle = :dash, label = "Order 7 Reference") +display(current()) +savefig("convergence_rkv76iia.png") \ No newline at end of file diff --git a/lib/OrdinaryDiffEqVerner/test/runtests.jl b/lib/OrdinaryDiffEqVerner/test/runtests.jl index 5ed6ac4c1c..ca1529ea9d 100644 --- a/lib/OrdinaryDiffEqVerner/test/runtests.jl +++ b/lib/OrdinaryDiffEqVerner/test/runtests.jl @@ -4,5 +4,6 @@ using SafeTestsets if isempty(VERSION.prerelease) @time @safetestset "JET Tests" include("jet.jl") @time @safetestset "Aqua" include("qa.jl") + @time @safetestset "RKV76IIa Tests" include("rkv76iia_tests.jl") @time @safetestset "Allocation Tests" include("allocation_tests.jl") end \ No newline at end of file diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 9091a8e16d..0da9691b46 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -128,8 +128,8 @@ export OrdinaryDiffEqRKN, Nystrom4, FineRKN4, FineRKN5, Nystrom4VelocityIndepend import OrdinaryDiffEqVerner: OrdinaryDiffEqVerner using OrdinaryDiffEqVerner: Vern6, Vern7, Vern8, Vern9, AutoVern6, AutoVern7, AutoVern8, - AutoVern9 -export OrdinaryDiffEqVerner, Vern6, Vern7, Vern8, Vern9 + AutoVern9, RKV76IIa +export OrdinaryDiffEqVerner, Vern6, Vern7, Vern8, Vern9, RKV76IIa import OrdinaryDiffEqHighOrderRK: OrdinaryDiffEqHighOrderRK using OrdinaryDiffEqHighOrderRK: TanYam7, DP8, PFRK87, TsitPap8