diff --git a/src/common.jl b/src/common.jl index 01e6ca6a..63920a03 100644 --- a/src/common.jl +++ b/src/common.jl @@ -95,13 +95,14 @@ function __solvebp_call(cache::IntegralCache, args...; kwargs...) __solvebp_call(build_problem(cache), args...; kwargs...) end -mutable struct SampledIntegralCache{Y, X, D, PK, A, K, Tc} +mutable struct SampledIntegralCache{Y, X, D, PK, A, K, C, Tc} y::Y x::X dim::D prob_kwargs::PK alg::A kwargs::K + cumulative::C isfresh::Bool # state of whether weights have been calculated cacheval::Tc # store alg weights here end @@ -114,10 +115,10 @@ function Base.setproperty!(cache::SampledIntegralCache, name::Symbol, x) end function SciMLBase.init(prob::SampledIntegralProblem, - alg::SciMLBase.AbstractIntegralAlgorithm; + alg::SciMLBase.AbstractIntegralAlgorithm; cumulative = Val(false), kwargs...) NamedTuple(kwargs) == NamedTuple() || - throw(ArgumentError("There are no keyword arguments allowed to `solve`")) + throw(ArgumentError("There are no keyword arguments allowed to `solve` except `cumulative`")) cacheval = init_cacheval(alg, prob) isfresh = true @@ -128,6 +129,7 @@ function SciMLBase.init(prob::SampledIntegralProblem, prob.kwargs, alg, kwargs, + cumulative, isfresh, cacheval) end @@ -139,12 +141,13 @@ solve(prob::SampledIntegralProblem, alg::SciMLBase.AbstractIntegralAlgorithm; kw ## Keyword Arguments -There are no keyword arguments used to solve `SampledIntegralProblem`s +- cumulative: Boolean value to indicate if it should return cumulative integral, i.e., a vector of integral values where at every sampled point, it corresponds to the integral from the beginning to that point. +Default value is `false`. """ function SciMLBase.solve(prob::SampledIntegralProblem, - alg::SciMLBase.AbstractIntegralAlgorithm; + alg::SciMLBase.AbstractIntegralAlgorithm; cumulative = Val(false), kwargs...) - solve!(init(prob, alg; kwargs...)) + solve!(init(prob, alg; cumulative, kwargs...)) end function SciMLBase.solve!(cache::SampledIntegralCache) diff --git a/src/sampled.jl b/src/sampled.jl index 947ad962..81c6d8f4 100644 --- a/src/sampled.jl +++ b/src/sampled.jl @@ -25,27 +25,49 @@ _eachslice(data::AbstractArray{T, 1}; dims = ndims(data)) where {T} = data dimension(::Val{D}) where {D} = D dimension(D::Int) = D -function evalrule(data::AbstractArray, weights, dim) +function update_outs!(outs, out, idx) + if typeof(outs) <: AbstractVector + outs[idx] = out + elseif typeof(outs) <: AbstractMatrix + @views outs[:, idx] = out + end +end + +function evalrule(data::AbstractArray, weights, dim, cumulative::Val{C}) where C fw = zip(_eachslice(data, dims = dim), weights) next = iterate(fw) next === nothing && throw(ArgumentError("No points to integrate")) (f1, w1), state = next out = w1 * f1 + C && begin + outs = zeros(eltype(out), size(data)) + idx = 1 + update_outs!(outs, out, idx) + idx += 1 + end next = iterate(fw, state) if isbits(out) while next !== nothing (fi, wi), state = next out += wi * fi + C && begin + update_outs!(outs, out, idx) + idx += 1 + end next = iterate(fw, state) end else while next !== nothing (fi, wi), state = next out .+= wi .* fi + C && begin + update_outs!(outs, out, idx) + idx += 1 + end next = iterate(fw, state) end end - return out + return C ? outs : out end # can be reused for other sampled rules, which should implement find_weights(x, alg) @@ -67,7 +89,7 @@ function __solvebp_call(cache::SampledIntegralCache, cache.isfresh = false end weights = cache.cacheval - I = evalrule(data, weights, dim) + I = evalrule(data, weights, dim, cache.cumulative; kwargs...) prob = build_problem(cache) return SciMLBase.build_solution(prob, alg, I, err, retcode = ReturnCode.Success) end diff --git a/test/sampled_tests.jl b/test/sampled_tests.jl index 8a77c63f..a9ca8cc2 100644 --- a/test/sampled_tests.jl +++ b/test/sampled_tests.jl @@ -9,21 +9,31 @@ using Integrals, Test grid2 = [lb; sort(grid2); ub] exact_sols = [1 / 6 * (ub^6 - lb^6), sin(ub) - sin(lb)] + exact_sols_cumulative = [[ + [1 / 6 * (x^6 - lb^6) for x in grid], + [sin(x) - sin(lb) for x in grid], + ] for grid in [grid1, grid2]] for method in [TrapezoidalRule] # Simpson's later - for grid in [grid1, grid2] + for (j, grid) in enumerate([grid1, grid2]) for (i, f) in enumerate([x -> x^5, x -> cos(x)]) exact = exact_sols[i] + exact_cum = exact_sols_cumulative[j] # single dimensional y y = f.(grid) prob = SampledIntegralProblem(y, grid) error = solve(prob, method()).u .- exact - @test all(error .< 10^-4) + error_cum = solve(prob, method(); cumulative = Val(true)).u .- exact_cum[i] + @test error < 10^-4 + @test all(error_cum .< 10^-2) # along dim=2 y = f.([grid grid]') prob = SampledIntegralProblem(y, grid; dim = 2) error = solve(prob, method()).u .- exact + error_cum = solve(prob, method(); cumulative = Val(true)).u .- + [exact_cum[i] exact_cum[i]]' @test all(error .< 10^-4) + @test all(error_cum .< 10^-2) end end end @@ -33,38 +43,41 @@ end x = 0.0:0.1:1.0 y = sin.(x) - prob = SampledIntegralProblem(y, x) - alg = TrapezoidalRule() - - cache = init(prob, alg) - sol1 = solve!(cache) - - @test sol1 == solve(prob, alg) - - cache.y = cos.(x) # use .= to update in-place - sol2 = solve!(cache) - - @test sol2 == solve(SampledIntegralProblem(cache.y, cache.x), alg) + function test_interface(x, y, cumulative) + prob = SampledIntegralProblem(y, x) + alg = TrapezoidalRule() - cache.x = 0.0:0.2:2.0 - cache.y = sin.(cache.x) - sol3 = solve!(cache) + cache = init(prob, alg; cumulative) + sol1 = solve!(cache) + @test sol1 == solve(prob, alg; cumulative) - @test sol3 == solve(SampledIntegralProblem(cache.y, cache.x), alg) + cache.y = cos.(x) # use .= to update in-place + sol2 = solve!(cache) + @test sol2 == solve(SampledIntegralProblem(cache.y, cache.x), alg; cumulative) - x = 0.0:0.1:1.0 - y = sin.(x) .* cos.(x') - - prob = SampledIntegralProblem(y, x) - alg = TrapezoidalRule() + cache.x = 0.0:0.2:2.0 + cache.y = sin.(cache.x) + sol3 = solve!(cache) + @test sol3 == solve(SampledIntegralProblem(cache.y, cache.x), alg; cumulative) - cache = init(prob, alg) - sol1 = solve!(cache) + x = 0.0:0.1:1.0 + y = sin.(x) .* cos.(x') - @test sol1 == solve(prob, alg) + prob = SampledIntegralProblem(y, x) + alg = TrapezoidalRule() - cache.dim = 1 - sol2 = solve!(cache) + cache = init(prob, alg; cumulative) + sol1 = solve!(cache) + @test sol1 == solve(prob, alg; cumulative) - @test sol2 == solve(SampledIntegralProblem(y, x, dim = 1), alg) + cache.dim = 1 + sol2 = solve!(cache) + @test sol2 == solve(SampledIntegralProblem(y, x, dim = 1), alg; cumulative) + end + @testset "Total Integral" begin + test_interface(x, y, Val(false)) + end + @testset "Cumulative Integral" begin + test_interface(x, y, Val(true)) + end end