4 changes: 4 additions & 0 deletions Project.toml
@@ -12,11 +12,13 @@ CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -47,13 +49,15 @@ CommonSolve = "0.2.4"
ComponentArrays = "0.15.19"
DistributionFits = "0.3.9"
Distributions = "0.25.117"
FillArrays = "1.13.0"
Flux = "0.14, 0.15, 0.16"
Functors = "0.4, 0.5"
GPUArraysCore = "0.1, 0.2"
LinearAlgebra = "1.10"
Lux = "1.4.2"
MLDataDevices = "1.5, 1.6"
MLUtils = "0.4.5"
Missings = "1.2.0"
Optimization = "3.19.3, 4"
Random = "1.10.0"
SimpleChains = "0.4"
1 change: 1 addition & 0 deletions _typos.toml
@@ -4,3 +4,4 @@ extend-exclude = ["docs/src_stash/"]
[default.extend-words]
SOM = "SOM"
negLogLik = "negLogLik"
Missings = "Missings"
14 changes: 7 additions & 7 deletions dev/doubleMM.jl
@@ -32,7 +32,7 @@ cdev = gdev isa MLDataDevices.AbstractGPUDevice ? cpu_device() : identity

#------ setup synthetic data and training data loader
prob0_ = HybridProblem(DoubleMM.DoubleMMCase(); scenario);
-(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc
+(; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc
) = gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario);
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob0_; scenario)
ζP_true, ζMs_true = log.(θP_true), log.(θMs_true)
@@ -59,7 +59,7 @@ n_epoch = 80
maxiters = n_batches_in_epoch * n_epoch);
# update the problem with optimized parameters
prob0o = prob1o = probo;
-y_pred_global, y_pred, θMs = gf(prob0o; scenario, is_inferred=Val(true));
+y_pred, θMs = gf(prob0o; scenario, is_inferred=Val(true));
# @descend_code_warntype gf(prob0o; scenario)
#@usingany UnicodePlots
plt = scatterplot(θMs_true'[:, 1], θMs[:, 1]);
@@ -77,7 +77,7 @@ histogram(vec(y_pred) - vec(y_true)) # predictions centered around y_o (or y_tru
(; ϕ, resopt) = solve(prob0o, solver1; scenario, rng,
callback = callback_loss(20), maxiters = 400)
prob1o = HybridProblem(prob0o; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP)
-y_pred_global, y_pred, θMs = gf(prob1o, xM, xP; scenario)
+y_pred, θMs = gf(prob1o, xM, xP; scenario)
scatterplot(θMs_true[1, :], θMs[1, :])
scatterplot(θMs_true[2, :], θMs[2, :])
prob1o.θP
@@ -91,7 +91,7 @@ end
(; ϕ, resopt) = solve(prob2, solver1; scenario, rng,
callback = callback_loss(20), maxiters = 600)
prob2o = HybridProblem(prob2; ϕg = collect(ϕ.ϕg), θP = ϕ.θP)
-y_pred_global, y_pred, θMs = gf(prob2o, xM, xP)
+y_pred, θMs = gf(prob2o, xM, xP)
prob2o.θP
end

@@ -127,7 +127,7 @@ end
(; ϕ, resopt) = solve(prob3, solver1; scenario, rng,
callback = callback_loss(50), maxiters = 600)
prob3o = HybridProblem(prob3; ϕg = cpu_ca(ϕ).ϕg, θP = cpu_ca(ϕ).θP)
-y_pred_global, y_pred, θMs = gf(prob3o, xM, xP; scenario)
+y_pred, θMs = gf(prob3o, xM, xP; scenario)
scatterplot(θMs_true[2, :], θMs[2, :])
prob3o.θP
scatterplot(vec(y_true), vec(y_pred))
@@ -173,7 +173,7 @@ solver_post = HybridPosteriorSolver(; alg = OptimizationOptimisers.Adam(0.01), n
(y1, θsP1, θsMs1) = (y, θsP, θsMs);

() -> begin # prediction with fitted parameters (should be smaller than mean)
-y_pred_global, y_pred2, θMs = gf(prob1o, xM, xP; scenario)
+y_pred2, θMs = gf(prob1o, xM, xP; scenario)
scatterplot(θMs_true[1, :], θMs[1, :])
scatterplot(θMs_true[2, :], θMs[2, :])
hcat(θP_true, θP) # all parameters overestimated
@@ -366,7 +366,7 @@ end
# ζMs = invt.transM.(θMs_i)
# _f = get_hybridproblem_PBmodel(probo; scenario)
# y_site = map(eachcol(θPs), θMs_i) do θP, θM
-# y_global, y = _f(θP, reshape(θM, (length(θM), 1)), xP[[i_site]])
+# y = _f(θP, reshape(θM, (length(θM), 1)), xP[[i_site]])
# y[:,1]
# end |> stack
nLs = get_hybridproblem_neg_logden_obs(
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -23,9 +23,9 @@ makedocs(;
#"Test quarto markdown" => "tutorials/test1.md",
],
"How to" => [
".. use GPU" => "tutorials/lux_gpu.md",
".. model independent parameters" => "tutorials/blocks_corr.md",
".. model site-global corr" => "tutorials/corr_site_global.md",
".. use GPU" => "tutorials/lux_gpu.md",
],
"Explanation" => [
#"Theory" => "explanation/theory_hvi.md", TODO activate when paper is published
1 change: 1 addition & 0 deletions docs/src/tutorials/Project.toml
@@ -8,6 +8,7 @@ DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
HybridVariationalInference = "a108c475-a4e2-4021-9a84-cfa7df242f64"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
6 changes: 3 additions & 3 deletions docs/src/tutorials/basic_cpu.md
@@ -104,14 +104,14 @@ HVI is an approximate Bayesian analysis and combines prior information on
the parameters with the model and observed data.

Here, we provide a wide prior by fitting a LogNormal distribution to
-- the mean corresponding to the initial value provided above
-- the 0.95-quantile 3 times the mean
+- the mode corresponding to the initial value provided above
+- the 0.95-quantile 3 times the mode
using the `DistributionFits.jl` package.

``` julia
θall = vcat(θP, θM)
priors_dict = Dict{Symbol, Distribution}(
-keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
+keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode)))
```
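
As a quick check of what such a prior looks like, here is a minimal sketch for a single parameter (the value `1.0` is illustrative, not from the tutorial): the fitted `LogNormal` should reproduce the requested mode and 0.95-quantile.

``` julia
using Distributions, DistributionFits
d = fit(LogNormal, 1.0, QuantilePoint(3.0, 0.95), Val(:mode))
mode(d)           # ≈ 1.0, the initial value
quantile(d, 0.95) # ≈ 3.0, three times the mode
```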

## Observations, model drivers and covariates
8 changes: 4 additions & 4 deletions docs/src/tutorials/basic_cpu.qmd
@@ -109,14 +109,14 @@ HVI is an approximate Bayesian analysis and combines prior information on
the parameters with the model and observed data.

Here, we provide a wide prior by fitting a LogNormal distribution to
-- the mean corresponding to the initial value provided above
-- the 0.95-quantile 3 times the mean
+- the mode corresponding to the initial value provided above
+- the 0.95-quantile 3 times the mode
using the `DistributionFits.jl` package.

```{julia}
θall = vcat(θP, θM)
priors_dict = Dict{Symbol, Distribution}(
-keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
+keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode)))
```

## Observations, model drivers and covariates
@@ -138,7 +138,7 @@ rng = StableRNG(111)
#| echo: false
#| eval: false
() -> begin
-(; xM, θP_true, θMs_true, xP, y_global_true, y_true, y_global_o, y_o, y_unc) =
+(; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc) =
gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))
end
```
4 changes: 2 additions & 2 deletions docs/src/tutorials/inspect_results.md
@@ -5,7 +5,7 @@
CurrentModule = HybridVariationalInference
```

-This tutorial leads you through querying relevant information from the
+This tutorial leads you through extracting relevant information from the
inversion results and to produce some typical plots.

First load necessary packages.
@@ -110,7 +110,7 @@ In addition to the uncertainty in parameters, we are also interested in
the uncertainty of predictions, i.e. the predictive posterior.

We can either run the PBM for all the parameter samples that we obtained already,
-using the [`AbstractModelApplicator`](@ref), or use [`predict_hvi`](@ref) which combines
+using the [`AbstractPBMApplicator`](@ref), or use [`predict_hvi`](@ref) which combines
sampling the posterior and predictive posterior and returns the additional
`NamedTuple` entry `y`.

Binary file modified docs/src/tutorials/intermediate/basic_cpu_results.jld2
Binary file not shown.
13 changes: 8 additions & 5 deletions docs/src/tutorials/lux_gpu.md
@@ -27,6 +27,7 @@ using StableRNGs
using MLUtils
using JLD2
using Random
+using MLDataDevices
# using CairoMakie
# using PairPlots # scatterplot matrices
```
@@ -88,15 +89,16 @@ Hence specify
- `gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device())`: to move both ML model and PBM to GPU
- `gdevs = (; gdev_M=gpu_device(), gdev_P=identity)`: to move both ML model to GPU but execute the PBM (and parameter transformation) on CPU

+Currently, putting the PBM on GPU is not efficient during inversion, because
+the prior distribution needs to be evaluated for each sample.
+However, sampling and prediction using a fitted model is efficient.

In addition, the libraries of the GPU device need to be activated by
importing respective Julia packages.
Currently, only CUDA is tested with this `HybridVariationalInference` package.

``` julia
import CUDA, cuDNN # so that gpu_device() returns a CUDADevice
-#CUDA.device!(4)
-gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device())
-#gdevs = (; gdev_M=gpu_device(), gdev_P=identity)

using OptimizationOptimisers
import Zygote
@@ -105,7 +107,7 @@ solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
(; probo) = solve(prob_lux, solver;
callback = callback_loss(100),
epochs = 10,
-gdevs,
+gdevs = (; gdev_M=gpu_device(), gdev_P=identity)
); probo_lux = probo;
```

@@ -116,7 +118,8 @@ The sampling and prediction methods also take this `gdevs` keyword argument.
``` julia
n_sample_pred = 400
(y_dev, θsP_dev, θsMs_dev) = (; y, θsP, θsMs) = predict_hvi(
-rng, probo_lux; n_sample_pred, gdevs);
+rng, probo_lux; n_sample_pred,
+gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device()));
```

If `gdev_P` is not an `AbstractGPUDevice` then all the results are on CPU.
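
Otherwise the returned arrays still live on the GPU; a minimal sketch of moving them back to host memory with `cpu_device` from `MLDataDevices` (the `*_dev` names come from the call above):

``` julia
cdev = cpu_device()
(y, θsP, θsMs) = (cdev(y_dev), cdev(θsP_dev), cdev(θsMs_dev))
```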
13 changes: 8 additions & 5 deletions docs/src/tutorials/lux_gpu.qmd
@@ -38,6 +38,7 @@ using StableRNGs
using MLUtils
using JLD2
using Random
+using MLDataDevices
# using CairoMakie
# using PairPlots # scatterplot matrices
```
@@ -98,6 +99,10 @@ Hence specify
- `gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device())`: to move both ML model and PBM to GPU
- `gdevs = (; gdev_M=gpu_device(), gdev_P=identity)`: to move both ML model to GPU but execute the PBM (and parameter transformation) on CPU

+Currently, putting the PBM on GPU is not efficient during inversion, because
+the prior distribution needs to be evaluated for each sample.
+However, sampling and prediction using a fitted model is efficient.

In addition, the libraries of the GPU device need to be activated by
importing respective Julia packages.
Currently, only CUDA is tested with this `HybridVariationalInference` package.
@@ -112,9 +117,6 @@ end
```
```{julia}
import CUDA, cuDNN # so that gpu_device() returns a CUDADevice
-#CUDA.device!(4)
-gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device())
-#gdevs = (; gdev_M=gpu_device(), gdev_P=identity)

using OptimizationOptimisers
import Zygote
@@ -123,7 +125,7 @@ solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
(; probo) = solve(prob_lux, solver;
callback = callback_loss(100),
epochs = 10,
-gdevs,
+gdevs = (; gdev_M=gpu_device(), gdev_P=identity)
); probo_lux = probo;
```

@@ -134,7 +136,8 @@ The sampling and prediction methods also take this `gdevs` keyword argument.
```{julia}
n_sample_pred = 400
(y_dev, θsP_dev, θsMs_dev) = (; y, θsP, θsMs) = predict_hvi(
-rng, probo_lux; n_sample_pred, gdevs);
+rng, probo_lux; n_sample_pred,
+gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device()));
```

If `gdev_P` is not an `AbstractGPUDevice` then all the results are on CPU.
7 changes: 7 additions & 0 deletions ext/HybridVariationalInferenceCUDAExt.jl
@@ -2,6 +2,7 @@ module HybridVariationalInferenceCUDAExt

using HybridVariationalInference, CUDA
using HybridVariationalInference: HybridVariationalInference as HVI
+using ComponentArrays: ComponentArrays as CA
using ChainRulesCore

# here, really CUDA-specific implementation, in case need to code other GPU devices
@@ -84,6 +85,12 @@ function HVI._create_randn(rng, v::CUDA.CuVector{T,M}, dims...) where {T,M}
res::CUDA.CuArray{T, length(dims),M}
end

+function HVI.ones_similar_x(x::CuArray, size_ret = size(x))
+# call CUDA.ones rather than ones for x::CuArray
+ChainRulesCore.@ignore_derivatives CUDA.ones(eltype(x), size_ret)
+end

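A hypothetical usage sketch of this extension method (the array `x` and the requested size are illustrative; requires a functional CUDA device):

```julia
import CUDA
import HybridVariationalInference as HVI
x = CUDA.rand(Float32, 2, 3)
o = HVI.ones_similar_x(x, (4, 3)) # CuArray{Float32} of ones; excluded from AD
```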
3 changes: 1 addition & 2 deletions src/AbstractHybridProblem.jl
@@ -141,9 +141,8 @@ Setup synthetic data, a NamedTuple of
- θP_true: vector global process-model parameters
- θMs_true: matrix of site-varying process-model parameters, with
- xP: Vector of process-model drivers, with an entry per site
-- y_global_true: vector of global observations
- y_true: matrix of site-specific observations with one column per site
-- y_global_o, y_o: observations with added noise
+- y_o: observations with added noise
"""
function gen_hybridproblem_synthetic end
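
A usage sketch of the updated return value (mirroring the call in `dev/doubleMM.jl` above):

```julia
(; xM, θP_true, θMs_true, xP, y_true, y_o, y_unc) =
    gen_hybridproblem_synthetic(rng, DoubleMM.DoubleMMCase(); scenario)
```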

33 changes: 18 additions & 15 deletions src/DoubleMM/f_doubleMM.jl
@@ -77,20 +77,23 @@ Returns a matrix `(n_obs x n_site)` of predictions.
```julia
function f_doubleMM_sites(θc::ComponentMatrix, xPc::ComponentMatrix)
# extract several covariates from xP
-ST = typeof(getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
-S1 = (getdata(xPc[:S1,:])::ST)
-S2 = (getdata(xPc[:S2,:])::ST)
+ST = typeof(CA.getdata(xPc)[1:1,:]) # workaround for non-type-stable Symbol-indexing
+S1 = (CA.getdata(xPc[:S1,:])::ST)
+S2 = (CA.getdata(xPc[:S2,:])::ST)
#
# extract the parameters as vectors that are row-repeated into a matrix
+VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
n_obs = size(S1, 1)
-VT = typeof(getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
+rep_fac = ones_similar_x(xPc, n_obs) # to reshape into matrix, avoiding repeat
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
-p1 = getdata(θc[:, par]) ::VT
-repeat(p1', n_obs) # matrix: same for each concentration row in S1
+p1 = CA.getdata(θc[:, par]) ::VT
+#(r0 .* rep_fac)' # move to computation below to save allocation
+#repeat(p1', n_obs) # matrix: same for each concentration row in S1
end
#
# each variable is a matrix (n_obs x n_site)
-r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
+#r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
+(r0 .* rep_fac)' .+ (r1 .* rep_fac)' .* S1 ./ ((K1 .* rep_fac)' .+ S1) .* S2 ./ ((K2 .* rep_fac)' .+ S2)
end
```
"""
Expand All @@ -101,15 +104,18 @@ function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
S2 = (CA.getdata(xPc[:S2,:])::ST)
#
# extract the parameters as vectors that are row-repeated into a matrix
-n_obs = size(S1, 1)
VT = typeof(CA.getdata(θc)[:,1]) # workaround for non-type-stable Symbol-indexing
+n_obs = size(S1, 1)
+rep_fac = HVI.ones_similar_x(xPc, n_obs) # to reshape into matrix, avoiding repeat
(r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
p1 = CA.getdata(θc[:, par]) ::VT
-repeat(p1', n_obs) # matrix: same for each concentration row in S1
+#repeat(p1', n_obs) # matrix: same for each concentration row in S1
+#(rep_fac .* p1') # move to computation below to save allocation
end
#
# each variable is a matrix (n_obs x n_site)
-r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
+#r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
+(rep_fac .* r0') .+ (rep_fac .* r1') .* S1 ./ ((rep_fac .* K1') .+ S1) .* S2 ./ ((rep_fac .* K2') .+ S2)
end
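
A quick equivalence check of the broadcasting trick used above, as a self-contained sketch with plain `Array`s (`p1` and `n_obs` are illustrative): multiplying a ones-vector with a transposed parameter vector yields the same matrix as `repeat`, while staying GPU- and AD-friendly.

```julia
p1 = [1.0, 2.0, 3.0]          # one parameter value per site
n_obs = 4
rep_fac = ones(n_obs)
(rep_fac .* p1') == repeat(p1', n_obs)  # true: one identical row per observation
```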

# function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
@@ -204,7 +210,7 @@ end
# end

function HVI.get_hybridproblem_priors(::DoubleMMCase; scenario::Val{scen}) where {scen}
-Dict(keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95)))
+Dict(keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode)))
end

function HVI.get_hybridproblem_MLapplicator(
@@ -371,21 +377,18 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
xP = int_xP_sites(vcat(repeat(xP_S1, 1, n_site), repeat(xP_S2, 1, n_site)))
#xP[:S1,:]
θP = par_templates.θP
-y_global_true, y_true = f(θP, θMs_true', xP)
+y_true = f(θP, θMs_true', xP)
σ_o = FloatType(0.01)
#σ_o = FloatType(0.002)
logσ2_o = FloatType(2) .* log.(σ_o)
#σ_o = 0.002
-y_global_o = y_global_true .+ randn(rng, FloatType, size(y_global_true)) .* σ_o
y_o = y_true .+ randn(rng, FloatType, size(y_true)) .* σ_o
(;
xM,
θP_true = θP,
θMs_true,
xP,
-y_global_true,
y_true,
-y_global_o,
y_o,
y_unc = fill(logσ2_o, size(y_o))
)