diff --git a/Project.toml b/Project.toml index 392eb9d..e4a390d 100644 --- a/Project.toml +++ b/Project.toml @@ -12,11 +12,13 @@ CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -47,6 +49,7 @@ CommonSolve = "0.2.4" ComponentArrays = "0.15.19" DistributionFits = "0.3.9" Distributions = "0.25.117" +FillArrays = "1.13.0" Flux = "0.14, 0.15, 0.16" Functors = "0.4, 0.5" GPUArraysCore = "0.1, 0.2" @@ -54,6 +57,7 @@ LinearAlgebra = "1.10" Lux = "1.4.2" MLDataDevices = "1.5, 1.6" MLUtils = "0.4.5" +Missings = "1.2.0" Optimization = "3.19.3, 4" Random = "1.10.0" SimpleChains = "0.4" diff --git a/_typos.toml b/_typos.toml index 7ec6b95..e6619d1 100644 --- a/_typos.toml +++ b/_typos.toml @@ -4,3 +4,4 @@ extend-exclude = ["docs/src_stash/"] [default.extend-words] SOM = "SOM" negLogLik = "negLogLik" +Missings = "Missings" diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index b7849d3..3f1a813 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -32,7 +32,7 @@ cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity #------ setup synthetic data and training data loader prob0_ = HybridProblem(DoubleMM.DoubleMMCase(); scenario); -(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc +(; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc ) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario); n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario) ζP_true, ζMs_true = log.(θP_true), log.(θMs_true) @@ -59,7 +59,7 @@ n_epoch = 80 maxiters = n_batches_in_epoch * n_epoch); # update the problem with optimized parameters prob0o = prob1o =probo; -y_pred_global, y_pred, θMs = gf(prob0o; scenario, is_inferred=Val(true)); + y_pred, θMs = gf(prob0o; scenario, is_inferred=Val(true)); # @descend_code_warntype gf(prob0o; scenario) #@usingany UnicodePlots plt = scatterplot(θMs_true'[:, 1], θMs[:, 1]); @@ -77,7 +77,7 @@ histogram(vec(y_pred) - vec(y_true)) # predictions centered around y_o (or y_tru (; ϕ, resopt) = solve(prob0o, solver1; scenario, rng, callback = callback_loss(20), maxiters = 400) prob1o = HybridProblem(prob0o; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP) - y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario) + y_pred, θMs = gf(prob1o, xM, xP; scenario) scatterplot(θMs_true[1, :], θMs[1, :]) scatterplot(θMs_true[2, :], θMs[2, :]) prob1o.θP @@ -91,7 +91,7 @@ end (; ϕ, resopt) = solve(prob2, solver1; scenario, rng, callback = callback_loss(20), maxiters = 600) prob2o = HybridProblem(prob2; ϕg = collect(ϕ.ϕg), θP = ϕ.θP) - y_pred_global, y_pred, θMs = gf(prob2o, xM, xP) + y_pred, θMs = gf(prob2o, xM, xP) prob2o.θP end @@ -127,7 +127,7 @@ end (; ϕ, resopt) = solve(prob3, solver1; scenario, rng, callback = callback_loss(50), maxiters = 600) prob3o = HybridProblem(prob3; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP) - y_pred_global, y_pred, θMs = gf(prob3o, xM, xP; 
scenario) + y_pred, θMs = gf(prob3o, xM, xP; scenario) scatterplot(θMs_true[2, :], θMs[2, :]) prob3o.θP scatterplot(vec(y_true), vec(y_pred)) @@ -173,7 +173,7 @@ solver_post = HybridPosteriorSolver(; alg = OptimizationOptimisers.Adam(0.01), n (y1, θsP1, θsMs1) = (y, θsP, θsMs); () -> begin # prediction with fitted parameters (should be smaller than mean) - y_pred_global, y_pred2, θMs = gf(prob1o, xM, xP; scenario) + y_pred2, θMs = gf(prob1o, xM, xP; scenario) scatterplot(θMs_true[1, :], θMs[1, :]) scatterplot(θMs_true[2, :], θMs[2, :]) hcat(θP_true, θP) # all parameters overestimated @@ -366,7 +366,7 @@ end # ζMs = invt.transM.(θMs_i) # _f = get_hybridproblem_PBmodel(probo; scenario) # y_site = map(eachcol(θPs), θMs_i) do θP, θM - # y_global, y = _f(θP, reshape(θM, (length(θM), 1)), xP[[i_site]]) + # y = _f(θP, reshape(θM, (length(θM), 1)), xP[[i_site]]) # y[:,1] # end |> stack nLs = get_hybridproblem_neg_logden_obs( diff --git a/docs/make.jl b/docs/make.jl index a2198bb..93bf49c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -23,9 +23,9 @@ makedocs(; #"Test quarto markdown" => "tutorials/test1.md", ], "How to" => [ + ".. use GPU" => "tutorials/lux_gpu.md", ".. model independent parameters" => "tutorials/blocks_corr.md", ".. model site-global corr" => "tutorials/corr_site_global.md", - ".. use GPU" => "tutorials/lux_gpu.md", ], "Explanation" => [ #"Theory" => "explanation/theory_hvi.md", TODO activate when paper is published diff --git a/docs/src/tutorials/Project.toml b/docs/src/tutorials/Project.toml index 6a3bf83..59ffeee 100644 --- a/docs/src/tutorials/Project.toml +++ b/docs/src/tutorials/Project.toml @@ -8,6 +8,7 @@ DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921" HybridVariationalInference = "a108c475-a4e2-4021-9a84-cfa7df242f64" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da" diff --git a/docs/src/tutorials/basic_cpu.md b/docs/src/tutorials/basic_cpu.md index 5a50ff7..e1a80c9 100644 --- a/docs/src/tutorials/basic_cpu.md +++ b/docs/src/tutorials/basic_cpu.md @@ -104,14 +104,14 @@ HVI is an approximate bayesian analysis and combines prior information on the parameters with the model and observed data. Here, we provide a wide prior by fitting a Lognormal distributions to -- the mean corresponding to the initial value provided above -- the 0.95-quantile 3 times the mean +- the mode corresponding to the initial value provided above +- the 0.95-quantile 3 times the mode using the `DistributionFits.jl` package. ``` julia θall = vcat(θP, θM) priors_dict = Dict{Symbol, Distribution}( - keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95))) + keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode))) ``` ## Observations, model drivers and covariates diff --git a/docs/src/tutorials/basic_cpu.qmd b/docs/src/tutorials/basic_cpu.qmd index 89a77ce..37d8449 100644 --- a/docs/src/tutorials/basic_cpu.qmd +++ b/docs/src/tutorials/basic_cpu.qmd @@ -109,14 +109,14 @@ HVI is an approximate bayesian analysis and combines prior information on the parameters with the model and observed data. 
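As a quick check of the prior change above (the `fit` calls in `basic_cpu.md` now anchor the LogNormal at the mode via `Val(:mode)` instead of the mean), here is a minimal sketch, not part of the patched tutorial sources, of what the refitted prior looks like for a single parameter; the value `0.5` is only an illustrative stand-in for an initial parameter value:

```julia
using Distributions, DistributionFits

θ_init = 0.5                       # hypothetical initial value, used as the mode
d = fit(LogNormal, θ_init, QuantilePoint(3 * θ_init, 0.95), Val(:mode))
mode(d) ≈ θ_init                   # the mode matches the initial value
quantile(d, 0.95) ≈ 3 * θ_init     # wide prior: 95% quantile at three times the mode
```

This mirrors the updated test in `test/test_doubleMM.jl` below, which checks `mode(priors[:K2]) == θall.K2` and `quantile(priors[:K2], 0.95) ≈ θall.K2 * 3`.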
Here, we provide a wide prior by fitting a Lognormal distributions to -- the mean corresponding to the initial value provided above -- the 0.95-quantile 3 times the mean +- the mode corresponding to the initial value provided above +- the 0.95-quantile 3 times the mode using the `DistributionFits.jl` package. ```{julia} θall = vcat(θP, θM) priors_dict = Dict{Symbol, Distribution}( - keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95))) + keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode))) ``` ## Observations, model drivers and covariates @@ -138,7 +138,7 @@ rng = StableRNG(111) #| echo: false #| eval: false () -> begin - (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc) = + (; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,))) end ``` diff --git a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-11-output-1.png b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-11-output-1.png index 6a60d64..deb0d1d 100644 Binary files a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-11-output-1.png and b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-11-output-1.png differ diff --git a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-12-output-1.png b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-12-output-1.png index 4d4bee7..7225faa 100644 Binary files a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-12-output-1.png and b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-12-output-1.png differ diff --git a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-13-output-1.png b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-13-output-1.png index 5c086ba..a5ab7fb 100644 Binary files a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-13-output-1.png and b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-13-output-1.png differ diff --git a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-10-output-1.png b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-10-output-1.png index 9aa44b3..9d2324e 100644 Binary files a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-10-output-1.png and b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-10-output-1.png differ diff --git a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-11-output-1.png b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-11-output-1.png index cd914fe..123a022 100644 Binary files a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-11-output-1.png and b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-11-output-1.png differ diff --git a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-12-output-1.png b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-12-output-1.png index 2ef2dac..11ba18f 100644 Binary files a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-12-output-1.png and b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-12-output-1.png differ diff --git a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-9-output-1.png b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-9-output-1.png index d10960c..ab8eb5c 100644 Binary files a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-9-output-1.png and 
b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-9-output-1.png differ diff --git a/docs/src/tutorials/inspect_results.md b/docs/src/tutorials/inspect_results.md index 4eb8b8f..0c41fe8 100644 --- a/docs/src/tutorials/inspect_results.md +++ b/docs/src/tutorials/inspect_results.md @@ -5,7 +5,7 @@ CurrentModule = HybridVariationalInference ``` -This tutorial leads you through querying relevant information from the +This tutorial leads you through extracting relevant information from the inversion results and to produce some typical plots. First load necessary packages. @@ -110,7 +110,7 @@ In addition to the uncertainty in parameters, we are also interested in the uncertainty of predictions, i.e. the predictive posterior. We cam either run the PBM for all the parameter samples that we obtained already, -using the [`AbstractModelApplicator`](@ref), or use [`predict_hvi`](@ref) which combines +using the [`AbstractPBMApplicator`](@ref), or use [`predict_hvi`](@ref) which combines sampling the posterior and predictive posterior and returns the additional `NamedTuple` entry `y`. diff --git a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-10-output-1.png b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-10-output-1.png index 32ef60b..75b8d1d 100644 Binary files a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-10-output-1.png and b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-10-output-1.png differ diff --git a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-13-output-1.png b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-13-output-1.png index f994f9e..88e0fff 100644 Binary files a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-13-output-1.png and b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-13-output-1.png differ diff --git a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-8-output-1.png b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-8-output-1.png index a6c133b..feeda24 100644 Binary files a/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-8-output-1.png and b/docs/src/tutorials/inspect_results_files/figure-commonmark/cell-8-output-1.png differ diff --git a/docs/src/tutorials/intermediate/basic_cpu_results.jld2 b/docs/src/tutorials/intermediate/basic_cpu_results.jld2 index 126fa04..4046a37 100644 Binary files a/docs/src/tutorials/intermediate/basic_cpu_results.jld2 and b/docs/src/tutorials/intermediate/basic_cpu_results.jld2 differ diff --git a/docs/src/tutorials/lux_gpu.md b/docs/src/tutorials/lux_gpu.md index ade618c..9e7328c 100644 --- a/docs/src/tutorials/lux_gpu.md +++ b/docs/src/tutorials/lux_gpu.md @@ -27,6 +27,7 @@ using StableRNGs using MLUtils using JLD2 using Random +using MLDataDevices # using CairoMakie # using PairPlots # scatterplot matrices ``` @@ -88,15 +89,16 @@ Hence specify - `gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device())`: to move both ML model and PBM to GPU - `gdevs = (; gdev_M=gpu_device(), gdev_P=identity)`: to move both ML model to GPU but execute the PBM (and parameter transformation) on CPU +Currently, putting the PBM on gpu is not efficient during inversion, because +prior distribution needs to be evaluated for each sample. +However, sampling and prediction using a fitted model is efficient. + In addition, the libraries of the GPU device need to be activated by importing respective Julia packages. 
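Related to the `gdevs` options above, a minimal sketch of selecting the devices with a CPU fallback; this is not part of the patched tutorial, and the fallback tuple is an assumption that mirrors the `gdev isa MLDataDevices.AbstractGPUDevice` check used in `dev/doubleMM.jl`:

```julia
using MLDataDevices

# assumed fallback: run the ML model on GPU only when a GPU device is found;
# keep the PBM (and parameter transformations) on CPU during inversion
gdev = gpu_device()
gdevs = gdev isa MLDataDevices.AbstractGPUDevice ?
    (; gdev_M = gdev, gdev_P = identity) :
    (; gdev_M = identity, gdev_P = identity)
```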
Currently, only CUDA is tested with this `HybridVariationalInference` package. ``` julia import CUDA, cuDNN # so that gpu_device() returns a CUDADevice -#CUDA.device!(4) -gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device()) -#gdevs = (; gdev_M=gpu_device(), gdev_P=identity) using OptimizationOptimisers import Zygote @@ -105,7 +107,7 @@ solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3) (; probo) = solve(prob_lux, solver; callback = callback_loss(100), epochs = 10, - gdevs, + gdevs = (; gdev_M=gpu_device(), gdev_P=identity) ); probo_lux = probo; ``` @@ -116,7 +118,8 @@ The sampling and prediction methods, also take this `gdevs` keyword argument. ``` julia n_sample_pred = 400 (y_dev, θsP_dev, θsMs_dev) = (; y, θsP, θsMs) = predict_hvi( - rng, probo_lux; n_sample_pred, gdevs); + rng, probo_lux; n_sample_pred, + gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device())); ``` If `gdev_P` is not an `AbstractGPUDevice` then all the results are on CPU. diff --git a/docs/src/tutorials/lux_gpu.qmd b/docs/src/tutorials/lux_gpu.qmd index 6072d20..44bc774 100644 --- a/docs/src/tutorials/lux_gpu.qmd +++ b/docs/src/tutorials/lux_gpu.qmd @@ -38,6 +38,7 @@ using StableRNGs using MLUtils using JLD2 using Random +using MLDataDevices # using CairoMakie # using PairPlots # scatterplot matrices ``` @@ -98,6 +99,10 @@ Hence specify - `gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device())`: to move both ML model and PBM to GPU - `gdevs = (; gdev_M=gpu_device(), gdev_P=identity)`: to move both ML model to GPU but execute the PBM (and parameter transformation) on CPU +Currently, putting the PBM on gpu is not efficient during inversion, because +prior distribution needs to be evaluated for each sample. +However, sampling and prediction using a fitted model is efficient. + In addition, the libraries of the GPU device need to be activated by importing respective Julia packages. Currently, only CUDA is tested with this `HybridVariationalInference` package. @@ -112,9 +117,6 @@ end ``` ```{julia} import CUDA, cuDNN # so that gpu_device() returns a CUDADevice -#CUDA.device!(4) -gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device()) -#gdevs = (; gdev_M=gpu_device(), gdev_P=identity) using OptimizationOptimisers import Zygote @@ -123,7 +125,7 @@ solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3) (; probo) = solve(prob_lux, solver; callback = callback_loss(100), epochs = 10, - gdevs, + gdevs = (; gdev_M=gpu_device(), gdev_P=identity) ); probo_lux = probo; ``` @@ -134,7 +136,8 @@ The sampling and prediction methods, also take this `gdevs` keyword argument. ```{julia} n_sample_pred = 400 (y_dev, θsP_dev, θsMs_dev) = (; y, θsP, θsMs) = predict_hvi( - rng, probo_lux; n_sample_pred, gdevs); + rng, probo_lux; n_sample_pred, + gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device())); ``` If `gdev_P` is not an `AbstractGPUDevice` then all the results are on CPU. diff --git a/ext/HybridVariationalInferenceCUDAExt.jl b/ext/HybridVariationalInferenceCUDAExt.jl index aa1e2f2..180944c 100644 --- a/ext/HybridVariationalInferenceCUDAExt.jl +++ b/ext/HybridVariationalInferenceCUDAExt.jl @@ -2,6 +2,7 @@ module HybridVariationalInferenceCUDAExt using HybridVariationalInference, CUDA using HybridVariationalInference: HybridVariationalInference as HVI +using ComponentArrays: ComponentArrays as CA using ChainRulesCore # here, really CUDA-specific implementation, in case need to code other GPU devices @@ -84,6 +85,12 @@ function HVI._create_randn(rng, v::CUDA.CuVector{T,M}, dims...) 
where {T,M} res::CUDA.CuArray{T, length(dims),M} end +function HVI.ones_similar_x(x::CuArray, size_ret = size(x)) + # call CUDA.ones rather than ones for x::CuArray + ChainRulesCore.@ignore_derivatives CUDA.ones(eltype(x), size_ret) +end + + diff --git a/src/AbstractHybridProblem.jl b/src/AbstractHybridProblem.jl index 798ed57..ad0404a 100644 --- a/src/AbstractHybridProblem.jl +++ b/src/AbstractHybridProblem.jl @@ -141,9 +141,8 @@ Setup synthetic data, a NamedTuple of - θP_true: vector global process-model parameters - θMs_true: matrix of site-varying process-model parameters, with - xP: Vector of process-model drivers, with an entry per site -- y_global_true: vector of global observations - y_true: matrix of site-specific observations with one column per site -- y_global_o, y_o: observations with added noise +- y_o: observations with added noise """ function gen_hybridproblem_synthetic end diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl index 2d8afbb..7c8928a 100644 --- a/src/DoubleMM/f_doubleMM.jl +++ b/src/DoubleMM/f_doubleMM.jl @@ -77,20 +77,23 @@ Returns a matrix `(n_obs x n_site)` of predictions. ```julia function f_doubleMM_sites(θc::ComponentMatrix, xPc::ComponentMatrix) # extract several covariates from xP - ST = typeof(getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing - S1 = (getdata(xPc[:S1,:])::ST) - S2 = (getdata(xPc[:S2,:])::ST) + ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing + S1 = (CA.getdata(xPc[:S1,:])::ST) + S2 = (CA.getdata(xPc[:S2,:])::ST) # # extract the parameters as vectors that are row-repeated into a matrix + VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing n_obs = size(S1, 1) - VT = typeof(getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing + rep_fac = ones_similar_x(xPc, n_obs) # to reshape into matrix, avoiding repeat (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par - p1 = getdata(θc[:, par]) ::VT - repeat(p1', n_obs) # matrix: same for each concentration row in S1 + p1 = CA.getdata(θc[:, par]) ::VT + #(r0 .* rep_fac)' # move to computation below to save allocation + #repeat(p1', n_obs) # matrix: same for each concentration row in S1 end # # each variable is a matrix (n_obs x n_site) - r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2) + #r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2) + (r0 .* rep_fac)' .+ (r1 .* rep_fac)' .* S1 ./ ((K1 .* rep_fac)' .+ S1) .* S2 ./ ((K2 .* rep_fac)' .+ S2) end ``` """ @@ -101,15 +104,18 @@ function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix) S2 = (CA.getdata(xPc[:S2,:])::ST) # # extract the parameters as vectors that are row-repeated into a matrix - n_obs = size(S1, 1) VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing + n_obs = size(S1, 1) + rep_fac = HVI.ones_similar_x(xPc, n_obs) # to reshape into matrix, avoiding repeat (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par p1 = CA.getdata(θc[:, par]) ::VT - repeat(p1', n_obs) # matrix: same for each concentration row in S1 + #repeat(p1', n_obs) # matrix: same for each concentration row in S1 + #(rep_fac .* p1') # move to computation below to save allocation end # # each variable is a matrix (n_obs x n_site) - r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2) + #r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2) + (rep_fac .* r0') .+ (rep_fac .* r1') .* S1 ./ ((rep_fac .* K1') .+ S1) .* S2 ./ ((rep_fac .* K2') .+ S2) end # function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix) @@ 
-204,7 +210,7 @@ end # end function HVI.get_hybridproblem_priors(::DoubleMMCase; scenario::Val{scen}) where {scen} - Dict(keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95))) + Dict(keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode))) end function HVI.get_hybridproblem_MLapplicator( @@ -371,21 +377,18 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase; xP = int_xP_sites(vcat(repeat(xP_S1, 1, n_site), repeat(xP_S2, 1, n_site))) #xP[:S1,:] θP = par_templates.θP - y_global_true, y_true = f(θP, θMs_true', xP) + y_true = f(θP, θMs_true', xP) σ_o = FloatType(0.01) #σ_o = FloatType(0.002) logσ2_o = FloatType(2) .* log.(σ_o) #σ_o = 0.002 - y_global_o = y_global_true .+ randn(rng, FloatType, size(y_global_true)) .* σ_o y_o = y_true .+ randn(rng, FloatType, size(y_true)) .* σ_o (; xM, θP_true = θP, θMs_true, xP, - y_global_true, y_true, - y_global_o, y_o, y_unc = fill(logσ2_o, size(y_o)) ) diff --git a/src/HybridSolver.jl b/src/HybridSolver.jl index 88a8aed..87cbaf6 100644 --- a/src/HybridSolver.jl +++ b/src/HybridSolver.jl @@ -32,12 +32,14 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve train_loader_dev = train_loader end f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites=false) - y_global_o = FT[] # TODO pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + priors = get_hybridproblem_priors(prob; scenario) + priorsP = [priors[k] for k in keys(par_templates.θP)] + priorsM = [priors[k] for k in keys(par_templates.θM)] #intP = ComponentArrayInterpreter(par_templates.θP) - loss_gf = get_loss_gf(g_dev, transM, transP, f, y_global_o, intϕ; - cdev=infer_cdev(gdevs), pbm_covars, n_site_batch=n_batch) + loss_gf = get_loss_gf(g_dev, transM, transP, f, intϕ; + cdev=infer_cdev(gdevs), pbm_covars, n_site_batch=n_batch, priorsP, priorsM,) # call loss function once l1 = is_infer ? 
Test.@inferred(loss_gf(ϕ0_dev, first(train_loader_dev)...))[1] : @@ -122,6 +124,9 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS int_unc = interpreters.unc int_μP_ϕg_unc = interpreters.μP_ϕg_unc transMs = StackedArray(transM, n_batch) + priors = get_hybridproblem_priors(prob; scenario) + priorsP = [priors[k] for k in keys(par_templates.θP)] + priorsM = [priors[k] for k in keys(par_templates.θM)] # train_loader = get_hybridproblem_train_dataloader(prob; scenario) if gdevs.gdev_M isa MLDataDevices.AbstractGPUDevice @@ -145,12 +150,11 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPosteriorS priors_θP_mean, priors_θMs_mean = construct_priors_θ_mean( prob, ϕ0_dev.ϕg, keys(θM), θP, θmean_quant, g_dev, transM, transP; scenario, get_ca_int_PMs, gdevs, pbm_covars) - y_global_o = Float32[] # TODO loss_elbo = get_loss_elbo( - g_dev, transP, transMs, f_dev, py, y_global_o; + g_dev, transP, transMs, f_dev, py; solver.n_MC, solver.n_MC_cap, cor_ends, priors_θP_mean, priors_θMs_mean, - cdev=infer_cdev(gdevs), pbm_covars, θP, int_unc, int_μP_ϕg_unc) + cdev=infer_cdev(gdevs), pbm_covars, θP, int_unc, int_μP_ϕg_unc, priorsP, priorsM,) # test loss function once # tmp = first(train_loader_dev) # using ShareAdd @@ -193,26 +197,30 @@ The loss function takes in addition to ϕ, data that changes with minibatch - `xP`: drivers for the processmodel: Iterator of size n_site - `y_o`, `y_unc`: matrix of observations and uncertainties, sites in columns """ -function get_loss_elbo(g, transP, transMs, f, py, y_o_global; +function get_loss_elbo(g, transP, transMs, f, py; n_MC, n_MC_mean = max(n_MC,20), n_MC_cap=n_MC, cor_ends, priors_θP_mean, priors_θMs_mean, cdev, pbm_covars, θP, int_unc, int_μP_ϕg_unc, + priorsP, priorsM, floss_penalty = zero_penalty_loss, ) - let g = g, transP = transP, transMs = transMs, f = f, py = py, y_o_global = y_o_global, + let g = g, transP = transP, transMs = transMs, f = f, py = py, n_MC = n_MC, n_MC_cap = n_MC_cap, n_MC_mean = n_MC_mean, cor_ends = cor_ends, int_unc = get_concrete(int_unc), int_μP_ϕg_unc = get_concrete(int_μP_ϕg_unc), priors_θP_mean = priors_θP_mean, priors_θMs_mean = priors_θMs_mean, cdev = cdev, pbm_covar_indices = get_pbm_covar_indices(θP, pbm_covars), trans_mP=StackedArray(transP, n_MC_mean), - trans_mMs=StackedArray(transMs.stacked, n_MC_mean) + trans_mMs=StackedArray(transMs.stacked, n_MC_mean), + priorsP=priorsP, priorsM=priorsM, floss_penalty=floss_penalty function loss_elbo(ϕ, rng, xM, xP, y_o, y_unc, i_sites) + #ϕc = int_μP_ϕg_unc(ϕ) neg_elbo_gtf( rng, ϕ, g, f, py, xM, xP, y_o, y_unc, i_sites; int_unc, int_μP_ϕg_unc, n_MC, n_MC_cap, n_MC_mean, cor_ends, priors_θP_mean, priors_θMs_mean, cdev, pbm_covar_indices, transP, transMs, trans_mP, trans_mMs, + priorsP, priorsM, floss_penalty, #ϕg = ϕc.ϕg, ϕunc = ϕc.unc, ) end end diff --git a/src/HybridVariationalInference.jl b/src/HybridVariationalInference.jl index b982db5..0292776 100644 --- a/src/HybridVariationalInference.jl +++ b/src/HybridVariationalInference.jl @@ -20,9 +20,13 @@ using Distributions, DistributionFits using StaticArrays: StaticArrays as SA using Functors using Test: Test # @inferred +using Missings +using FillArrays export DoubleMM +include("util.jl") + export extend_stacked_nrow, StackedArray #public Exp #julia 1.10 public: https://github.com/JuliaLang/julia/pull/55097 @@ -41,6 +45,7 @@ export NullModelApplicator, MagnitudeModelApplicator, NormalScalingModelApplicat include("ModelApplicator.jl") export AbstractPBMApplicator, 
NullPBMApplicator, PBMSiteApplicator, PBMPopulationApplicator +export DirectPBMApplicator include("PBMApplicator.jl") # export AbstractGPUDataHandler, NullGPUDataHandler, get_default_GPUHandler @@ -87,13 +92,15 @@ include("util_opt.jl") export cpu_ca, apply_preserve_axes include("util_ca.jl") +include("util_gpu.jl") + export neg_logden_indep_normal, entropy_MvNormal include("logden_normal.jl") export get_ca_starts, get_ca_ends, get_cor_count include("cholesky.jl") -export neg_elbo_gtf, sample_posterior, predict_hvi +export neg_elbo_gtf, sample_posterior, predict_hvi, zero_penalty_loss include("elbo.jl") export init_hybrid_params, init_hybrid_ϕunc diff --git a/src/PBMApplicator.jl b/src/PBMApplicator.jl index 59abeb5..c354d02 100644 --- a/src/PBMApplicator.jl +++ b/src/PBMApplicator.jl @@ -18,7 +18,7 @@ Provided are implementations - `PBMSiteApplicator`: based on a function that computes predictions per site - `PBMPopulationApplicator`: based on a function that computes predictions for entire population - `NullPBMApplicator`: returning its input `θMs` for testing -- `PlainPBMApplicator`: based on a function that takes the same arguments as `apply_model` +- `DirectPBMApplicator`: based on a function that takes the same arguments as `apply_model` """ abstract type AbstractPBMApplicator end @@ -43,21 +43,37 @@ function apply_model(app::AbstractPBMApplicator, θsP::AbstractMatrix, θsMs::Ab # stack does not work on GPU, see specialized method for GPUArrays below y_pred = stack( map(eachcol(CA.getdata(θsP)), eachslice(CA.getdata(θsMs), dims=3)) do θP, θMs - y_global, y_pred_i = app(θP, θMs, xP) - y_pred_i + app(θP, θMs, xP) end) end +# function apply_model(app::AbstractPBMApplicator, θsP::GPUArraysCore.AbstractGPUMatrix, θsMs::GPUArraysCore.AbstractGPUArray{ET,3}, xP) where ET +# # stack does not work on GPU, need to resort to slower mapreduce +# # for type stability, apply f at first iterate to supply init to mapreduce +# P1, Pit = Iterators.peel(eachcol(CA.getdata(θsP))); +# Ms1, Msit = Iterators.peel(eachslice(CA.getdata(θsMs), dims=3)); +# y1 = apply_model(app, P1, Ms1, xP)[2] +# y1a = reshape(y1, size(y1)..., 1) # add one dimension +# y_pred = mapreduce((a,b) -> cat(a,b; dims=3), Pit, Msit; init=y1a) do θP, θMs +# y_pred_i = app(θP, θMs, xP) +# end +# end function apply_model(app::AbstractPBMApplicator, θsP::GPUArraysCore.AbstractGPUMatrix, θsMs::GPUArraysCore.AbstractGPUArray{ET,3}, xP) where ET # stack does not work on GPU, need to resort to slower mapreduce # for type stability, apply f at first iterate to supply init to mapreduce - P1, Pit = Iterators.peel(eachcol(CA.getdata(θsP))); - Ms1, Msit = Iterators.peel(eachslice(CA.getdata(θsMs), dims=3)); - y1 = apply_model(app, P1, Ms1, xP)[2] - y1a = reshape(y1, size(y1)..., 1) # add one dimension - y_pred = mapreduce((a,b) -> cat(a,b; dims=3), Pit, Msit; init=y1a) do θP, θMs - y_global, y_pred_i = app(θP, θMs, xP) - y_pred_i + # avoid Iterators.peel for CUDA + y1 = apply_model(app, CA.getdata(θsP)[:,1], CA.getdata(θsMs)[:,:,1], xP)[2] + y1a = reshape(y1, :, 1) # add one dimension + n_sample = size(θsP,2) + y_pred = if (n_sample == 1) + y1a + else + mapreduce((a,b) -> cat(a,b; dims=3), + eachcol(CA.getdata(θsP)[:,2:end]), eachslice(CA.getdata(θsMs)[:,:,2:end], dims=3); + init=y1a) do θP, θMs + app(θP, θMs, xP) + end end + return(y_pred) end @@ -74,6 +90,21 @@ function apply_model(app::NullPBMApplicator, θP::AbstractVector, θMs::Abstract return CA.getdata(θMs) end +""" + DirectPBMApplicator() + +Process-based-Model applicator that invokes 
directly given +function `f(θP::AbstractVector, θMs::AbstractMatrix, xP)`. +""" +struct DirectPBMApplicator{F} <: AbstractPBMApplicator + f::F +end + +function apply_model(app::DirectPBMApplicator, θP::AbstractVector, θMs::AbstractMatrix, xP) + return app.f(θP, θMs, xP) +end + + struct PBMSiteApplicator{F, IT, IXT, VFT} <: AbstractPBMApplicator fθ::F @@ -142,27 +173,20 @@ function apply_model(app::PBMSiteApplicator, θP::AbstractVector, θMs::Abstract obs1 = apply_PBMsite(θMs1, xP1) local pred_sites = mapreduce( apply_PBMsite, hcat, it_θMs, it_xP; init=reshape(obs1, :, 1)) - # # special case of mapreduce producing a vector rather than a matrix - # pred_sites = !(pred_sites0 isa AbstractMatrix) ? hcat(pred_sites0) : pred_sites0 - #obs1 = apply_PBMsite(first(eachrow(θMs)), first(eachcol(xP))) - #obs_vecs = map(apply_PBMsite, eachrow(θMs), eachcol(xP)) - #obs_vecs = (apply_PBMsite(θMs1, xP1) for (θMs1, xP1) in zip(eachrow(θMs), eachcol(xP))) - #pred_sites = stack(obs_vecs; dims = 1) - #pred_sites = stack(obs_vecs) # does not work with Zygote - local pred_global = eltype(pred_sites)[] # TODO remove - return pred_global, pred_sites + return pred_sites end -struct PBMPopulationApplicator{MFT, IPT, IT, IXT, F} <: AbstractPBMApplicator +struct PBMPopulationApplicator{MFT, RFT, IT, IXT, F} <: AbstractPBMApplicator fθpop::F θFixm::MFT # may be CuMatrix rather than Matrix - isP::IPT #Matrix{Int} # transferred to CuMatrix? + #isP::IPT #Matrix{Int} # transferred to CuMatrix? + rep_fac::RFT intθ::IT int_xP::IXT end # let fmap not descend into isP, because indexing with isP on cpu is faster -@functor PBMPopulationApplicator (θFixm, ) +@functor PBMPopulationApplicator (θFixm, rep_fac) """ PBMPopulationApplicator(fθpop, n_batch; θP, θM, θFix, xPvec) @@ -195,9 +219,11 @@ function PBMPopulationApplicator(fθpop, n_batch; # intθ = get_concrete(ComponentArrayInterpreter((n_batch,), intθvec)) int_xP = get_concrete(ComponentArrayInterpreter(int_xP_vec, (n_batch,))) - isP = repeat(axes(θP, 1)', n_batch) + #isP = repeat(axes(θP, 1)', n_batch) + # n_site = size(θMs, 1) + rep_fac = ones_similar_x(θP, n_batch) # to reshape into matrix, avoiding repeat θFixm = CA.getdata(θFix[isFix]) - PBMPopulationApplicator(fθpop, θFixm, isP, intθ, int_xP) + PBMPopulationApplicator(fθpop, θFixm, rep_fac, intθ, int_xP) end function apply_model(app::PBMPopulationApplicator, θP::AbstractVector, θMs::AbstractMatrix, xP) @@ -209,17 +235,18 @@ function apply_model(app::PBMPopulationApplicator, θP::AbstractVector, θMs::Ab "or compute PBM on CPU.") end # repeat θP and concatenate with - # Main.@infiltrate_main # repeat is 2x slower for Vector and 100 times slower (with allocation) on GPU # app.isP on CPU is slightly faster than app.isP on GPU + # multiplication has one more allocation on CPU and same speed, but 5x faster on GPU #@benchmark CA.getdata(θP[app.isP]) #@benchmark CA.getdata(repeat(θP', size(θMs,1))) - local θ = hcat(CA.getdata(θP[app.isP]), CA.getdata(θMs), app.θFixm) + #@benchmark rep_fac .* CA.getdata(θP)' # + local θ = hcat(app.rep_fac .* CA.getdata(θP)', CA.getdata(θMs), app.θFixm) + #local θ = hcat(CA.getdata(θP[app.isP]), CA.getdata(θMs), app.θFixm) #local θ = hcat(CA.getdata(repeat(θP', size(θMs,1))), CA.getdata(θMs), app.θFixm) local θc = app.intθ(CA.getdata(θ)) local xPc = app.int_xP(CA.getdata(xP)) local pred_sites = app.fθpop(θc, xPc) - local pred_global = eltype(pred_sites)[] # TODO remove - return pred_global, pred_sites + return pred_sites end diff --git a/src/elbo.jl b/src/elbo.jl index 2c26470..cc26bec 100644 --- 
a/src/elbo.jl +++ b/src/elbo.jl @@ -28,8 +28,11 @@ expected value of the likelihood of observations. using the mechanistic model f. """ function neg_elbo_gtf(args...; kwargs...) - nLy, entropy_ζ, nLmean_θ = neg_elbo_gtf_components(args...; kwargs...) - nLy - entropy_ζ + nLmean_θ + # TODO prior and penalty loss + (;nLjoint, entropy_ζ, loss_penalty, + nLy, neg_log_prior, neg_log_jac, + nLmean_θ) = neg_elbo_gtf_components(args...; kwargs...) + nLjoint - entropy_ζ + loss_penalty + nLmean_θ end function neg_elbo_gtf_components(rng, ϕ::AbstractVector{FT}, g, f, py, @@ -46,6 +49,8 @@ function neg_elbo_gtf_components(rng, ϕ::AbstractVector{FT}, g, f, py, transP, transMs, trans_mP =StackedArray(transP, n_MC), # provide with creating cost function trans_mMs =StackedArray(transMs.stacked, n_MC), + priorsP, priorsM, + floss_penalty = zero_penalty_loss, ) where {FT} n_MCr = isempty(priors_θP_mean) ? n_MC : max(n_MC, n_MC_mean) ζsP, ζsMs, σ = generate_ζ(rng, g, ϕ, xM; n_MC=n_MCr, cor_ends, pbm_covar_indices, @@ -54,14 +59,16 @@ function neg_elbo_gtf_components(rng, ϕ::AbstractVector{FT}, g, f, py, ζsMs_cpu = cdev(ζsMs) # fetch to CPU, because for <1000 sites (n_batch) this is faster # # maybe: translate ζ once and supply to both neg_elbo and negloglik_meanθ - nLy, entropy_ζ = neg_elbo_ζtf( + ϕc = int_μP_ϕg_unc(ϕ) + loss_comps = neg_elbo_ζtf( ζsP_cpu[:,1:n_MC], ζsMs_cpu[:,:,1:n_MC], σ, f, py, xP, y_ob, y_unc; - n_MC_cap, transP, transMs, ) + n_MC_cap, transP, transMs, priorsP, priorsM, + floss_penalty, ϕg = ϕc.ϕg, ϕunc = ϕc.unc,) # # maybe: provide trans_mP and trans_mMs with creating cost function nLmean_θ = _compute_negloglik_meanθ(ζsP_cpu, ζsMs_cpu; - trans_mP, trans_mMs, priors_θP_mean, priors_θMs_mean, i_sites) - nLy, entropy_ζ, nLmean_θ + trans_mP, trans_mMs, priors_θP_mean, priors_θMs_mean, i_sites, ) + (;loss_comps..., nLmean_θ) end function _compute_negloglik_meanθ(ζsP::AbstractMatrix{FT}, ζsMs; @@ -83,34 +90,60 @@ end Compute the neg_elbo for each sampled parameter vector (last dimension of ζs). - Transform and compute log-jac - call forward model -- compute log-density of predictions +- compute the negative log-density of the joint distribution of predictions and unconstrained parameters, `nLjoint`, + and its components + - `nLy`: the negative log-likelihood of the data, given the parameters + - `neg_log_prior`: the negative log-prior of the parameters at constrained scale + - `neg_log_jac`: the negative logarithm of the absolute value of the determinant of the Jacobian of + the transformation `θ=T(ζ)`. 
+- `loss_penalty`: additional loss terms from `floss_penalty` - compute entropy of transformation """ function neg_elbo_ζtf(ζsP, ζsMs, σ, f, py, xP, y_ob, y_unc; n_MC_cap=size(ζsP,2), transP, - transMs=StackedArray(transM, size(ζsMs, 2)) + transMs=StackedArray(transM, size(ζsMs, 2)), + priorsP, priorsM, + floss_penalty, ϕg, ϕunc, ) n_MC = size(ζsP,2) - nLys = map(eachcol(ζsP), eachslice(ζsMs; dims=3)) do ζP, ζMs - θP, θMs, logjac = transform_and_logjac_ζ(ζP, ζMs; transP, transMs) - y_pred_global, y_pred_i = f(θP, θMs, xP) - # TODO nLogDen prior on \theta - #nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, y_unc) - # Main.@infiltrate_main - # Test.@inferred( f(θP, θMs, xP) ) - # using ShareAdd - # @usingany Cthulhu - # @descend_code_warntype f(θP, θMs, xP) - nLy1 = py(y_ob, y_pred_i, y_unc) - nLy1 - logjac - end + cdev = cpu_device() #TODO avoid the cdev + f_sample = (ζP, ζMs) -> begin + θP, θMs, logjac_i = transform_and_logjac_ζ(ζP, ζMs; transP, transMs) + logpdf_t = (prior, θ) -> logpdf(prior, θ)::eltype(θP) + logpdf_tv = (prior, θ::AbstractVector) -> begin + map(Base.Fix1(logpdf, prior), θ)::Vector{eltype(θMs)} + end + #TODO avoid the cdev, i.e. compute the prior on the GPU, because the transfer takes very long + neg_log_prior_i = -sum(logpdf_t.(priorsP, cdev(θP))) - sum(map( + (priorMi, θMi) -> sum(logpdf_tv(priorMi, θMi)), priorsM, eachcol(cdev(θMs)))) + y_pred_i = f(θP, θMs, xP) + #nLy1 = neg_logden_indep_normal(y_ob, y_pred_i, y_unc) + # Main.@infiltrate_main + # Test.@inferred( f(θP, θMs, xP) ) + # using ShareAdd + # @usingany Cthulhu + # @descend_code_warntype f(θP, θMs, xP) + nLy_i = py(y_ob, y_pred_i, y_unc) + loss_penalty_i = convert(eltype(ζMs),floss_penalty(y_pred_i, θMs, θP, ϕg, ϕunc)) + # make sure names do not match the outer scope, otherwise boxed variables cause type instability + (nLy_i, neg_log_prior_i, -logjac_i, loss_penalty_i) + #(nLy_i, 0.0, 0.0, 0.0) + end + # only Vector is inferred, need to provide a type hint + # make sure that all components use the same Float type + map_res = map(f_sample, eachcol(ζsP), eachslice(ζsMs; dims=3))::Vector{NTuple{4,eltype(ζsP)}} + nLys, neg_log_priors, neglogjacs, loss_penalties = vectuptotupvec(map_res) # For robustness may compute the expectation only on the n_smallest values # because its very sensitive to few large outliers #nLys_smallest = nsmallest(n_MC_cap, nLys) # does not work with Zygote if n_MC_cap == n_MC nLy = sum(nLys) / n_MC + neg_log_prior = sum(neg_log_priors) / n_MC + neg_log_jac = sum(neglogjacs) / n_MC + loss_penalty = sum(loss_penalties) / n_MC else + @warn "neg_elbo_ζtf: TODO n_MC_cap: capping not implemented for logjac, loss_penalty, and neg_log_prior" nLys_smallest = partialsort(nLys, 1:n_MC_cap) nLy = sum(nLys_smallest) / n_MC_cap end @@ -130,9 +163,38 @@ function neg_elbo_ζtf(ζsP, ζsMs, σ, f, py, xP, y_ob, y_unc; # @show std(nLys), std(nLys)/abs(nLy) # @show std(nLys_smallest), std(nLys_smallest)/abs(nLy) # end - nLy, entropy_ζ + nLjoint = nLy + neg_log_prior + neg_log_jac + (;nLjoint, entropy_ζ, loss_penalty, nLy, neg_log_prior, neg_log_jac) end +""" + zero_penalty_loss(y_pred, θMs, θP, ϕg, ϕunc) + +Add zero, i.e. no additional loss terms, during the HVI fit. + +The basic cost in HVI is the negative log of the joint probability, i.e. +the likelihood of the observations given the parameters times the prior probability +of the parameters. + +Sometimes there is additional knowledge not encoded in the prior, such as +the constraint that one parameter must be larger than another, or entropy-weights of the +ML-parameters, and the solver accepts a function to add such additional loss terms. 
+ +# Arguments +- y_pred::AbstractMatrix: predictions of the PBM +- θMs::AbstractMatrix: site parameters +- θP::AbstractVector: global parameters +- ϕg: ML-model parameters +- ϕunc::AbstractVector: additional parameters of the posterior +""" +function zero_penalty_loss( + y_pred::AbstractMatrix, θMs::AbstractMatrix, θP::AbstractVector, + ϕg, ϕunc::AbstractVector) + return zero(eltype(θMs)) +end + + + """ predict_hvi([rng], predict_hvi(rng, prob::AbstractHybridProblem) @@ -275,72 +337,9 @@ function sample_posterior(rng, g, ϕ::AbstractVector, xM::AbstractMatrix; θsP, θsMs = is_infer ? Test.@inferred(transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs)) : transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs) - # res_pred = is_infer ? - # apply_f_trans(ζsP_cpu, ζsMs_cpu, f, xP; transP, transM, kwargs...) : - # Test.@inferred apply_f_trans(ζsP_cpu, ζsMs_cpu, f, xP; transP, transM, kwargs...) (; θsP, θsMs, entropy_ζ) end - -# """ -# Compute predictions of the transformation at given -# transformed parameters for each site. -# The number of sites is given by the number of rows in `ζsMs`. -# Steps: -# - transform the parameters to original constrained space -# - Applies the mechanistic model for each site -# `ζsP` and `ζsMs` are shaped according to the output of `generate_ζ`. -# Results are of shape `(n_obs x n_site_pred x n_MC)`. -# """ -# function apply_f_trans(ζsP::AbstractMatrix, ζsMs::AbstractArray, f, xP; -# transP, transM::Stacked, -# trans_mP=StackedArray(transP, size(ζsP, 2)), -# trans_mMs=StackedArray(transM, size(ζsMs, 1) * size(ζsMs, 3)) -# ) -# θsP, θsMs = transform_ζs(ζsP, ζsMs; trans_mP, trans_mMs) -# y = apply_process_model(θsP, θsMs, f, xP) -# (; y, θsP, θsMs) -# end -# function apply_f_trans(ζP::AbstractVector, ζMs::AbstractMatrix, f, xP; -# transP, transM::Stacked, transMs::StackedArray=StackedArray(transM, size(ζMs, 1)), -# ) -# θP = transP(ζP) -# θMs = transMs(ζMs) -# y_global, y = f(θP, θMs, xP) -# (; y, θP, θMs) -# end -# """ -# apply_process_model(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP) -# Call a PBM applicator for a sample of parameters of each site, and stack results -# `θsP` and `θsMs` are shaped according to the output of `generate_ζ`, i.e. -# `(n_site_pred x n_par x n_MC)`. -# Results are of shape `(n_obs x n_site_pred x n_MC)`. -# """ -# function apply_process_model(θsP::AbstractMatrix, θsMs::AbstractArray{ET,3}, f, xP) where ET -# error("deprecated, use f(θsP, θsMs, xP)") -# # stack does not work on GPU -# # y_pred = stack( -# # map(eachcol(CA.getdata(θsP)), eachslice(CA.getdata(θsMs), dims=3)) do θP, θMs -# # y_global, y_pred_i = f(θP, θMs, xP) -# # y_pred_i -# # end) -# # for type stability, apply f at first iterate to supply init to mapreduce -# P1, Pit = Iterators.peel(eachcol(CA.getdata(θsP))); -# Ms1, Msit = Iterators.peel(eachslice(CA.getdata(θsMs), dims=3)); -# y1 = f(P1, Ms1, xP)[2] -# y1a = reshape(y1, size(y1)..., 1) # add one dimension -# y_pred = mapreduce((a,b) -> cat(a,b; dims=3), Pit, Msit; init=y1a) do θP, θMs -# y_global, y_pred_i = f(θP, θMs, xP) -# y_pred_i -# end -# end - """ Generate samples of (inv-transformed) model parameters, ζ, and the vector of standard deviations, σ, i.e. the diagonal of the cholesky-factor. 
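To illustrate the `floss_penalty` hook introduced in the elbo code above (with `zero_penalty_loss` as the default), here is a minimal sketch of a user-supplied penalty with the same signature. The name `order_penalty`, the positional index mapping, and the weight are assumptions for illustration only:

```julia
# hypothetical penalty: softly enforce that the second global parameter
# stays at least as large as the first one (the index mapping is assumed)
function order_penalty(y_pred::AbstractMatrix, θMs::AbstractMatrix, θP::AbstractVector,
        ϕg, ϕunc::AbstractVector)
    w = convert(eltype(θP), 100)                        # assumed penalty weight
    violation = max(zero(eltype(θP)), θP[1] - θP[2])    # > 0 only when θP[1] > θP[2]
    return w * violation^2
end
```

Such a function could then be supplied via the `floss_penalty` keyword of `get_loss_gf` or `get_loss_elbo`; whether it is also forwarded through `solve` is not shown in this diff.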
diff --git a/src/gf.jl b/src/gf.jl index 11c552a..4322992 100644 --- a/src/gf.jl +++ b/src/gf.jl @@ -97,6 +97,7 @@ function gf(g::AbstractModelApplicator, transMs, transP, f, xM, xP, ϕg, ζP; end + function gf(g::AbstractModelApplicator, transMs, transP, f, xM, xP, ϕg, ζP, pbm_covar_indices::AbstractVector{<:Integer}; cdev) @@ -112,8 +113,8 @@ function gf(g::AbstractModelApplicator, transMs, transP, f, xM, xP, ϕg, ζP, θMs = gtrans(g, transMs, xMP, ϕg; cdev) θP = transP(CA.getdata(ζP)) θP_cpu = cdev(θP) - y_pred_global, y_pred = f(θP_cpu, θMs, xP) - return y_pred_global, y_pred, θMs, θP_cpu + y_pred = f(θP_cpu, θMs, xP) + return y_pred, θMs, θP_cpu end """ @@ -136,7 +137,6 @@ Create a loss function for given - g(x, ϕ): machine learning model - transM: transforamtion of parameters at unconstrained space - f(θMs, θP): mechanistic model -- y_o_global: site-independent observations - intϕ: interpreter attaching axis with components ϕg and ϕP - intP: interpreter attaching axis to ζP = ϕP with components used by f - kwargs: additional keyword arguments passed to gf, such as gdev or pbm_covars @@ -148,19 +148,30 @@ The loss function `loss_gf(ϕ, xM, xP, y_o, y_unc, i_sites)` takes - y_o: matrix of observations, sites in columns - y_unc: vector of uncertainty information for each observation - i_sites: index of sites in the batch + +and returns a NamedTuple of +- `nLjoint_pen`: the negative log of the joint parameter probability (likelihood * prior) plus the penalty +- `y_pred`: predicted values +- `θMs_pred`, `θP_pred`: PBM-parameters +- `nLy`: negative log-likelihood of y_pred +- `neg_log_prior`: negative log-prior of `θMs_pred` and `θP_pred` +- `loss_penalty`: additional loss terms from `floss_penalty` """ -function get_loss_gf(g, transM, transP, f, y_o_global, +function get_loss_gf(g, transM, transP, f, intϕ::AbstractComponentArrayInterpreter, intP::AbstractComponentArrayInterpreter = ComponentArrayInterpreter( intϕ(1:length(intϕ)).ϕP); cdev=cpu_device(), - pbm_covars, n_site_batch, kwargs...) + pbm_covars, n_site_batch, + priorsP, priorsM, floss_penalty = zero_penalty_loss, + kwargs...) - let g = g, transM = transM, transP = transP, f = f, y_o_global = y_o_global, + let g = g, transM = transM, transP = transP, f = f, intϕ = get_concrete(intϕ), transMs = StackedArray(transM, n_site_batch), cdev = cdev, - pbm_covar_indices = CA.getdata(intP(1:length(intP))[pbm_covars]) + pbm_covar_indices = CA.getdata(intP(1:length(intP))[pbm_covars]), + priorsP = priorsP, priorsM = priorsM, floss_penalty = floss_penalty #, intP = get_concrete(intP) #inv_transP = inverse(transP), kwargs = kwargs function loss_gf(ϕ, xM, xP, y_o, y_unc, i_sites) @@ -173,11 +184,21 @@ function get_loss_gf(g, transM, transP, f, y_o_global, # ζP_cpu = cdev(CA.getdata(μ_ζP)) # ζMs_cpu = cdev(CA.getdata(μ_ζMs)) # y_pred, _, _ = apply_f_trans(ζP_cpu, ζMs_cpu, f, xP; transM, transP) - y_pred_global, y_pred, θMs, θP = gf( + y_pred, θMs_pred, θP_pred = gf( g, transMs, transP, f, xM, xP, CA.getdata(ϕc.ϕg), CA.getdata(ϕc.ϕP), pbm_covar_indices; cdev, kwargs...) 
- loss = sum(abs2, (y_pred .- y_o) ./ σ) #+ sum(abs2, y_pred_global .- y_o_global) - return loss, y_pred, θMs, θP + nLy = sum(abs2, (y_pred .- y_o) ./ σ) + # logpdf is not typestable for Distribution{Univariate, Continuous} + logpdf_t = (prior, θ) -> logpdf(prior, θ)::eltype(θP_pred) + logpdf_tv = (prior, θ::AbstractVector) -> begin + map(Base.Fix1(logpdf, prior), θ)::Vector{eltype(θP_pred)} + end + neg_log_prior = -sum(logpdf_t.(priorsP, θP_pred)) - sum(map( + (priorMi, θMi) -> sum(logpdf_tv(priorMi, θMi)), priorsM, eachcol(θMs_pred))) + ϕunc = eltype(θP_pred)[] # no uncertainty parameters optimized + loss_penalty = floss_penalty(y_pred, θMs_pred, θP_pred, ϕc.ϕg, ϕunc) + nLjoint_pen = nLy + neg_log_prior + loss_penalty + return (;nLjoint_pen, y_pred, θMs_pred, θP_pred, nLy, neg_log_prior, loss_penalty) end end end diff --git a/src/util.jl b/src/util.jl new file mode 100644 index 0000000..ff898d5 --- /dev/null +++ b/src/util.jl @@ -0,0 +1,36 @@ +""" + vectuptotupvec(vectup) + vectuptotupvec_allowmissing(vectup) + +Type-safe conversion from a Vector of Tuples to a Tuple of Vectors. +The first variant does not allow for `missing` in `vectup`. +The second variant allows for `missing` but has `eltype` of `Union{Missing, ...}` in +all components of the returned Tuple, even when there are no `missing` entries in `vectup`. + +# Arguments +* `vectup`: A Vector of Tuples of identical type + +# Examples +```jldoctest; output=false +vectup = [(1,1.01, "string 1"), (2,2.02, "string 2")] +HybridVariationalInference.vectuptotupvec_allowmissing(vectup) == + ([1, 2], [1.01, 2.02], ["string 1", "string 2"]) +# output +true +``` +""" +function vectuptotupvec(vectup::AbstractVector{<:Tuple}) + Ti = eltype(vectup).parameters + npar = length(Ti) + ntuple(i -> + (getindex.(vectup, i))::Vector{Ti[i]}, npar) +end +function vectuptotupvec_allowmissing( + vectup::AbstractVector{<:Union{Missing,Tuple}}) + Ti = nonmissingtype(eltype(vectup)).parameters + npar = length(Ti) + Tim = ntuple(i -> Union{Missing,Ti[i]}, npar) + ntuple(i -> begin + allowmissing(passmissing(getindex).(vectup, i))::Vector{Tim[i]} + end, npar) +end diff --git a/src/util_gpu.jl b/src/util_gpu.jl new file mode 100644 index 0000000..7253cbd --- /dev/null +++ b/src/util_gpu.jl @@ -0,0 +1,22 @@ +""" + ones_similar_x(x::AbstractArray, size_ret = size(x)) + +Return `ones(eltype(x), size_ret)`. +Overload this method for specific AbstractGPUArrays to return the +correct container type. +See e.g. `HybridVariationalInferenceCUDAExt` +that calls `CUDA.ones` to return a `CuArray` rather than an `Array`. 
+""" +function ones_similar_x(x::AbstractArray, size_ret = size(x)) + #ones(eltype(x), size_ret) + Ones{eltype(x)}(size_ret) +end + +# handle containers and transformations of Arrays +ones_similar_x(x::CA.ComponentArray, s = size(x)) = ones_similar_x(CA.getdata(x), s) +ones_similar_x(x::LinearAlgebra.Adjoint, s = size(x)) = ones_similar_x(parent(x), s) +ones_similar_x(x::SubArray, s = size(x)) = ones_similar_x(parent(x), s) + + + + diff --git a/test/Project.toml b/test/Project.toml index 81c469c..ab63301 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" diff --git a/test/runtests.jl b/test/runtests.jl index 7247a5c..87bbd4b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,8 +5,12 @@ const GROUP = get(ENV, "GROUP", "All") # defined in in CI.yml if GROUP == "All" || GROUP == "Basic" #@safetestset "test" include("test/test_bijectors_utils.jl") @time @safetestset "test_bijectors_utils" include("test_bijectors_utils.jl") + #@safetestset "test" include("test/test_util.jl") + @time @safetestset "test_util" include("test_util.jl") #@safetestset "test" include("test/test_util_ca.jl") @time @safetestset "test_util_ca" include("test_util_ca.jl") + #@safetestset "test" include("test/test_util_gpu.jl") + @time @safetestset "test_util_gpu" include("test_util_gpu.jl") #@safetestset "test" include("test/test_ComponentArrayInterpreter.jl") @time @safetestset "test_ComponentArrayInterpreter" include("test_ComponentArrayInterpreter.jl") #@safetestset "test" include("test/test_hybridprobleminterpreters.jl") diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index 48f55aa..ef105bf 100644 --- a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -47,7 +47,7 @@ function construct_problem(; scenario::Val{scen}) where scen # n_batch = 10 n_site, n_batch = get_hybridproblem_n_site_and_batch(CP.DoubleMM.DoubleMMCase(); scenario) # dependency on DeoubleMMCase -> take care of changes in covariates - (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc + (; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc ) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario) n_covar = size(xM,1) n_input = (:covarK2 ∈ scen) ? 
 n_covar +1 : n_covar
@@ -73,7 +73,7 @@ function construct_problem(; scenario::Val{scen}) where scen
     θall = vcat(θP, θM)
     priors_dict = Dict{Symbol, Distribution}(
         keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
-    priors_dict[:r1] = fit(Normal, θall.r1, qp_uu(3 * θall.r1)) # not transformed to log-scale
+    priors_dict[:r1] = fit(Normal, θall.r1, qp_uu(3 * θall.r1), Val(:mode)) # not transformed to log-scale
     # scale (0,1) outputs MLmodel to normal distribution fitted to priors translated to ζ
     priorsM = [priors_dict[k] for k in keys(θM)]
     lowers, uppers = get_quantile_transformed(priorsM, transM)
@@ -147,13 +147,15 @@ test_without_flux = (scenario) -> begin
     pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario)
     intϕ = ComponentArrayInterpreter(CA.ComponentVector(
         ϕg=1:length(ϕg0), ϕP=par_templates.θP))
+    priors = get_hybridproblem_priors(prob; scenario)
+    priorsP = [priors[k] for k in keys(par_templates.θP)]
+    priorsM = [priors[k] for k in keys(par_templates.θM)]
     # slightly disturb θP_true
     p = p0 = vcat(ϕg0, par_templates.θP .* convert(eltype(ϕg0), 0.8))
     # Pass the site-data for the batches as separate vectors wrapped in a tuple
-    y_global_o = Float64[]
-    loss_gf = get_loss_gf(g, transM, transP, f, y_global_o, intϕ;
-        pbm_covars, n_site_batch = n_batch)
+    loss_gf = get_loss_gf(g, transM, transP, f, intϕ;
+        pbm_covars, n_site_batch = n_batch, priorsP, priorsM)
     (_xM, _xP, _y_o, _y_unc, _i_sites) = first(train_loader)
     l1 = loss_gf(p0, _xM, _xP, _y_o, _y_unc, _i_sites)
@@ -173,9 +175,10 @@ test_without_flux = (scenario) -> begin
         # optprob, Adam(0.02), callback = callback_loss(100),
         optprob, Adam(0.02), epochs = 150);
-    loss_gf_sites = get_loss_gf(g, transM, transP, f, y_global_o, intϕ;
+    loss_gf_sites = get_loss_gf(g, transM, transP, f, intϕ;
         pbm_covars, n_site_batch = n_site)
-    l1, y_pred_global, y_pred, θMs_pred = loss_gf_sites(res.u, train_loader.data...)
+    l1, y_pred, θMs_pred, θP, nLy, neg_log_prior = loss_gf_sites(
+        res.u, train_loader.data...)
     @test isapprox(par_templates.θP, transP(intϕ(res.u).ϕP), rtol=0.5)
     end
 end
@@ -258,8 +261,7 @@ test_with_flux_gpu = (scenario) -> begin
         epochs = 2, θmean_quant = 0.01, # test constraining mean to initial prediction
         is_inferred = Val(true),
-        gdevs = (; gdev_M=gpu_device(), gdev_P=identity),
-        );
+        gdevs = (; gdev_M=gpu_device(), gdev_P=identity),);
     @test CA.getdata(ϕ) isa GPUArraysCore.AbstractGPUVector
     #@test cdev(ϕ.unc.ρsM)[1] > 0 # too few iterations in test -> may fail
     #
@@ -309,7 +311,6 @@ test_with_flux_gpu = (scenario) -> begin
             is = vcat(pos.P, vec(pos.Ms[i_sites,:]))
             cr[is,is]
         end
-    end;
     @testset "HybridPosteriorSolver also f on gpu $(last(CP._val_value(scenario)))" begin
         scenf = Val((CP._val_value(scenario)..., :use_Flux, :use_gpu, :omit_r0, :f_on_gpu))
diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl
index e87aab8..9f16a4b 100644
--- a/test/test_doubleMM.jl
+++ b/test/test_doubleMM.jl
@@ -32,12 +32,12 @@ par_templates = get_hybridproblem_par_templates(prob; scenario)
 @testset "get_hybridproblem_priors" begin
     θall = vcat(par_templates...)
     priors = get_hybridproblem_priors(prob; scenario)
-    @test mean(priors[:K2]) == θall.K2
+    @test mode(priors[:K2]) == θall.K2
     @test quantile(priors[:K2], 0.95) ≈ θall.K2 * 3 # fitted in f_doubleMM
 end
 
 rng = StableRNG(111) # make sure to be the same as when constructing train_dataloader
-(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
+(; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc,
 ) = gen_hybridproblem_synthetic(rng, prob; scenario);
 n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
 i_sites = 1:n_site
@@ -82,7 +82,7 @@ end
     f_batch = PBMSiteApplicator(CP.DoubleMM.f_doubleMM;
         θP = θP_true, θM = θMs_true[:,1],
         θFix=CA.ComponentVector(), xPvec=xP[:,1])
-    y_exp = f_batch(θP_true, θMs_true', xP)[2]
+    y_exp = f_batch(θP_true, θMs_true', xP)
     @test y == y_exp
     ygrad = Zygote.gradient(θv -> sum(fy(θv, xPM)), θvec)[1]
     if gdev isa MLDataDevices.AbstractGPUDevice
@@ -90,10 +90,10 @@ end
         # xPMg = gdev(xPM)
        # yg = CP.DoubleMM.f_doubleMM(θg, xPMg, intθ);
        θvecg = gdev(θvec); # errors without ";"
-       xPMg = CP.apply_preserve_axes(gdev, xPM)
-       yg = @inferred fy(θvecg, xPMg)
+       xPMg = CP.apply_preserve_axes(gdev, xPM);
+       yg = @inferred fy(θvecg, xPMg);
        @test cdev(yg) == y_exp
-       ygradg = Zygote.gradient(θv -> sum(fy(θv, xPMg)), θvecg)[1]
+       ygradg = Zygote.gradient(θv -> sum(fy(θv, xPMg)), θvecg)[1];
        @test ygradg isa CA.ComponentArray
        @test CA.getdata(ygradg) isa GPUArraysCore.AbstractGPUArray
        ygradgc = CP.apply_preserve_axes(cdev, ygradg) # can print the cpu version
@@ -199,6 +199,9 @@ end
     n_site, n_site_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
     f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false)
     f2 = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = true)
+    priors = get_hybridproblem_priors(prob; scenario)
+    priorsP = [priors[k] for k in keys(par_templates.θP)]
+    priorsM = [priors[k] for k in keys(par_templates.θM)]
     intϕ = ComponentArrayInterpreter(CA.ComponentVector(
         ϕg = 1:length(ϕg0), ϕP = par_templates.θP))
@@ -214,12 +217,12 @@ end
     @assert train_loader.data == (xM, xP, y_o, y_unc, i_sites)
 
     pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario)
-    #loss_gf = get_loss_gf(g, transM, f, y_global_o, intϕ; gdev = identity)
-    loss_gf = get_loss_gf(g, transM, transP, f, y_global_o, intϕ;
-        pbm_covars, n_site_batch = n_batch)
-    loss_gf2 = get_loss_gf(g, transM, transP, f2, y_global_o, intϕ;
-        pbm_covars, n_site_batch = n_site)
-    l1 = @inferred first(loss_gf(p0, first(train_loader)...))
+    #loss_gf = get_loss_gf(g, transM, f, intϕ; gdev = identity)
+    loss_gf = get_loss_gf(g, transM, transP, f, intϕ;
+        pbm_covars, n_site_batch = n_batch, priorsP, priorsM)
+    loss_gf2 = get_loss_gf(g, transM, transP, f2, intϕ;
+        pbm_covars, n_site_batch = n_site, priorsP, priorsM)
+    nLjoint = @inferred first(loss_gf(p0, first(train_loader)...))
     (xM_batch, xP_batch, y_o_batch, y_unc_batch, i_sites_batch) = first(train_loader)
     # @usingany Cthulhu
     # @descend_code_warntype loss_gf(p0, xM_batch, xP_batch, y_o_batch, y_unc_batch, i_sites_batch)
@@ -232,16 +235,19 @@ end
     res = Optimization.solve(
         #optprob, Adam(0.02), callback = callback_loss(100), maxiters = 5000);
-        optprob, Adam(0.02), maxiters = 1000)
+        optprob, Adam(0.02), maxiters = 2000)
 
-    l1, y_pred, θMs_pred, θP_pred = loss_gf2(res.u, train_loader.data...)
-    #l1, y_pred_global, y_pred, θMs_pred = loss_gf(p0, xM, xP, y_o, y_unc);
+    (;nLjoint_pen, y_pred, θMs_pred, θP_pred, nLy, neg_log_prior, loss_penalty) = loss_gf2(
+        res.u, train_loader.data...)
+    #(nLjoint, y_pred, θMs_pred, θP, nLy, neg_log_prior, loss_penalty) = loss_gf(p0, xM, xP, y_o, y_unc);
     θMs_pred = CA.ComponentArray(θMs_pred, CA.getaxes(θMs_true')) #TODO
     @test isapprox(par_templates.θP, intϕ(res.u).ϕP, rtol = 0.15)
     #@test cor(vec(θMs_true), vec(θMs_pred)) > 0.8
     @test cor(θMs_true'[:, 1], θMs_pred[:, 1]) > 0.8
     @test cor(θMs_true'[:, 2], θMs_pred[:, 2]) > 0.8
     # started from low values -> increased but not too much above true values
+    # logpdf.(priorsP, θP_pred)
+    # logpdf.(priorsP, par_templates.θP)
     @test all(transP(intϕ(p0).ϕP) .< θP_pred .< (1.2 .* par_templates.θP))
 
     () -> begin
diff --git a/test/test_elbo.jl b/test/test_elbo.jl
index 7a894d3..022d99a 100644
--- a/test/test_elbo.jl
+++ b/test/test_elbo.jl
@@ -35,13 +35,14 @@ test_scenario = (scenario) -> begin
     int_P, int_M = map(ComponentArrayInterpreter, par_templates)
     pbm_covars = get_hybridproblem_pbmpar_covars(probc; scenario)
     pbm_covar_indices = CP.get_pbm_covar_indices(par_templates.θP, pbm_covars)
+    #get_hybridproblem_
     #θsite_true = get_hybridproblem_par_templates(probc; scenario)
     n_site, n_batch = get_hybridproblem_n_site_and_batch(probc; scenario)
 
     # note: need to use prob rather than probc here, make sure the same
     rng = StableRNG(111)
-    (; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o,
-        y_unc) = gen_hybridproblem_synthetic(rng, prob; scenario)
+    (; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc) =
+        gen_hybridproblem_synthetic(rng, prob; scenario)
 
     tmpf = () -> begin
         # wrap inside function to not define(pollute) variables in level up
         _trainloader = get_hybridproblem_train_dataloader(probc; scenario)
@@ -60,6 +61,10 @@ test_scenario = (scenario) -> begin
     py = neg_logden_indep_normal
 
+    priors = get_hybridproblem_priors(prob; scenario)
+    priorsP = [priors[k] for k in keys(par_templates.θP)]
+    priorsM = [priors[k] for k in keys(par_templates.θM)]
+
     n_MC = 3
     (; transP, transM) = get_hybridproblem_transforms(probc; scenario)
     cor_ends = get_hybridproblem_cor_ends(probc; scenario)
@@ -144,7 +149,7 @@ test_scenario = (scenario) -> begin
     _ϕ = vcat(ϕ_ini.μP, probc.ϕg, probd.ϕunc)
     #hcat(ϕ_ini, ϕ, _ϕ)[1:4,:]
     #hcat(ϕ_ini, ϕ, _ϕ)[(end-20):end,:]
-    n_predict = 8000
+    n_predict = 10_000 #8_000
     xM_batch = xM[:, 1:n_batch]
     _ζsP, _ζsMs, _σ = @inferred (
         # @descend_code_warntype (
@@ -166,7 +171,7 @@ test_scenario = (scenario) -> begin
     #scatterplot(ζMs_g[:,1], mMs[:,1])
     #scatterplot(ζMs_g[:,2], mMs[:,2])
     @test cor(ζMs_g[:,1], mMs[:,1]) > 0.9
-    @test cor(ζMs_g[:,2], mMs[:,2]) > 0.8
+    @test cor(ζMs_g[:,2], mMs[:,2]) > 0.7
     map(axes(mMs,2)) do ipar
         #@show ipar
         @test isapprox(mMs[:,ipar], ζMs_g[:,ipar]; rtol=0.1)
@@ -208,7 +213,8 @@ test_scenario = (scenario) -> begin
     # test if uncertainty and reshaping is propagated
     # here inverse the predicted θs and then test distribution
     probcu = HybridProblem(probc, ϕunc=ϕunc_true);
-    n_sample_pred = 2_400
+    n_sample_pred = 10_000 #2_400
+    #n_sample_pred = 400
     (; y, θsP, θsMs, entropy_ζ) = predict_hvi(rng, probcu; scenario, n_sample_pred);
     #size(_ζsMs), size(θsMs)
     #size(_ζsP), size(θsP)
@@ -332,14 +338,14 @@ test_scenario = (scenario) -> begin
         neg_elbo_gtf(rng, ϕ_ini, g, f, py,
             xM[:, i_sites], xP[:, i_sites], y_o[:, i_sites], y_unc[:, i_sites],
             i_sites; int_unc, int_μP_ϕg_unc,
-            cor_ends, pbm_covar_indices, transP, transMs)
+            cor_ends, pbm_covar_indices, transP, transMs, priorsP, priorsM,)
     )
     @test cost isa Float64
     gr = Zygote.gradient(
         ϕ -> neg_elbo_gtf(rng, ϕ, g, f, py,
             xM[:, i_sites], xP[:, i_sites], y_o[:, i_sites], y_unc[:, i_sites],
             i_sites; int_unc, int_μP_ϕg_unc,
-            cor_ends, pbm_covar_indices, transP, transMs),
+            cor_ends, pbm_covar_indices, transP, transMs, priorsP, priorsM,),
         CA.getdata(ϕ_ini))
     @test gr[1] isa Vector
 end
@@ -356,14 +362,16 @@ test_scenario = (scenario) -> begin
         neg_elbo_gtf(rng, ϕ, g_gpu, f, py,
             xMg_batch, xP_batch, y_o[:, i_sites], y_unc[:, i_sites],
             i_sites; int_unc, int_μP_ϕg_unc,
-            n_MC=3, cor_ends, pbm_covar_indices, transP, transMs)
+            n_MC=3, cor_ends, pbm_covar_indices, transP, transMs, priorsP, priorsM,
+            )
     )
     @test cost isa Float64
     gr = Zygote.gradient(
         ϕ -> neg_elbo_gtf(rng, ϕ, g_gpu, f, py,
             xMg_batch, xP_batch, y_o[:, i_sites], y_unc[:, i_sites],
             i_sites; int_unc, int_μP_ϕg_unc,
-            n_MC=3, cor_ends, pbm_covar_indices, transP, transMs),
+            n_MC=3, cor_ends, pbm_covar_indices, transP, transMs, priorsP, priorsM,
+            ),
         ϕ)
     @test gr[1] isa GPUArraysCore.AbstractGPUVector
     @test eltype(gr[1]) == FT
@@ -420,6 +428,7 @@ test_scenario = (scenario) -> begin
     xP_dev = ggdev(xP);
     f_pred_dev = fmap(ggdev, f_pred)
     y = @inferred f_pred_dev(θsP, θsMs, xP_dev)
+    #@benchmark f_pred_dev(θsP, θsMs, xP_dev)
     @test y isa GPUArraysCore.AbstractGPUArray
     @test size(y) == (size(y_o)..., n_sample_pred)
 end
diff --git a/test/test_sample_zeta.jl b/test/test_sample_zeta.jl
index e16e6a4..f4f963c 100644
--- a/test/test_sample_zeta.jl
+++ b/test/test_sample_zeta.jl
@@ -24,7 +24,7 @@ scenario = Val((:default,))
 
 n_θM, n_θP = length.(values(get_hybridproblem_par_templates(prob; scenario)))
 
-(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o
+(; xM, θP_true, θMs_true, xP, y_true, y_o
 ) = gen_hybridproblem_synthetic(rng, prob; scenario)
 
 n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
diff --git a/test/test_util.jl b/test/test_util.jl
new file mode 100644
index 0000000..e273924
--- /dev/null
+++ b/test/test_util.jl
@@ -0,0 +1,46 @@
+using Test
+using HybridVariationalInference: vectuptotupvec_allowmissing, vectuptotupvec
+using Zygote
+
+@testset "vectuptotupvec" begin
+    vectup = [(1,1.01, "string 1"), (2,2.02, "string 2")]
+    tupvec = @inferred vectuptotupvec(vectup)
+    #@code_warntype vectuptotupvec_allowmissing(vectup)
+    @test tupvec == ([1, 2], [1.01, 2.02], ["string 1", "string 2"])
+    @test typeof(first(tupvec)) == Vector{Int}
+    # empty not allowed
+    @test_throws Exception tupvec = vectuptotupvec([])
+    # do not allow tuples of different types - note the Float64 in first entry
+    vectupm = [(1.00,1.01, "string 1"), (2,2.02, "string 2",:asymbol)]
+    @test_throws Exception tupvecm = vectuptotupvec(vectupm)
+    #
+    gr = Zygote.gradient(x -> sum(vectuptotupvec(x)[1]), vectup)
+end;
+
+
+@testset "vectuptotupvec_allowmissing" begin
+    vectup = [(1,1.01, "string 1"), (2,2.02, "string 2")]
+    tupvec = @inferred vectuptotupvec_allowmissing(vectup)
+    #@code_warntype vectuptotupvec_allowmissing(vectup)
+    @test tupvec == ([1, 2], [1.01, 2.02], ["string 1", "string 2"])
+    @test typeof(first(tupvec)) == Vector{Union{Missing,Int}}
+    # empty not allowed
+    @test_throws Exception tupvec = vectuptotupvec_allowmissing([])
+    # first missing
+    vectupm = [missing, (1,1.01, "string 1"), (2,2.02, "string 2")]
+    vectuptotupvec_allowmissing(vectupm)
+    tupvecm = @inferred vectuptotupvec_allowmissing(vectupm)
+    @test ismissing(vectupm[1]) # did not change underlying vector
+    #@code_warntype vectuptotupvec_allowmissing(vectupm)
+    @test isequal(tupvecm, ([missing, 1, 2], [missing, 1.01, 2.02], [missing, "string 1", "string 2"]))
+    # do not allow tuples of different length
+    vectupm = [(1,1.01, "string 1"), (2,2.02, "string 2",:asymbol)]
+    @test_throws Exception tupvecm = vectuptotupvec_allowmissing(vectupm)
+    # do not allow tuples of different types - note the Float64 in first entry
+    vectupm = [(1.00,1.01, "string 1"), (2,2.02, "string 2",:asymbol)]
+    @test_throws Exception tupvecm = vectuptotupvec_allowmissing(vectupm)
+    #
+    vectupm = [missing, (1,1.01, "string 1"), (2,2.02, "string 2")]
+    gr = Zygote.gradient(x -> sum(skipmissing(vectuptotupvec_allowmissing(x)[1])), vectupm)
+end;
+
diff --git a/test/test_util_gpu.jl b/test/test_util_gpu.jl
new file mode 100644
index 0000000..75162f8
--- /dev/null
+++ b/test/test_util_gpu.jl
@@ -0,0 +1,25 @@
+using Test
+using HybridVariationalInference: HybridVariationalInference as HVI
+using ComponentArrays
+using MLDataDevices
+import CUDA, cuDNN
+using FillArrays
+
+@testset "ones_similar_x" begin
+    A = rand(Float64, 3, 4);
+    @test HVI.ones_similar_x(A, 3) isa FillArrays.AbstractFill #Vector
+    @test HVI.ones_similar_x(A, size(A,1)) isa FillArrays.AbstractFill #Vector
+end
+
+gdev = gpu_device()
+if gdev isa MLDataDevices.CUDADevice
+    @testset "ones_similar_x" begin
+        B = CUDA.rand(Float32, 5, 2); # GPU matrix
+        @test HVI.ones_similar_x(B, size(B,1)) isa CuArray
+        @test HVI.ones_similar_x(ComponentVector(b=B), size(B,1)) isa CuArray
+        @test HVI.ones_similar_x(B', size(B,1)) isa CuArray
+        @test HVI.ones_similar_x(@view(B[:,2]), size(B,1)) isa CuArray
+        @test HVI.ones_similar_x(ComponentVector(b=B)[:,1], size(B,1)) isa CuArray
+    end
+end
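
Note on the new `test/test_util_gpu.jl` above: the sketch below illustrates the kind of dispatch the `ones_similar_x` tests exercise (a lazy `FillArrays.Ones` on CPU, a device array when the input is GPU-backed, also through wrappers). It is an illustrative assumption, not the package's implementation; `ones_similar_sketch` and `_storage` are hypothetical names.

``` julia
# Hypothetical sketch (not HybridVariationalInference's source): return a length-n
# vector of ones whose storage matches x, i.e. a lazy FillArrays.Ones for plain CPU
# arrays and a dense device vector when x is GPU-backed, also when x is wrapped in
# a ComponentArray, Adjoint, or SubArray.
using FillArrays
using GPUArraysCore: AbstractGPUArray
using ComponentArrays: ComponentArray, getdata

# unwrap ComponentArrays and lazy wrappers down to the backing buffer
_storage(x::ComponentArray) = _storage(getdata(x))
_storage(x::AbstractArray) = parent(x) === x ? x : _storage(parent(x))

ones_similar_sketch(x::AbstractArray, n::Integer) = _ones_impl(_storage(x), n)

_ones_impl(::AbstractArray{T}, n) where {T} = Ones{T}(n)   # allocation-free on CPU
_ones_impl(x::AbstractGPUArray{T}, n) where {T} = fill!(similar(x, T, n), one(T))  # ones on the device

# e.g. ones_similar_sketch(rand(3, 4), 3) isa FillArrays.AbstractFill  # true on CPU
```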