Commit 517dae8

implement fitting cholesky factors (#5)

* implement fitting Cholesky factors, including filling and extracting tridiagonal matrices on GPU and corresponding rrules for Zygote
* run Aqua only on version 1.11.2
* avoid testing CUDA code on machines without CUDA

1 parent 26997b9 commit 517dae8
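For orientation, the pattern named in the first bullet can be sketched in a few lines. The following is a minimal CPU illustration only: the helper name fill_upper_tri, the unit-diagonal parameterization, and the loop-based fill are assumptions made for this sketch, not the package's actual API (that code lives in src/cholesky.jl, whose body is not part of this page), and a real GPU path must avoid the scalar indexing used here. The accompanying rrule is what makes the fill differentiable by Zygote: the pullback of scattering a vector into fixed matrix positions is simply gathering the cotangent from those same positions.

using ChainRulesCore
using LinearAlgebra

# Hypothetical helper: scatter a parameter vector into the strict upper
# triangle of an n×n unit-diagonal matrix (a Cholesky-factor template).
function fill_upper_tri(v::AbstractVector{T}, n::Integer) where {T}
    U = Matrix{T}(I, n, n)        # start from identity: unit diagonal
    k = 0
    for j in 2:n, i in 1:(j - 1)  # strict upper triangle, column-major order
        k += 1
        U[i, j] = v[k]
    end
    return U
end

# rrule so Zygote can differentiate through the fill: the pullback reads
# the cotangent back out of exactly the positions that were filled.
function ChainRulesCore.rrule(::typeof(fill_upper_tri), v::AbstractVector, n::Integer)
    U = fill_upper_tri(v, n)
    function fill_upper_tri_pullback(ΔU)
        Δ = unthunk(ΔU)
        Δv = similar(v)
        k = 0
        for j in 2:n, i in 1:(j - 1)
            k += 1
            Δv[k] = Δ[i, j]
        end
        return NoTangent(), Δv, NoTangent()
    end
    return U, fill_upper_tri_pullback
end

The second and third bullets are test-infrastructure guards; their likely shape (a hedged sketch — the test files are not shown in this commit) is:

# test/runtests.jl — guard pattern (sketch; exact contents unknown)
using CUDA
if CUDA.functional()
    include("test_cholesky_gpu.jl")  # hypothetical GPU test file
else
    @info "CUDA not functional on this machine; skipping GPU tests"
end
if VERSION == v"1.11.2"
    import Aqua
    Aqua.test_all(HybridVariationalInference)  # quality checks pinned to one Julia version
end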

File tree
10 files changed (+773, −21 lines)

Project.toml
Lines changed: 8 additions & 0 deletions

@@ -4,8 +4,12 @@ authors = ["Thomas Wutzler <twutz@bgc-jena.mpg.de> and contributors"]
 version = "1.0.0-DEV"
 
 [deps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -21,9 +25,13 @@ HybridVariationalInferenceLuxExt = "Lux"
 HybridVariationalInferenceSimpleChainsExt = "SimpleChains"
 
 [compat]
+ChainRulesCore = "1.25"
+CUDA = "5.5.2"
 Combinatorics = "1.0.2"
 ComponentArrays = "0.15.19"
 Flux = "v0.15.2"
+GPUArraysCore = "0.1, 0.2"
+LinearAlgebra = "1.10.0"
 Lux = "1.4.2"
 Random = "1.10.0"
 SimpleChains = "0.4"

README.md
Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 Extending Variational Inference (VI), an approximate bayesian inversion method,
 to hybrid models, i.e. models that combine mechanistic and machine-learning parts.
 
-The model inversion, inferes parametric approximations of posterior density
+The model inversion, infers parametric approximations of posterior density
 of model parameters, by comparing model outputs to uncertain observations. At
 the same time, a machine learning model is fit that predicts parameters of these
 approximations by covariates.
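The paragraph fixed in this hunk is also a compact specification of the method. In standard notation — a generic amortized-VI objective, not a formula quoted from the repository — the inversion maximizes an ELBO of the form

\mathcal{L}(\phi) = \mathbb{E}_{\zeta \sim q_\phi(\cdot \mid x)}\big[\log p(y_o \mid f(\zeta))\big] - \mathrm{KL}\big(q_\phi(\zeta \mid x) \,\|\, p(\zeta)\big),

where f is the mechanistic model, y_o are the uncertain observations, and the parameters of the approximation q_\phi are predicted from the covariates x by the machine-learning model g.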

dev/Project.toml
Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,6 @@
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
@@ -11,5 +12,6 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
 UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
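ChainRulesTestUtils is the natural companion for the new rrules: its test_rrule compares a hand-written pullback against finite differences. A hedged sketch of such a check, written against the hypothetical fill_upper_tri helper from the commit-message discussion above (not the package's actual test code):

using ChainRulesTestUtils

v = randn(6)                      # 6 = 4·3/2 strict-upper-triangle entries for n = 4
test_rrule(fill_upper_tri, v, 4)  # the integer n is treated as non-differentiable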

dev/doubleMM.jl
Lines changed: 235 additions & 18 deletions

@@ -65,28 +65,245 @@ loss_g(ϕg_opt1, xM, g)
 scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2]))
 @test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, xM, g)[2])) > 0.9
 
-#----------- fit g and θP to y_o
-f = gen_hybridcase_PBmodel(case; scenario)
+tmpf = () -> begin
+    #----------- fit g and θP to y_o
+    # end2end inversion
+    f = gen_hybridcase_PBmodel(case; scenario)
 
-int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
-    ϕg = 1:length(ϕg0), θP = par_templates.θP))
-p = p0 = vcat(ϕg0, par_templates.θP .* 0.9); # slightly disturb θP_true
+    int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
+        ϕg = 1:length(ϕg0), θP = par_templates.θP))
+    p = p0 = vcat(ϕg0, par_templates.θP .* 0.9); # slightly disturb θP_true
 
-# Pass the site-data for the batches as separate vectors wrapped in a tuple
-train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
+    # Pass the site-data for the batches as separate vectors wrapped in a tuple
+    train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
 
-loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
-l1 = loss_gf(p0, train_loader.data...)[1]
+    loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
+    l1 = loss_gf(p0, train_loader.data...)[1]
 
-optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
+    optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
+        Optimization.AutoZygote())
+    optprob = OptimizationProblem(optf, p0, train_loader)
+
+    res = Optimization.solve(
+        optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
+
+    l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
+    scatterplot(vec(θMs_true), vec(θMs))
+    scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
+    scatterplot(vec(y_pred), vec(y_o))
+    hcat(par_templates.θP, int_ϕθP(res.u).θP)
+end
+
+#---------- HADVI
+# TODO think about good general initializations
+coef_logσ2_logMs = [-5.769 -3.501; -0.01791 0.007951]
+logσ2_logP = CA.ComponentVector(r0=-8.997, K2=-5.893)
+mean_σ_o_MC = 0.006042
+
+# correlation matrices
+ρsP = zeros(sum(1:(n_θP-1)))
+ρsM = zeros(sum(1:(n_θM-1)))
+
+ϕunc = CA.ComponentVector(;
+    logσ2_logP=logσ2_logP,
+    coef_logσ2_logMs=coef_logσ2_logMs,
+    ρsP,
+    ρsM)
+int_unc = ComponentArrayInterpreter(ϕunc)
+
+# for a conservative uncertainty assume σ2=1e-10 and no relationship with magnitude
+ϕunc0 = CA.ComponentVector(;
+    logσ2_logP=fill(-10.0, n_θP),
+    coef_logσ2_logMs=reduce(hcat, ([-10.0, 0.0] for _ in 1:n_θM)),
+    ρsP,
+    ρsM)
+
+logσ2y = fill(2 * log(σ_o), size(y_o, 1))
+n_MC = 3
+
+
+#-------------- ADVI with g inside cost function
+using CUDA
+using TransformVariables
+
+transPMs_batch = as(
+    (P=as(Array, asℝ₊, n_θP),
+    Ms=as(Array, asℝ₊, n_θM, n_batch)))
+transPMs_all = as(
+    (P=as(Array, asℝ₊, n_θP),
+    Ms=as(Array, asℝ₊, n_θM, n_site)))
+
+ϕ_true = θ = CA.ComponentVector(;
+    μP=θP_true,
+    ϕg=ϕg_opt,
+    unc=ϕunc);
+trans_gu = as(
+    (μP=as(Array, asℝ₊, n_θP),
+    ϕg=as(Array, n_ϕg),
+    unc=as(Array, length(ϕunc))))
+trans_g = as(
+    (μP=as(Array, asℝ₊, n_θP),
+    ϕg=as(Array, n_ϕg)))
+
+const int_PMs_batch = ComponentArrayInterpreter(CA.ComponentVector(; θP,
+    θMs=CA.ComponentMatrix(
+        zeros(n_θM, n_batch), first(CA.getaxes(θM)), CA.Axis(i=1:n_batch))))
+
+interpreters = interpreters_g = map(get_concrete, (;
+    μP_ϕg_unc=ComponentArrayInterpreter(ϕ_true),
+    PMs=int_PMs_batch,
+    unc=ComponentArrayInterpreter(ϕunc)
+))
+
+ϕg_true_vec = CA.ComponentVector(
+    TransformVariables.inverse(trans_gu, cv2NamedTuple(ϕ_true)))
+ϕcg_true = interpreters.μP_ϕg_unc(ϕg_true_vec)
+ϕ_ini = ζ = vcat(ϕcg_true[[:μP, :ϕg]] .* 1.2, ϕcg_true[[:unc]]);
+ϕ_ini0 = ζ = vcat(ϕcg_true[:μP] .* 0.0, SimpleChains.init_params(g), ϕunc0);
+
+neg_elbo_transnorm_gf(rng, g, f, ϕcg_true, y_o[:, 1:n_batch], x_o[:, 1:n_batch],
+    transPMs_batch, map(get_concrete, interpreters);
+    n_MC=8, logσ2y)
+Zygote.gradient(ϕ -> neg_elbo_transnorm_gf(
+        rng, g, f, ϕ, y_o[:, 1:n_batch], x_o[:, 1:n_batch],
+        transPMs_batch, interpreters; n_MC=8, logσ2y), ϕcg_true)
+
+() -> begin
+    train_loader = MLUtils.DataLoader((x_o, y_o), batchsize = n_batch)
+
+    optf = Optimization.OptimizationFunction((ζg, data) -> begin
+            x_o, y_o = data
+            neg_elbo_transnorm_gf(
+                rng, g, f, ζg, y_o, x_o, transPMs_batch, map(get_concrete, interpreters_g); n_MC=5, logσ2y)
+        end,
+        Optimization.AutoZygote())
+    optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini), train_loader);
+    res = Optimization.solve(optprob, Optimisers.Adam(0.02), callback=callback_loss(50), maxiters=800);
+    #optprob = Optimization.OptimizationProblem(optf, ϕ_ini0);
+    #res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(50), maxiters=1_400);
+end
+
+#using Lux
+ϕ = ϕcg_true |> gpu;
+x_o_gpu = x_o |> gpu;
+# y_o = y_o |> gpu
+# logσ2y = logσ2y |> gpu
+n_covar = size(x_o, 1)
+g_flux = Flux.Chain(
+    # dense layer with bias that maps to 8 outputs and applies `tanh` activation
+    Flux.Dense(n_covar => n_covar * 4, tanh),
+    Flux.Dense(n_covar * 4 => n_covar * 4, logistic),
+    # dense layer without bias that maps to n outputs and `identity` activation
+    Flux.Dense(n_covar * 4 => n_θM, identity, bias=false),
+)
+() -> begin
+    using Lux
+    g_lux = Lux.Chain(
+        # dense layer with bias that maps to 8 outputs and applies `tanh` activation
+        Lux.Dense(n_covar => n_covar * 4, tanh),
+        Lux.Dense(n_covar * 4 => n_covar * 4, logistic),
+        # dense layer without bias that maps to n outputs and `identity` activation
+        Lux.Dense(n_covar * 4 => n_θM, identity, use_bias=false),
+    )
+    ps, st = Lux.setup(Random.default_rng(), g_lux)
+    ps_ca = CA.ComponentArray(ps) |> gpu
+    st = st |> gpu
+    g_luxs = StatefulLuxLayer{true}(g_lux, nothing, st)
+    g_luxs(x_o_gpu[:, 1:n_batch], ps_ca)
+    ax_g = CA.getaxes(ps_ca)
+    g_luxs(x_o_gpu[:, 1:n_batch], CA.ComponentArray(ϕ.ϕg, ax_g))
+    interpreters = (interpreters..., ϕg = ComponentArrayInterpreter(ps_ca))
+    ϕg = CA.ComponentArray(ϕ.ϕg, ax_g)
+    ϕgc = interpreters.ϕg(ϕ.ϕg)
+    g_gpu = g_luxs
+end
+g_gpu = g_flux
+
+#Zygote.gradient(ϕg -> sum(g_gpu(x_o_gpu[:, 1:n_batch], ϕg)), ϕgc)
+# Zygote.gradient(ϕg -> sum(compute_g(g_gpu, x_o_gpu[:, 1:n_batch], ϕg, interpreters)), ϕ.ϕg)
+# Zygote.gradient(ϕ -> sum(tmp_gen1(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ.ϕg)
+# Zygote.gradient(ϕ -> sum(tmp_gen2(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), CA.getdata(ϕ))
+# Zygote.gradient(ϕ -> sum(tmp_gen2(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ) |> cpu
+# Zygote.gradient(ϕ -> sum(tmp_gen3(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)), ϕ) |> cpu
+# Zygote.gradient(ϕ -> sum(tmp_gen4(g_gpu, x_o_gpu[:, 1:n_batch], ϕ, interpreters)[1]), ϕ) |> cpu
+# generate_ζ(rng, g_gpu, f, ϕ, x_o_gpu[:, 1:n_batch], interpreters)
+# Zygote.gradient(ϕ -> sum(generate_ζ(rng, g_gpu, f, ϕ, x_o_gpu[:, 1:n_batch], interpreters)[1]), ϕ) |> cpu
+# include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss
+# neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
+#     x_o_gpu[:, 1:n_batch], transPMs_batch, interpreters; logσ2y)
+# Zygote.gradient(ϕ -> sum(neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
+#     x_o_gpu[:, 1:n_batch], transPMs_batch, interpreters; logσ2y)[1]), ϕ) |> cpu
+
+
+fcost(ϕ) = neg_elbo_transnorm_gf(rng, g_gpu, f, ϕ, y_o[:, 1:n_batch],
+    x_o_gpu[:, 1:n_batch], transPMs_batch, map(get_concrete, interpreters);
+    n_MC=8, logσ2y = logσ2y)
+fcost(ϕ)
+gr = Zygote.gradient(fcost, ϕ) |> cpu;
+Zygote.gradient(fcost, CA.getdata(ϕ))
+
+
+train_loader = MLUtils.DataLoader((x_o_gpu, y_o), batchsize = n_batch)
+
+optf = Optimization.OptimizationFunction((ζg, data) -> begin
+        x_o, y_o = data
+        neg_elbo_transnorm_gf(
+            rng, g_gpu, f, ζg, y_o, x_o, transPMs_batch, map(get_concrete, interpreters_g); n_MC=5, logσ2y)
+    end,
     Optimization.AutoZygote())
-optprob = OptimizationProblem(optf, p0, train_loader)
+optprob = Optimization.OptimizationProblem(optf, CA.getdata(ϕ_ini) |> gpu, train_loader);
+res = res_gpu = Optimization.solve(optprob, Optimisers.Adam(0.02), callback=callback_loss(50), maxiters=800);
+
+ζ_VIc = interpreters_g.μP_ϕg_unc(res.u |> cpu)
+ζMs_VI = g(x_o, ζ_VIc.ϕg)
+ϕunc_VI = int_unc(ζ_VIc.unc)
+
+hcat(θP_true, exp.(ζ_VIc.μP))
+plt = scatterplot(vec(θMs_true), vec(exp.(ζMs_VI)))
+#lineplot!(plt, 0.0, 1.1, identity)
+#
+hcat(ϕunc, ϕunc_VI) # need to compare to MC sample
+# hard to estimate for original very small theta's but otherwise good
+
+# test predicting correct obs-uncertainty of predictive posterior
+n_sample_pred = 200
+intm_PMs_gen = ComponentArrayInterpreter(CA.ComponentVector(; θP,
+    θMs=CA.ComponentMatrix(
+        zeros(n_θM, n_site), first(CA.getaxes(θM)), CA.Axis(i=1:n_sample_pred))))
+
+include(joinpath(@__DIR__, "uncNN", "elbo.jl")) # callback_loss
+ζs, _ = generate_ζ(rng, g, f, res.u |> cpu, x_o,
+    (; interpreters..., PMs = intm_PMs_gen); n_MC=n_sample_pred)
+# ζ = ζs[:,1]
+θsc = stack(ζ -> CA.getdata(CA.ComponentVector(
+    TransformVariables.transform(transPMs_all, ζ))), eachcol(ζs));
+y_pred = stack(map(ζ -> first(predict_y(ζ, f, transPMs_all)), eachcol(ζs)));
+
+size(y_pred)
+σ_o_post = mapslices(std, y_pred; dims=3);
+#describe(σ_o_post)
+vcat(σ_o, mean_σ_o_MC, mean(σ_o_post), sqrt(mean(abs2, σ_o_post)))
+mean_y_pred = map(mean, eachslice(y_pred; dims=(1, 2)))
+#describe(mean_y_pred - y_o)
+histogram(vec(mean_y_pred - y_true)) # predictions centered around y_o (or y_true)
+
+# look at θP, θM1 of first site
+intm = ComponentArrayInterpreter(int_θdoubleMM(1:length(int_θdoubleMM)), (n_sample_pred,))
+ζs1c = intm(ζs[1:length(int_θdoubleMM), :])
+vcat(θP_true, θM_true)
+histogram(exp.(ζs1c[:r0, :]))
+histogram(exp.(ζs1c[:K2, :]))
+histogram(exp.(ζs1c[:r1, :]))
+histogram(exp.(ζs1c[:K1, :]))
+# all parameters estimated too high (true not in cf bounds)
+scatterplot(ζs1c[:r1, :], ζs1c[:K1, :]) # r1 and K1 strongly correlated (from θM)
+scatterplot(ζs1c[:r0, :], ζs1c[:K2, :]) # r0 and K2 also correlated (from θP)
+scatterplot(ζs1c[:r0, :], ζs1c[:K1, :]) # no correlation (modeled independent)
+
+# TODO compare distributions to MC sample
+
+
+
 
-res = Optimization.solve(
-    optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
 
-l1, y_pred_global, y_pred, θMs = loss_gf(res.u, train_loader.data...)
-scatterplot(vec(θMs_true), vec(θMs))
-scatterplot(log.(vec(θMs_true)), log.(vec(θMs)))
-scatterplot(vec(y_pred), vec(y_o))
-hcat(par_templates.θP, int_ϕθP(res.u).θP)

src/HybridVariationalInference.jl
Lines changed: 7 additions & 0 deletions

@@ -4,6 +4,10 @@ using ComponentArrays: ComponentArrays as CA
 using Random
 using StatsBase # fit ZScoreTransform
 using Combinatorics # gen_hybridcase_synthetic/combinations
+using GPUArraysCore
+using LinearAlgebra
+using CUDA
+using ChainRulesCore
 
 export ComponentArrayInterpreter, flatten1, get_concrete
 include("ComponentArrayInterpreter.jl")
@@ -25,6 +29,9 @@ include("gencovar.jl")
 export callback_loss
 include("util_opt.jl")
 
+#export - all internal
+include("cholesky.jl")
+
 export DoubleMM
 include("DoubleMM/DoubleMM.jl")
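The imports added above (GPUArraysCore, CUDA, ChainRulesCore) are exactly what a GPU-aware fill needs: dispatch on AbstractGPUArray so the scalar-indexing CPU loop is replaced by vectorized scatter operations. A hedged sketch of such a method for the hypothetical fill_upper_tri from the commit-message discussion (the actual src/cholesky.jl may well use dedicated kernels instead):

using GPUArraysCore: AbstractGPUArray
using LinearAlgebra: diagind

function fill_upper_tri(v::AbstractGPUArray{T}, n::Integer) where {T}
    # column-major linear positions of the strict upper triangle
    lin = similar(v, Int, length(v))
    copyto!(lin, [(j - 1) * n + i for j in 2:n for i in 1:(j - 1)])
    u = similar(v, n * n)
    fill!(u, zero(T))
    u[lin] = v               # one vectorized scatter, no scalar indexing
    U = reshape(u, n, n)
    U[diagind(U)] .= one(T)  # unit diagonal
    return U
end

The matching pullback needs the same treatment — a vectorized gather such as Δv = vec(Δ)[lin] — to stay off the scalar-indexing path.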
