Skip to content

Commit 37226f4

Browse files
committed
add doubleMM testset
1 parent 214e530 commit 37226f4

File tree

7 files changed

+105
-17
lines changed

7 files changed

+105
-17
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
/docs/build/
77
test/Manifest*.toml
88
dev/Manifest*.toml
9+
tmp/

dev/doubleMM.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,3 @@
1-
() -> begin
2-
using SimpleChains, BenchmarkTools, Static, OptimizationOptimisers
3-
import Zygote
4-
using StatsFuns: logistic
5-
using UnicodePlots
6-
using Distributions
7-
using StableRNGs
8-
using LinearAlgebra, StatsBase, Combinatorics
9-
using Random
10-
end
11-
121
using Test
132
using HybridVariationalInference
143
using StableRNGs

ext/HybridVariationalInferenceLuxExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function HVI.construct_LuxApplicator(m::Chain; device = gpu_device())
1616
st = st |> device
1717
stateful_layer = StatefulLuxLayer{true}(m, nothing, st)
1818
#stateful_layer(x_o_gpu[:, 1:n_site_batch], ps_ca)
19-
int_ϕ = ComponentArrayInterpreter(ps_ca)
19+
int_ϕ = get_concrete(ComponentArrayInterpreter(ps_ca))
2020
LuxApplicator(stateful_layer, int_ϕ)
2121
end
2222

test/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
45
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
56
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
67
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
8+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
9+
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
10+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
711
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
812
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
913
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
14+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1015
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1116
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1218
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ const GROUP = get(ENV, "GROUP", "All") # defined in in CI.yml
77
@time @safetestset "test_gencovar" include("test_gencovar.jl")
88
#@safetestset "test" include("test/test_SimpleChains.jl")
99
@time @safetestset "test_SimpleChains" include("test_SimpleChains.jl")
10+
#@safetestset "test" include("test/test_doubleMM.jl")
11+
@time @safetestset "test_doubleMM" include("test_doubleMM.jl")
12+
#
1013
#@safetestset "test" include("test/test_Flux.jl")
1114
@time @safetestset "test_Flux" include("test_Flux.jl")
1215
#@safetestset "test" include("test/test_Lux.jl")

test/test_doubleMM.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
using Test
2+
using HybridVariationalInference
3+
using StableRNGs
4+
using Random
5+
using Statistics
6+
using ComponentArrays: ComponentArrays as CA
7+
8+
using SimpleChains
9+
using MLUtils
10+
import Zygote
11+
12+
using OptimizationOptimisers
13+
14+
const case = DoubleMM.DoubleMMCase()
15+
const MLengine = Val(nameof(SimpleChains))
16+
scenario = (:default,)
17+
18+
par_templates = get_hybridcase_par_templates(case; scenario)
19+
20+
(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)
21+
22+
rng = StableRNG(111)
23+
(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o
24+
) = gen_hybridcase_synthetic(case, rng; scenario);
25+
26+
@testset "gen_hybridcase_synthetic" begin
27+
@test isapprox(
28+
vec(mean(CA.getdata(θMs_true); dims = 2)), CA.getdata(par_templates.θM), rtol = 0.02)
29+
@test isapprox(vec(std(CA.getdata(θMs_true); dims = 2)),
30+
CA.getdata(par_templates.θM) .* 0.1, rtol = 0.02)
31+
32+
# test same results for same rng
33+
rng2 = StableRNG(111)
34+
gen2 = gen_hybridcase_synthetic(case, rng2; scenario);
35+
@test gen2.y_o == y_o
36+
end
37+
38+
@testset "loss_g" begin
39+
g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario);
40+
41+
function loss_g(ϕg, x, g)
42+
ζMs = g(x, ϕg) # predict the log of the parameters
43+
θMs = exp.(ζMs)
44+
loss = sum(abs2, θMs .- θMs_true)
45+
return loss, θMs
46+
end
47+
loss_g(ϕg0, xM, g)
48+
Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0);
49+
50+
optf = Optimization.OptimizationFunction((ϕg, p) -> loss_g(ϕg, xM, g)[1],
51+
Optimization.AutoZygote())
52+
optprob = Optimization.OptimizationProblem(optf, ϕg0);
53+
res = Optimization.solve(optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600);
54+
55+
ϕg_opt1 = res.u;
56+
pred = loss_g(ϕg_opt1, xM, g)
57+
θMs_pred = pred[2]
58+
#scatterplot(vec(θMs_true), vec(θMs_pred))
59+
@test cor(vec(θMs_true), vec(θMs_pred)) > 0.9
60+
end
61+
62+
@testset "loss_gf" begin
63+
#----------- fit g and θP to y_o
64+
g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario);
65+
f = gen_hybridcase_PBmodel(case; scenario)
66+
67+
int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
68+
ϕg = 1:length(ϕg0), θP = par_templates.θP))
69+
p = p0 = vcat(ϕg0, par_templates.θP .* 0.8); # slightly disturb θP_true
70+
71+
# Pass the site-data for the batches as separate vectors wrapped in a tuple
72+
train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)
73+
74+
loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
75+
l1 = loss_gf(p0, train_loader.data...)[1]
76+
77+
optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
78+
Optimization.AutoZygote())
79+
optprob = OptimizationProblem(optf, p0, train_loader)
80+
81+
res = Optimization.solve(
82+
optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);
83+
84+
l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...)
85+
@test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11)
86+
@test cor(vec(θMs_true), vec(θMs_pred)) > 0.9
87+
88+
() -> begin
89+
scatterplot(vec(θMs_true), vec(θMs_pred))
90+
scatterplot(log.(vec(θMs_true)), log.(vec(θMs_pred)))
91+
scatterplot(vec(y_pred), vec(y_o))
92+
hcat(par_templates.θP, int_ϕθP(p0).θP, int_ϕθP(res.u).θP)
93+
end
94+
end

tmp/scratch.jl

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)