From c0158ead74a35b620781ae691f73e1c98550e5c6 Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Thu, 7 Aug 2025 09:08:25 +0100 Subject: [PATCH 1/8] Add GibbsConditional sampler and corresponding tests --- src/mcmc/Inference.jl | 2 + src/mcmc/gibbs_conditional.jl | 245 ++++++++++++++++++++++++++++++++++ test/mcmc/gibbs.jl | 114 ++++++++++++++++ test_gibbs_conditional.jl | 78 +++++++++++ 4 files changed, 439 insertions(+) create mode 100644 src/mcmc/gibbs_conditional.jl create mode 100644 test_gibbs_conditional.jl diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 0951026aa..35cdc46b5 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -67,6 +67,7 @@ export InferenceAlgorithm, ESS, Emcee, Gibbs, # classic sampling + GibbsConditional, # conditional sampling HMC, SGLD, PolynomialStepsize, @@ -392,6 +393,7 @@ include("mh.jl") include("is.jl") include("particle_mcmc.jl") include("gibbs.jl") +include("gibbs_conditional.jl") include("sghmc.jl") include("emcee.jl") include("prior.jl") diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl new file mode 100644 index 000000000..e01e17aff --- /dev/null +++ b/src/mcmc/gibbs_conditional.jl @@ -0,0 +1,245 @@ +using DynamicPPL: VarName +using Random: Random +import AbstractMCMC + +# These functions are defined in gibbs.jl which is loaded before this file + +""" + GibbsConditional(sym::Symbol, conditional) + +A Gibbs sampler component that samples a variable according to a user-provided +analytical conditional distribution. + +The `conditional` function should take a `NamedTuple` of conditioned variables and return +a `Distribution` from which to sample the variable `sym`. + +# Examples + +```julia +# Define a model +@model function inverse_gdemo(x) + λ ~ Gamma(2, 3) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end +end + +# Define analytical conditionals +function cond_λ(c::NamedTuple) + a = 2.0 + b = 3.0 + m = c.m + x = c.x + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) +end + +function cond_m(c::NamedTuple) + λ = c.λ + x = c.x + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) +end + +# Sample using GibbsConditional +model = inverse_gdemo([1.0, 2.0, 3.0]) +chain = sample(model, Gibbs( + :λ => GibbsConditional(:λ, cond_λ), + :m => GibbsConditional(:m, cond_m) +), 1000) +``` +""" +struct GibbsConditional{S,C} <: InferenceAlgorithm + conditional::C + + function GibbsConditional(sym::Symbol, conditional::C) where {C} + return new{sym,C}(conditional) + end +end + +# Mark GibbsConditional as a valid Gibbs component +isgibbscomponent(::GibbsConditional) = true + +""" + DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi) + +Initialize the GibbsConditional sampler. +""" +function DynamicPPL.initialstep( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional}, + vi::DynamicPPL.AbstractVarInfo; + kwargs..., +) + # GibbsConditional doesn't need any special initialization + # Just return the initial state + return nothing, vi +end + +""" + AbstractMCMC.step(rng, model, sampler::GibbsConditional, state) + +Perform a step of GibbsConditional sampling. +""" +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional{S}}, + state::DynamicPPL.AbstractVarInfo; + kwargs..., +) where {S} + alg = sampler.alg + + # For GibbsConditional within Gibbs, we need to get all variable values + # Check if we're in a Gibbs context + global_vi = if hasproperty(model, :context) && model.context isa GibbsContext + # We're in a Gibbs context, get the global varinfo + get_global_varinfo(model.context) + else + # We're not in a Gibbs context, use the current state + state + end + + # Extract conditioned values as a NamedTuple + # Include both random variables and observed data + condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) + condvals_obs = NamedTuple{keys(model.args)}(model.args) + condvals = merge(condvals_vars, condvals_obs) + + # Get the conditional distribution + conddist = alg.conditional(condvals) + + # Sample from the conditional distribution + updated = rand(rng, conddist) + + # Update the variable in state + # We need to get the actual VarName for this variable + # The symbol S tells us which variable to update + vn = VarName{S}() + + # Check if the variable needs to be a vector + new_vi = if haskey(state, vn) + # Update the existing variable + DynamicPPL.setindex!!(state, updated, vn) + else + # Try to find the variable with indices + # This handles cases where the variable might have indices + local updated_vi = state + found = false + for key in keys(state) + if DynamicPPL.getsym(key) == S + updated_vi = DynamicPPL.setindex!!(state, updated, key) + found = true + break + end + end + if !found + error("Could not find variable $S in VarInfo") + end + updated_vi + end + + # Update log joint probability + new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext())) + + return nothing, new_vi +end + +""" + setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo) + +Update the variable info with new parameters for GibbsConditional. +""" +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional}, + state, + params::DynamicPPL.AbstractVarInfo, +) + # For GibbsConditional, we just return the params as-is since + # the state is nothing and we don't need to update anything + return params +end + +""" + gibbs_initialstep_recursive( + rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state + ) + +Initialize the GibbsConditional sampler. +""" +function gibbs_initialstep_recursive( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional}, + target_varnames::AbstractVector{<:VarName}, + global_vi::DynamicPPL.AbstractVarInfo, + prev_state, +) + # GibbsConditional doesn't need any special initialization + # Just perform one sampling step + return gibbs_step_recursive( + rng, model, sampler_wrapped, target_varnames, global_vi, nothing + ) +end + +""" + gibbs_step_recursive( + rng, model, sampler::GibbsConditional, target_varnames, global_vi, state + ) + +Perform a single step of GibbsConditional sampling. +""" +function gibbs_step_recursive( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}}, + target_varnames::AbstractVector{<:VarName}, + global_vi::DynamicPPL.AbstractVarInfo, + state, +) where {S} + sampler = sampler_wrapped.alg + + # Extract conditioned values as a NamedTuple + # Include both random variables and observed data + condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) + condvals_obs = NamedTuple{keys(model.args)}(model.args) + condvals = merge(condvals_vars, condvals_obs) + + # Get the conditional distribution + conddist = sampler.conditional(condvals) + + # Sample from the conditional distribution + updated = rand(rng, conddist) + + # Update the variable in global_vi + # We need to get the actual VarName for this variable + # The symbol S tells us which variable to update + vn = VarName{S}() + + # Check if the variable needs to be a vector + if haskey(global_vi, vn) + # Update the existing variable + global_vi = DynamicPPL.setindex!!(global_vi, updated, vn) + else + # Try to find the variable with indices + # This handles cases where the variable might have indices + for key in keys(global_vi) + if DynamicPPL.getsym(key) == S + global_vi = DynamicPPL.setindex!!(global_vi, updated, key) + break + end + end + end + + # Update log joint probability + global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext())) + + return nothing, global_vi +end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index f44a9fefc..a7884bb7e 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -882,6 +882,120 @@ end sampler = Gibbs(:w => HMC(0.05, 10)) @test (sample(model, sampler, 10); true) end + + @testset "GibbsConditional" begin + # Test with the inverse gamma example from the issue + @model function inverse_gdemo(x) + λ ~ Gamma(2, 3) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end + end + + # Define analytical conditionals + function cond_λ(c::NamedTuple) + a = 2.0 + b = 3.0 + m = c.m + x = c.x + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) + end + + function cond_m(c::NamedTuple) + λ = c.λ + x = c.x + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) + end + + # Test basic functionality + @testset "basic sampling" begin + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] + model = inverse_gdemo(x_obs) + + # Test that GibbsConditional works + sampler = Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m)) + chain = sample(model, sampler, 1000) + + # Check that we got the expected variables + @test :λ in names(chain) + @test :m in names(chain) + + # Check that the values are reasonable + λ_samples = vec(chain[:λ]) + m_samples = vec(chain[:m]) + + # Given the observed data, we expect certain behavior + @test mean(λ_samples) > 0 # λ should be positive + @test minimum(λ_samples) > 0 + @test std(m_samples) < 2.0 # m should be relatively well-constrained + end + + # Test mixing with other samplers + @testset "mixed samplers" begin + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0] + model = inverse_gdemo(x_obs) + + # Mix GibbsConditional with standard samplers + sampler = Gibbs(GibbsConditional(:λ, cond_λ), :m => MH()) + chain = sample(model, sampler, 500) + + @test :λ in names(chain) + @test :m in names(chain) + @test size(chain, 1) == 500 + end + + # Test with a simpler model + @testset "simple normal model" begin + @model function simple_normal(x) + μ ~ Normal(0, 10) + σ ~ truncated(Normal(1, 1); lower=0.01) + for i in 1:length(x) + x[i] ~ Normal(μ, σ) + end + end + + # Conditional for μ given σ and x + function cond_μ(c::NamedTuple) + σ = c.σ + x = c.x + n = length(x) + # Prior: μ ~ Normal(0, 10) + # Likelihood: x[i] ~ Normal(μ, σ) + # Posterior: μ ~ Normal(μ_post, σ_post) + prior_var = 100.0 # 10^2 + likelihood_var = σ^2 / n + post_var = 1 / (1 / prior_var + n / σ^2) + post_mean = post_var * (0 / prior_var + sum(x) / σ^2) + return Normal(post_mean, sqrt(post_var)) + end + + Random.seed!(42) + x_obs = randn(10) .+ 2.0 # Data centered around 2 + model = simple_normal(x_obs) + + sampler = Gibbs(GibbsConditional(:μ, cond_μ), :σ => MH()) + + chain = sample(model, sampler, 1000) + + μ_samples = vec(chain[:μ]) + @test abs(mean(μ_samples) - 2.0) < 0.5 # Should be close to true mean + end + + # Test that GibbsConditional is marked as a valid component + @testset "isgibbscomponent" begin + gc = GibbsConditional(:x, c -> Normal(0, 1)) + @test Turing.Inference.isgibbscomponent(gc) + end + end end end diff --git a/test_gibbs_conditional.jl b/test_gibbs_conditional.jl new file mode 100644 index 000000000..d6466e537 --- /dev/null +++ b/test_gibbs_conditional.jl @@ -0,0 +1,78 @@ +using Turing +using Turing.Inference: GibbsConditional +using Distributions +using Random +using Statistics + +# Test with the inverse gamma example from the issue +@model function inverse_gdemo(x) + λ ~ Gamma(2, 3) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end +end + +# Define analytical conditionals +function cond_λ(c::NamedTuple) + a = 2.0 + b = 3.0 + m = c.m + x = c.x + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) +end + +function cond_m(c::NamedTuple) + λ = c.λ + x = c.x + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) +end + +# Generate some observed data +Random.seed!(42) +x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] + +# Create the model +model = inverse_gdemo(x_obs) + +# Sample using GibbsConditional +println("Testing GibbsConditional sampler...") +sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) + +# Run a short chain to test +chain = sample(model, sampler, 100) + +println("Sampling completed successfully!") +println("\nChain summary:") +println(chain) + +# Extract samples +λ_samples = vec(chain[:λ]) +m_samples = vec(chain[:m]) + +println("\nλ statistics:") +println(" Mean: ", mean(λ_samples)) +println(" Std: ", std(λ_samples)) +println(" Min: ", minimum(λ_samples)) +println(" Max: ", maximum(λ_samples)) + +println("\nm statistics:") +println(" Mean: ", mean(m_samples)) +println(" Std: ", std(m_samples)) +println(" Min: ", minimum(m_samples)) +println(" Max: ", maximum(m_samples)) + +# Test mixing with other samplers +println("\n\nTesting mixed samplers...") +sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH()) + +chain2 = sample(model, sampler2, 100) +println("Mixed sampling completed successfully!") +println("\nMixed chain summary:") +println(chain2) From a972b5a2b54216b2dfb5b5ee2fa807229be081bb Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Thu, 7 Aug 2025 09:13:15 +0100 Subject: [PATCH 2/8] clarified comment --- src/mcmc/gibbs_conditional.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index e01e17aff..2bf7a7bb5 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -2,7 +2,7 @@ using DynamicPPL: VarName using Random: Random import AbstractMCMC -# These functions are defined in gibbs.jl which is loaded before this file +# These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl """ GibbsConditional(sym::Symbol, conditional) From c3cc7739cbcae0675ed50bff85dcdd34ddd29c3a Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Fri, 8 Aug 2025 11:11:32 +0100 Subject: [PATCH 3/8] add MHs suggestions --- src/mcmc/gibbs_conditional.jl | 148 +++++++++------------------------- 1 file changed, 37 insertions(+), 111 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 2bf7a7bb5..c2eba05ba 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -54,17 +54,45 @@ chain = sample(model, Gibbs( ), 1000) ``` """ -struct GibbsConditional{S,C} <: InferenceAlgorithm +struct GibbsConditional{C} <: InferenceAlgorithm conditional::C function GibbsConditional(sym::Symbol, conditional::C) where {C} - return new{sym,C}(conditional) + return new{C}(conditional) end end # Mark GibbsConditional as a valid Gibbs component isgibbscomponent(::GibbsConditional) = true +""" + find_global_varinfo(context, fallback_vi) + +Traverse the context stack to find global variable information from +GibbsContext, ConditionContext, FixedContext, etc. +""" +function find_global_varinfo(context, fallback_vi) + # Start with the given context and traverse down + current_context = context + + while current_context !== nothing + if current_context isa GibbsContext + # Found GibbsContext, return its global varinfo + return get_global_varinfo(current_context) + elseif hasproperty(current_context, :childcontext) && + isdefined(DynamicPPL, :childcontext) + # Move to child context if it exists + current_context = DynamicPPL.childcontext(current_context) + else + # No more child contexts + break + end + end + + # If no GibbsContext found, use the fallback + return fallback_vi +end + """ DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi) @@ -97,12 +125,10 @@ function AbstractMCMC.step( alg = sampler.alg # For GibbsConditional within Gibbs, we need to get all variable values - # Check if we're in a Gibbs context - global_vi = if hasproperty(model, :context) && model.context isa GibbsContext - # We're in a Gibbs context, get the global varinfo - get_global_varinfo(model.context) + # Traverse the context stack to find all conditioned/fixed/Gibbs variables + global_vi = if hasproperty(model, :context) + find_global_varinfo(model.context, state) else - # We're not in a Gibbs context, use the current state state end @@ -119,34 +145,10 @@ function AbstractMCMC.step( updated = rand(rng, conddist) # Update the variable in state - # We need to get the actual VarName for this variable - # The symbol S tells us which variable to update - vn = VarName{S}() - - # Check if the variable needs to be a vector - new_vi = if haskey(state, vn) - # Update the existing variable - DynamicPPL.setindex!!(state, updated, vn) - else - # Try to find the variable with indices - # This handles cases where the variable might have indices - local updated_vi = state - found = false - for key in keys(state) - if DynamicPPL.getsym(key) == S - updated_vi = DynamicPPL.setindex!!(state, updated, key) - found = true - break - end - end - if !found - error("Could not find variable $S in VarInfo") - end - updated_vi - end - - # Update log joint probability - new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext())) + # The Gibbs sampler ensures that state only contains one variable + # Get the variable name from the keys + varname = first(keys(state)) + new_vi = DynamicPPL.setindex!!(state, updated, varname) return nothing, new_vi end @@ -167,79 +169,3 @@ function setparams_varinfo!!( return params end -""" - gibbs_initialstep_recursive( - rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state - ) - -Initialize the GibbsConditional sampler. -""" -function gibbs_initialstep_recursive( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional}, - target_varnames::AbstractVector{<:VarName}, - global_vi::DynamicPPL.AbstractVarInfo, - prev_state, -) - # GibbsConditional doesn't need any special initialization - # Just perform one sampling step - return gibbs_step_recursive( - rng, model, sampler_wrapped, target_varnames, global_vi, nothing - ) -end - -""" - gibbs_step_recursive( - rng, model, sampler::GibbsConditional, target_varnames, global_vi, state - ) - -Perform a single step of GibbsConditional sampling. -""" -function gibbs_step_recursive( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}}, - target_varnames::AbstractVector{<:VarName}, - global_vi::DynamicPPL.AbstractVarInfo, - state, -) where {S} - sampler = sampler_wrapped.alg - - # Extract conditioned values as a NamedTuple - # Include both random variables and observed data - condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) - condvals_obs = NamedTuple{keys(model.args)}(model.args) - condvals = merge(condvals_vars, condvals_obs) - - # Get the conditional distribution - conddist = sampler.conditional(condvals) - - # Sample from the conditional distribution - updated = rand(rng, conddist) - - # Update the variable in global_vi - # We need to get the actual VarName for this variable - # The symbol S tells us which variable to update - vn = VarName{S}() - - # Check if the variable needs to be a vector - if haskey(global_vi, vn) - # Update the existing variable - global_vi = DynamicPPL.setindex!!(global_vi, updated, vn) - else - # Try to find the variable with indices - # This handles cases where the variable might have indices - for key in keys(global_vi) - if DynamicPPL.getsym(key) == S - global_vi = DynamicPPL.setindex!!(global_vi, updated, key) - break - end - end - end - - # Update log joint probability - global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext())) - - return nothing, global_vi -end From 714c1e82979e5daa6cb1d005c113c967e9d4647a Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Fri, 8 Aug 2025 11:11:52 +0100 Subject: [PATCH 4/8] formatter --- src/mcmc/gibbs_conditional.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index c2eba05ba..fe04b048d 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -74,13 +74,13 @@ GibbsContext, ConditionContext, FixedContext, etc. function find_global_varinfo(context, fallback_vi) # Start with the given context and traverse down current_context = context - + while current_context !== nothing if current_context isa GibbsContext # Found GibbsContext, return its global varinfo return get_global_varinfo(current_context) - elseif hasproperty(current_context, :childcontext) && - isdefined(DynamicPPL, :childcontext) + elseif hasproperty(current_context, :childcontext) && + isdefined(DynamicPPL, :childcontext) # Move to child context if it exists current_context = DynamicPPL.childcontext(current_context) else @@ -88,7 +88,7 @@ function find_global_varinfo(context, fallback_vi) break end end - + # If no GibbsContext found, use the fallback return fallback_vi end @@ -168,4 +168,3 @@ function setparams_varinfo!!( # the state is nothing and we don't need to update anything return params end - From 94b723da263927edfef7c20d8e56e543d0d84fc3 Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Fri, 8 Aug 2025 14:37:27 +0100 Subject: [PATCH 5/8] fixed exporting thing --- src/mcmc/gibbs_conditional.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index fe04b048d..7415c5f3f 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -65,6 +65,9 @@ end # Mark GibbsConditional as a valid Gibbs component isgibbscomponent(::GibbsConditional) = true +# Required methods for Gibbs constructor +Base.length(::GibbsConditional) = 1 # Each GibbsConditional handles one variable + """ find_global_varinfo(context, fallback_vi) From 2058ae54e34111e17441c60ab001ba929284646c Mon Sep 17 00:00:00 2001 From: Aoife Date: Tue, 23 Sep 2025 13:09:12 +0100 Subject: [PATCH 6/8] Refactor Gibbs sampler to use inverse of parameters for Gamma distribution and improve context variable retrieval --- src/mcmc/gibbs_conditional.jl | 56 ++++++++++++--------- test/mcmc/gibbs.jl | 4 +- test_gibbs_conditional.jl | 93 ++++++++++++++++++++++------------- 3 files changed, 92 insertions(+), 61 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 7415c5f3f..74c0686b1 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -18,7 +18,7 @@ a `Distribution` from which to sample the variable `sym`. ```julia # Define a model @model function inverse_gdemo(x) - λ ~ Gamma(2, 3) + λ ~ Gamma(2, inv(3)) m ~ Normal(0, sqrt(1 / λ)) for i in 1:length(x) x[i] ~ Normal(m, sqrt(1 / λ)) @@ -28,7 +28,7 @@ end # Define analytical conditionals function cond_λ(c::NamedTuple) a = 2.0 - b = 3.0 + b = inv(3) m = c.m x = c.x n = length(x) @@ -75,25 +75,39 @@ Traverse the context stack to find global variable information from GibbsContext, ConditionContext, FixedContext, etc. """ function find_global_varinfo(context, fallback_vi) - # Start with the given context and traverse down + # Traverse the entire context stack to find relevant contexts current_context = context + gibbs_context = nothing + condition_context = nothing + fixed_context = nothing while current_context !== nothing - if current_context isa GibbsContext - # Found GibbsContext, return its global varinfo - return get_global_varinfo(current_context) - elseif hasproperty(current_context, :childcontext) && - isdefined(DynamicPPL, :childcontext) - # Move to child context if it exists + # Use NodeTrait for robust context checking + if DynamicPPL.NodeTrait(current_context) isa DynamicPPL.IsParent + if current_context isa GibbsContext + gibbs_context = current_context + elseif current_context isa DynamicPPL.ConditionContext + condition_context = current_context + elseif current_context isa DynamicPPL.FixedContext + fixed_context = current_context + end + # Move to child context current_context = DynamicPPL.childcontext(current_context) else - # No more child contexts break end end - # If no GibbsContext found, use the fallback - return fallback_vi + # Return the most relevant context's varinfo + if gibbs_context !== nothing + return get_global_varinfo(gibbs_context) + elseif condition_context !== nothing + return DynamicPPL.getvarinfo(condition_context) + elseif fixed_context !== nothing + return DynamicPPL.getvarinfo(fixed_context) + else + return fallback_vi + end end """ @@ -121,19 +135,15 @@ Perform a step of GibbsConditional sampling. function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:GibbsConditional{S}}, + sampler::DynamicPPL.Sampler{<:GibbsConditional}, state::DynamicPPL.AbstractVarInfo; kwargs..., -) where {S} +) alg = sampler.alg # For GibbsConditional within Gibbs, we need to get all variable values - # Traverse the context stack to find all conditioned/fixed/Gibbs variables - global_vi = if hasproperty(model, :context) - find_global_varinfo(model.context, state) - else - state - end + # Model always has a context field, so we can traverse it directly + global_vi = find_global_varinfo(model.context, state) # Extract conditioned values as a NamedTuple # Include both random variables and observed data @@ -147,11 +157,9 @@ function AbstractMCMC.step( # Sample from the conditional distribution updated = rand(rng, conddist) - # Update the variable in state + # Update the variable in state using unflatten for simplicity # The Gibbs sampler ensures that state only contains one variable - # Get the variable name from the keys - varname = first(keys(state)) - new_vi = DynamicPPL.setindex!!(state, updated, varname) + new_vi = DynamicPPL.unflatten(state, [updated]) return nothing, new_vi end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index d7c41d70d..a825401a1 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -946,7 +946,7 @@ end @testset "GibbsConditional" begin # Test with the inverse gamma example from the issue @model function inverse_gdemo(x) - λ ~ Gamma(2, 3) + λ ~ Gamma(2, inv(3)) m ~ Normal(0, sqrt(1 / λ)) for i in 1:length(x) x[i] ~ Normal(m, sqrt(1 / λ)) @@ -956,7 +956,7 @@ end # Define analytical conditionals function cond_λ(c::NamedTuple) a = 2.0 - b = 3.0 + b = inv(3) m = c.m x = c.x n = length(x) diff --git a/test_gibbs_conditional.jl b/test_gibbs_conditional.jl index d6466e537..1a01fa9b2 100644 --- a/test_gibbs_conditional.jl +++ b/test_gibbs_conditional.jl @@ -3,10 +3,11 @@ using Turing.Inference: GibbsConditional using Distributions using Random using Statistics +using Test # Test with the inverse gamma example from the issue @model function inverse_gdemo(x) - λ ~ Gamma(2, 3) + λ ~ Gamma(2, inv(3)) m ~ Normal(0, sqrt(1 / λ)) for i in 1:length(x) x[i] ~ Normal(m, sqrt(1 / λ)) @@ -16,7 +17,7 @@ end # Define analytical conditionals function cond_λ(c::NamedTuple) a = 2.0 - b = 3.0 + b = inv(3) m = c.m x = c.x n = length(x) @@ -34,45 +35,67 @@ function cond_m(c::NamedTuple) return Normal(m_mean, sqrt(m_var)) end -# Generate some observed data -Random.seed!(42) -x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] +@testset "GibbsConditional Integration Tests" begin + # Generate some observed data + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] -# Create the model -model = inverse_gdemo(x_obs) + # Create the model + model = inverse_gdemo(x_obs) -# Sample using GibbsConditional -println("Testing GibbsConditional sampler...") -sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) + @testset "Basic GibbsConditional sampling" begin + # Sample using GibbsConditional + sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) -# Run a short chain to test -chain = sample(model, sampler, 100) + # Run a short chain to test + chain = sample(model, sampler, 100) -println("Sampling completed successfully!") -println("\nChain summary:") -println(chain) + # Test that sampling completed successfully + @test chain isa MCMCChains.Chains + @test size(chain, 1) == 100 + @test :λ in names(chain) + @test :m in names(chain) + end + + @testset "Sample statistics" begin + # Generate samples for statistics testing + sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) + chain = sample(model, sampler, 100) -# Extract samples -λ_samples = vec(chain[:λ]) -m_samples = vec(chain[:m]) + # Extract samples + λ_samples = vec(chain[:λ]) + m_samples = vec(chain[:m]) -println("\nλ statistics:") -println(" Mean: ", mean(λ_samples)) -println(" Std: ", std(λ_samples)) -println(" Min: ", minimum(λ_samples)) -println(" Max: ", maximum(λ_samples)) + # Test λ statistics + @test mean(λ_samples) > 0 # λ should be positive + @test minimum(λ_samples) > 0 # All λ samples should be positive + @test std(λ_samples) > 0 # Should have some variability + @test isfinite(mean(λ_samples)) + @test isfinite(std(λ_samples)) + + # Test m statistics + @test isfinite(mean(m_samples)) + @test isfinite(std(m_samples)) + @test std(m_samples) > 0 # Should have some variability + end -println("\nm statistics:") -println(" Mean: ", mean(m_samples)) -println(" Std: ", std(m_samples)) -println(" Min: ", minimum(m_samples)) -println(" Max: ", maximum(m_samples)) + @testset "Mixed samplers" begin + # Test mixing with other samplers + sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH()) -# Test mixing with other samplers -println("\n\nTesting mixed samplers...") -sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH()) + chain2 = sample(model, sampler2, 100) -chain2 = sample(model, sampler2, 100) -println("Mixed sampling completed successfully!") -println("\nMixed chain summary:") -println(chain2) + # Test that mixed sampling completed successfully + @test chain2 isa MCMCChains.Chains + @test size(chain2, 1) == 100 + @test :λ in names(chain2) + @test :m in names(chain2) + + # Test that values are reasonable + λ_samples2 = vec(chain2[:λ]) + m_samples2 = vec(chain2[:m]) + @test all(λ_samples2 .> 0) # All λ should be positive + @test all(isfinite.(λ_samples2)) + @test all(isfinite.(m_samples2)) + end +end From b0812a3bfc3be6e3be11fdcdb06b34a324b4c26c Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Thu, 25 Sep 2025 09:31:31 +0100 Subject: [PATCH 7/8] removed file added by mistake --- test_gibbs_conditional.jl | 101 -------------------------------------- 1 file changed, 101 deletions(-) delete mode 100644 test_gibbs_conditional.jl diff --git a/test_gibbs_conditional.jl b/test_gibbs_conditional.jl deleted file mode 100644 index 1a01fa9b2..000000000 --- a/test_gibbs_conditional.jl +++ /dev/null @@ -1,101 +0,0 @@ -using Turing -using Turing.Inference: GibbsConditional -using Distributions -using Random -using Statistics -using Test - -# Test with the inverse gamma example from the issue -@model function inverse_gdemo(x) - λ ~ Gamma(2, inv(3)) - m ~ Normal(0, sqrt(1 / λ)) - for i in 1:length(x) - x[i] ~ Normal(m, sqrt(1 / λ)) - end -end - -# Define analytical conditionals -function cond_λ(c::NamedTuple) - a = 2.0 - b = inv(3) - m = c.m - x = c.x - n = length(x) - a_new = a + (n + 1) / 2 - b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 - return Gamma(a_new, 1 / b_new) -end - -function cond_m(c::NamedTuple) - λ = c.λ - x = c.x - n = length(x) - m_mean = sum(x) / (n + 1) - m_var = 1 / (λ * (n + 1)) - return Normal(m_mean, sqrt(m_var)) -end - -@testset "GibbsConditional Integration Tests" begin - # Generate some observed data - Random.seed!(42) - x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] - - # Create the model - model = inverse_gdemo(x_obs) - - @testset "Basic GibbsConditional sampling" begin - # Sample using GibbsConditional - sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) - - # Run a short chain to test - chain = sample(model, sampler, 100) - - # Test that sampling completed successfully - @test chain isa MCMCChains.Chains - @test size(chain, 1) == 100 - @test :λ in names(chain) - @test :m in names(chain) - end - - @testset "Sample statistics" begin - # Generate samples for statistics testing - sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) - chain = sample(model, sampler, 100) - - # Extract samples - λ_samples = vec(chain[:λ]) - m_samples = vec(chain[:m]) - - # Test λ statistics - @test mean(λ_samples) > 0 # λ should be positive - @test minimum(λ_samples) > 0 # All λ samples should be positive - @test std(λ_samples) > 0 # Should have some variability - @test isfinite(mean(λ_samples)) - @test isfinite(std(λ_samples)) - - # Test m statistics - @test isfinite(mean(m_samples)) - @test isfinite(std(m_samples)) - @test std(m_samples) > 0 # Should have some variability - end - - @testset "Mixed samplers" begin - # Test mixing with other samplers - sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH()) - - chain2 = sample(model, sampler2, 100) - - # Test that mixed sampling completed successfully - @test chain2 isa MCMCChains.Chains - @test size(chain2, 1) == 100 - @test :λ in names(chain2) - @test :m in names(chain2) - - # Test that values are reasonable - λ_samples2 = vec(chain2[:λ]) - m_samples2 = vec(chain2[:m]) - @test all(λ_samples2 .> 0) # All λ should be positive - @test all(isfinite.(λ_samples2)) - @test all(isfinite.(m_samples2)) - end -end From d91031205ce308881f652f4f135734243d06eaa9 Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Mon, 29 Sep 2025 12:48:39 +0100 Subject: [PATCH 8/8] Add safety checks and error handling in find_global_varinfo and AbstractMCMC.step functions --- src/mcmc/gibbs_conditional.jl | 132 ++++++++++++++++++++++++---------- 1 file changed, 95 insertions(+), 37 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 74c0686b1..8401d3405 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -81,33 +81,56 @@ function find_global_varinfo(context, fallback_vi) condition_context = nothing fixed_context = nothing - while current_context !== nothing - # Use NodeTrait for robust context checking - if DynamicPPL.NodeTrait(current_context) isa DynamicPPL.IsParent - if current_context isa GibbsContext - gibbs_context = current_context - elseif current_context isa DynamicPPL.ConditionContext - condition_context = current_context - elseif current_context isa DynamicPPL.FixedContext - fixed_context = current_context + # Safety check: avoid infinite loops with a maximum depth + max_depth = 20 + depth = 0 + + while current_context !== nothing && depth < max_depth + depth += 1 + + try + # Use NodeTrait for robust context checking + if DynamicPPL.NodeTrait(current_context) isa DynamicPPL.IsParent + if current_context isa GibbsContext + gibbs_context = current_context + elseif current_context isa DynamicPPL.ConditionContext + condition_context = current_context + elseif current_context isa DynamicPPL.FixedContext + fixed_context = current_context + end + # Move to child context + current_context = DynamicPPL.childcontext(current_context) + else + break end - # Move to child context - current_context = DynamicPPL.childcontext(current_context) - else + catch e + # If there's an error traversing contexts, break and use fallback + @debug "Error traversing context at depth $depth: $e" break end end - # Return the most relevant context's varinfo - if gibbs_context !== nothing - return get_global_varinfo(gibbs_context) - elseif condition_context !== nothing - return DynamicPPL.getvarinfo(condition_context) - elseif fixed_context !== nothing - return DynamicPPL.getvarinfo(fixed_context) - else - return fallback_vi + # Return the most relevant context's varinfo with error handling + try + if gibbs_context !== nothing + return get_global_varinfo(gibbs_context) + elseif condition_context !== nothing + # Check if getvarinfo method exists for ConditionContext + if hasmethod(DynamicPPL.getvarinfo, (typeof(condition_context),)) + return DynamicPPL.getvarinfo(condition_context) + end + elseif fixed_context !== nothing + # Check if getvarinfo method exists for FixedContext + if hasmethod(DynamicPPL.getvarinfo, (typeof(fixed_context),)) + return DynamicPPL.getvarinfo(fixed_context) + end + end + catch e + @debug "Error accessing varinfo from context: $e" end + + # Fall back to the provided fallback_vi + return fallback_vi end """ @@ -141,27 +164,62 @@ function AbstractMCMC.step( ) alg = sampler.alg - # For GibbsConditional within Gibbs, we need to get all variable values - # Model always has a context field, so we can traverse it directly - global_vi = find_global_varinfo(model.context, state) + try + # For GibbsConditional within Gibbs, we need to get all variable values + # Model always has a context field, so we can traverse it directly + global_vi = find_global_varinfo(model.context, state) + + # Extract conditioned values as a NamedTuple + # Include both random variables and observed data + # Use a safe approach for invlink to avoid linking conflicts + invlinked_global_vi = try + DynamicPPL.invlink(global_vi, model) + catch e + @debug "Failed to invlink global_vi, using as-is: $e" + global_vi + end - # Extract conditioned values as a NamedTuple - # Include both random variables and observed data - condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) - condvals_obs = NamedTuple{keys(model.args)}(model.args) - condvals = merge(condvals_vars, condvals_obs) + condvals_vars = DynamicPPL.values_as(invlinked_global_vi, NamedTuple) + condvals_obs = NamedTuple{keys(model.args)}(model.args) + condvals = merge(condvals_vars, condvals_obs) - # Get the conditional distribution - conddist = alg.conditional(condvals) + # Get the conditional distribution + conddist = alg.conditional(condvals) - # Sample from the conditional distribution - updated = rand(rng, conddist) + # Sample from the conditional distribution + updated = rand(rng, conddist) - # Update the variable in state using unflatten for simplicity - # The Gibbs sampler ensures that state only contains one variable - new_vi = DynamicPPL.unflatten(state, [updated]) + # Update the variable in state, handling linking properly + # The Gibbs sampler ensures that state only contains one variable + state_is_linked = try + DynamicPPL.islinked(state, model) + catch e + @debug "Error checking if state is linked: $e" + false + end - return nothing, new_vi + if state_is_linked + # If state is linked, we need to unlink, update, then relink + try + unlinked_state = DynamicPPL.invlink(state, model) + updated_state = DynamicPPL.unflatten(unlinked_state, [updated]) + new_vi = DynamicPPL.link(updated_state, model) + catch e + @debug "Error in linked state update path: $e, falling back to direct update" + new_vi = DynamicPPL.unflatten(state, [updated]) + end + else + # State is not linked, we can update directly + new_vi = DynamicPPL.unflatten(state, [updated]) + end + + return nothing, new_vi + + catch e + # If there's any error in the step, log it and rethrow + @error "Error in GibbsConditional step: $e" + rethrow(e) + end end """