|
| 1 | +using Test |
| 2 | +using HybridVariationalInference |
| 3 | +using StableRNGs |
| 4 | +using Random |
| 5 | +using Statistics |
| 6 | +using ComponentArrays: ComponentArrays as CA |
| 7 | + |
| 8 | +using SimpleChains |
| 9 | +using MLUtils |
| 10 | +import Zygote |
| 11 | + |
| 12 | +using OptimizationOptimisers |
| 13 | + |
# Test-case setup: problem definition, dimensions, and synthetic observations
# shared by all testsets below.
const case = DoubleMM.DoubleMMCase()
const MLengine = Val(nameof(SimpleChains))
scenario = (:default,)

# Parameter templates for this case (θP: global process parameters,
# θM: site parameters predicted by the ML model).
par_templates = get_hybridcase_par_templates(case; scenario)

# Problem sizes: number of covariates, sites, batch size, and parameter counts.
(; n_covar, n_site, n_batch, n_θM, n_θP) = get_hybridcase_sizes(case; scenario)

# Seeded rng so the generated data — and hence the testsets — are reproducible.
rng = StableRNG(111)
(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o
) = gen_hybridcase_synthetic(case, rng; scenario);
| 25 | + |
@testset "gen_hybridcase_synthetic" begin
    # Site parameters should scatter around the template mean with ~10% sd.
    θM_template = CA.getdata(par_templates.θM)
    θM_means = vec(mean(CA.getdata(θMs_true); dims = 2))
    θM_sds = vec(std(CA.getdata(θMs_true); dims = 2))
    @test isapprox(θM_means, θM_template, rtol = 0.02)
    @test isapprox(θM_sds, θM_template .* 0.1, rtol = 0.02)

    # Identically seeded rng must reproduce the same observations.
    gen2 = gen_hybridcase_synthetic(case, StableRNG(111); scenario)
    @test gen2.y_o == y_o
end
| 37 | + |
@testset "loss_g" begin
    g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario)

    # Squared-error loss of the ML-predicted site parameters against the truth.
    # g predicts log-parameters, so exponentiate before comparing.
    function loss_g(ϕg, x, g)
        ζMs = g(x, ϕg)   # log of the site parameters
        θMs = exp.(ζMs)
        return sum(abs2, θMs .- θMs_true), θMs
    end
    # Smoke-test the loss and its Zygote gradient before optimizing.
    loss_g(ϕg0, xM, g)
    Zygote.gradient(x -> loss_g(x, xM, g)[1], ϕg0)

    optf = Optimization.OptimizationFunction(
        (ϕg, p) -> loss_g(ϕg, xM, g)[1], Optimization.AutoZygote())
    optprob = Optimization.OptimizationProblem(optf, ϕg0)
    res = Optimization.solve(
        optprob, Adam(0.02), callback = callback_loss(100), maxiters = 600)

    # The fitted network should recover the true site parameters closely.
    ϕg_opt1 = res.u
    θMs_pred = loss_g(ϕg_opt1, xM, g)[2]
    #scatterplot(vec(θMs_true), vec(θMs_pred))
    @test cor(vec(θMs_true), vec(θMs_pred)) > 0.9
end
| 61 | + |
@testset "loss_gf" begin
    #----------- fit g and θP jointly to y_o
    g, ϕg0 = gen_hybridcase_MLapplicator(case, MLengine; scenario);
    f = gen_hybridcase_PBmodel(case; scenario)

    # Interpreter that splits the flat optimized vector into the ML weights
    # (ϕg) and the global process parameters (θP).
    int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector(
        ϕg = 1:length(ϕg0), θP = par_templates.θP))
    p = p0 = vcat(ϕg0, par_templates.θP .* 0.8); # slightly disturb θP_true

    # Pass the site-data for the batches as separate vectors wrapped in a tuple
    train_loader = MLUtils.DataLoader((xM, xP, y_o), batchsize = n_batch)

    loss_gf = get_loss_gf(g, f, y_global_o, int_ϕθP)
    l1 = loss_gf(p0, train_loader.data...)[1]

    optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
        Optimization.AutoZygote())
    # Qualified consistently with the other call sites in this file; the bare
    # `OptimizationProblem` name is not explicitly imported at the top.
    optprob = Optimization.OptimizationProblem(optf, p0, train_loader)

    res = Optimization.solve(
        optprob, Adam(0.02), callback = callback_loss(100), maxiters = 1000);

    # Evaluate the fit on the full (unbatched) data.
    l1, y_pred_global, y_pred, θMs_pred = loss_gf(res.u, train_loader.data...)
    # θP should be recovered despite the disturbed starting point ...
    @test isapprox(par_templates.θP, int_ϕθP(res.u).θP, rtol = 0.11)
    # ... and predicted site parameters should correlate with the truth.
    @test cor(vec(θMs_true), vec(θMs_pred)) > 0.9

    # Interactive diagnostics; deliberately wrapped in a thunk so the test
    # run never executes the plotting calls.
    () -> begin
        scatterplot(vec(θMs_true), vec(θMs_pred))
        scatterplot(log.(vec(θMs_true)), log.(vec(θMs_pred)))
        scatterplot(vec(y_pred), vec(y_o))
        hcat(par_templates.θP, int_ϕθP(p0).θP, int_ϕθP(res.u).θP)
    end
end
0 commit comments