Skip to content

Commit 214e530

Browse files
committed
reorganize AbstractHybridCase
1 parent 015e761 commit 214e530

File tree

6 files changed

+169
-124
lines changed

6 files changed

+169
-124
lines changed

dev/doubleMM.jl

Lines changed: 36 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -24,100 +24,80 @@ using OptimizationOptimisers
2424

2525
using UnicodePlots
2626

27-
const EX = HybridVariationalInference.DoubleMM
2827
const case = DoubleMM.DoubleMMCase()
2928
const MLengine = Val(nameof(SimpleChains))
30-
scenario=(:default,)
31-
29+
scenario = (:default,)
3230
rng = StableRNG(111)
3331

34-
(; n_covar_pc, n_covar, n_site, n_batch, n_θM, n_θP) = get_case_sizes(case; scenario)
32+
par_templates = get_hybridcase_par_templates(case; scenario)
33+
34+
(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
3535

36-
# const int_θP = ComponentArrayInterpreter(EX.θP)
37-
# const int_θM = ComponentArrayInterpreter(EX.θM)
36+
# const int_θP = ComponentArrayInterpreter(par_templates.θP)
37+
# const int_θM = ComponentArrayInterpreter(par_templates.θM)
3838
# const int_θPMs_flat = ComponentArrayInterpreter(P = n_θP, Ms = n_θM * n_batch)
39-
# const int_θ = ComponentArrayInterpreter(CA.ComponentVector(;θP=EX.θP,θM=EX.θM))
39+
# const int_θ = ComponentArrayInterpreter(CA.ComponentVector(;θP=par_templates.θP,θM=par_templates.θM))
4040
# # moved to f_doubleMM
4141
# # const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(;θP,θM)))
4242
# # const S1 = [1.0, 1.0, 1.0, 0.3, 0.1]
4343
# # const S2 = [1.0, 3.0, 5.0, 5.0, 5.0]
44-
# θ = CA.getdata(vcat(EX.θP, EX.θM))
45-
46-
# const int_θPMs = ComponentArrayInterpreter(CA.ComponentVector(;EX.θP,
47-
# θMs=CA.ComponentMatrix(zeros(n_θM, n_batch), first(CA.getaxes(EX.θM)), CA.Axis(i=1:n_batch))))
44+
# θ = CA.getdata(vcat(par_templates.θP, par_templates.θM))
4845

49-
# moved to f_doubleMM
50-
# gen_q(InteractionsCovCor)
51-
x_o, θMs_true0 = gen_cov_pred(case, rng; scenario)
52-
# normalize to be distributed around the prescribed true values
53-
int_θMs_sites = ComponentArrayInterpreter(EX.θM, (n_site,))
54-
int_θMs_batch = ComponentArrayInterpreter(EX.θM, (n_batch,))
55-
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, EX.θM, 0.1));
46+
# const int_θPMs = ComponentArrayInterpreter(CA.ComponentVector(;par_templates.θP,
47+
# θMs=CA.ComponentMatrix(zeros(n_θM, n_batch), first(CA.getaxes(par_templates.θM)), CA.Axis(i=1:n_batch))))
5648

57-
@test isapprox(vec(mean(CA.getdata(θMs_true); dims=2)), CA.getdata(EX.θM), rtol=0.02)
58-
@test isapprox(vec(std(CA.getdata(θMs_true); dims=2)), CA.getdata(EX.θM) .* 0.1, rtol=0.02)
49+
(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o) = gen_hybridcase_synthetic(
50+
case, rng; scenario);
5951

52+
@test isapprox(
53+
vec(mean(CA.getdata(θMs_true); dims = 2)), CA.getdata(par_templates.θM), rtol = 0.02)
54+
@test isapprox(vec(std(CA.getdata(θMs_true); dims = 2)),
55+
CA.getdata(par_templates.θM) .* 0.1, rtol = 0.02)
6056

6157
#----- fit g to θMs_true
62-
g, ϕg0 = gen_g(case, MLengine; scenario)
63-
n_ϕg = length(ϕg0)
58+
g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario);
6459

6560
function loss_g(ϕg, x, g)
6661
ζMs = g(x, ϕg) # predict the log of the parameters
67-
θMs = exp.(ζMs)
62+
θMs = exp.(ζMs)
6863
loss = sum(abs2, θMs .- θMs_true)
6964
return loss, θMs
7065
end
71-
loss_g(ϕg0, x_o, g)
72-
Zygote.gradient(x-> loss_g(x, x_o, g)[1], ϕg0);
66+
loss_g(ϕg0, xM, g)
67+
Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0);
7368

74-
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg,x_o, g)[1],
69+
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
7570
Optimization.AutoZygote())
7671
optprob = Optimization.OptimizationProblem(optf, ϕg0);
77-
res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(100), maxiters=600);
72+
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600);
7873

7974
ϕg_opt1 = res.u;
80-
loss_g(ϕg_opt1, x_o, g)
81-
scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, x_o, g)[2]))
82-
@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, x_o, g)[2])) > 0.9
83-
84-
#----------- fit g and θP to y_obs
85-
f = gen_f(case; scenario)
86-
y_true = f(EX.θP, θMs_true, zip())[2]
87-
88-
σ_o = 0.01
89-
#σ_o = 0.002
90-
y_o = y_true .+ reshape(randn(length(y_true)), size(y_true)...) .* σ_o
91-
scatterplot(vec(y_true), vec(y_o))
92-
scatterplot(vec(log.(y_true)), vec(log.(y_o)))
93-
94-
# fit g to log(θ_true) ~ x_o
75+
loss_g(ϕg_opt1, xM, g)
76+
scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
77+
@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
9578

96-
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(ϕg=1:length(ϕg0), θP=EX.θP))
97-
p = p0 = vcat(ϕg0, EX.θP .* 0.9); # slightly disturb θP_true
98-
# #p = p0 = vcat(ϕg_opt1, θP_true .* 0.9); # slightly disturb θP_true
99-
# p0c = int_ϕθP(p0);
100-
# #gf(g,f_doubleMM, x_o, pc.ϕg, pc.θP)[1]
79+
#----------- fit g and θP to y_o
80+
f = gen_hybridcase_PBmodel(case; scenario)
10181

82+
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
83+
ϕg = 1:length(ϕg0), θP = par_templates.θP))
84+
p = p0 = vcat(ϕg0, par_templates.θP .* 0.9); # slightly disturb θP_true
10285

103-
# Pass the data for the batches as separate vectors wrapped in a tuple
104-
train_loader = MLUtils.DataLoader((
105-
x_o,
106-
fill((), n_site), # xP
107-
y_o
108-
), batchsize = n_batch)
86+
# Pass the site-data for the batches as separate vectors wrapped in a tuple
87+
train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
10988

110-
loss_gf = get_loss_gf(g, f, Float32[], int_ϕθP)
89+
loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
11190
l1 = loss_gf(p0, train_loader.data...)[1]
11291

11392
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
11493
Optimization.AutoZygote())
11594
optprob = OptimizationProblem(optf, p0, train_loader)
11695

117-
res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(100), maxiters=1000);
96+
res = Optimization.solve(
97+
optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
11898

11999
l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
120100
scatterplot(vec(θMs_true), vec(θMs))
121101
scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
122102
scatterplot(vec(y_pred), vec(y_o))
123-
hcat(EX.θP, int_ϕθP(res.u).θP)
103+
hcat(par_templates.θP, int_ϕθP(res.u).θP)

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ HVI.construct_SimpleChainsApplicator(m::SimpleChain) = SimpleChainsApplicator(m)
1212

1313
HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
1414

15-
function HVI.gen_g(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
15+
function HVI.gen_hybridcase_MLapplicator(case::HVI.DoubleMM.DoubleMMCase, ::Val{:SimpleChains};
1616
scenario::NTuple=())
17-
(;n_covar, n_θM) = get_case_sizes(case; scenario)
18-
FloatType = get_case_FloatType(case; scenario)
17+
(;n_covar, n_θM) = get_hybridcase_sizes(case; scenario)
18+
FloatType = get_hybridcase_FloatType(case; scenario)
1919
n_out = n_θM
2020
is_using_dropout = :use_dropout scenario
2121
g_chain = if is_using_dropout

src/DoubleMM/f_doubleMM.jl

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,33 +16,58 @@ function f_doubleMM(θ::AbstractVector)
1616
return (y)
1717
end
1818

19-
function HybridVariationalInference.gen_f(::DoubleMMCase; scenario::NTuple = ())
20-
fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
21-
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
22-
pred_sites = applyf(fsite, θMs, θP, x)
23-
pred_global = eltype(pred_sites)[]
24-
return pred_global, pred_sites
25-
end
19+
function HybridVariationalInference.get_hybridcase_par_templates(::DoubleMMCase; scenario::NTuple = ())
20+
(; θP, θM)
2621
end
2722

28-
function HybridVariationalInference.get_case_sizes(::DoubleMMCase; scenario = ())
23+
function HybridVariationalInference.get_hybridcase_sizes(::DoubleMMCase; scenario = ())
2924
n_covar_pc = 2
3025
n_covar = n_covar_pc + 3 # linear dependent
3126
n_site = 10^n_covar_pc
3227
n_batch = 10
3328
n_θM = length(θM)
3429
n_θP = length(θP)
35-
(; n_covar_pc, n_covar, n_site, n_batch, n_θM, n_θP)
30+
(; n_covar, n_site, n_batch, n_θM, n_θP)
3631
end
3732

38-
function HybridVariationalInference.get_case_FloatType(::DoubleMMCase; scenario)
33+
function HybridVariationalInference.gen_hybridcase_PBmodel(::DoubleMMCase; scenario::NTuple = ())
34+
fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
35+
function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, x)
36+
pred_sites = applyf(fsite, θMs, θP, x)
37+
pred_global = eltype(pred_sites)[]
38+
return pred_global, pred_sites
39+
end
40+
end
41+
42+
function HybridVariationalInference.get_hybridcase_FloatType(::DoubleMMCase; scenario)
3943
return Float32
4044
end
4145

42-
function HybridVariationalInference.gen_cov_pred(case::DoubleMMCase, rng::AbstractRNG;
46+
function HybridVariationalInference.gen_hybridcase_synthetic(case::DoubleMMCase, rng::AbstractRNG;
4347
scenario = ())
44-
(; n_covar_pc, n_covar, n_site, n_batch, n_θM, n_θP) = get_case_sizes(case; scenario)
45-
FloatType = get_case_FloatType(case; scenario)
46-
gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM;
48+
n_covar_pc = 2
49+
(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
50+
FloatType = get_hybridcase_FloatType(case; scenario)
51+
xM, θMs_true0 = gen_cov_pred(rng, FloatType, n_covar_pc, n_covar, n_site, n_θM;
4752
rhodec = 8, is_using_dropout = false)
53+
int_θMs_sites = ComponentArrayInterpreter(θM, (n_site,))
54+
# normalize to be distributed around the prescribed true values
55+
θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, θM, 0.1))
56+
f = gen_hybridcase_PBmodel(case; scenario)
57+
xP = fill((), n_site)
58+
y_global_true, y_true = f(θP, θMs_true, zip())
59+
σ_o = 0.01
60+
#σ_o = 0.002
61+
y_global_o = y_global_true .+ randn(rng, size(y_global_true)) .* σ_o
62+
y_o = y_true .+ randn(rng, size(y_true)) .* σ_o
63+
(;
64+
xM,
65+
θP_true = θP,
66+
θMs_true,
67+
xP,
68+
y_global_true,
69+
y_true,
70+
y_global_o,
71+
y_o,
72+
)
4873
end

src/HybridVariationalInference.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module HybridVariationalInference
33
using ComponentArrays: ComponentArrays as CA
44
using Random
55
using StatsBase # fit ZScoreTransform
6-
using Combinatorics # gen_cov_pred/combinations
6+
using Combinatorics # gen_hybridcase_synthetic/combinations
77

88
export ComponentArrayInterpreter, flatten1
99
include("ComponentArrayInterpreter.jl")
@@ -12,7 +12,10 @@ export AbstractModelApplicator, construct_SimpleChainsApplicator, construct_Flux
1212
construct_LuxApplicator
1313
include("ModelApplicator.jl")
1414

15-
export AbstractHybridCase, gen_g, gen_f, get_case_sizes, get_case_FloatType, gen_cov_pred
15+
export AbstractHybridCase, gen_hybridcase_MLapplicator, gen_hybridcase_PBmodel, get_hybridcase_sizes, get_hybridcase_FloatType, gen_hybridcase_synthetic,
16+
get_hybridcase_par_templates, gen_cov_pred
17+
include("hybrid_case.jl")
18+
1619
export applyf, gf, get_loss_gf
1720
include("gf.jl")
1821

src/gf.jl

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,3 @@
1-
"""
2-
Type to dispatch constructing data and network structures
3-
for different cases of hybrid problem setups
4-
"""
5-
abstract type AbstractHybridCase end;
6-
7-
function get_case_sizes end
8-
9-
"""
10-
Determine the FloatType for given Case and scenario, defaults to Float32
11-
"""
12-
function get_case_FloatType(::AbstractHybridCase; scenario)
13-
return Float32
14-
end
15-
16-
function gen_cov_pred end
17-
18-
"""
19-
gen_g(::AbstractHybridCase, MLEngine, n_covar, n_out; scenario::NTuple=())
20-
21-
Construct the machine learning model fro given problem case and ML-Framework and
22-
scenario.
23-
24-
The MLEngine is a value type of a Symbol, usually the name of the module, e.g.
25-
`const MLengine = Val(nameof(SimpleChains))`.
26-
27-
returns a Tuple of
28-
- AbstractModelApplicator
29-
- initial parameter vector
30-
"""
31-
function gen_g end
32-
33-
"""
34-
gen_f(::AbstractHybridCase; scenario::NTuple=())
35-
36-
Construct the process-based model function
37-
`f(θP::AbstractVector, θMs::AbstractMatrix, x) -> (AbstractVector, AbstractMatrix)`
38-
with
39-
- θP: calibrated parameters that are constant across site
40-
- θMs: calibrated parameters that vary across sites, with a column for each site
41-
- x: drivers, indexed by site
42-
43-
returns a tuple of predictions with components
44-
- first, those that are constant across sites
45-
- second, those that vary across sites, with a column for each site
46-
"""
47-
function gen_f end
48-
49-
501
function applyf(f, θMs::AbstractMatrix, θP::AbstractVector, x)
512
# predict several sites with same physical parameters
523
yv = map(eachcol(θMs), x) do θM, x_site

src/hybrid_case.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
Type to dispatch constructing data and network structures
3+
for different cases of hybrid problem setups
4+
5+
For a specific case, provide functions that specify details
6+
- get_hybridcase_par_templates
7+
- get_hybridcase_sizes
8+
- gen_hybridcase_MLapplicator
9+
- gen_hybridcase_PBmodel
10+
optionally
11+
- gen_hybridcase_synthetic
12+
- get_hybridcase_FloatType (if it shoudl differ from Float32)
13+
"""
14+
abstract type AbstractHybridCase end;
15+
16+
"""
17+
get_hybridcase_par_templates(::AbstractHybridCase; scenario)
18+
19+
Provide tuple of templates of ComponentVectors `θP` and `θM`.
20+
"""
21+
function get_hybridcase_par_templates end
22+
23+
"""
24+
get_hybridcase_par_templates(::AbstractHybridCase; scenario)
25+
26+
Provide a NamedTuple of number of
27+
- n_covar: covariates xM
28+
- n_site: all sites in the data
29+
- n_batch: sites in one minibatch during fitting
30+
- n_θM, n_θP: entries in parameter vectors
31+
"""
32+
function get_hybridcase_sizes end
33+
34+
"""
35+
gen_hybridcase_MLapplicator(::AbstractHybridCase, MLEngine, n_covar, n_out; scenario=())
36+
37+
Construct the machine learning model fro given problem case and ML-Framework and
38+
scenario.
39+
40+
The MLEngine is a value type of a Symbol, usually the name of the module, e.g.
41+
`const MLengine = Val(nameof(SimpleChains))`.
42+
43+
returns a Tuple of
44+
- AbstractModelApplicator
45+
- initial parameter vector
46+
"""
47+
function gen_hybridcase_MLapplicator end
48+
49+
"""
50+
gen_hybridcase_PBmodel(::AbstractHybridCase; scenario::NTuple=())
51+
52+
Construct the process-based model function
53+
`f(θP::AbstractVector, θMs::AbstractMatrix, x) -> (AbstractVector, AbstractMatrix)`
54+
with
55+
- θP: calibrated parameters that are constant across site
56+
- θMs: calibrated parameters that vary across sites, with a column for each site
57+
- x: drivers, indexed by site
58+
59+
returns a tuple of predictions with components
60+
- first, those that are constant across sites
61+
- second, those that vary across sites, with a column for each site
62+
"""
63+
function gen_hybridcase_PBmodel end
64+
65+
"""
66+
gen_hybridcase_synthetic(::AbstractHybridCase, rng; scenario)
67+
68+
Setup synthetic data, a NamedTuple of
69+
- xM: matrix of covariates, with one column per site
70+
- θP_true: vector global process-model parameters
71+
- θMs_true: matrix of site-varying process-model parameters, with
72+
- xP: Vector of process-model drivers, with an entry per site
73+
- y_global_true: vector of global observations
74+
- y_true: matrix of site-specific observations with one column per site
75+
- y_global_o, y_o: observations with added noise
76+
"""
77+
function gen_hybridcase_synthetic end
78+
79+
"""
80+
get_hybridcase_FloatType(::AbstractHybridCase; scenario)
81+
82+
Determine the FloatType for given Case and scenario, defaults to Float32
83+
"""
84+
function get_hybridcase_FloatType(::AbstractHybridCase; scenario)
85+
return Float32
86+
end

0 commit comments

Comments
 (0)