@@ -24,100 +24,80 @@ using OptimizationOptimisers
using UnicodePlots
-const EX = HybridVariationalInference.DoubleMM
const case = DoubleMM.DoubleMMCase()
const MLengine = Val(nameof(SimpleChains))
-scenario= (:default,)
-
+scenario = (:default,)
rng = StableRNG(111)

-(; n_covar_pc, n_covar, n_site, n_batch, n_θM, n_θP) = get_case_sizes(case; scenario)
+par_templates = get_hybridcase_par_templates(case; scenario)
+
+(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
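# As used below: n_site is the number of sites, n_batch the minibatch size,
# n_θM the number of site-level parameters predicted by the ML model, and
# n_θP the number of global process parameters shared across sites.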

-# const int_θP = ComponentArrayInterpreter(EX.θP)
-# const int_θM = ComponentArrayInterpreter(EX.θM)
+# const int_θP = ComponentArrayInterpreter(par_templates.θP)
+# const int_θM = ComponentArrayInterpreter(par_templates.θM)
# const int_θPMs_flat = ComponentArrayInterpreter(P = n_θP, Ms = n_θM * n_batch)
-# const int_θ = ComponentArrayInterpreter(CA.ComponentVector(;θP=EX.θP,θM=EX.θM))
+# const int_θ = ComponentArrayInterpreter(CA.ComponentVector(;θP=par_templates.θP,θM=par_templates.θM))
# # moved to f_doubleMM
# # const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(;θP,θM)))
# # const S1 = [1.0, 1.0, 1.0, 0.3, 0.1]
# # const S2 = [1.0, 3.0, 5.0, 5.0, 5.0]
-# θ = CA.getdata(vcat(EX.θP, EX.θM))
-
-# const int_θPMs = ComponentArrayInterpreter(CA.ComponentVector(;EX.θP,
-#   θMs=CA.ComponentMatrix(zeros(n_θM, n_batch), first(CA.getaxes(EX.θM)), CA.Axis(i=1:n_batch))))
+# θ = CA.getdata(vcat(par_templates.θP, par_templates.θM))

-# moved to f_doubleMM
-# gen_q(InteractionsCovCor)
-x_o, θMs_true0 = gen_cov_pred(case, rng; scenario)
-# normalize to be distributed around the prescribed true values
-int_θMs_sites = ComponentArrayInterpreter(EX.θM, (n_site,))
-int_θMs_batch = ComponentArrayInterpreter(EX.θM, (n_batch,))
-θMs_true = int_θMs_sites(scale_centered_at(θMs_true0, EX.θM, 0.1));
+# const int_θPMs = ComponentArrayInterpreter(CA.ComponentVector(;par_templates.θP,
+#   θMs=CA.ComponentMatrix(zeros(n_θM, n_batch), first(CA.getaxes(par_templates.θM)), CA.Axis(i=1:n_batch))))

-@test isapprox(vec(mean(CA.getdata(θMs_true); dims = 2)), CA.getdata(EX.θM), rtol = 0.02)
-@test isapprox(vec(std(CA.getdata(θMs_true); dims = 2)), CA.getdata(EX.θM) .* 0.1, rtol = 0.02)
+(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o) = gen_hybridcase_synthetic(
+    case, rng; scenario);

+@test isapprox(
+    vec(mean(CA.getdata(θMs_true); dims = 2)), CA.getdata(par_templates.θM), rtol = 0.02)
+@test isapprox(vec(std(CA.getdata(θMs_true); dims = 2)),
+    CA.getdata(par_templates.θM) .* 0.1, rtol = 0.02)
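# The two checks above confirm that the synthetic θMs_true are centered on the
# template values par_templates.θM with roughly 10% relative spread.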
# ----- fit g to θMs_true
-g, ϕg0 = gen_g(case, MLengine; scenario)
-n_ϕg = length(ϕg0)
+g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario);

function loss_g(ϕg, x, g)
    ζMs = g(x, ϕg) # predict the log of the parameters
-    θMs = exp.(ζMs)
+    θMs = exp.(ζMs)
    loss = sum(abs2, θMs .- θMs_true)
    return loss, θMs
end
-loss_g(ϕg0, x_o, g)
-Zygote.gradient(x->loss_g(x, x_o, g)[1], ϕg0);
+loss_g(ϕg0, xM, g)
+Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0);
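# Note: g(x, ϕg) returns log-parameters, so the exponentiation keeps θMs positive;
# judging from the dims = 2 checks above, θMs (like θMs_true) is an
# n_θM × n_site matrix with one column per site.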

-optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, x_o, g)[1],
+optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
    Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, ϕg0);
-res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(100), maxiters=600);
+res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600);

ϕg_opt1 = res.u;
-loss_g(ϕg_opt1, x_o, g)
-scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, x_o, g)[2]))
-@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, x_o, g)[2])) > 0.9
-
-# ----------- fit g and θP to y_obs
-f = gen_f(case; scenario)
-y_true = f(EX.θP, θMs_true, zip())[2]
-
-σ_o = 0.01
-# σ_o = 0.002
-y_o = y_true .+ reshape(randn(length(y_true)), size(y_true)...) .* σ_o
-scatterplot(vec(y_true), vec(y_o))
-scatterplot(vec(log.(y_true)), vec(log.(y_o)))
-
-# fit g to log(θ_true) ~ x_o
+loss_g(ϕg_opt1, xM, g)
+scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
+@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9

-int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(ϕg = 1:length(ϕg0), θP = EX.θP))
-p = p0 = vcat(ϕg0, EX.θP .* 0.9); # slightly disturb θP_true
-# #p = p0 = vcat(ϕg_opt1, θP_true .* 0.9); # slightly disturb θP_true
-# p0c = int_ϕθP(p0);
-# #gf(g,f_doubleMM, x_o, pc.ϕg, pc.θP)[1]
+# ----------- fit g and θP to y_o
+f = gen_hybridcase_PBmodel(case; scenario)

+int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
+    ϕg = 1:length(ϕg0), θP = par_templates.θP))
+p = p0 = vcat(ϕg0, par_templates.θP .* 0.9); # slightly disturb θP_true
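# Illustration (int_ϕθP is applied as int_ϕθP(res.u).θP further below): the
# interpreter addresses the flat optimization vector by name, e.g.
#   p0c = int_ϕθP(p0)
#   p0c.ϕg   # coefficients of the ML model g
#   p0c.θP   # global process parameters (here the disturbed template values)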

-# Pass the data for the batches as separate vectors wrapped in a tuple
-train_loader = MLUtils.DataLoader((
-    x_o,
-    fill((), n_site), # xP
-    y_o
-), batchsize = n_batch)
+# Pass the site-data for the batches as separate vectors wrapped in a tuple
+train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)

-loss_gf = get_loss_gf(g, f, Float32[], int_ϕθP)
+loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
l1 = loss_gf(p0, train_loader.data...)[1]

optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
    Optimization.AutoZygote())
optprob = OptimizationProblem(optf, p0, train_loader)

-res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(100), maxiters=1000);
+res = Optimization.solve(
+    optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
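# Since train_loader is passed as the problem's data, each optimizer step above is
# assumed to receive one (xM, xP, y_o) minibatch of n_batch sites, roughly:
#   for (xM_b, xP_b, y_o_b) in train_loader
#       loss_gf(ϕ, xM_b, xP_b, y_o_b)[1]
#   end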

l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
scatterplot(vec(θMs_true), vec(θMs))
scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
scatterplot(vec(y_pred), vec(y_o))
-hcat(EX.θP, int_ϕθP(res.u).θP)
+hcat(par_templates.θP, int_ϕθP(res.u).θP)
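# Side-by-side comparison of the template θP with the θP recovered by the joint fit.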