@@ -65,28 +65,245 @@ loss_g(ϕg_opt1, xM, g)
 scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
 @test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
 
-# ----------- fit g and θP to y_o
-f = gen_hybridcase_PBmodel(case; scenario)
+tmpf = () -> begin
+    # ----------- fit g and θP to y_o
+    # end-to-end inversion
+    f = gen_hybridcase_PBmodel(case; scenario)
-int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
-    ϕg = 1:length(ϕg0), θP = par_templates.θP))
-p = p0 = vcat(ϕg0, par_templates.θP .* 0.9); # slightly disturb θP_true
+    int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
+        ϕg = 1:length(ϕg0), θP = par_templates.θP))
+    p = p0 = vcat(ϕg0, par_templates.θP .* 0.9); # slightly disturb θP_true
-# Pass the site-data for the batches as separate vectors wrapped in a tuple
-train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
+    # Pass the site-data for the batches as separate vectors wrapped in a tuple
+    train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
-loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
-l1 = loss_gf(p0, train_loader.data...)[1]
+    loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
+    l1 = loss_gf(p0, train_loader.data...)[1]
-optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
+    optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
+        Optimization.AutoZygote())
+    optprob = OptimizationProblem(optf, p0, train_loader)
+
+    res = Optimization.solve(
+        optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
+
+    l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
+    scatterplot(vec(θMs_true), vec(θMs))
+    scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
+    scatterplot(vec(y_pred), vec(y_o))
+    hcat(par_templates.θP, int_ϕθP(res.u).θP)
+end
+
+# ---------- HADVI
+# TODO think about good general initializations
+coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
+logσ2_logP = CA.ComponentVector(r0=-8.997, K2=-5.893)
+mean_σ_o_MC = 0.006042
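+# (assumed interpretation, inferred from the conservative default ϕunc0 below: each
+# column of coef_logσ2_logMs holds [intercept, slope] of the log-variance of a
+# log-θM component as a linear function of its magnitude)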
+
+# correlation matrices
+ρsP = zeros(sum(1:(n_θP - 1)))
+ρsM = zeros(sum(1:(n_θM - 1)))
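+# (sum(1:(n-1)) = n(n-1)/2: the packed strict lower triangle of an n×n correlation
+# matrix, so the zeros start the fit from uncorrelated parameters)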
+
+ϕunc = CA.ComponentVector(;
+    logσ2_logP=logσ2_logP,
+    coef_logσ2_logMs=coef_logσ2_logMs,
+    ρsP,
+    ρsM)
+int_unc = ComponentArrayInterpreter(ϕunc)
+
+# for a conservative initial uncertainty assume logσ2 = -10 and no relationship with magnitude
+ϕunc0 = CA.ComponentVector(;
+    logσ2_logP=fill(-10.0, n_θP),
+    coef_logσ2_logMs=reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
+    ρsP,
+    ρsM)
+
+logσ2y = fill(2 * log(σ_o), size(y_o, 1))
+n_MC = 3
+
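+# (logσ2y stores the log observation variance per observation row, using
+# 2*log(σ_o) = log(σ_o²); n_MC is the number of Monte Carlo draws for the ELBO
+# estimate and is overridden in the calls below, e.g. n_MC=8 or n_MC=5)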
+
+# -------------- ADVI with g inside cost function
+using CUDA
+using TransformVariables
+
+transPMs_batch = as(
+    (P=as(Array, asℝ₊, n_θP),
+     Ms=as(Array, asℝ₊, n_θM, n_batch)))
+transPMs_all = as(
+    (P=as(Array, asℝ₊, n_θP),
+     Ms=as(Array, asℝ₊, n_θM, n_site)))
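+# (asℝ₊ maps unconstrained reals to strictly positive parameters via exp; the two
+# transforms differ only in the site dimension: n_batch for minibatches,
+# n_site when predicting all sites)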
+
+ϕ_true = θ = CA.ComponentVector(;
+    μP=θP_true,
+    ϕg=ϕg_opt,
+    unc=ϕunc);
+trans_gu = as(
+    (μP=as(Array, asℝ₊, n_θP),
+     ϕg=as(Array, n_ϕg),
+     unc=as(Array, length(ϕunc))))
+trans_g = as(
+    (μP=as(Array, asℝ₊, n_θP),
+     ϕg=as(Array, n_ϕg)))
+
+const int_PMs_batch = ComponentArrayInterpreter(CA.ComponentVector(; θP,
+    θMs=CA.ComponentMatrix(
+        zeros(n_θM, n_batch), first(CA.getaxes(θM)), CA.Axis(i=1:n_batch))))
+
+interpreters = interpreters_g = map(get_concrete, (;
+    μP_ϕg_unc=ComponentArrayInterpreter(ϕ_true),
+    PMs=int_PMs_batch,
+    unc=ComponentArrayInterpreter(ϕunc)
+))
+
+ϕg_true_vec = CA.ComponentVector(
+    TransformVariables.inverse(trans_gu, cv2NamedTuple(ϕ_true)))
+ϕcg_true = interpreters.μP_ϕg_unc(ϕg_true_vec)
+ϕ_ini = ζ = vcat(ϕcg_true[[:μP, :ϕg]] .* 1.2, ϕcg_true[[:unc]]);
+ϕ_ini0 = ζ = vcat(ϕcg_true[:μP] .* 0.0, SimpleChains.init_params(g), ϕunc0);
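+# (ϕ_ini perturbs the true μP and network weights by 20% to start the optimization
+# away from the optimum; ϕ_ini0 is a cold start with zeroed log-scale μP, freshly
+# initialized network weights, and the conservative ϕunc0)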
+
+neg_elbo_transnorm_gf(rng, g, f, ϕcg_true, y_o[:, 1:n_batch], x_o[:, 1:n_batch],
+    transPMs_batch, map(get_concrete, interpreters);
+    n_MC=8, logσ2y)
+Zygote.gradient(ϕ -> neg_elbo_transnorm_gf(
+        rng, g, f, ϕ, y_o[:, 1:n_batch], x_o[:, 1:n_batch],
+        transPMs_batch, interpreters; n_MC=8, logσ2y), ϕcg_true)
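+# (sanity check before optimizing: the ELBO-based cost evaluates at the true
+# parameters and Zygote can differentiate it)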
+
+() -> begin
+    train_loader = MLUtils.DataLoader((x_o, y_o), batchsize = n_batch)
+
+    optf = Optimization.OptimizationFunction((ζg, data) -> begin
+            x_o, y_o = data
+            neg_elbo_transnorm_gf(
+                rng, g, f, ζg, y_o, x_o, transPMs_batch, map(get_concrete, interpreters_g); n_MC=5, logσ2y)
+        end,
+        Optimization.AutoZygote())
+    optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini), train_loader);
+    res = Optimization.solve(optprob, Optimisers.Adam(0.02), callback=callback_loss(50), maxiters=800);
+    # optprob = Optimization.OptimizationProblem(optf, ϕ_ini0);
+    # res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400);
+end
+
+# using Lux
+ϕ = ϕcg_true |> gpu;
+x_o_gpu = x_o |> gpu;
+# y_o = y_o |> gpu
+# logσ2y = logσ2y |> gpu
+n_covar = size(x_o, 1)
+g_flux = Flux.Chain(
+    # dense layer with bias that maps to n_covar*4 outputs and applies `tanh` activation
+    Flux.Dense(n_covar => n_covar * 4, tanh),
+    Flux.Dense(n_covar * 4 => n_covar * 4, logistic),
+    # dense layer without bias that maps to n_θM outputs with `identity` activation
+    Flux.Dense(n_covar * 4 => n_θM, identity, bias=false),
+)
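+# (g_flux is assumed to mirror the SimpleChains network g so the covariate → θM
+# mapping can run on the GPU; `logistic` is presumably provided by
+# StatsFuns/LogExpFunctions rather than by Flux itself)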
+() -> begin
+    using Lux
+    g_lux = Lux.Chain(
+        # dense layer with bias that maps to n_covar*4 outputs and applies `tanh` activation
+        Lux.Dense(n_covar => n_covar * 4, tanh),
+        Lux.Dense(n_covar * 4 => n_covar * 4, logistic),
+        # dense layer without bias that maps to n_θM outputs with `identity` activation
+        Lux.Dense(n_covar * 4 => n_θM, identity, use_bias=false),
+    )
+    ps, st = Lux.setup(Random.default_rng(), g_lux)
+    ps_ca = CA.ComponentArray(ps) |> gpu
+    st = st |> gpu
+    g_luxs = StatefulLuxLayer{true}(g_lux, nothing, st)
+    g_luxs(x_o_gpu[:, 1:n_batch], ps_ca)
+    ax_g = CA.getaxes(ps_ca)
+    g_luxs(x_o_gpu[:, 1:n_batch], CA.ComponentArray(ϕ.ϕg, ax_g))
+    interpreters = (interpreters..., ϕg = ComponentArrayInterpreter(ps_ca))
+    ϕg = CA.ComponentArray(ϕ.ϕg, ax_g)
+    ϕgc = interpreters.ϕg(ϕ.ϕg)
+    g_gpu = g_luxs
+end
+g_gpu = g_flux
+
+# Zygote.gradient(ϕg -> sum(g_gpu(x_o_gpu[:, 1:n_batch], ϕg)), ϕgc)
+# Zygote.gradient(ϕg -> sum(compute_g(g_gpu, x_o_gpu[:, 1:n_batch], ϕg, interpreters)), ϕ.ϕg)
+# Zygote.gradient(ϕ -> sum(tmp_gen1(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ.ϕg)
+# Zygote.gradient(ϕ -> sum(tmp_gen2(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), CA.getdata(ϕ))
+# Zygote.gradient(ϕ -> sum(tmp_gen2(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ) |> cpu
+# Zygote.gradient(ϕ -> sum(tmp_gen3(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ) |> cpu
+# Zygote.gradient(ϕ -> sum(tmp_gen4(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)[1]), ϕ) |> cpu
+# generate_ζ(rng, g_gpu, f, ϕ, x_o_gpu[:, 1:n_batch], interpreters)
+# Zygote.gradient(ϕ -> sum(generate_ζ(rng, g_gpu, f, ϕ, x_o_gpu[:, 1:n_batch], interpreters)[1]), ϕ) |> cpu
+# include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss
+# neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
+#     x_o_gpu[:, 1:n_batch], transPMs_batch, interpreters; logσ2y)
+# Zygote.gradient(ϕ -> sum(neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
+#     x_o_gpu[:, 1:n_batch], transPMs_batch, interpreters; logσ2y)[1]), ϕ) |> cpu
+
+
+fcost(ϕ) = neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
+    x_o_gpu[:, 1:n_batch], transPMs_batch, map(get_concrete, interpreters);
+    n_MC=8, logσ2y = logσ2y)
+fcost(ϕ)
+gr = Zygote.gradient(fcost, ϕ) |> cpu;
+Zygote.gradient(fcost, CA.getdata(ϕ))
+
+train_loader = MLUtils.DataLoader((x_o_gpu, y_o), batchsize = n_batch)
+
+optf = Optimization.OptimizationFunction((ζg, data) -> begin
+        x_o, y_o = data
+        neg_elbo_transnorm_gf(
+            rng, g_gpu, f, ζg, y_o, x_o, transPMs_batch, map(get_concrete, interpreters_g); n_MC=5, logσ2y)
+    end,
     Optimization.AutoZygote())
-optprob = OptimizationProblem(optf, p0, train_loader)
+optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini) |> gpu, train_loader);
+res = res_gpu = Optimization.solve(optprob, Optimisers.Adam(0.02), callback=callback_loss(50), maxiters=800);
+
+ζ_VIc = interpreters_g.μP_ϕg_unc(res.u |> cpu)
+ζMs_VI = g(x_o, ζ_VIc.ϕg)
+ϕunc_VI = int_unc(ζ_VIc.unc)
+
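+# (the optimized vector is moved back to the CPU and split by the μP_ϕg_unc
+# interpreter into global parameters μP, network weights ϕg, and uncertainty
+# parameters unc)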
+hcat(θP_true, exp.(ζ_VIc.μP))
+plt = scatterplot(vec(θMs_true), vec(exp.(ζMs_VI)))
+# lineplot!(plt, 0.0, 1.1, identity)
+#
+hcat(ϕunc, ϕunc_VI) # need to compare to MC sample
+# hard to estimate for the originally very small θs, but otherwise good
+
+# test predicting correct obs-uncertainty of the predictive posterior
+n_sample_pred = 200
+intm_PMs_gen = ComponentArrayInterpreter(CA.ComponentVector(; θP,
+    θMs=CA.ComponentMatrix(
+        zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i=1:n_site))))
+
+include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss
+ζs, _ = generate_ζ(rng, g, f, res.u |> cpu, x_o,
+    (; interpreters..., PMs = intm_PMs_gen); n_MC=n_sample_pred)
+# ζ = ζs[:,1]
+θsc = stack(ζ -> CA.getdata(CA.ComponentVector(
+        TransformVariables.transform(transPMs_all, ζ))), eachcol(ζs));
+y_pred = stack(map(ζ -> first(predict_y(ζ, f, transPMs_all)), eachcol(ζs)));
+
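+# (each column of ζs is one posterior draw; transPMs_all maps it back to the
+# positive θP/θMs scale, and predict_y pushes the draw through the
+# process-based model f)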
+size(y_pred)
+σ_o_post = mapslices(std, y_pred; dims=3);
+# describe(σ_o_post)
+vcat(σ_o, mean_σ_o_MC, mean(σ_o_post), sqrt(mean(abs2, σ_o_post)))
+mean_y_pred = map(mean, eachslice(y_pred; dims=(1, 2)))
+# describe(mean_y_pred - y_o)
+histogram(vec(mean_y_pred - y_true)) # predictions centered around y_o (or y_true)
+
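+# (σ_o_post takes the sd over the sample dimension of the posterior-predictive
+# draws; comparing it with the true σ_o and mean_σ_o_MC checks that the
+# predictive posterior recovers the observation uncertainty)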
+# look at θP and the θM of the first site
+intm = ComponentArrayInterpreter(int_θdoubleMM(1:length(int_θdoubleMM)), (n_sample_pred,))
+ζs1c = intm(ζs[1:length(int_θdoubleMM), :])
+vcat(θP_true, θM_true)
+histogram(exp.(ζs1c[:r0, :]))
+histogram(exp.(ζs1c[:K2, :]))
+histogram(exp.(ζs1c[:r1, :]))
+histogram(exp.(ζs1c[:K1, :]))
+# all parameters estimated too high (true values not within the confidence bounds)
+scatterplot(ζs1c[:r1, :], ζs1c[:K1, :]) # r1 and K1 strongly correlated (from θM)
+scatterplot(ζs1c[:r0, :], ζs1c[:K2, :]) # r0 and K2 also correlated (from θP)
+scatterplot(ζs1c[:r0, :], ζs1c[:K1, :]) # no correlation (modeled as independent)
+
+# TODO compare distributions to MC sample
+
-res = Optimization.solve(
-    optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
 
-l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
-scatterplot(vec(θMs_true), vec(θMs))
-scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
-scatterplot(vec(y_pred), vec(y_o))
-hcat(par_templates.θP, int_ϕθP(res.u).θP)