Skip to content

Add cumulative kwarg for solving SampledIntegralProblem #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,14 @@ function __solvebp_call(cache::IntegralCache, args...; kwargs...)
__solvebp_call(build_problem(cache), args...; kwargs...)
end

mutable struct SampledIntegralCache{Y, X, D, PK, A, K, Tc}
mutable struct SampledIntegralCache{Y, X, D, PK, A, K, C, Tc}
y::Y
x::X
dim::D
prob_kwargs::PK
alg::A
kwargs::K
cumulative::C
isfresh::Bool # state of whether weights have been calculated
cacheval::Tc # store alg weights here
end
Expand All @@ -114,10 +115,10 @@ function Base.setproperty!(cache::SampledIntegralCache, name::Symbol, x)
end

function SciMLBase.init(prob::SampledIntegralProblem,
alg::SciMLBase.AbstractIntegralAlgorithm;
alg::SciMLBase.AbstractIntegralAlgorithm; cumulative = Val(false),
kwargs...)
NamedTuple(kwargs) == NamedTuple() ||
throw(ArgumentError("There are no keyword arguments allowed to `solve`"))
throw(ArgumentError("There are no keyword arguments allowed to `solve` except `cumulative`"))

cacheval = init_cacheval(alg, prob)
isfresh = true
Expand All @@ -128,6 +129,7 @@ function SciMLBase.init(prob::SampledIntegralProblem,
prob.kwargs,
alg,
kwargs,
cumulative,
isfresh,
cacheval)
end
Expand All @@ -139,12 +141,13 @@ solve(prob::SampledIntegralProblem, alg::SciMLBase.AbstractIntegralAlgorithm; kw

## Keyword Arguments

There are no keyword arguments used to solve `SampledIntegralProblem`s
- cumulative: Boolean value to indicate if it should return cumulative integral, i.e., a vector of integral values where at every sampled point, it corresponds to the integral from the beginning to that point.
Default value is `false`.
"""
function SciMLBase.solve(prob::SampledIntegralProblem,
alg::SciMLBase.AbstractIntegralAlgorithm;
alg::SciMLBase.AbstractIntegralAlgorithm; cumulative = Val(false),
kwargs...)
solve!(init(prob, alg; kwargs...))
solve!(init(prob, alg; cumulative, kwargs...))
end

function SciMLBase.solve!(cache::SampledIntegralCache)
Expand Down
28 changes: 25 additions & 3 deletions src/sampled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,49 @@ _eachslice(data::AbstractArray{T, 1}; dims = ndims(data)) where {T} = data
dimension(::Val{D}) where {D} = D
dimension(D::Int) = D

function evalrule(data::AbstractArray, weights, dim)
function update_outs!(outs, out, idx)
if typeof(outs) <: AbstractVector
outs[idx] = out
elseif typeof(outs) <: AbstractMatrix
@views outs[:, idx] = out
end
end

function evalrule(data::AbstractArray, weights, dim, cumulative::Val{C}) where C
fw = zip(_eachslice(data, dims = dim), weights)
next = iterate(fw)
next === nothing && throw(ArgumentError("No points to integrate"))
(f1, w1), state = next
out = w1 * f1
C && begin
outs = zeros(eltype(out), size(data))
idx = 1
update_outs!(outs, out, idx)
idx += 1
end
next = iterate(fw, state)
if isbits(out)
while next !== nothing
(fi, wi), state = next
out += wi * fi
C && begin
update_outs!(outs, out, idx)
idx += 1
end
next = iterate(fw, state)
end
else
while next !== nothing
(fi, wi), state = next
out .+= wi .* fi
C && begin
update_outs!(outs, out, idx)
idx += 1
end
next = iterate(fw, state)
end
end
return out
return C ? outs : out
end

# can be reused for other sampled rules, which should implement find_weights(x, alg)
Expand All @@ -67,7 +89,7 @@ function __solvebp_call(cache::SampledIntegralCache,
cache.isfresh = false
end
weights = cache.cacheval
I = evalrule(data, weights, dim)
I = evalrule(data, weights, dim, cache.cumulative; kwargs...)
prob = build_problem(cache)
return SciMLBase.build_solution(prob, alg, I, err, retcode = ReturnCode.Success)
end
71 changes: 42 additions & 29 deletions test/sampled_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,31 @@ using Integrals, Test
grid2 = [lb; sort(grid2); ub]

exact_sols = [1 / 6 * (ub^6 - lb^6), sin(ub) - sin(lb)]
exact_sols_cumulative = [[
[1 / 6 * (x^6 - lb^6) for x in grid],
[sin(x) - sin(lb) for x in grid],
] for grid in [grid1, grid2]]
for method in [TrapezoidalRule] # Simpson's later
for grid in [grid1, grid2]
for (j, grid) in enumerate([grid1, grid2])
for (i, f) in enumerate([x -> x^5, x -> cos(x)])
exact = exact_sols[i]
exact_cum = exact_sols_cumulative[j]
# single dimensional y
y = f.(grid)
prob = SampledIntegralProblem(y, grid)
error = solve(prob, method()).u .- exact
@test all(error .< 10^-4)
error_cum = solve(prob, method(); cumulative = Val(true)).u .- exact_cum[i]
@test error < 10^-4
@test all(error_cum .< 10^-2)

# along dim=2
y = f.([grid grid]')
prob = SampledIntegralProblem(y, grid; dim = 2)
error = solve(prob, method()).u .- exact
error_cum = solve(prob, method(); cumulative = Val(true)).u .-
[exact_cum[i] exact_cum[i]]'
@test all(error .< 10^-4)
@test all(error_cum .< 10^-2)
end
end
end
Expand All @@ -33,38 +43,41 @@ end
x = 0.0:0.1:1.0
y = sin.(x)

prob = SampledIntegralProblem(y, x)
alg = TrapezoidalRule()

cache = init(prob, alg)
sol1 = solve!(cache)

@test sol1 == solve(prob, alg)

cache.y = cos.(x) # use .= to update in-place
sol2 = solve!(cache)

@test sol2 == solve(SampledIntegralProblem(cache.y, cache.x), alg)
function test_interface(x, y, cumulative)
prob = SampledIntegralProblem(y, x)
alg = TrapezoidalRule()

cache.x = 0.0:0.2:2.0
cache.y = sin.(cache.x)
sol3 = solve!(cache)
cache = init(prob, alg; cumulative)
sol1 = solve!(cache)
@test sol1 == solve(prob, alg; cumulative)

@test sol3 == solve(SampledIntegralProblem(cache.y, cache.x), alg)
cache.y = cos.(x) # use .= to update in-place
sol2 = solve!(cache)
@test sol2 == solve(SampledIntegralProblem(cache.y, cache.x), alg; cumulative)

x = 0.0:0.1:1.0
y = sin.(x) .* cos.(x')

prob = SampledIntegralProblem(y, x)
alg = TrapezoidalRule()
cache.x = 0.0:0.2:2.0
cache.y = sin.(cache.x)
sol3 = solve!(cache)
@test sol3 == solve(SampledIntegralProblem(cache.y, cache.x), alg; cumulative)

cache = init(prob, alg)
sol1 = solve!(cache)
x = 0.0:0.1:1.0
y = sin.(x) .* cos.(x')

@test sol1 == solve(prob, alg)
prob = SampledIntegralProblem(y, x)
alg = TrapezoidalRule()

cache.dim = 1
sol2 = solve!(cache)
cache = init(prob, alg; cumulative)
sol1 = solve!(cache)
@test sol1 == solve(prob, alg; cumulative)

@test sol2 == solve(SampledIntegralProblem(y, x, dim = 1), alg)
cache.dim = 1
sol2 = solve!(cache)
@test sol2 == solve(SampledIntegralProblem(y, x, dim = 1), alg; cumulative)
end
@testset "Total Integral" begin
test_interface(x, y, Val(false))
end
@testset "Cumulative Integral" begin
test_interface(x, y, Val(true))
end
end