Skip to content

Commit ac0caec

Browse files
Merge pull request #148 from Vaibhavdixit02/master
Make sample_u0 and save_idxs be compatible
2 parents 7b73e87 + d11a10c commit ac0caec

9 files changed

+91
-16
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBayes"
22
uuid = "ebbdde9d-f333-5424-9be2-dbf1e9acfb5e"
33
authors = ["Vaibhavdixit02 <vaibhavyashdixit@gmail.com>"]
4-
version = "2.9.1"
4+
version = "2.10.0"
55

66
[deps]
77
ApproxBayes = "f5f396d3-230c-5e07-80e6-9fadf06146cc"

src/abc_inference.jl

+15-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
1-
function createabcfunction(prob, t, distancefunction, alg; save_idxs = nothing, kwargs...)
1+
function createabcfunction(prob, t, distancefunction, alg; save_idxs = nothing, sample_u0 = false, kwargs...)
22
function simfunc(params, constants, data)
3-
sol = concrete_solve(STANDARD_PROB_GENERATOR(prob, params), alg; saveat = t, save_idxs = save_idxs, kwargs...)
3+
local u0
4+
if sample_u0
5+
u0 = save_idxs === nothing ? params[1:length(prob.u0)] : params[1:length(save_idxs)]
6+
if length(u0) < length(prob.u0)
7+
for i in length(u0):length(prob.u0)
8+
push!(u0,prob.u0[i])
9+
end
10+
end
11+
else
12+
u0 = prob.u0
13+
end
14+
sol = concrete_solve(STANDARD_PROB_GENERATOR(prob, params), alg, u0; saveat = t, save_idxs = save_idxs, kwargs...)
415
if size(sol, 2) < length(t)
516
return Inf,nothing
617
else
@@ -12,9 +23,9 @@ end
1223

1324
function abc_inference(prob::DiffEqBase.DEProblem, alg, t, data, priors; ϵ=0.001,
1425
distancefunction = euclidean, ABCalgorithm = ABCSMC, progress = false,
15-
num_samples = 500, maxiterations = 10^5, save_idxs = nothing, kwargs...)
26+
num_samples = 500, maxiterations = 10^5, save_idxs = nothing, sample_u0 = false, kwargs...)
1627

17-
abcsetup = ABCalgorithm(createabcfunction(prob, t, distancefunction, alg; save_idxs = save_idxs, kwargs...),
28+
abcsetup = ABCalgorithm(createabcfunction(prob, t, distancefunction, alg; save_idxs = save_idxs, sample_u0 = sample_u0, kwargs...),
1829
length(priors),
1930
ϵ,
2031
ApproxBayes.Prior(priors);

src/dynamichmc_inference.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,15 @@ function (P::DynamicHMCPosterior)(θ)
3636
@unpack algorithm, problem, data, t, parameter_priors = P
3737
@unpack σ_priors, solve_kwargs, sample_u0, save_idxs = P
3838
T = eltype(parameters)
39-
nu = length(problem.u0)
39+
nu = save_idxs == nothing ? length(problem.u0) : length(save_idxs)
4040
u0 = convert.(T, sample_u0 ? parameters[1:nu] : problem.u0)
4141
p = convert.(T, sample_u0 ? parameters[(nu + 1):end] : parameters)
42+
if length(u0) < length(problem.u0)
43+
# assumes u is ordered such that the observed variables are in the begining, consistent with ordered theta
44+
for i in length(u0):length(problem.u0)
45+
push!(u0, convert(T,problem.u0[i]))
46+
end
47+
end
4248
_saveat = t === nothing ? Float64[] : t
4349
sol = concrete_solve(problem, algorithm, u0, p; saveat = _saveat, save_idxs = save_idxs, solve_kwargs...)
4450
failure = size(sol, 2) < length(_saveat)

src/stan_inference.jl

+13-4
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing;alg=:
5757
length_of_params = length(vars)
5858
if isnothing(diffeq_string)
5959
sys = first(ModelingToolkit.modelingtoolkitize(prob))
60-
length_of_parameter = length(sys.ps)
60+
length_of_parameter = length(sys.ps) + sample_u0 * length(save_idxs)
6161
else
62-
length_of_parameter = length(prob.p) + sample_u0 * length(prob.u0)
62+
length_of_parameter = length(prob.p) + sample_u0 * length(save_idxs)
6363
end
6464
if alg ==:rk45
6565
algorithm = "integrate_ode_rk45"
@@ -90,8 +90,17 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing;alg=:
9090
priors_string = string(generate_priors(length_of_parameter,priors))
9191
stan_likelihood = stan_string(likelihood)
9292
if sample_u0
93-
nu = length(prob.u0)
94-
integral_string = "u_hat = $algorithm(sho, theta[1:$nu], t0, ts, theta[$(nu+1):$length_of_parameter], x_r, x_i, $reltol, $abstol, $maxiter);"
93+
nu = length(save_idxs)
94+
if nu < length(prob.u0)
95+
u0 = "{"
96+
for u_ in prob.u0[nu+1:length(prob.u0)]
97+
u0 = u0*string(u_)
98+
end
99+
u0 = u0*"}"
100+
integral_string = "u_hat = $algorithm(sho, append_array(theta[1:$nu],$u0), t0, ts, theta[$(nu+1):$length_of_parameter], x_r, x_i, $reltol, $abstol, $maxiter);"
101+
else
102+
integral_string = "u_hat = $algorithm(sho, theta[1:$nu], t0, ts, theta[$(nu+1):$length_of_parameter], x_r, x_i, $reltol, $abstol, $maxiter);"
103+
end
95104
else
96105
integral_string = "u_hat = $algorithm(sho, u0, t0, ts, theta, x_r, x_i, $reltol, $abstol, $maxiter);"
97106
end

src/turing_inference.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,15 @@ function turing_inference(
3030
for i in 1:length(likelihood_dist_priors)
3131
σ[i] ~ likelihood_dist_priors[i]
3232
end
33-
nu = length(prob.u0)
33+
nu = save_idxs === nothing ? length(prob.u0) : length(save_idxs)
3434
u0 = convert.(T, sample_u0 ? theta[1:nu] : prob.u0)
3535
p = convert.(T, sample_u0 ? theta[(nu + 1):end] : theta)
36+
if length(u0) < length(prob.u0)
37+
# assumes u is ordered such that the observed variables are in the begining, consistent with ordered theta
38+
for i in length(u0):length(prob.u0)
39+
push!(u0, convert(T,prob.u0[i]))
40+
end
41+
end
3642
_saveat = isnothing(t) ? Float64[] : t
3743
sol = concrete_solve(prob, alg, u0, p; saveat = _saveat, progress = progress, save_idxs = save_idxs, kwargs...)
3844
failure = size(sol, 2) < length(_saveat)
@@ -55,4 +61,4 @@ function turing_inference(
5561
model = mf(data)
5662
chn = sample(model, sampler, num_samples; progress = progress)
5763
return chn
58-
end
64+
end

test/abc.jl

+21-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions, Distances, StatsBase, RecursiveArrayTools
2-
using Test
2+
using Test, Distributions
33

44
# One parameter case
55
f1 = @ode_def begin
@@ -20,15 +20,33 @@ bayesian_result = abc_inference(prob1,Tsit5(),t,data,priors;
2020

2121
@test mean(bayesian_result.parameters, weights(bayesian_result.weights)) 1.5 atol=0.1
2222

23+
priors = [Normal(1.,0.01),Normal(1.,0.01),Normal(1.5,0.01)]
24+
bayesian_result = abc_inference(prob1,Tsit5(),t,data,priors;
25+
num_samples=500= 0.001,sample_u0=true)
26+
27+
meanvals = mean(bayesian_result.parameters, weights(bayesian_result.weights), 1)
28+
@test meanvals[1] 1. atol=0.1
29+
@test meanvals[2] 1. atol=0.1
30+
@test meanvals[3] 1.5 atol=0.1
31+
2332
sol = solve(prob1,Tsit5(),save_idxs=[1])
2433
randomized = VectorOfArray([(sol(t[i]) + .01randn(1)) for i in 1:length(t)])
2534
data = convert(Array,randomized)
26-
35+
priors = [Normal(1.5,0.01)]
2736
bayesian_result = abc_inference(prob1,Tsit5(),t,data,priors;
2837
num_samples=500= 0.001,save_idxs=[1])
2938

3039
@test mean(bayesian_result.parameters, weights(bayesian_result.weights)) 1.5 atol=0.1
3140

41+
priors = [Normal(1.,0.01),Normal(1.5,0.01)]
42+
bayesian_result = abc_inference(prob1,Tsit5(),t,data,priors;
43+
num_samples=500= 0.001,sample_u0=true,save_idxs=[1])
44+
45+
meanvals = mean(bayesian_result.parameters, weights(bayesian_result.weights), 1)
46+
@test meanvals[1] 1. atol=0.1
47+
@test meanvals[2] 1.5 atol=0.1
48+
49+
3250
# custom distance-function
3351
weights_ = ones(size(data)) # weighted data
3452
for i = 1:3:length(data)
@@ -42,6 +60,7 @@ distfn = function (d1, d2)
4260
end
4361
return sqrt(d)
4462
end
63+
priors = [Normal(1.5,0.01)]
4564
bayesian_result = abc_inference(prob1,Tsit5(),t,data,priors;
4665
num_samples=500, ϵ = 0.001,
4766
distancefunction = distfn)

test/dynamicHMC.jl

+8
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ bayesian_result = dynamichmc_inference(prob1, Tsit5(), t, data, (Normal(1.5, 1),
4242
as(Vector, asℝ₊, 1),mcmc_kwargs=mcmc_kwargs, save_idxs = [1])
4343
@test mean(p.parameters[1] for p in bayesian_result.posterior) p[1] atol = 0.1
4444

45+
priors = [Normal(1.,0.001),Normal(1.5,0.001)]
46+
mcmc_kwargs = (initialization = (q = zeros(2 + 1),), reporter = reporter)
47+
bayesian_result = dynamichmc_inference(prob1, Tsit5(), t, data, priors,
48+
as(Vector, asℝ₊, 2),mcmc_kwargs=mcmc_kwargs,save_idxs=[1],sample_u0=true)
49+
50+
@test mean(p.parameters[1] for p in bayesian_result.posterior) 1. atol = 0.1
51+
@test mean(p.parameters[2] for p in bayesian_result.posterior) 1.5 atol = 0.1
52+
4553
# With hand-code likelihood function
4654
weights_ = ones(size(data)) # weighted data
4755
for i = 1:3:length(data)

test/stan.jl

+10-1
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,22 @@ sdf = CmdStan.read_summary(bayesian_result.model)
3535
sol = solve(prob1,Tsit5(),save_idxs=[1])
3636
randomized = VectorOfArray([(sol(t[i]) + .01 * randn(1)) for i in 1:length(t)])
3737
data = convert(Array,randomized)
38-
38+
priors = [Truncated(Normal(1.5,0.1),0,2)]
3939
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
4040
num_warmup=500,likelihood=Normal,save_idxs=[1])
4141

4242
sdf = CmdStan.read_summary(bayesian_result.model)
4343
@test sdf[sdf.parameters .== :theta1, :mean][1] 1.5 atol=3e-1
4444

45+
46+
priors = [Normal(1.,0.01),Normal(1.5,0.01)]
47+
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
48+
num_warmup=500,likelihood=Normal,save_idxs=[1],sample_u0=true)
49+
50+
sdf = CmdStan.read_summary(bayesian_result.model)
51+
@test sdf[sdf.parameters .== :theta1, :mean][1] 1. atol=3e-1
52+
@test sdf[sdf.parameters .== :theta2, :mean][1] 1.5 atol=3e-1
53+
4554
println("Four parameter case")
4655
f1 = @ode_def begin
4756
dx = a*x - b*x*y

test/turing.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,19 @@ bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=500,s
3232
sol = solve(prob1,Tsit5(),save_idxs=[1])
3333
randomized = VectorOfArray([(sol(t[i]) + .01 * randn(1)) for i in 1:length(t)])
3434
data = convert(Array,randomized)
35-
35+
priors = [Normal(1.5,0.01)]
3636
bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=500,
3737
syms=[:a],save_idxs=[1])
3838

3939
@test mean(get(bayesian_result,:a)[1]) 1.5 atol=3e-1
4040

41+
priors = [Normal(1.,0.01),Normal(1.5,0.01)]
42+
bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=500,sample_u0 =true,
43+
syms=[:u1,:a],save_idxs=[1])
44+
45+
@test mean(get(bayesian_result,:a)[1]) 1.5 atol=3e-1
46+
@test mean(get(bayesian_result,:u1)[1]) 1.0 atol=3e-1
47+
4148
println("Four parameter case")
4249
f2 = @ode_def begin
4350
dx = a*x - b*x*y

0 commit comments

Comments
 (0)