Commit 467fa26: move ML frameworks to extensions
1 parent 03c0f5b
23 files changed: +955 -177 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@
 /docs/Manifest*.toml
 /docs/build/
 test/Manifest*.toml
+dev/Manifest*.toml

Project.toml

Lines changed: 27 additions & 1 deletion
@@ -3,5 +3,31 @@ uuid = "a108c475-a4e2-4021-9a84-cfa7df242f64"
 authors = ["Thomas Wutzler <twutz@bgc-jena.mpg.de> and contributors"]
 version = "1.0.0-DEV"
 
+[weakdeps]
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
+
+[deps]
+Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
+ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+
 [compat]
-julia = "1.6.7"
+Combinatorics = "1.0.2"
+ComponentArrays = "0.15.19"
+Flux = "v0.15.2"
+Lux = "1.4.2"
+Random = "1.10.0"
+SimpleChains = "0.4"
+julia = "1.10"
+
+[extensions]
+HybridVariationalInferenceSimpleChainsExt = "SimpleChains"
+HybridVariationalInferenceFluxExt = "Flux"
+HybridVariationalInferenceLuxExt = "Lux"
+
+[workspace]
+projects = ["test", "docs"]
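The new [weakdeps] and [extensions] tables declare that the SimpleChains, Flux, and Lux integrations now live in package extensions, which Julia (1.9 and later, hence the compat bump to julia = "1.10") only loads when the corresponding framework is also present in the environment. The extension source files themselves are presumably among the other files of this commit and are not shown in this excerpt; a minimal sketch of how such a file is typically laid out (illustrative only, not the committed code):

# ext/HybridVariationalInferenceFluxExt.jl (sketch; contents assumed, not taken from this commit)
module HybridVariationalInferenceFluxExt

using HybridVariationalInference
using Flux

# Flux-specific methods extending the generic API of HybridVariationalInference
# (for example, building the covariate-to-parameter network g as a Flux.Chain)
# would be defined here.

end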

dev/Project.toml

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+[deps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
+HybridVariationalInference = "a108c475-a4e2-4021-9a84-cfa7df242f64"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
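With the frameworks declared as weak dependencies, a dev environment like this one activates an extension simply by loading the trigger package next to HybridVariationalInference; no extra wiring is needed. A minimal check of that behaviour, sketched under the assumption that the extension names registered above are used (Base.get_extension is available from Julia 1.9 on):

using HybridVariationalInference   # core package loads without any ML framework
using SimpleChains                  # loading the weak dependency triggers the extension
# returns the extension module once it has been loaded, or `nothing` if it is not active
Base.get_extension(HybridVariationalInference, :HybridVariationalInferenceSimpleChainsExt)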

dev/doubleMM.jl

Lines changed: 331 additions & 0 deletions
@@ -0,0 +1,331 @@
+using SimpleChains, BenchmarkTools, Static, OptimizationOptimisers
+import Zygote
+using StatsFuns: logistic
+using UnicodePlots
+using Distributions
+using StableRNGs
+using LinearAlgebra, StatsBase, Combinatorics
+using Random
+using MLUtils
+
+using Test
+using HybridVariationalInference
+using StableRNGs
+using Random
+using Statistics
+using ComponentArrays: ComponentArrays as CA
+
+using SimpleChains
+import Zygote
+
+using OptimizationOptimisers
+
+using UnicodePlots
+
+const EX = HybridVariationalInference.DoubleMM
+
+() -> begin
+    #const PROJECT_ROOT = pkgdir(@__MODULE__)
+    _project_dir = basename(@__DIR__) == "uncNN" ? dirname(@__DIR__) : @__DIR__
+    include(joinpath(_project_dir, "uncNN", "ComponentArrayInterpreter.jl"))
+    include(joinpath(_project_dir, "uncNN", "util.jl")) # flatten1
+end
+
+const T = Float32
+rng = StableRNG(111)
+
+const n_covar_pc = 2
+const n_covar = n_covar_pc + 3 # linear dependent
+const n_site = 10^n_covar_pc
+# n responses each per 200 observations
+n_batch = n_site
+
+# moved to f_doubleMM
+#θP = θP_true = CA.ComponentVector(r0 = 0.3, K2=2.0)
+#θM = EX.θM = CA.ComponentVector(r1 = 0.5, K1 = 0.2)
+
+const n_θP = length(EX.θP)
+const n_θM = length(EX.θM)
+
+const int_θP = ComponentArrayInterpreter(EX.θP)
+const int_θM = ComponentArrayInterpreter(EX.θM)
+const int_θMs = ComponentArrayInterpreter(EX.θM, (n_batch,))
+const int_θPMs_flat = ComponentArrayInterpreter(P = n_θP, Ms = n_θM * n_batch)
+const int_θ = ComponentArrayInterpreter(CA.ComponentVector(;θP=EX.θP, θM=EX.θM))
+# moved to f_doubleMM
+# const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(;θP,θM)))
+# const S1 = [1.0, 1.0, 1.0, 0.3, 0.1]
+# const S2 = [1.0, 3.0, 5.0, 5.0, 5.0]
+θ = CA.getdata(vcat(EX.θP, EX.θM))
+
+const int_θPMs = ComponentArrayInterpreter(CA.ComponentVector(;EX.θP,
+    θMs=CA.ComponentMatrix(zeros(n_θM, n_batch), first(CA.getaxes(EX.θM)), CA.Axis(i=1:n_batch))))
+
+f = EX.f_doubleMM
+
+
+# moved to f_doubleMM
+# gen_q(InteractionsCovCor)
+x_o, θMs_true0, g, q = EX.gen_q(
+    rng, T, length(EX.θM), n_covar, n_site, n_θM);
+
+# normalize to be distributed around the prescribed true values
+θMs_true = int_θMs(scale_centered_at(θMs_true0, EX.θM, 0.1))
+
+extrema(θMs_true)
+histogram(vec(θMs_true[:r1, :]))
+histogram(vec(θMs_true[:K1, :]))
+
+@test isapprox(vec(mean(CA.getdata(θMs_true); dims=2)), CA.getdata(EX.θM), rtol=0.02)
+@test isapprox(vec(std(CA.getdata(θMs_true); dims=2)), CA.getdata(EX.θM) .* 0.1, rtol=0.02)
+
+# moved to f_doubleMM
+#applyf(f_double, θMs_true, stack(Iterators.repeated(CA.getdata(θP_true), size(θMs_true,2))))
+
+y_true = applyf(f, θMs_true, EX.θP)
+σ_o = 0.01
+#σ_o = 0.002
+y_o = y_true .+ reshape(randn(length(y_true)), size(y_true)...) .* σ_o
+scatterplot(vec(y_true), vec(y_o))
+scatterplot(vec(log.(y_true)), vec(log.(y_o)))
+
+# fit g to log(θ_true) ~ x_o
+ϕg = ϕg0 = SimpleChains.init_params(g);
+n_ϕg = length(ϕg)
+
+
+#----- fit g to θMs_true
+function loss_g(ϕg, x, g)
+    ζMs = g(x, ϕg) # predict the log of the parameters
+    θMs = exp.(ζMs)
+    loss = sum(abs2, θMs .- θMs_true)
+    return loss, θMs
+end
+loss_g(ϕg, x_o, g)
+Zygote.gradient(x -> loss_g(x, x_o, g)[1], ϕg);
+
+optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, x_o, g)[1],
+    Optimization.AutoZygote())
+optprob = Optimization.OptimizationProblem(optf, ϕg0);
+res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(100), maxiters=500);
+
+ϕg_opt1 = res.u;
+scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, x_o, g)[2]))
+@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, x_o, g)[2])) > 0.9
+
+#----------- fit q and θP to y_obs
+int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(ϕg=1:length(ϕg), θP=EX.θP))
+p = p0 = vcat(ϕg0, EX.θP .* 0.9); # slightly disturb θP_true
+#p = p0 = vcat(ϕg_opt1, θP_true .* 0.9); # slightly disturb θP_true
+p0c = int_ϕθP(p0);
+#gf(g, f_doubleMM, x_o, pc.ϕg, pc.θP)[1]
+
+
+k = 10
+# Pass the data for the batches as separate vectors wrapped in a tuple
+train_loader = MLUtils.DataLoader((x_o, y_o), batchsize = k)
+#l1 = loss_gf(p0, train_loader.data...)[1]
+
+optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
+    Optimization.AutoZygote())
+optprob = OptimizationProblem(optf, p0, train_loader)
+
+# ---- an earlier draft of the same script follows ----
+using SimpleChains, BenchmarkTools, Static, OptimizationOptimisers
+import Zygote
+using StatsFuns: logistic
+using UnicodePlots
+using Distributions
+using StableRNGs
+using LinearAlgebra, StatsBase, Combinatorics
+using Random
+using MLUtils
+
+using Test
+using HybridVariationalInference
+using StableRNGs
+using Random
+using ComponentArrays: ComponentArrays as CA
+
+const EX = HybridVariationalInference.DoubleMM
+
+#const PROJECT_ROOT = pkgdir(@__MODULE__)
+_project_dir = basename(@__DIR__) == "uncNN" ? dirname(@__DIR__) : @__DIR__
+include(joinpath(_project_dir, "uncNN", "ComponentArrayInterpreter.jl"))
+include(joinpath(_project_dir, "uncNN", "util.jl")) # flatten1
+
+T = Float32
+rng = StableRNG(111)
+
+const n_covar_pc = 2
+const n_covar = n_covar_pc + 3 # linear dependent
+const n_site = 10^n_covar_pc
+# n responses each per 200 observations
+n_batch = n_site
+
+# moved to f_doubleMM
+#θP = θP_true = CA.ComponentVector(r0 = 0.3, K2=2.0)
+#θM = EX.θM = CA.ComponentVector(r1 = 0.5, K1 = 0.2)
+
+const n_θP = length(EX.θP)
+const n_θM = length(EX.θM)
+
+const int_θP = ComponentArrayInterpreter(EX.θP)
+const int_θM = ComponentArrayInterpreter(EX.θM)
+const int_θMs = ComponentArrayInterpreter(EX.θM, (n_batch,))
+const int_θPMs_flat = ComponentArrayInterpreter(P = n_θP, Ms = n_θM * n_batch)
+const int_θ = ComponentArrayInterpreter(CA.ComponentVector(;θP=EX.θP, θM=EX.θM))
+# moved to f_doubleMM
+# const int_θdoubleMM = ComponentArrayInterpreter(flatten1(CA.ComponentVector(;θP,θM)))
+# const S1 = [1.0, 1.0, 1.0, 0.3, 0.1]
+# const S2 = [1.0, 3.0, 5.0, 5.0, 5.0]
+θ = CA.getdata(vcat(EX.θP, EX.θM))
+
+const int_θPMs = ComponentArrayInterpreter(CA.ComponentVector(;EX.θP,
+    θMs=CA.ComponentMatrix(zeros(n_θM, n_batch), first(CA.getaxes(EX.θM)), CA.Axis(i=1:n_batch))))
+
+f = EX.f_doubleMM
+
+
+# moved to f_doubleMM
+# gen_q(InteractionsCovCor)
+x_o, θMs_true0, g, q = EX.gen_q(
+    rng, T, length(EX.θM), n_covar, n_site, n_θM);
+
+# normalize to be distributed around the true values
+σ_θM = EX.θM .* 0.1 # 10% around expected
+dt = fit(ZScoreTransform, θMs_true0, dims=2)
+θMs_true0_scaled = StatsBase.transform(dt, θMs_true0)
+θMs_true = int_θMs(EX.θM .+ θMs_true0_scaled .* σ_θM)
+#map(mean, eachrow(θMs_true)), map(std, eachrow(θMs_true))
+#scatterplot(vec(θMs_true0), vec(θMs_true))
+#scatterplot(vec(θMs_true0), vec(θMs_true0_scaled))
+
+extrema(θMs_true)
+histogram(vec(θMs_true))
+
+# moved to f_doubleMM
+#applyf(f_double, θMs_true, stack(Iterators.repeated(CA.getdata(θP_true), size(θMs_true,2))))
+
+y_true = applyf(f_doubleMM, θMs_true, θP_true)
+σ_o = 0.01
+#σ_o = 0.002
+y_o = y_true .+ reshape(randn(length(y_true)), size(y_true)...) .* σ_o
+scatterplot(vec(y_true), vec(y_o))
+scatterplot(vec(log.(y_true)), vec(log.(y_o)))
+
+ϕg = ϕg0 = SimpleChains.init_params(g);
+n_ϕg = length(ϕg)
+ϕq = SimpleChains.init_params(q);
+#G = SimpleChains.alloc_threaded_grad(g);
+#@benchmark valgrad!($g, $mlpdloss, $x_o, $ϕg) # dropout active
+
+#----- fit g to x_o and θMs_true
+function loss_g(ϕg, x, g)
+    ζMs = g(x, ϕg) # predict the log of the parameters
+    θMs = exp.(ζMs)
+    loss = sum(abs2, θMs .- θMs_true)
+    return loss, θMs
+end
+loss_g(ϕg, x_o, g)
+Zygote.gradient(x -> loss_g(x, x_o, g)[1], ϕg);
+
+optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, x_o, g)[1],
+    Optimization.AutoZygote())
+optprob = Optimization.OptimizationProblem(optf, ϕg0);
+res = Optimization.solve(optprob, Adam(0.02), callback=callback_loss(100), maxiters=500);
+
+ϕg_opt1 = res.u
+scatterplot(vec(θMs_true), vec(loss_g(ϕg_opt1, x_o, g)[2]))
+@test cor(vec(θMs_true), vec(loss_g(ϕg_opt1, x_o, g)[2])) > 0.9
+
+#-------- fit g and θP to x_o and y_o
+int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(ϕg=1:length(ϕg), θP=θP_true))
+p = p0 = vcat(ϕg0, θP_true .* 0.9); # slightly disturb θP_true
+#p = p0 = vcat(ϕg_opt1, θP_true .* 0.9); # slightly disturb θP_true
+p0c = int_ϕθP(p0);
+#gf(g, f_doubleMM, x_o, pc.ϕg, pc.θP)[1]
+
+
+
+
+k = 10
+# Pass the data for the batches as separate vectors wrapped in a tuple
+train_loader = MLUtils.DataLoader((x_o, y_o), batchsize = k)
+#l1 = loss_gf(p0, train_loader.data...)[1]
+
+optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
+    Optimization.AutoZygote())
+optprob = OptimizationProblem(optf, p0, train_loader)
+# caution: larger learning rate (of 0.02) or fewer iterations -> skewed θMs_pred ~ θMs_true
+res = Optimization.solve(optprob, Optimisers.Adam(0.02); callback=callback_loss(100), maxiters=2_000);
+#res = Optimization.solve(optprob, Optimisers.ADAM(0.02); callback=callback_loss(100), epochs=200);
+
+
+
+() -> begin
+    loss_gf(p0)[1]
+    loss_gf(vcat(ϕg, θP_true))[1]
+    loss_gf(vcat(ϕg_opt1, θP_true))[1]
+    loss_gf(res.u)[1]
+
+    scatterplot(vec(loss_gf(res.u)[2]), vec(y_true))
+    scatterplot(vec(loss_gf(res.u)[2]), vec(y_o))
+    scatterplot(vec(y_true), vec(y_o))
+end
+
+poptc = int_ϕθP(res.u);
+ϕg_opt, θP_opt = poptc.ϕg, poptc.θP;
+hcat(θP_true, θP_opt, p0c.θP)
+y_pred, θMs_pred = gf(g, f_doubleMM, x_o, ϕg_opt, θP_opt);
+() -> begin
+    scatterplot(vec(y_pred), vec(y_o))
+    scatterplot(vec(y_pred), vec(y_true))
+
+    scatterplot(y_pred[1,:], y_true[1,:])
+    scatterplot(y_pred[2,:], y_true[2,:])
+    scatterplot(y_pred[1,:], y_o[1,:])
+    scatterplot(y_pred[2,:], y_o[2,:])
+
+    plt = scatterplot(θMs_true[1,:], θMs_pred[1,:])
+    plt = scatterplot(θMs_true[2,:], θMs_pred[2,:])
+end
+#vcat(θMs_true, θMs_pred)
+plt = scatterplot(vec(θMs_true), vec(θMs_pred))
+#lineplot!(plt, 0.0, 1.1, identity)
+
+
+# caution: larger learning rate (of 0.02) or fewer iterations -> skewed θMs_pred ~ θMs_true
+res = Optimization.solve(optprob, Optimisers.Adam(0.02); callback=callback_loss(100), maxiters=2_000);
+#res = Optimization.solve(optprob, Optimisers.ADAM(0.02); callback=callback_loss(100), epochs=200);
+
+
+
+() -> begin
+    loss_gf(p0)[1]
+    loss_gf(vcat(ϕg, θP_true))[1]
+    loss_gf(vcat(ϕg_opt1, θP_true))[1]
+    loss_gf(res.u)[1]
+
+    scatterplot(vec(loss_gf(res.u)[2]), vec(y_true))
+    scatterplot(vec(loss_gf(res.u)[2]), vec(y_o))
+    scatterplot(vec(y_true), vec(y_o))
+end
+
+poptc = int_ϕθP(res.u);
+ϕg_opt, θP_opt = poptc.ϕg, poptc.θP;
+hcat(θP_true, θP_opt, p0c.θP)
+y_pred, θMs_pred = gf(g, f_doubleMM, x_o, ϕg_opt, θP_opt);
+() -> begin
+    scatterplot(vec(y_pred), vec(y_o))
+    scatterplot(vec(y_pred), vec(y_true))
+
+    scatterplot(y_pred[1,:], y_true[1,:])
+    scatterplot(y_pred[2,:], y_true[2,:])
+    scatterplot(y_pred[1,:], y_o[1,:])
+    scatterplot(y_pred[2,:], y_o[2,:])
+
+    plt = scatterplot(θMs_true[1,:], θMs_pred[1,:])
+    plt = scatterplot(θMs_true[2,:], θMs_pred[2,:])
+end
+#vcat(θMs_true, θMs_pred)
+plt = scatterplot(vec(θMs_true), vec(θMs_pred))
+#lineplot!(plt, 0.0, 1.1, identity)
+

0 commit comments
