From 05c5bd07641de66c9da0de9c0b30581592500906 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Mon, 25 Sep 2023 09:30:01 +0000 Subject: [PATCH 1/4] test: add tests for cumulative integrals --- test/sampled_tests.jl | 71 +++++++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/test/sampled_tests.jl b/test/sampled_tests.jl index 8a77c63f..4db1a675 100644 --- a/test/sampled_tests.jl +++ b/test/sampled_tests.jl @@ -9,21 +9,31 @@ using Integrals, Test grid2 = [lb; sort(grid2); ub] exact_sols = [1 / 6 * (ub^6 - lb^6), sin(ub) - sin(lb)] + exact_sols_cumulative = [[ + [1 / 6 * (x^6 - lb^6) for x in grid], + [sin(x) - sin(lb) for x in grid], + ] for grid in [grid1, grid2]] for method in [TrapezoidalRule] # Simpson's later - for grid in [grid1, grid2] + for (j, grid) in enumerate([grid1, grid2]) for (i, f) in enumerate([x -> x^5, x -> cos(x)]) exact = exact_sols[i] + exact_cum = exact_sols_cumulative[j] # single dimensional y y = f.(grid) prob = SampledIntegralProblem(y, grid) error = solve(prob, method()).u .- exact - @test all(error .< 10^-4) + error_cum = solve(prob, method(); cumulative = true).u .- exact_cum[i] + @test error < 10^-4 + @test all(error_cum .< 10^-2) # along dim=2 y = f.([grid grid]') prob = SampledIntegralProblem(y, grid; dim = 2) error = solve(prob, method()).u .- exact + error_cum = solve(prob, method(); cumulative = true).u .- + [exact_cum[i] exact_cum[i]]' @test all(error .< 10^-4) + @test all(error_cum .< 10^-2) end end end @@ -33,38 +43,41 @@ end x = 0.0:0.1:1.0 y = sin.(x) - prob = SampledIntegralProblem(y, x) - alg = TrapezoidalRule() - - cache = init(prob, alg) - sol1 = solve!(cache) - - @test sol1 == solve(prob, alg) - - cache.y = cos.(x) # use .= to update in-place - sol2 = solve!(cache) - - @test sol2 == solve(SampledIntegralProblem(cache.y, cache.x), alg) + function test_interface(x, y, cumulative) + prob = SampledIntegralProblem(y, x) + alg = TrapezoidalRule() - cache.x = 0.0:0.2:2.0 - cache.y = sin.(cache.x) - sol3 = solve!(cache) + cache = init(prob, alg; cumulative) + sol1 = solve!(cache) + @test sol1 == solve(prob, alg; cumulative) - @test sol3 == solve(SampledIntegralProblem(cache.y, cache.x), alg) + cache.y = cos.(x) # use .= to update in-place + sol2 = solve!(cache) + @test sol2 == solve(SampledIntegralProblem(cache.y, cache.x), alg; cumulative) - x = 0.0:0.1:1.0 - y = sin.(x) .* cos.(x') - - prob = SampledIntegralProblem(y, x) - alg = TrapezoidalRule() + cache.x = 0.0:0.2:2.0 + cache.y = sin.(cache.x) + sol3 = solve!(cache) + @test sol3 == solve(SampledIntegralProblem(cache.y, cache.x), alg; cumulative) - cache = init(prob, alg) - sol1 = solve!(cache) + x = 0.0:0.1:1.0 + y = sin.(x) .* cos.(x') - @test sol1 == solve(prob, alg) + prob = SampledIntegralProblem(y, x) + alg = TrapezoidalRule() - cache.dim = 1 - sol2 = solve!(cache) + cache = init(prob, alg; cumulative) + sol1 = solve!(cache) + @test sol1 == solve(prob, alg; cumulative) - @test sol2 == solve(SampledIntegralProblem(y, x, dim = 1), alg) + cache.dim = 1 + sol2 = solve!(cache) + @test sol2 == solve(SampledIntegralProblem(y, x, dim = 1), alg; cumulative) + end + @testset "Total Integral" begin + test_interface(x, y, false) + end + @testset "Cumulative Integral" begin + test_interface(x, y, true) + end end From 86866989022c9af7739f2e3579d894d435acc378 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Mon, 25 Sep 2023 09:30:36 +0000 Subject: [PATCH 2/4] feat: add `cumulative` kwarg which returns cumulative integral for solving `SampledIntegralProblem` --- src/common.jl | 13 ++++++++----- src/sampled.jl | 28 +++++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/common.jl b/src/common.jl index 01e6ca6a..dab0d383 100644 --- a/src/common.jl +++ b/src/common.jl @@ -102,6 +102,7 @@ mutable struct SampledIntegralCache{Y, X, D, PK, A, K, Tc} prob_kwargs::PK alg::A kwargs::K + cumulative::Bool isfresh::Bool # state of whether weights have been calculated cacheval::Tc # store alg weights here end @@ -114,10 +115,10 @@ function Base.setproperty!(cache::SampledIntegralCache, name::Symbol, x) end function SciMLBase.init(prob::SampledIntegralProblem, - alg::SciMLBase.AbstractIntegralAlgorithm; + alg::SciMLBase.AbstractIntegralAlgorithm; cumulative = false, kwargs...) NamedTuple(kwargs) == NamedTuple() || - throw(ArgumentError("There are no keyword arguments allowed to `solve`")) + throw(ArgumentError("There are no keyword arguments allowed to `solve` except `cumulative`")) cacheval = init_cacheval(alg, prob) isfresh = true @@ -128,6 +129,7 @@ function SciMLBase.init(prob::SampledIntegralProblem, prob.kwargs, alg, kwargs, + cumulative, isfresh, cacheval) end @@ -139,12 +141,13 @@ solve(prob::SampledIntegralProblem, alg::SciMLBase.AbstractIntegralAlgorithm; kw ## Keyword Arguments -There are no keyword arguments used to solve `SampledIntegralProblem`s +- cumulative: Boolean value to indicate if it should return cumulative integral, i.e., a vector of integral values where at every sampled point, it corresponds to the integral from the beginning to that point. +Default value is `false`. """ function SciMLBase.solve(prob::SampledIntegralProblem, - alg::SciMLBase.AbstractIntegralAlgorithm; + alg::SciMLBase.AbstractIntegralAlgorithm; cumulative = false, kwargs...) - solve!(init(prob, alg; kwargs...)) + solve!(init(prob, alg; cumulative, kwargs...)) end function SciMLBase.solve!(cache::SampledIntegralCache) diff --git a/src/sampled.jl b/src/sampled.jl index 947ad962..21f22ecf 100644 --- a/src/sampled.jl +++ b/src/sampled.jl @@ -25,27 +25,49 @@ _eachslice(data::AbstractArray{T, 1}; dims = ndims(data)) where {T} = data dimension(::Val{D}) where {D} = D dimension(D::Int) = D -function evalrule(data::AbstractArray, weights, dim) +function update_outs!(outs, out, idx) + if typeof(outs) <: AbstractVector + outs[idx] = out + elseif typeof(outs) <: AbstractMatrix + @views outs[:, idx] = out + end +end + +function evalrule(data::AbstractArray, weights, dim, cumulative) fw = zip(_eachslice(data, dims = dim), weights) next = iterate(fw) next === nothing && throw(ArgumentError("No points to integrate")) (f1, w1), state = next out = w1 * f1 + cumulative && begin + outs = zeros(eltype(out), size(data)) + idx = 1 + update_outs!(outs, out, idx) + idx += 1 + end next = iterate(fw, state) if isbits(out) while next !== nothing (fi, wi), state = next out += wi * fi + cumulative && begin + update_outs!(outs, out, idx) + idx += 1 + end next = iterate(fw, state) end else while next !== nothing (fi, wi), state = next out .+= wi .* fi + cumulative && begin + update_outs!(outs, out, idx) + idx += 1 + end next = iterate(fw, state) end end - return out + return cumulative ? outs : out end # can be reused for other sampled rules, which should implement find_weights(x, alg) @@ -67,7 +89,7 @@ function __solvebp_call(cache::SampledIntegralCache, cache.isfresh = false end weights = cache.cacheval - I = evalrule(data, weights, dim) + I = evalrule(data, weights, dim, cache.cumulative; kwargs...) prob = build_problem(cache) return SciMLBase.build_solution(prob, alg, I, err, retcode = ReturnCode.Success) end From 46a11e10c8a29578d00d27df01202700ffb1d117 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Sun, 1 Oct 2023 18:43:28 +0000 Subject: [PATCH 3/4] refactor: use `Val` for `cumulative` to make it type stable --- src/common.jl | 8 ++++---- src/sampled.jl | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/common.jl b/src/common.jl index dab0d383..63920a03 100644 --- a/src/common.jl +++ b/src/common.jl @@ -95,14 +95,14 @@ function __solvebp_call(cache::IntegralCache, args...; kwargs...) __solvebp_call(build_problem(cache), args...; kwargs...) end -mutable struct SampledIntegralCache{Y, X, D, PK, A, K, Tc} +mutable struct SampledIntegralCache{Y, X, D, PK, A, K, C, Tc} y::Y x::X dim::D prob_kwargs::PK alg::A kwargs::K - cumulative::Bool + cumulative::C isfresh::Bool # state of whether weights have been calculated cacheval::Tc # store alg weights here end @@ -115,7 +115,7 @@ function Base.setproperty!(cache::SampledIntegralCache, name::Symbol, x) end function SciMLBase.init(prob::SampledIntegralProblem, - alg::SciMLBase.AbstractIntegralAlgorithm; cumulative = false, + alg::SciMLBase.AbstractIntegralAlgorithm; cumulative = Val(false), kwargs...) NamedTuple(kwargs) == NamedTuple() || throw(ArgumentError("There are no keyword arguments allowed to `solve` except `cumulative`")) @@ -145,7 +145,7 @@ solve(prob::SampledIntegralProblem, alg::SciMLBase.AbstractIntegralAlgorithm; kw Default value is `false`. """ function SciMLBase.solve(prob::SampledIntegralProblem, - alg::SciMLBase.AbstractIntegralAlgorithm; cumulative = false, + alg::SciMLBase.AbstractIntegralAlgorithm; cumulative = Val(false), kwargs...) solve!(init(prob, alg; cumulative, kwargs...)) end diff --git a/src/sampled.jl b/src/sampled.jl index 21f22ecf..81c6d8f4 100644 --- a/src/sampled.jl +++ b/src/sampled.jl @@ -33,13 +33,13 @@ function update_outs!(outs, out, idx) end end -function evalrule(data::AbstractArray, weights, dim, cumulative) +function evalrule(data::AbstractArray, weights, dim, cumulative::Val{C}) where C fw = zip(_eachslice(data, dims = dim), weights) next = iterate(fw) next === nothing && throw(ArgumentError("No points to integrate")) (f1, w1), state = next out = w1 * f1 - cumulative && begin + C && begin outs = zeros(eltype(out), size(data)) idx = 1 update_outs!(outs, out, idx) @@ -50,7 +50,7 @@ function evalrule(data::AbstractArray, weights, dim, cumulative) while next !== nothing (fi, wi), state = next out += wi * fi - cumulative && begin + C && begin update_outs!(outs, out, idx) idx += 1 end @@ -60,14 +60,14 @@ function evalrule(data::AbstractArray, weights, dim, cumulative) while next !== nothing (fi, wi), state = next out .+= wi .* fi - cumulative && begin + C && begin update_outs!(outs, out, idx) idx += 1 end next = iterate(fw, state) end end - return cumulative ? outs : out + return C ? outs : out end # can be reused for other sampled rules, which should implement find_weights(x, alg) From 414833779ce7caa8c89387bc01bc6fcb478a7439 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Sun, 1 Oct 2023 18:44:14 +0000 Subject: [PATCH 4/4] test: update tests to use `Val` for `cumulative` --- test/sampled_tests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/sampled_tests.jl b/test/sampled_tests.jl index 4db1a675..a9ca8cc2 100644 --- a/test/sampled_tests.jl +++ b/test/sampled_tests.jl @@ -22,7 +22,7 @@ using Integrals, Test y = f.(grid) prob = SampledIntegralProblem(y, grid) error = solve(prob, method()).u .- exact - error_cum = solve(prob, method(); cumulative = true).u .- exact_cum[i] + error_cum = solve(prob, method(); cumulative = Val(true)).u .- exact_cum[i] @test error < 10^-4 @test all(error_cum .< 10^-2) @@ -30,7 +30,7 @@ using Integrals, Test y = f.([grid grid]') prob = SampledIntegralProblem(y, grid; dim = 2) error = solve(prob, method()).u .- exact - error_cum = solve(prob, method(); cumulative = true).u .- + error_cum = solve(prob, method(); cumulative = Val(true)).u .- [exact_cum[i] exact_cum[i]]' @test all(error .< 10^-4) @test all(error_cum .< 10^-2) @@ -75,9 +75,9 @@ end @test sol2 == solve(SampledIntegralProblem(y, x, dim = 1), alg; cumulative) end @testset "Total Integral" begin - test_interface(x, y, false) + test_interface(x, y, Val(false)) end @testset "Cumulative Integral" begin - test_interface(x, y, true) + test_interface(x, y, Val(true)) end end