From 772ece395b5b3a00e5d3ce24baef8259a4b8bf69 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 27 Jun 2025 20:08:04 +0200 Subject: [PATCH 1/6] feat: allow sampling with fixed step size --- src/stepsize_adapt.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/stepsize_adapt.rs b/src/stepsize_adapt.rs index f8323e9..00819e6 100644 --- a/src/stepsize_adapt.rs +++ b/src/stepsize_adapt.rs @@ -39,6 +39,10 @@ impl Strategy { position: &[f64], rng: &mut R, ) -> Result<(), NutsError> { + if let Some(step_size) = self.options.fixed_step_size { + *hamiltonian.step_size_mut() = step_size; + return Ok(()); + } let mut state = hamiltonian.init_state(math, position)?; hamiltonian.initialize_trajectory(math, &mut state, rng)?; @@ -118,13 +122,17 @@ impl Strategy { pub fn update_stepsize( &mut self, - potential: &mut impl Hamiltonian, + hamiltonian: &mut impl Hamiltonian, use_best_guess: bool, ) { + if let Some(step_size) = self.options.fixed_step_size { + *hamiltonian.step_size_mut() = step_size; + return; + } if use_best_guess { - *potential.step_size_mut() = self.step_size_adapt.current_step_size_adapted(); + *hamiltonian.step_size_mut() = self.step_size_adapt.current_step_size_adapted(); } else { - *potential.step_size_mut() = self.step_size_adapt.current_step_size(); + *hamiltonian.step_size_mut() = self.step_size_adapt.current_step_size(); } } @@ -226,6 +234,7 @@ pub struct DualAverageSettings { pub target_accept: f64, pub initial_step: f64, pub params: DualAverageOptions, + pub fixed_step_size: Option, } impl Default for DualAverageSettings { @@ -234,6 +243,7 @@ impl Default for DualAverageSettings { target_accept: 0.8, initial_step: 0.1, params: DualAverageOptions::default(), + fixed_step_size: None, } } } From 57c8ae6e0ea92b633bb00f9b0846fbff5b649d7a Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 27 Jun 2025 20:08:04 +0200 Subject: [PATCH 2/6] feat: add untested walnuts implementation --- src/adapt_strategy.rs | 1 + src/euclidean_hamiltonian.rs | 3 +- src/hamiltonian.rs | 25 ++++++ src/nuts.rs | 143 +++++++++++++++++++++++++++++++-- src/sampler.rs | 9 ++- src/stepsize_adapt.rs | 5 +- src/transformed_hamiltonian.rs | 3 +- 7 files changed, 178 insertions(+), 11 deletions(-) diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index 5b14935..e891e18 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -503,6 +503,7 @@ mod test { store_unconstrained: true, check_turning: true, store_divergences: false, + walnuts_options: None, }; let rng = { diff --git a/src/euclidean_hamiltonian.rs b/src/euclidean_hamiltonian.rs index ae5be9a..9e5d750 100644 --- a/src/euclidean_hamiltonian.rs +++ b/src/euclidean_hamiltonian.rs @@ -309,6 +309,7 @@ impl> Hamiltonian for EuclideanHamiltonian, dir: Direction, + step_size_factor: f64, collector: &mut C, ) -> LeapfrogResult { let mut out = self.pool().new_state(math); @@ -321,7 +322,7 @@ impl> Hamiltonian for EuclideanHamiltonian -1, }; - let epsilon = (sign as f64) * self.step_size; + let epsilon = (sign as f64) * self.step_size * step_size_factor; start .point() diff --git a/src/hamiltonian.rs b/src/hamiltonian.rs index 1584904..7a49979 100644 --- a/src/hamiltonian.rs +++ b/src/hamiltonian.rs @@ -28,12 +28,36 @@ pub struct DivergenceInfo { pub logp_function_error: Option>, } +impl DivergenceInfo { + pub fn new() -> Self { + DivergenceInfo { + start_momentum: None, + start_location: None, + start_gradient: None, + end_location: None, + energy_error: None, + end_idx_in_trajectory: None, + start_idx_in_trajectory: None, + logp_function_error: None, + } + } +} + #[derive(Debug, Copy, Clone)] pub enum Direction { Forward, Backward, } +impl Direction { + pub fn reverse(&self) -> Self { + match self { + Direction::Forward => Direction::Backward, + Direction::Backward => Direction::Forward, + } + } +} + impl Distribution for StandardUniform { fn sample(&self, rng: &mut R) -> Direction { if rng.random::() { @@ -82,6 +106,7 @@ pub trait Hamiltonian: SamplerStats + Sized { math: &mut M, start: &State, dir: Direction, + step_size_factor: f64, collector: &mut C, ) -> LeapfrogResult; diff --git a/src/nuts.rs b/src/nuts.rs index 0072c69..2a75da7 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -120,7 +120,7 @@ impl, C: Collector> NutsTree { H: Hamiltonian, R: rand::Rng + ?Sized, { - let mut other = match self.single_step(math, hamiltonian, direction, collector) { + let mut other = match self.single_step(math, hamiltonian, direction, options, collector) { Ok(Ok(tree)) => tree, Ok(Err(info)) => return ExtendResult::Diverging(self, info), Err(err) => return ExtendResult::Err(err), @@ -213,19 +213,141 @@ impl, C: Collector> NutsTree { math: &mut M, hamiltonian: &mut H, direction: Direction, + options: &NutsOptions, collector: &mut C, ) -> Result, DivergenceInfo>> { let start = match direction { Direction::Forward => &self.right, Direction::Backward => &self.left, }; - let end = match hamiltonian.leapfrog(math, start, direction, collector) { - LeapfrogResult::Divergence(info) => return Ok(Err(info)), - LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())), - LeapfrogResult::Ok(end) => end, + + let (log_size, end) = match options.walnuts_options { + Some(ref options) => { + // Walnuts implementation + // TODO: Shouldn't all be in this one big function... + let mut step_size_factor = 1.0; + let mut num_steps = 1; + let mut current = start.clone(); + + let mut success = false; + + 'step_size_search: for _ in 0..options.max_step_size_halvings { + current = start.clone(); + let mut min_energy = current.energy(); + let mut max_energy = min_energy; + + for _ in 0..num_steps { + current = match hamiltonian.leapfrog( + math, + ¤t, + direction, + step_size_factor, + collector, + ) { + LeapfrogResult::Ok(state) => state, + LeapfrogResult::Divergence(_) => { + num_steps *= 2; + step_size_factor *= 0.5; + continue 'step_size_search; + } + LeapfrogResult::Err(err) => { + return Err(NutsError::LogpFailure(err.into())); + } + }; + + // Update min/max energies + let current_energy = current.energy(); + min_energy = min_energy.min(current_energy); + max_energy = max_energy.max(current_energy); + } + + if max_energy - min_energy > options.max_energy_error { + num_steps *= 2; + step_size_factor *= 0.5; + continue 'step_size_search; + } + + success = true; + break 'step_size_search; + } + + if !success { + // TODO: More info + return Ok(Err(DivergenceInfo::new())); + } + + // TODO + let back = direction.reverse(); + let mut current_backward; + + let mut reversible = true; + + 'rev_step_size: while num_steps >= 2 { + num_steps /= 2; + step_size_factor *= 0.5; + + // TODO: Can we share code for the micro steps in the two directions? + current_backward = current.clone(); + + let mut min_energy = current_backward.energy(); + let mut max_energy = min_energy; + + for _ in 0..num_steps { + current_backward = match hamiltonian.leapfrog( + math, + ¤t_backward, + back, + step_size_factor, + collector, + ) { + LeapfrogResult::Ok(state) => state, + LeapfrogResult::Divergence(_) => { + // We also reject in the backward direction, all is good so far... + continue 'rev_step_size; + } + LeapfrogResult::Err(err) => { + return Err(NutsError::LogpFailure(err.into())); + } + }; + + // Update min/max energies + let current_energy = current_backward.energy(); + min_energy = min_energy.min(current_energy); + max_energy = max_energy.max(current_energy); + if max_energy - min_energy > options.max_energy_error { + // We reject also in the backward direction, all good so far... + continue 'rev_step_size; + } + } + + // We did not reject in the backward direction, so we are not reversible + reversible = false; + break; + } + + if reversible { + let log_size = -current.point().energy_error(); + (log_size, current) + } else { + // TODO: More info + return Ok(Err(DivergenceInfo::new())); + } + } + None => { + // Classical NUTS + // + let end = match hamiltonian.leapfrog(math, start, direction, 1.0, collector) { + LeapfrogResult::Divergence(info) => return Ok(Err(info)), + LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())), + LeapfrogResult::Ok(end) => end, + }; + + let log_size = -end.point().energy_error(); + + (log_size, end) + } }; - let log_size = -end.point().energy_error(); Ok(Ok(NutsTree { right: end.clone(), left: end.clone(), @@ -248,12 +370,21 @@ impl, C: Collector> NutsTree { } } +#[derive(Debug, Clone, Copy)] +pub struct WalnutsOptions { + pub max_energy_error: f64, + pub max_step_size_halvings: u64, +} + +#[derive(Debug, Clone, Copy)] pub struct NutsOptions { pub maxdepth: u64, pub store_gradient: bool, pub store_unconstrained: bool, pub check_turning: bool, pub store_divergences: bool, + + pub walnuts_options: Option, } pub(crate) fn draw( diff --git a/src/sampler.rs b/src/sampler.rs index 0bb4f7c..e80f629 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -24,7 +24,7 @@ use crate::{ mass_matrix::DiagMassMatrix, mass_matrix_adapt::Strategy as DiagMassMatrixStrategy, math_base::Math, - nuts::NutsOptions, + nuts::{NutsOptions, WalnutsOptions}, sampler_stats::{SamplerStats, StatTraceBuilder}, transform_adapt_strategy::{TransformAdaptation, TransformedSettings}, transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions}, @@ -102,6 +102,7 @@ pub struct NutsSettings { pub num_chains: usize, pub seed: u64, + pub walnuts_options: Option, } pub type DiagGradNutsSettings = NutsSettings>; @@ -122,6 +123,7 @@ impl Default for DiagGradNutsSettings { check_turning: true, seed: 0, num_chains: 6, + walnuts_options: None, } } } @@ -140,6 +142,7 @@ impl Default for LowRankNutsSettings { check_turning: true, seed: 0, num_chains: 6, + walnuts_options: None, }; vals.adapt_options.mass_matrix_update_freq = 10; vals @@ -160,6 +163,7 @@ impl Default for TransformedNutsSettings { check_turning: true, seed: 0, num_chains: 1, + walnuts_options: None, } } } @@ -191,6 +195,7 @@ impl Settings for LowRankNutsSettings { store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, check_turning: self.check_turning, + walnuts_options: self.walnuts_options, }; let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); @@ -250,6 +255,7 @@ impl Settings for DiagGradNutsSettings { store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, check_turning: self.check_turning, + walnuts_options: self.walnuts_options, }; let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); @@ -306,6 +312,7 @@ impl Settings for TransformedNutsSettings { store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, check_turning: self.check_turning, + walnuts_options: self.walnuts_options, }; let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); diff --git a/src/stepsize_adapt.rs b/src/stepsize_adapt.rs index 00819e6..d55942b 100644 --- a/src/stepsize_adapt.rs +++ b/src/stepsize_adapt.rs @@ -52,7 +52,8 @@ impl Strategy { *hamiltonian.step_size_mut() = self.options.initial_step; - let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector); + let state_next = + hamiltonian.leapfrog(math, &state, Direction::Forward, 1.0, &mut collector); let LeapfrogResult::Ok(_) = state_next else { return Ok(()); @@ -68,7 +69,7 @@ impl Strategy { for _ in 0..100 { let mut collector = AcceptanceRateCollector::new(); collector.register_init(math, &state, options); - let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector); + let state_next = hamiltonian.leapfrog(math, &state, dir, 1.0, &mut collector); let LeapfrogResult::Ok(_) = state_next else { *hamiltonian.step_size_mut() = self.options.initial_step; return Ok(()); diff --git a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs index 93581a7..014c9f7 100644 --- a/src/transformed_hamiltonian.rs +++ b/src/transformed_hamiltonian.rs @@ -456,6 +456,7 @@ impl Hamiltonian for TransformedHamiltonian { math: &mut M, start: &State, dir: Direction, + step_size_factor: f64, collector: &mut C, ) -> LeapfrogResult { let mut out = self.pool().new_state(math); @@ -469,7 +470,7 @@ impl Hamiltonian for TransformedHamiltonian { Direction::Backward => -1, }; - let epsilon = (sign as f64) * self.step_size; + let epsilon = (sign as f64) * self.step_size * step_size_factor; start .point() From 10bd9b9f038d38ca9b9d028ae4be8622152dd75e Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 27 Jun 2025 20:08:04 +0200 Subject: [PATCH 3/6] doc: add temporary comments for comparison with walnuts c++ --- src/nuts.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/nuts.rs b/src/nuts.rs index 2a75da7..bcc4867 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -59,22 +59,41 @@ pub struct SampleInfo { } /// A part of the trajectory tree during NUTS sampling. +/// +/// Corresponds to SpanW in walnuts C++ code struct NutsTree, C: Collector> { /// The left position of the tree. /// /// The left side always has the smaller index_in_trajectory. /// Leapfrogs in backward direction will replace the left. + /// + /// theta_bk_, rho_bk_, grad_theta_bk_, logp_bk_ in C++ code left: State, + + /// The right position of the tree. + /// + /// theta_fw_, rho_fw_, grad_theta_fw_, logp_fw_ in C++ code right: State, /// A draw from the trajectory between left and right using /// multinomial sampling. + /// + /// theta_select_ in C++ code draw: State, + + /// Constant for acceptance probability + /// + /// logp_ in C++ code log_size: f64, + + /// The depth of the tree depth: u64, /// A tree is the main tree if it contains the initial point /// of the trajectory. + /// + /// This is used to determine whether to use Metropolis + /// accptance or Barker is_main: bool, _phantom2: PhantomData, } @@ -171,6 +190,7 @@ impl, C: Collector> NutsTree { } } + // `combine` in C++ code fn merge_into( &mut self, _math: &mut M, @@ -208,6 +228,7 @@ impl, C: Collector> NutsTree { self.log_size = log_size; } + // Corresponds to `build_leaf` in C++ code fn single_step( &self, math: &mut M, From 3648e00f47ba91bfc073341e1bbb4af6cf77190b Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Sat, 28 Jun 2025 12:27:47 +0200 Subject: [PATCH 4/6] feat: clean up walnuts a little bit --- src/adapt_strategy.rs | 5 +- src/chain.rs | 33 ++++++- src/euclidean_hamiltonian.rs | 44 ++++----- src/hamiltonian.rs | 99 ++++++++++++++++++- src/lib.rs | 2 +- src/nuts.rs | 169 ++++++++++++++------------------ src/stepsize.rs | 1 + src/stepsize_adapt.rs | 6 +- src/transform_adapt_strategy.rs | 1 + src/transformed_hamiltonian.rs | 43 ++++---- 10 files changed, 248 insertions(+), 155 deletions(-) diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index e891e18..eb00e4e 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -339,11 +339,12 @@ where start: &State, end: &State, divergence_info: Option<&DivergenceInfo>, + num_substeps: u64, ) { self.collector1 - .register_leapfrog(math, start, end, divergence_info); + .register_leapfrog(math, start, end, divergence_info, num_substeps); self.collector2 - .register_leapfrog(math, start, end, divergence_info); + .register_leapfrog(math, start, end, divergence_info, num_substeps); } fn register_draw(&mut self, math: &mut M, state: &State, info: &crate::nuts::SampleInfo) { diff --git a/src/chain.rs b/src/chain.rs index 441a32e..755531d 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -183,6 +183,7 @@ where &mut self.hamiltonian, &self.options, &mut self.collector, + self.draw_count < 70, )?; let mut position: Box<[f64]> = vec![0f64; math.dim()].into(); state.write_position(math, &mut position); @@ -237,6 +238,7 @@ pub struct NutsStatsBuilder> { divergence_start_grad: Option>>, divergence_end: Option>>, divergence_momentum: Option>>, + non_reversible: Option, divergence_msg: Option, } @@ -274,7 +276,9 @@ impl> NutsStatsBuilder { None }; - let (div_start, div_start_grad, div_end, div_mom, div_msg) = if options.store_divergences { + let (div_start, div_start_grad, div_end, div_mom, non_rev, div_msg) = if options + .store_divergences + { let start_location_prim = PrimitiveBuilder::new(); let start_location_list = FixedSizeListBuilder::new(start_location_prim, dim as i32); @@ -288,6 +292,8 @@ impl> NutsStatsBuilder { let momentum_location_list = FixedSizeListBuilder::new(momentum_location_prim, dim as i32); + let non_reversible = BooleanBuilder::new(); + let msg_list = StringBuilder::new(); ( @@ -295,10 +301,11 @@ impl> NutsStatsBuilder { Some(start_grad_list), Some(end_location_list), Some(momentum_location_list), + Some(non_reversible), Some(msg_list), ) } else { - (None, None, None, None, None) + (None, None, None, None, None, None) }; Self { @@ -320,6 +327,7 @@ impl> NutsStatsBuilder { divergence_start_grad: div_start_grad, divergence_end: div_end, divergence_momentum: div_mom, + non_reversible: non_rev, divergence_msg: div_msg, } } @@ -350,6 +358,7 @@ impl> StatTraceBuilder> StatTraceBuilder> StatTraceBuilder> StatTraceBuilder> StatTraceBuilder> StatTraceBuilder> Hamiltonian for EuclideanHamiltonian, dir: Direction, - step_size_factor: f64, + step_size_splits: u64, collector: &mut C, ) -> LeapfrogResult { let mut out = self.pool().new_state(math); @@ -322,7 +322,7 @@ impl> Hamiltonian for EuclideanHamiltonian -1, }; - let epsilon = (sign as f64) * self.step_size * step_size_factor; + let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64); start .point() @@ -334,17 +334,9 @@ impl> Hamiltonian for EuclideanHamiltonian> Hamiltonian for EuclideanHamiltonian self.max_energy_error) | !energy_error.is_finite() { - let divergence_info = DivergenceInfo { - logp_function_error: None, - start_location: Some(math.box_array(start.point().position())), - start_gradient: Some(math.box_array(start.point().gradient())), - end_location: Some(math.box_array(&out_point.position)), - start_momentum: Some(math.box_array(&out_point.momentum)), - start_idx_in_trajectory: Some(start.index_in_trajectory()), - end_idx_in_trajectory: Some(out.index_in_trajectory()), - energy_error: Some(energy_error), - }; - collector.register_leapfrog(math, start, &out, Some(&divergence_info)); + let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out); + collector.register_leapfrog( + math, + start, + &out, + Some(&divergence_info), + step_size_splits, + ); return LeapfrogResult::Divergence(divergence_info); } - collector.register_leapfrog(math, start, &out, None); + collector.register_leapfrog(math, start, &out, None, step_size_splits); LeapfrogResult::Ok(out) } @@ -447,4 +437,8 @@ impl> Hamiltonian for EuclideanHamiltonian &mut f64 { &mut self.step_size } + + fn max_energy_error(&self) -> f64 { + self.max_energy_error + } } diff --git a/src/hamiltonian.rs b/src/hamiltonian.rs index 7a49979..d6071d8 100644 --- a/src/hamiltonian.rs +++ b/src/hamiltonian.rs @@ -16,6 +16,7 @@ use crate::{ /// a cutoff value or nan. /// - The logp function caused a recoverable error (eg if an ODE solver /// failed) +#[non_exhaustive] #[derive(Debug, Clone)] pub struct DivergenceInfo { pub start_momentum: Option>, @@ -26,6 +27,7 @@ pub struct DivergenceInfo { pub end_idx_in_trajectory: Option, pub start_idx_in_trajectory: Option, pub logp_function_error: Option>, + pub non_reversible: bool, } impl DivergenceInfo { @@ -39,8 +41,67 @@ impl DivergenceInfo { end_idx_in_trajectory: None, start_idx_in_trajectory: None, logp_function_error: None, + non_reversible: false, } } + + pub fn new_energy_error_too_large( + math: &mut M, + start: &State>, + stop: &State>, + ) -> Self { + DivergenceInfo { + logp_function_error: None, + start_location: Some(math.box_array(start.point().position())), + start_gradient: Some(math.box_array(start.point().gradient())), + // TODO + start_momentum: None, + start_idx_in_trajectory: Some(start.index_in_trajectory()), + end_location: Some(math.box_array(&stop.point().position())), + end_idx_in_trajectory: Some(stop.index_in_trajectory()), + // TODO + energy_error: None, + non_reversible: false, + } + } + + pub fn new_logp_function_error( + math: &mut M, + start: &State>, + logp_function_error: Arc, + ) -> Self { + DivergenceInfo { + logp_function_error: Some(logp_function_error), + start_location: Some(math.box_array(start.point().position())), + start_gradient: Some(math.box_array(start.point().gradient())), + // TODO + start_momentum: None, + start_idx_in_trajectory: Some(start.index_in_trajectory()), + end_location: None, + end_idx_in_trajectory: None, + energy_error: None, + non_reversible: false, + } + } + + pub fn new_not_reversible(math: &mut M, start: &State>) -> Self { + // TODO add info about what went wrong + DivergenceInfo { + logp_function_error: None, + start_location: Some(math.box_array(start.point().position())), + start_gradient: Some(math.box_array(start.point().gradient())), + // TODO + start_momentum: None, + start_idx_in_trajectory: Some(start.index_in_trajectory()), + end_location: None, + end_idx_in_trajectory: None, + energy_error: None, + non_reversible: true, + } + } + pub fn new_max_step_size_halvings(math: &mut M, num_steps: u64, info: Self) -> Self { + info // TODO + } } #[derive(Debug, Copy, Clone)] @@ -106,10 +167,44 @@ pub trait Hamiltonian: SamplerStats + Sized { math: &mut M, start: &State, dir: Direction, - step_size_factor: f64, + step_size_splits: u64, collector: &mut C, ) -> LeapfrogResult; + fn split_leapfrog>( + &mut self, + math: &mut M, + start: &State, + dir: Direction, + num_steps: u64, + collector: &mut C, + max_error: f64, + ) -> LeapfrogResult { + let mut state = start.clone(); + + let mut min_energy = start.energy(); + let mut max_energy = min_energy; + + for _ in 0..num_steps { + state = match self.leapfrog(math, &state, dir, num_steps, collector) { + LeapfrogResult::Ok(state) => state, + LeapfrogResult::Divergence(info) => return LeapfrogResult::Divergence(info), + LeapfrogResult::Err(err) => return LeapfrogResult::Err(err), + }; + let energy = state.energy(); + min_energy = min_energy.min(energy); + max_energy = max_energy.max(energy); + + // TODO: walnuts papers says to use abs, but c++ code doesn't? + if max_energy - min_energy > max_error { + let info = DivergenceInfo::new_energy_error_too_large(math, start, &state); + return LeapfrogResult::Divergence(info); + } + } + + LeapfrogResult::Ok(state) + } + fn is_turning( &self, math: &mut M, @@ -141,4 +236,6 @@ pub trait Hamiltonian: SamplerStats + Sized { fn step_size(&self) -> f64; fn step_size_mut(&mut self) -> &mut f64; + + fn max_energy_error(&self) -> f64; } diff --git a/src/lib.rs b/src/lib.rs index b4798a0..ca94d37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,7 +108,7 @@ pub use chain::Chain; pub use cpu_math::{CpuLogpFunc, CpuMath}; pub use hamiltonian::DivergenceInfo; pub use math_base::{LogpError, Math}; -pub use nuts::NutsError; +pub use nuts::{NutsError, WalnutsOptions}; pub use sampler::{ sample_sequentially, ChainOutput, ChainProgress, DiagGradNutsSettings, DrawStorage, LowRankNutsSettings, Model, NutsSettings, Progress, ProgressCallback, Sampler, diff --git a/src/nuts.rs b/src/nuts.rs index bcc4867..69b9a69 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -34,6 +34,7 @@ pub trait Collector> { _start: &State, _end: &State, _divergence_info: Option<&DivergenceInfo>, + _num_substeps: u64, ) { } fn register_draw(&mut self, _math: &mut M, _state: &State, _info: &SampleInfo) {} @@ -134,20 +135,23 @@ impl, C: Collector> NutsTree { direction: Direction, collector: &mut C, options: &NutsOptions, + early: bool, ) -> ExtendResult where H: Hamiltonian, R: rand::Rng + ?Sized, { - let mut other = match self.single_step(math, hamiltonian, direction, options, collector) { - Ok(Ok(tree)) => tree, - Ok(Err(info)) => return ExtendResult::Diverging(self, info), - Err(err) => return ExtendResult::Err(err), - }; + let mut other = + match self.single_step(math, hamiltonian, direction, options, collector, early) { + Ok(Ok(tree)) => tree, + Ok(Err(info)) => return ExtendResult::Diverging(self, info), + Err(err) => return ExtendResult::Err(err), + }; while other.depth < self.depth { use ExtendResult::*; - other = match other.extend(math, rng, hamiltonian, direction, collector, options) { + other = match other.extend(math, rng, hamiltonian, direction, collector, options, early) + { Ok(tree) => tree, Turning(_) => { return Turning(self); @@ -236,6 +240,7 @@ impl, C: Collector> NutsTree { direction: Direction, options: &NutsOptions, collector: &mut C, + early: bool, ) -> Result, DivergenceInfo>> { let start = match direction { Direction::Forward => &self.right, @@ -246,118 +251,79 @@ impl, C: Collector> NutsTree { Some(ref options) => { // Walnuts implementation // TODO: Shouldn't all be in this one big function... - let mut step_size_factor = 1.0; let mut num_steps = 1; let mut current = start.clone(); - let mut success = false; - - 'step_size_search: for _ in 0..options.max_step_size_halvings { - current = start.clone(); - let mut min_energy = current.energy(); - let mut max_energy = min_energy; - - for _ in 0..num_steps { - current = match hamiltonian.leapfrog( - math, - ¤t, - direction, - step_size_factor, - collector, - ) { - LeapfrogResult::Ok(state) => state, - LeapfrogResult::Divergence(_) => { - num_steps *= 2; - step_size_factor *= 0.5; - continue 'step_size_search; - } - LeapfrogResult::Err(err) => { - return Err(NutsError::LogpFailure(err.into())); - } - }; - - // Update min/max energies - let current_energy = current.energy(); - min_energy = min_energy.min(current_energy); - max_energy = max_energy.max(current_energy); - } - - if max_energy - min_energy > options.max_energy_error { - num_steps *= 2; - step_size_factor *= 0.5; - continue 'step_size_search; - } - - success = true; - break 'step_size_search; + let mut last_divergence = None; + + for _ in 0..options.max_step_size_halvings { + current = match hamiltonian.split_leapfrog( + math, + start, + direction, + num_steps, + collector, + options.max_energy_error, + ) { + LeapfrogResult::Ok(state) => { + last_divergence = None; + state + } + LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())), + LeapfrogResult::Divergence(info) => { + num_steps *= 2; + last_divergence = Some(info); + continue; + } + }; + break; } - if !success { - // TODO: More info - return Ok(Err(DivergenceInfo::new())); + if let Some(info) = last_divergence { + let info = DivergenceInfo::new_max_step_size_halvings(math, num_steps, info); + return Ok(Err(info)); } - // TODO let back = direction.reverse(); - let mut current_backward; - let mut reversible = true; - 'rev_step_size: while num_steps >= 2 { + while num_steps >= 2 { num_steps /= 2; - step_size_factor *= 0.5; - - // TODO: Can we share code for the micro steps in the two directions? - current_backward = current.clone(); - - let mut min_energy = current_backward.energy(); - let mut max_energy = min_energy; - - for _ in 0..num_steps { - current_backward = match hamiltonian.leapfrog( - math, - ¤t_backward, - back, - step_size_factor, - collector, - ) { - LeapfrogResult::Ok(state) => state, - LeapfrogResult::Divergence(_) => { - // We also reject in the backward direction, all is good so far... - continue 'rev_step_size; - } - LeapfrogResult::Err(err) => { - return Err(NutsError::LogpFailure(err.into())); - } - }; - - // Update min/max energies - let current_energy = current_backward.energy(); - min_energy = min_energy.min(current_energy); - max_energy = max_energy.max(current_energy); - if max_energy - min_energy > options.max_energy_error { - // We reject also in the backward direction, all good so far... - continue 'rev_step_size; + + match hamiltonian.split_leapfrog( + math, + ¤t, + back, + num_steps, + collector, + options.max_energy_error, + ) { + LeapfrogResult::Ok(_) => (), + LeapfrogResult::Divergence(_) => { + // We also reject in the backward direction, all is good so far... + continue; + } + LeapfrogResult::Err(err) => { + return Err(NutsError::LogpFailure(err.into())); } - } + }; // We did not reject in the backward direction, so we are not reversible reversible = false; break; } - if reversible { + if reversible || early { let log_size = -current.point().energy_error(); (log_size, current) } else { - // TODO: More info - return Ok(Err(DivergenceInfo::new())); + return Ok(Err(DivergenceInfo::new_not_reversible(math, start))); } } None => { - // Classical NUTS - // - let end = match hamiltonian.leapfrog(math, start, direction, 1.0, collector) { + // Classical NUTS. + // TODO Is equivalent to walnuts with max_step_size_halvings = 0? + let end = match hamiltonian.leapfrog(math, start, direction, 1, collector) { LeapfrogResult::Divergence(info) => return Ok(Err(info)), LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())), LeapfrogResult::Ok(end) => end, @@ -391,10 +357,20 @@ impl, C: Collector> NutsTree { } } +#[non_exhaustive] #[derive(Debug, Clone, Copy)] pub struct WalnutsOptions { - pub max_energy_error: f64, pub max_step_size_halvings: u64, + pub max_energy_error: f64, +} + +impl Default for WalnutsOptions { + fn default() -> Self { + WalnutsOptions { + max_step_size_halvings: 10, + max_energy_error: 5.0, + } + } } #[derive(Debug, Clone, Copy)] @@ -415,6 +391,7 @@ pub(crate) fn draw( hamiltonian: &mut H, options: &NutsOptions, collector: &mut C, + early: bool, ) -> Result<(State, SampleInfo)> where M: Math, @@ -435,7 +412,7 @@ where while tree.depth < options.maxdepth { let direction: Direction = rng.random(); - tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) { + tree = match tree.extend(math, rng, hamiltonian, direction, collector, options, early) { ExtendResult::Ok(tree) => tree, ExtendResult::Turning(tree) => { let info = tree.info(false, None); diff --git a/src/stepsize.rs b/src/stepsize.rs index 2556d8f..5ce9cd8 100644 --- a/src/stepsize.rs +++ b/src/stepsize.rs @@ -124,6 +124,7 @@ impl> Collector for AcceptanceRateCollector { _start: &State, end: &State, divergence_info: Option<&DivergenceInfo>, + num_substeps: u64, ) { match divergence_info { Some(_) => { diff --git a/src/stepsize_adapt.rs b/src/stepsize_adapt.rs index d55942b..0c867cd 100644 --- a/src/stepsize_adapt.rs +++ b/src/stepsize_adapt.rs @@ -3,6 +3,7 @@ use arrow::{ datatypes::{DataType, Field, Float64Type, UInt64Type}, }; use rand::Rng; +use rand_distr::Uniform; use crate::{ hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point}, @@ -52,8 +53,7 @@ impl Strategy { *hamiltonian.step_size_mut() = self.options.initial_step; - let state_next = - hamiltonian.leapfrog(math, &state, Direction::Forward, 1.0, &mut collector); + let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, 1, &mut collector); let LeapfrogResult::Ok(_) = state_next else { return Ok(()); @@ -69,7 +69,7 @@ impl Strategy { for _ in 0..100 { let mut collector = AcceptanceRateCollector::new(); collector.register_init(math, &state, options); - let state_next = hamiltonian.leapfrog(math, &state, dir, 1.0, &mut collector); + let state_next = hamiltonian.leapfrog(math, &state, dir, 1, &mut collector); let LeapfrogResult::Ok(_) = state_next else { *hamiltonian.step_size_mut() = self.options.initial_step; return Ok(()); diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs index b1fa2b0..c7b3132 100644 --- a/src/transform_adapt_strategy.rs +++ b/src/transform_adapt_strategy.rs @@ -104,6 +104,7 @@ impl> Collector for DrawCollector { _start: &State, end: &State, divergence_info: Option<&crate::DivergenceInfo>, + num_substeps: u64, ) { if divergence_info.is_some() { return; diff --git a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs index 014c9f7..6a347b7 100644 --- a/src/transformed_hamiltonian.rs +++ b/src/transformed_hamiltonian.rs @@ -456,7 +456,7 @@ impl Hamiltonian for TransformedHamiltonian { math: &mut M, start: &State, dir: Direction, - step_size_factor: f64, + step_size_splits: u64, collector: &mut C, ) -> LeapfrogResult { let mut out = self.pool().new_state(math); @@ -470,7 +470,7 @@ impl Hamiltonian for TransformedHamiltonian { Direction::Backward => -1, }; - let epsilon = (sign as f64) * self.step_size * step_size_factor; + let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64); start .point() @@ -481,17 +481,9 @@ impl Hamiltonian for TransformedHamiltonian { if !logp_error.is_recoverable() { return LeapfrogResult::Err(logp_error); } - let div_info = DivergenceInfo { - logp_function_error: Some(Arc::new(Box::new(logp_error))), - start_location: Some(math.box_array(start.point().position())), - start_gradient: Some(math.box_array(start.point().gradient())), - start_momentum: None, - end_location: None, - start_idx_in_trajectory: Some(start.point().index_in_trajectory()), - end_idx_in_trajectory: None, - energy_error: None, - }; - collector.register_leapfrog(math, start, &out, Some(&div_info)); + let logp_error = Arc::new(Box::new(logp_error)); + let div_info = DivergenceInfo::new_logp_function_error(math, start, logp_error); + collector.register_leapfrog(math, start, &out, Some(&div_info), step_size_splits); return LeapfrogResult::Divergence(div_info); } @@ -502,21 +494,18 @@ impl Hamiltonian for TransformedHamiltonian { let energy_error = out_point.energy_error(); if (energy_error > self.max_energy_error) | !energy_error.is_finite() { - let divergence_info = DivergenceInfo { - logp_function_error: None, - start_location: Some(math.box_array(start.point().position())), - start_gradient: Some(math.box_array(start.point().gradient())), - end_location: Some(math.box_array(out_point.position())), - start_momentum: None, - start_idx_in_trajectory: Some(start.index_in_trajectory()), - end_idx_in_trajectory: Some(out.index_in_trajectory()), - energy_error: Some(energy_error), - }; - collector.register_leapfrog(math, start, &out, Some(&divergence_info)); + let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out); + collector.register_leapfrog( + math, + start, + &out, + Some(&divergence_info), + step_size_splits, + ); return LeapfrogResult::Divergence(divergence_info); } - collector.register_leapfrog(math, start, &out, None); + collector.register_leapfrog(math, start, &out, None, step_size_splits); LeapfrogResult::Ok(out) } @@ -618,4 +607,8 @@ impl Hamiltonian for TransformedHamiltonian { fn step_size_mut(&mut self) -> &mut f64 { &mut self.step_size } + + fn max_energy_error(&self) -> f64 { + self.max_energy_error + } } From 16d2235e03ee359fdee9d8622eec335d02a7b6f5 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 30 Jun 2025 09:45:37 +0200 Subject: [PATCH 5/6] feat: add step size jitter --- src/adapt_strategy.rs | 6 ++++-- src/stepsize_adapt.rs | 26 ++++++++++++++++++-------- src/transform_adapt_strategy.rs | 6 ++++-- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index eb00e4e..e9155eb 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -143,6 +143,8 @@ impl> AdaptStrategy for GlobalStrategy self.step_size.update(&collector.collector1); if draw >= self.num_tune { + // Needed for step size jitter + self.step_size.update_stepsize(rng, hamiltonian, true); self.tuning = false; return Ok(()); } @@ -194,14 +196,14 @@ impl> AdaptStrategy for GlobalStrategy self.step_size .init(math, options, hamiltonian, &position, rng)?; } else { - self.step_size.update_stepsize(hamiltonian, false) + self.step_size.update_stepsize(rng, hamiltonian, false) } return Ok(()); } self.step_size.update_estimator_late(); let is_last = draw == self.num_tune - 1; - self.step_size.update_stepsize(hamiltonian, is_last); + self.step_size.update_stepsize(rng, hamiltonian, is_last); Ok(()) } diff --git a/src/stepsize_adapt.rs b/src/stepsize_adapt.rs index 0c867cd..91224b2 100644 --- a/src/stepsize_adapt.rs +++ b/src/stepsize_adapt.rs @@ -121,19 +121,27 @@ impl Strategy { .advance(self.last_sym_mean_tree_accept, self.options.target_accept); } - pub fn update_stepsize( + pub fn update_stepsize( &mut self, + rng: &mut R, hamiltonian: &mut impl Hamiltonian, use_best_guess: bool, ) { - if let Some(step_size) = self.options.fixed_step_size { - *hamiltonian.step_size_mut() = step_size; - return; - } - if use_best_guess { - *hamiltonian.step_size_mut() = self.step_size_adapt.current_step_size_adapted(); + let step_size = if let Some(step_size) = self.options.fixed_step_size { + step_size + } else if use_best_guess { + self.step_size_adapt.current_step_size_adapted() } else { - *hamiltonian.step_size_mut() = self.step_size_adapt.current_step_size(); + self.step_size_adapt.current_step_size() + }; + + if let Some(jitter) = self.options.jitter { + let jitter = + rng.sample(Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter")); + let jittered_step_size = step_size * jitter; + *hamiltonian.step_size_mut() = jittered_step_size; + } else { + *hamiltonian.step_size_mut() = step_size; } } @@ -236,6 +244,7 @@ pub struct DualAverageSettings { pub initial_step: f64, pub params: DualAverageOptions, pub fixed_step_size: Option, + pub jitter: Option, } impl Default for DualAverageSettings { @@ -245,6 +254,7 @@ impl Default for DualAverageSettings { initial_step: 0.1, params: DualAverageOptions::default(), fixed_step_size: None, + jitter: None, } } } diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs index c7b3132..5738e75 100644 --- a/src/transform_adapt_strategy.rs +++ b/src/transform_adapt_strategy.rs @@ -213,6 +213,8 @@ impl AdaptStrategy for TransformAdaptation { self.step_size.update(&collector.collector1); if draw >= self.num_tune { + // Needed for step size jitter + self.step_size.update_stepsize(rng, hamiltonian, true); self.tuning = false; return Ok(()); } @@ -238,13 +240,13 @@ impl AdaptStrategy for TransformAdaptation { )?; } self.step_size.update_estimator_early(); - self.step_size.update_stepsize(hamiltonian, false); + self.step_size.update_stepsize(rng, hamiltonian, false); return Ok(()); } self.step_size.update_estimator_late(); let is_last = draw == self.num_tune - 1; - self.step_size.update_stepsize(hamiltonian, is_last); + self.step_size.update_stepsize(rng, hamiltonian, is_last); Ok(()) } From b2366c14452f2f6954e8a434f8227e675c959659 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 30 Jun 2025 09:45:37 +0200 Subject: [PATCH 6/6] style: Some clippy fixes --- src/hamiltonian.rs | 8 +++++++- src/mass_matrix.rs | 4 ---- src/nuts.rs | 2 +- src/transform_adapt_strategy.rs | 4 ++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/hamiltonian.rs b/src/hamiltonian.rs index d6071d8..94bb537 100644 --- a/src/hamiltonian.rs +++ b/src/hamiltonian.rs @@ -30,6 +30,12 @@ pub struct DivergenceInfo { pub non_reversible: bool, } +impl Default for DivergenceInfo { + fn default() -> Self { + Self::new() + } +} + impl DivergenceInfo { pub fn new() -> Self { DivergenceInfo { @@ -57,7 +63,7 @@ impl DivergenceInfo { // TODO start_momentum: None, start_idx_in_trajectory: Some(start.index_in_trajectory()), - end_location: Some(math.box_array(&stop.point().position())), + end_location: Some(math.box_array(stop.point().position())), end_idx_in_trajectory: Some(stop.index_in_trajectory()), // TODO energy_error: None, diff --git a/src/mass_matrix.rs b/src/mass_matrix.rs index 2f0219c..e3b27c6 100644 --- a/src/mass_matrix.rs +++ b/src/mass_matrix.rs @@ -24,10 +24,6 @@ pub trait MassMatrix: SamplerStats { ); } -pub struct NullCollector {} - -impl> Collector for NullCollector {} - #[derive(Debug)] pub struct DiagMassMatrix { inv_stds: M::Vector, diff --git a/src/nuts.rs b/src/nuts.rs index 69b9a69..dfaeaea 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -306,7 +306,7 @@ impl, C: Collector> NutsTree { LeapfrogResult::Err(err) => { return Err(NutsError::LogpFailure(err.into())); } - }; + } // We did not reject in the backward direction, so we are not reversible reversible = false; diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs index 5738e75..5fbe915 100644 --- a/src/transform_adapt_strategy.rs +++ b/src/transform_adapt_strategy.rs @@ -221,7 +221,7 @@ impl AdaptStrategy for TransformAdaptation { if draw < self.final_window_size { if draw < 100 { - if (draw > 0) & (draw % 10 == 0) { + if (draw > 0) & draw.is_multiple_of(10) { hamiltonian.update_params( math, rng, @@ -230,7 +230,7 @@ impl AdaptStrategy for TransformAdaptation { collector.collector2.logps.iter(), )?; } - } else if (draw > 0) & (draw % self.options.transform_update_freq == 0) { + } else if (draw > 0) & draw.is_multiple_of(self.options.transform_update_freq) { hamiltonian.update_params( math, rng,