From b2f3ac1e4aeea8350005113fb8dd9f16bdc06421 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 19 Oct 2023 17:48:04 -0400 Subject: [PATCH 1/3] Put new loglikelihood behind a conditional --- src/NeuralPDE.jl | 4 +- src/{ => bayesian}/BPINN_ode.jl | 6 +- src/{ => bayesian}/advancedHMC_MCMC.jl | 22 ++- src/bayesian/collocated_estim.jl | 194 +++++++++++++++++++++++++ test/bpinnexperimental.jl | 66 +++++++++ 5 files changed, 281 insertions(+), 11 deletions(-) rename src/{ => bayesian}/BPINN_ode.jl (98%) rename src/{ => bayesian}/advancedHMC_MCMC.jl (97%) create mode 100644 src/bayesian/collocated_estim.jl create mode 100644 test/bpinnexperimental.jl diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 945093ea04..e38fca98d4 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -50,8 +50,8 @@ include("rode_solve.jl") include("transform_inf_integral.jl") include("discretize.jl") include("neural_adapter.jl") -include("advancedHMC_MCMC.jl") -include("BPINN_ode.jl") +include("bayesian/advancedHMC_MCMC.jl") +include("bayesian/BPINN_ode.jl") export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE, KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem, diff --git a/src/BPINN_ode.jl b/src/bayesian/BPINN_ode.jl similarity index 98% rename from src/BPINN_ode.jl rename to src/bayesian/BPINN_ode.jl index da49640314..5c26329f14 100644 --- a/src/BPINN_ode.jl +++ b/src/bayesian/BPINN_ode.jl @@ -178,7 +178,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem, verbose = false, saveat = 1 / 50.0, maxiters = nothing, - numensemble = floor(Int, alg.draw_samples / 3)) + numensemble = floor(Int, alg.draw_samples / 3), + estim_collocate = false) @unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy, draw_samples, dataset, init_params, nchains, physdt, Adaptorkwargs, Integratorkwargs, @@ -207,7 +208,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem, Integratorkwargs = Integratorkwargs, MCMCkwargs = MCMCkwargs, progress = progress, - verbose = verbose) + verbose = verbose, + estim_collocate = estim_collocate) fullsolution = BPINNstats(mcmcchain, samples, statistics) ninv = length(param) diff --git a/src/advancedHMC_MCMC.jl b/src/bayesian/advancedHMC_MCMC.jl similarity index 97% rename from src/advancedHMC_MCMC.jl rename to src/bayesian/advancedHMC_MCMC.jl index 6032c7ca21..6b6b3303e7 100644 --- a/src/advancedHMC_MCMC.jl +++ b/src/bayesian/advancedHMC_MCMC.jl @@ -16,11 +16,12 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I, physdt::Float64 extraparams::Int init_params::I + estim_collocate::Bool function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy, dataset, priors, phystd, l2std, autodiff, physdt, extraparams, - init_params::AbstractVector) + init_params::AbstractVector, estim_collocate) new{ typeof(chain), Nothing, @@ -39,12 +40,13 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I, autodiff, physdt, extraparams, - init_params) + init_params, + estim_collocate) end function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy, dataset, priors, phystd, l2std, autodiff, physdt, extraparams, - init_params::NamedTuple) + init_params::NamedTuple, estim_collocate) new{ typeof(chain), typeof(st), @@ -60,7 +62,8 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I, autodiff, physdt, extraparams, - init_params) + init_params, + estim_collocate) end end @@ -79,7 +82,11 @@ function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple) end function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ) - return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) + if Tar.estim_collocate + return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) + L2loss2(Tar, θ) + else + return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) + end end LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim @@ -481,7 +488,8 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain; Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), Integratorkwargs = (Integrator = Leapfrog,), MCMCkwargs = (n_leapfrog = 30,), - progress = false, verbose = false) + progress = false, verbose = false, + estim_collocate = false) # NN parameter prior mean and variance(PriorsNN must be a tuple) if isinplace(prob) @@ -542,7 +550,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain; t0 = prob.tspan[1] # dimensions would be total no of params,initial_nnθ for Lux namedTuples ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors, - phystd, l2std, autodiff, physdt, ninv, initial_nnθ) + phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate) try ℓπ(t0, initial_θ[1:(nparameters - ninv)]) diff --git a/src/bayesian/collocated_estim.jl b/src/bayesian/collocated_estim.jl new file mode 100644 index 0000000000..157388194e --- /dev/null +++ b/src/bayesian/collocated_estim.jl @@ -0,0 +1,194 @@ +# suggested extra loss function +function L2loss2(Tar::LogTargetDensity, θ) + f = Tar.prob.f + + # parameter estimation chosen or not + if Tar.extraparams > 0 + dataset, deri_sol = Tar.dataset + # deri_sol = deri_sol' + autodiff = Tar.autodiff + + # # Timepoints to enforce Physics + # dataset = Array(reduce(hcat, dataset)') + # t = dataset[end, :] + # û = dataset[1:(end - 1), :] + + # ode_params = Tar.extraparams == 1 ? + # θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] : + # θ[((length(θ) - Tar.extraparams) + 1):length(θ)] + + # if length(û[:, 1]) == 1 + # physsol = [f(û[:, i][1], + # ode_params, + # t[i]) + # for i in 1:length(û[1, :])] + # else + # physsol = [f(û[:, i], + # ode_params, + # t[i]) + # for i in 1:length(û[1, :])] + # end + # #form of NN output matrix output dim x n + # deri_physsol = reduce(hcat, physsol) + + # > for perfect deriv(basically gradient matching in case of an ODEFunction) + # in case of PDE or general ODE we would want to reduce residue of f(du,u,p,t) + # if length(û[:, 1]) == 1 + # deri_sol = [f(û[:, i][1], + # Tar.prob.p, + # t[i]) + # for i in 1:length(û[1, :])] + # else + # deri_sol = [f(û[:, i], + # Tar.prob.p, + # t[i]) + # for i in 1:length(û[1, :])] + # end + # deri_sol = reduce(hcat, deri_sol) + # deri_sol = reduce(hcat, derivatives) + + # Timepoints to enforce Physics + t = dataset[end] + u1 = dataset[2] + û = dataset[1] + # Tar(t, θ[1:(length(θ) - Tar.extraparams)])' + # + + nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff) + + ode_params = Tar.extraparams == 1 ? + θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] : + θ[((length(θ) - Tar.extraparams) + 1):length(θ)] + + if length(Tar.prob.u0) == 1 + physsol = [f(û[i], + ode_params, + t[i]) + for i in 1:length(û[:, 1])] + else + physsol = [f([û[i], u1[i]], + ode_params, + t[i]) + for i in 1:length(û[:, 1])] + end + #form of NN output matrix output dim x n + deri_physsol = reduce(hcat, physsol) + + # if length(Tar.prob.u0) == 1 + # nnsol = [f(û[i], + # Tar.prob.p, + # t[i]) + # for i in 1:length(û[:, 1])] + # else + # nnsol = [f([û[i], u1[i]], + # Tar.prob.p, + # t[i]) + # for i in 1:length(û[:, 1])] + # end + # form of NN output matrix output dim x n + # nnsol = reduce(hcat, nnsol) + + # > Instead of dataset gradients trying NN derivatives with dataset collocation + # # convert to matrix as nnsol + + physlogprob = 0 + for i in 1:length(Tar.prob.u0) + # can add phystd[i] for u[i] + physlogprob += logpdf(MvNormal(deri_physsol[i, :], + LinearAlgebra.Diagonal(map(abs2, + (Tar.l2std[i] * 4.0) .* + ones(length(nnsol[i, :]))))), + nnsol[i, :]) + end + return physlogprob + else + return 0 + end +end + +# PDE(DU,U,P,T)=0 + +# Derivated via Central Diff +# function calculate_derivatives2(dataset) +# x̂, time = dataset +# num_points = length(x̂) +# # Initialize an array to store the derivative values. +# derivatives = similar(x̂) + +# for i in 2:(num_points - 1) +# # Calculate the first-order derivative using central differences. +# Δt_forward = time[i + 1] - time[i] +# Δt_backward = time[i] - time[i - 1] + +# derivative = (x̂[i + 1] - x̂[i - 1]) / (Δt_forward + Δt_backward) + +# derivatives[i] = derivative +# end + +# # Derivatives at the endpoints can be calculated using forward or backward differences. +# derivatives[1] = (x̂[2] - x̂[1]) / (time[2] - time[1]) +# derivatives[end] = (x̂[end] - x̂[end - 1]) / (time[end] - time[end - 1]) +# return derivatives +# end + +function calderivatives(prob, dataset) + chainflux = Flux.Chain(Flux.Dense(1, 8, tanh), Flux.Dense(8, 8, tanh), + Flux.Dense(8, 2)) |> Flux.f64 + # chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> Flux.f64 + function loss(x, y) + # sum(Flux.mse.(prob.u0[1] .+ (prob.tspan[2] .- x)' .* chainflux(x)[1, :], y[1]) + + # Flux.mse.(prob.u0[2] .+ (prob.tspan[2] .- x)' .* chainflux(x)[2, :], y[2])) + # sum(Flux.mse.(prob.u0[1] .+ (prob.tspan[2] .- x)' .* chainflux(x)[1, :], y[1])) + sum(Flux.mse.(chainflux(x), y)) + end + optimizer = Flux.Optimise.ADAM(0.01) + epochs = 3000 + for epoch in 1:epochs + Flux.train!(loss, + Flux.params(chainflux), + [(dataset[end]', dataset[1:(end - 1)])], + optimizer) + end + + # A1 = (prob.u0' .+ + # (prob.tspan[2] .- (dataset[end]' .+ sqrt(eps(eltype(Float64)))))' .* + # chainflux(dataset[end]' .+ sqrt(eps(eltype(Float64))))') + + # A2 = (prob.u0' .+ + # (prob.tspan[2] .- (dataset[end]'))' .* + # chainflux(dataset[end]')') + + A1 = chainflux(dataset[end]' .+ sqrt(eps(eltype(dataset[end][1])))) + A2 = chainflux(dataset[end]') + + gradients = (A2 .- A1) ./ sqrt(eps(eltype(dataset[end][1]))) + + return gradients +end + +function calculate_derivatives(dataset) + + # u = dataset[1] + # u1 = dataset[2] + # t = dataset[end] + # # control points + # n = Int(floor(length(t) / 10)) + # # spline for datasetvalues(solution) + # # interp = BSplineApprox(u, t, 4, 10, :Uniform, :Uniform) + # interp = CubicSpline(u, t) + # interp1 = CubicSpline(u1, t) + # # derrivatives interpolation + # dx = t[2] - t[1] + # time = collect(t[1]:dx:t[end]) + # smoothu = [interp(i) for i in time] + # smoothu1 = [interp1(i) for i in time] + # # derivative of the spline (must match function derivative) + # û = tvdiff(smoothu, 20, 0.5, dx = dx, ε = 1) + # û1 = tvdiff(smoothu1, 20, 0.5, dx = dx, ε = 1) + # # tvdiff(smoothu, 100, 0.035, dx = dx, ε = 1) + # # FDM + # # û1 = diff(u) / dx + # # dataset[1] and smoothu are almost equal(rounding errors) + # return [û, û1] + +end \ No newline at end of file diff --git a/test/bpinnexperimental.jl b/test/bpinnexperimental.jl new file mode 100644 index 0000000000..153124b069 --- /dev/null +++ b/test/bpinnexperimental.jl @@ -0,0 +1,66 @@ +using Test, MCMCChains +using ForwardDiff, Distributions, OrdinaryDiffEq +using Flux, OptimizationOptimisers, AdvancedHMC, Lux +using Statistics, Random, Functors, ComponentArrays +using NeuralPDE, MonteCarloMeasurements + +Random.seed!(110) + +using NeuralPDE, Lux, Plots, OrdinaryDiffEq, Distributions, Random + +function lotka_volterra(u, p, t) + # Model parameters. + α, β, γ, δ = p + # Current state. + x, y = u + + # Evaluate differential equations. + dx = (α - β * y) * x # prey + dy = (δ * x - γ) * y # predator + + return [dx, dy] +end + +# initial-value problem. +u0 = [1.0, 1.0] +p = [1.5, 1.0, 3.0, 1.0] +tspan = (0.0, 4.0) +prob = ODEProblem(lotka_volterra, u0, tspan, p) + +# Solve using OrdinaryDiffEq.jl solver +dt = 0.01 +solution = solve(prob, Tsit5(); saveat = dt) + +times = solution.t +u = hcat(solution.u...) +x = u[1, :] + (u[1, :]) .* (0.05 .* randn(length(u[1, :]))) +y = u[2, :] + (u[2, :]) .* (0.05 .* randn(length(u[2, :]))) +dataset = [x, y, times] + +plot(times, x, label = "noisy x") +plot!(times, y, label = "noisy y") +plot!(solution, labels = ["x" "y"]) + +chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), + Lux.Dense(6, 2)) + +alg = BNNODE(chain; +dataset = dataset, +draw_samples = 1000, +l2std = [0.1, 0.1], +phystd = [0.1, 0.1], +priorsNNw = (0.0, 3.0), +param = [ + Normal(1, 2), + Normal(2, 2), + Normal(2, 2), + Normal(0, 2)], progress = false) + +sol_pestim = solve(prob, alg; saveat = dt) +plot(times, sol_pestim.ensemblesol[1], label = "estimated x") +plot!(times, sol_pestim.ensemblesol[2], label = "estimated y") + +# comparing it with the original solution +plot!(solution, labels = ["true x" "true y"]) + +sol_pestim.estimated_ode_params \ No newline at end of file From 058aa05eeb5dc434c825c30115b8ea7fd2d733ca Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 27 Oct 2023 16:58:42 -0400 Subject: [PATCH 2/3] fitzhughnagumo experiment and some edits --- src/NeuralPDE.jl | 1 + src/bayesian/advancedHMC_MCMC.jl | 6 +-- src/bayesian/collocated_estim.jl | 10 ++--- test/bpinnexperimental.jl | 68 ++++++++++++++++++++++++++++---- 4 files changed, 68 insertions(+), 17 deletions(-) diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index e38fca98d4..edfaf9664a 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -52,6 +52,7 @@ include("discretize.jl") include("neural_adapter.jl") include("bayesian/advancedHMC_MCMC.jl") include("bayesian/BPINN_ode.jl") +include("bayesian/collocated_estim.jl") export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE, KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem, diff --git a/src/bayesian/advancedHMC_MCMC.jl b/src/bayesian/advancedHMC_MCMC.jl index 6b6b3303e7..740bb344a3 100644 --- a/src/bayesian/advancedHMC_MCMC.jl +++ b/src/bayesian/advancedHMC_MCMC.jl @@ -587,8 +587,8 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain; MCMC_alg = kernelchoice(Kernel, MCMCkwargs) Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator) - samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor; - progress = progress, verbose = verbose) + samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor, draw_samples; + progress = progress, verbose = verbose, drop_warmup = true) samplesc[i] = samples statsc[i] = stats @@ -606,7 +606,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain; MCMC_alg = kernelchoice(Kernel, MCMCkwargs) Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator) samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, - adaptor; progress = progress, verbose = verbose) + adaptor, draw_samples; progress = progress, verbose = verbose, drop_warmup = true) # return a chain(basic chain),samples and stats matrix_samples = hcat(samples...) diff --git a/src/bayesian/collocated_estim.jl b/src/bayesian/collocated_estim.jl index 157388194e..b113b76f12 100644 --- a/src/bayesian/collocated_estim.jl +++ b/src/bayesian/collocated_estim.jl @@ -4,10 +4,8 @@ function L2loss2(Tar::LogTargetDensity, θ) # parameter estimation chosen or not if Tar.extraparams > 0 - dataset, deri_sol = Tar.dataset # deri_sol = deri_sol' autodiff = Tar.autodiff - # # Timepoints to enforce Physics # dataset = Array(reduce(hcat, dataset)') # t = dataset[end, :] @@ -48,9 +46,9 @@ function L2loss2(Tar::LogTargetDensity, θ) # deri_sol = reduce(hcat, derivatives) # Timepoints to enforce Physics - t = dataset[end] - u1 = dataset[2] - û = dataset[1] + t = Tar.dataset[end] + u1 = Tar.dataset[2] + û = Tar.dataset[1] # Tar(t, θ[1:(length(θ) - Tar.extraparams)])' # @@ -69,7 +67,7 @@ function L2loss2(Tar::LogTargetDensity, θ) physsol = [f([û[i], u1[i]], ode_params, t[i]) - for i in 1:length(û[:, 1])] + for i in 1:length(û)] end #form of NN output matrix output dim x n deri_physsol = reduce(hcat, physsol) diff --git a/test/bpinnexperimental.jl b/test/bpinnexperimental.jl index 153124b069..ffe7fcf0f8 100644 --- a/test/bpinnexperimental.jl +++ b/test/bpinnexperimental.jl @@ -28,13 +28,13 @@ tspan = (0.0, 4.0) prob = ODEProblem(lotka_volterra, u0, tspan, p) # Solve using OrdinaryDiffEq.jl solver -dt = 0.01 +dt = 0.2 solution = solve(prob, Tsit5(); saveat = dt) times = solution.t u = hcat(solution.u...) -x = u[1, :] + (u[1, :]) .* (0.05 .* randn(length(u[1, :]))) -y = u[2, :] + (u[2, :]) .* (0.05 .* randn(length(u[2, :]))) +x = u[1, :] + (u[1, :]) .* (0.3 .* randn(length(u[1, :]))) +y = u[2, :] + (u[2, :]) .* (0.3 .* randn(length(u[2, :]))) dataset = [x, y, times] plot(times, x, label = "noisy x") @@ -54,13 +54,65 @@ param = [ Normal(1, 2), Normal(2, 2), Normal(2, 2), - Normal(0, 2)], progress = false) + Normal(0, 2)], progress = true) -sol_pestim = solve(prob, alg; saveat = dt) -plot(times, sol_pestim.ensemblesol[1], label = "estimated x") -plot!(times, sol_pestim.ensemblesol[2], label = "estimated y") +@time sol_pestim1 = solve(prob, alg; saveat = dt,) +@time sol_pestim2 = solve(prob, alg; estim_collocate = true, saveat = dt) +plot(times, sol_pestim1.ensemblesol[1], label = "estimated x1") +plot!(times, sol_pestim2.ensemblesol[1], label = "estimated x2") +plot!(times, sol_pestim1.ensemblesol[2], label = "estimated y1") +plot!(times, sol_pestim2.ensemblesol[2], label = "estimated y2") # comparing it with the original solution plot!(solution, labels = ["true x" "true y"]) -sol_pestim.estimated_ode_params \ No newline at end of file +@show sol_pestim1.estimated_ode_params +@show sol_pestim2.estimated_ode_params + +function fitz(u, p , t) + v, w = u[1], u[2] + a,b,τinv,l = p[1], p[2], p[3], p[4] + + dv = v - 0.33*v^3 -w + l + dw = τinv*(v + a - b*w) + + return [dv, dw] +end + +prob_ode_fitzhughnagumo = ODEProblem(fitz, [1.0,1.0], (0.0,10.0), [0.7,0.8,1/12.5,0.5]) +dt = 0.5 +sol = solve(prob_ode_fitzhughnagumo, Tsit5(), saveat = dt) + +sig = 0.20 +data = Array(sol) +dataset = [data[1,:] .+ (sig .* rand(length(sol.t))), data[2, :] .+ (sig .* rand(length(sol.t))), sol.t] +priors = [truncated(Normal(0.5,1.0),0,1.5), truncated(Normal(0.5,1.0),0,1.5), truncated(Normal(0.0,0.5),0.0,0.5), truncated(Normal(0.5,1.0),0,1)] + + +plot(sol.t, dataset[1], label = "noisy x") +plot!(sol.t, dataset[2], label = "noisy y") +plot!(sol, labels = ["x" "y"]) + +chain = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 10, tanh), + Lux.Dense(10, 2)) + +Adaptorkwargs = (Adaptor = AdvancedHMC.StanHMCAdaptor, + Metric = AdvancedHMC.DiagEuclideanMetric, targetacceptancerate = 0.65) +alg = BNNODE(chain; +dataset = dataset, +draw_samples = 10000, +l2std = [0.1, 0.1], +phystd = [0.1, 0.1], +priorsNNw = (0.01, 3.0), +Adaptorkwargs = Adaptorkwargs, +param = priors, progress = true) + +@time sol_pestim1 = solve(prob_ode_fitzhughnagumo, alg; saveat = dt) +@time sol_pestim2 = solve(prob_ode_fitzhughnagumo, alg; estim_collocate = true, saveat = dt) +plot!(sol.t, sol_pestim1.ensemblesol[1], label = "estimated x1") +plot!(sol.t, sol_pestim2.ensemblesol[1], label = "estimated x2") +plot!(sol.t, sol_pestim1.ensemblesol[2], label = "estimated y1") +plot!(sol.t, sol_pestim2.ensemblesol[2], label = "estimated y2") + +@show sol_pestim1.estimated_ode_params +@show sol_pestim2.estimated_ode_params \ No newline at end of file From 103e1febf7f0d4153b560ad8a962106c6bf92cde Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sat, 28 Oct 2023 15:14:31 -0400 Subject: [PATCH 3/3] Scale logpdfs and fix chain creation --- src/bayesian/BPINN_ode.jl | 6 +++++- src/bayesian/advancedHMC_MCMC.jl | 13 ++++++------- test/bpinnexperimental.jl | 22 +++++++++++----------- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/bayesian/BPINN_ode.jl b/src/bayesian/BPINN_ode.jl index 5c26329f14..a2cce9db34 100644 --- a/src/bayesian/BPINN_ode.jl +++ b/src/bayesian/BPINN_ode.jl @@ -217,8 +217,12 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem, if chain isa Lux.AbstractExplicitLayer θinit, st = Lux.setup(Random.default_rng(), chain) + println(length(θinit)) + println(length(samples[1])) + println(draw_samples) θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit) - for i in (draw_samples - numensemble):draw_samples] + for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)] + luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble] # only need for size θinit = collect(ComponentArrays.ComponentArray(θinit)) diff --git a/src/bayesian/advancedHMC_MCMC.jl b/src/bayesian/advancedHMC_MCMC.jl index 740bb344a3..5e995ebfdb 100644 --- a/src/bayesian/advancedHMC_MCMC.jl +++ b/src/bayesian/advancedHMC_MCMC.jl @@ -83,9 +83,9 @@ end function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ) if Tar.estim_collocate - return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) + L2loss2(Tar, θ) + return physloglikelihood(Tar, θ)/length(Tar.dataset[1]) + priorweights(Tar, θ) + L2LossData(Tar, θ)/length(Tar.dataset[1]) + L2loss2(Tar, θ)/length(Tar.dataset[1]) else - return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) + return physloglikelihood(Tar, θ)/length(Tar.dataset[1]) + priorweights(Tar, θ) + L2LossData(Tar, θ)/length(Tar.dataset[1]) end end @@ -587,7 +587,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain; MCMC_alg = kernelchoice(Kernel, MCMCkwargs) Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator) - samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor, draw_samples; + samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor; progress = progress, verbose = verbose, drop_warmup = true) samplesc[i] = samples @@ -606,11 +606,10 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain; MCMC_alg = kernelchoice(Kernel, MCMCkwargs) Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator) samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, - adaptor, draw_samples; progress = progress, verbose = verbose, drop_warmup = true) - + adaptor; progress = progress, verbose = verbose, drop_warmup = true) # return a chain(basic chain),samples and stats - matrix_samples = hcat(samples...) - mcmc_chain = MCMCChains.Chains(matrix_samples') + matrix_samples = reshape(hcat(samples...), (length(samples[1]), length(samples), 1)) + mcmc_chain = MCMCChains.Chains(matrix_samples) return mcmc_chain, samples, stats end end \ No newline at end of file diff --git a/test/bpinnexperimental.jl b/test/bpinnexperimental.jl index ffe7fcf0f8..3de049bf58 100644 --- a/test/bpinnexperimental.jl +++ b/test/bpinnexperimental.jl @@ -86,7 +86,7 @@ sol = solve(prob_ode_fitzhughnagumo, Tsit5(), saveat = dt) sig = 0.20 data = Array(sol) dataset = [data[1,:] .+ (sig .* rand(length(sol.t))), data[2, :] .+ (sig .* rand(length(sol.t))), sol.t] -priors = [truncated(Normal(0.5,1.0),0,1.5), truncated(Normal(0.5,1.0),0,1.5), truncated(Normal(0.0,0.5),0.0,0.5), truncated(Normal(0.5,1.0),0,1)] +priors = [Normal(0.5,1.0), Normal(0.5,1.0), Normal(0.0,0.5), Normal(0.5,1.0)] plot(sol.t, dataset[1], label = "noisy x") @@ -97,22 +97,22 @@ chain = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 10, tanh), Lux.Dense(10, 2)) Adaptorkwargs = (Adaptor = AdvancedHMC.StanHMCAdaptor, - Metric = AdvancedHMC.DiagEuclideanMetric, targetacceptancerate = 0.65) + Metric = AdvancedHMC.DiagEuclideanMetric, targetacceptancerate = 0.8) alg = BNNODE(chain; dataset = dataset, -draw_samples = 10000, +draw_samples = 1000, l2std = [0.1, 0.1], phystd = [0.1, 0.1], priorsNNw = (0.01, 3.0), Adaptorkwargs = Adaptorkwargs, param = priors, progress = true) -@time sol_pestim1 = solve(prob_ode_fitzhughnagumo, alg; saveat = dt) -@time sol_pestim2 = solve(prob_ode_fitzhughnagumo, alg; estim_collocate = true, saveat = dt) -plot!(sol.t, sol_pestim1.ensemblesol[1], label = "estimated x1") -plot!(sol.t, sol_pestim2.ensemblesol[1], label = "estimated x2") -plot!(sol.t, sol_pestim1.ensemblesol[2], label = "estimated y1") -plot!(sol.t, sol_pestim2.ensemblesol[2], label = "estimated y2") +@time sol_pestim3 = solve(prob_ode_fitzhughnagumo, alg; saveat = dt) +@time sol_pestim4 = solve(prob_ode_fitzhughnagumo, alg; estim_collocate = true, saveat = dt) +plot!(sol.t, sol_pestim3.ensemblesol[1], label = "estimated x1") +plot!(sol.t, sol_pestim4.ensemblesol[1], label = "estimated x2") +plot!(sol.t, sol_pestim3.ensemblesol[2], label = "estimated y1") +plot!(sol.t, sol_pestim4.ensemblesol[2], label = "estimated y2") -@show sol_pestim1.estimated_ode_params -@show sol_pestim2.estimated_ode_params \ No newline at end of file +@show sol_pestim3.estimated_ode_params +@show sol_pestim4.estimated_ode_params